[ExecuTorch] Add broadcast support for optimized add op #8205
The first file in the diff is the shared header that defines `ElementwiseOptimizedPath` and the broadcast helpers. The PR adds a `BinaryOpType` tag plus a small trait mapping each op to the name string consumed by the `ET_SWITCH_*` macros:

```diff
@@ -8,6 +8,8 @@
 #pragma once
 
+#include <executorch/kernels/optimized/vec/functional.h>
+#include <executorch/kernels/portable/cpu/scalar_utils.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
 namespace torch {
@@ -47,8 +49,38 @@ enum class ElementwiseOptimizedPath {
   kBroadcastLastDimReverseArguments,
 };
 
+enum class BinaryOpType {
+  kAdd,
+  kSub,
+  kMul,
+  kDiv,
+};
+
+namespace internal {
+
+template <BinaryOpType op_type>
+struct BinaryOpTypeName;
+
+template <>
+struct BinaryOpTypeName<BinaryOpType::kAdd> {
+  static constexpr char kName[] = "add.out";
+};
+
+template <>
+struct BinaryOpTypeName<BinaryOpType::kSub> {
+  static constexpr char kName[] = "sub.out";
+};
+
+template <>
+struct BinaryOpTypeName<BinaryOpType::kMul> {
+  static constexpr char kName[] = "mul.out";
+};
+
+template <>
+struct BinaryOpTypeName<BinaryOpType::kDiv> {
+  static constexpr char kName[] = "div.out";
+};
+
 /*
 Given two tensors, this function returns the broadcast dim if it exists.
 Returns 0 if no broadcast dim is found.
```
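For intuition, here is a toy, hypothetical re-implementation of the contract that `get_broadcast_dim`'s call sites below rely on: the return value is a negative, from-the-end dim index (so `lhs->dim() + broadcast_dim` recovers the positive index), and 0 means no broadcast dim was found. This is not the ExecuTorch implementation, and it assumes both shapes have the same rank.

```cpp
// Toy sketch only; not the ExecuTorch implementation.
#include <cassert>
#include <cstdint>
#include <vector>

int32_t toy_get_broadcast_dim(
    const std::vector<int64_t>& a,
    const std::vector<int64_t>& b) {
  // Walk dims from the innermost outward; report the first dim where
  // exactly one of the two tensors has size 1.
  for (int32_t i = 1; i <= static_cast<int32_t>(a.size()); ++i) {
    const int64_t size_a = a[a.size() - i];
    const int64_t size_b = b[b.size() - i];
    if ((size_a == 1) != (size_b == 1)) {
      return -i;  // negative dim index, counted from the end
    }
  }
  return 0;  // no broadcast dim found
}

int main() {
  // {2, 3, 4, 5} vs {2, 1, 4, 5}: b broadcasts along dim 1, i.e. dim -3.
  assert(toy_get_broadcast_dim({2, 3, 4, 5}, {2, 1, 4, 5}) == -3);
  assert(toy_get_broadcast_dim({2, 3, 4}, {2, 3, 4}) == 0);
  return 0;
}
```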
The same header gains two generic handlers: one specialized for last-dim broadcast, and a general entry point that dispatches on the selected path:

```diff
@@ -190,5 +222,145 @@ std::array<int32_t, 3> inline get_normalized_tensor_size(
   return normalized_tensor_size;
 }
 
+template <BinaryOpType op_type, typename Op>
+Tensor& handle_last_dim_broadcast_elementwise(
+    KernelRuntimeContext& ctx,
+    const Op& vec_fun,
+    const Tensor& a,
+    const Tensor& b,
+    Tensor& out,
+    const ElementwiseOptimizedPath selected_optimized_path,
+    executorch::aten::optional<Scalar> alpha = {}) {
+  ScalarType out_type = out.scalar_type();
+  const Tensor* lhs;
+  const Tensor* rhs;
+  if (selected_optimized_path ==
+      ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments) {
+    lhs = &b;
+    rhs = &a;
+  } else {
+    lhs = &a;
+    rhs = &b;
+  }
+  auto error = resize_tensor(out, lhs->sizes());
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      error == Error::Ok,
+      InvalidArgument,
+      out,
+      "Failed to resize output tensor.");
+  const size_t outer_size = getLeadingDims(out, out.dim() - 1);
+  const auto broadcast_size = out.size(out.dim() - 1);
+  ET_SWITCH_REALB_TYPES(out_type, ctx, internal::BinaryOpTypeName<op_type>::kName, CTYPE, [&]() {
+    using Vec = executorch::vec::Vectorized<CTYPE>;
+    CTYPE alpha_val;
+    Vec alpha_val_vec;
+    if (alpha.has_value()) {
+      ET_KERNEL_CHECK(
+          ctx,
+          native::utils::extract_scalar(alpha.value(), &alpha_val),
+          InvalidArgument, );
+      alpha_val_vec = Vec(alpha_val);
+    }
+    auto vec_fun_alpha = [vec_fun, alpha_val_vec](const Vec& a, const Vec& b) {
+      return vec_fun(a, b, alpha_val_vec);
+    };
+    executorch::vec::broadcasting_map_broadcast_last_dim<CTYPE>(
+        vec_fun_alpha,
+        out.mutable_data_ptr<CTYPE>(),
+        lhs->const_data_ptr<CTYPE>(),
+        rhs->const_data_ptr<CTYPE>(),
+        outer_size,
+        broadcast_size);
+  });
+  return out;
+}
+
+template <BinaryOpType op_type, typename Op>
+Tensor& handle_broadcast_elementwise(
+    KernelRuntimeContext& ctx,
+    const Op& vec_fun,
+    const Tensor& a,
+    const Tensor& b,
+    Tensor& out,
+    const ElementwiseOptimizedPath selected_optimized_path,
+    executorch::aten::optional<Scalar> alpha = {}) {
+  if ((selected_optimized_path ==
+       ElementwiseOptimizedPath::kBroadcastLastDim) ||
+      (selected_optimized_path ==
+       ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments)) {
+    return handle_last_dim_broadcast_elementwise<op_type>(
+        ctx, vec_fun, a, b, out, selected_optimized_path, alpha);
+  }
+
+  ScalarType out_type = out.scalar_type();
+  const Tensor* lhs;
+  const Tensor* rhs;
+  if ((selected_optimized_path ==
+       ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) ||
+      (selected_optimized_path ==
+       ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
+    lhs = &b;
+    rhs = &a;
+  } else {
+    // Catch failure to update logic when adding new broadcasting possibility.
+    ET_DCHECK(
+        (selected_optimized_path ==
+         ElementwiseOptimizedPath::kBroadcast2dBy1d) ||
+        (selected_optimized_path ==
+         ElementwiseOptimizedPath::kBroadcastNdByNd));
+    lhs = &a;
+    rhs = &b;
+  }
+  auto error = resize_tensor(out, lhs->sizes());
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      error == Error::Ok,
+      InvalidArgument,
+      out,
+      "Failed to resize output tensor.");
+  int64_t outer_size = 1;
+  int64_t broadcast_size;
+  int64_t inner_size;
+  if ((selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNd) ||
+      (selected_optimized_path ==
+       ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
+    int32_t broadcast_dim = internal::get_broadcast_dim(*lhs, *rhs);
+    int32_t broadcast_dim_lhs = lhs->dim() + broadcast_dim;
+    auto normalized_tensor_size_lhs =
+        get_normalized_tensor_size(*lhs, broadcast_dim_lhs);
+    outer_size = normalized_tensor_size_lhs[0];
+    broadcast_size = normalized_tensor_size_lhs[1];
+    inner_size = normalized_tensor_size_lhs[2];
+  } else {
+    broadcast_size = lhs->sizes()[lhs->dim() - 2];
+    inner_size = lhs->sizes()[lhs->dim() - 1];
+  }
+  ET_SWITCH_REALB_TYPES(out_type, ctx, internal::BinaryOpTypeName<op_type>::kName, CTYPE, [&]() {
+    using Vec = executorch::vec::Vectorized<CTYPE>;
+    CTYPE alpha_val;
+    Vec alpha_val_vec;
+    if (alpha.has_value()) {
+      ET_KERNEL_CHECK(
+          ctx,
+          native::utils::extract_scalar(alpha.value(), &alpha_val),
+          InvalidArgument, );
+      alpha_val_vec = Vec(alpha_val);
+    }
+    auto vec_fun_alpha = [vec_fun, alpha_val_vec](const Vec& a, const Vec& b) {
+      return vec_fun(a, b, alpha_val_vec);
+    };
+    executorch::vec::
+        broadcasting_map_3d_and_unsqueezed_3d<CTYPE, decltype(vec_fun_alpha)>(
+            vec_fun_alpha,
+            out.mutable_data_ptr<CTYPE>(),
+            lhs->const_data_ptr<CTYPE>(),
+            rhs->const_data_ptr<CTYPE>(),
+            outer_size,
+            broadcast_size,
+            inner_size);
+  });
+  return out;
+}
 } // namespace executor
 } // namespace torch
```
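The `(outer_size, broadcast_size, inner_size)` triple produced by `get_normalized_tensor_size` collapses an N-d tensor into a 3-d view around the broadcast dim, which is what `broadcasting_map_3d_and_unsqueezed_3d` iterates over. A hypothetical sketch of that normalization, for illustration only (not the real helper):

```cpp
// Toy sketch: dims before the broadcast dim collapse into outer_size,
// dims after it collapse into inner_size.
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

std::array<int64_t, 3> toy_normalized_sizes(
    const std::vector<int64_t>& sizes, size_t broadcast_dim) {
  std::array<int64_t, 3> result = {1, sizes[broadcast_dim], 1};
  for (size_t i = 0; i < broadcast_dim; ++i) {
    result[0] *= sizes[i];  // outer_size
  }
  for (size_t i = broadcast_dim + 1; i < sizes.size(); ++i) {
    result[2] *= sizes[i];  // inner_size
  }
  return result;
}

int main() {
  // sizes {2, 3, 4, 5} broadcast at dim 1 -> outer 2, broadcast 3, inner 20.
  auto n = toy_normalized_sizes({2, 3, 4, 5}, 1);
  assert(n[0] == 2 && n[1] == 3 && n[2] == 20);
  return 0;
}
```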
The second file reworks `opt_add_out` so that every broadcast path funnels into the new helper, with the op expressed as a lambda:

```diff
@@ -140,41 +140,31 @@ Tensor& opt_add_out(
       out.numel());
     });
   } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
-    const Tensor* lhs;
-    const Tensor* rhs;
     if (selected_optimized_path ==
-        ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
-      lhs = &b;
-      rhs = &a;
+            ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
+        selected_optimized_path ==
+            ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
+        selected_optimized_path ==
+            ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
+      // This is a bit confusing. The reason the lambda swaps its
+      // arguments here is that handle_broadcast_elementwise swaps lhs
+      // and rhs internally for the reverse-argument paths. This should
+      // really be resolved inside handle_broadcast_elementwise; the
+      // current blocker is that it tries to stay agnostic of the op.
+      // That should be fixed, likely by moving lambda creation into
+      // handle_broadcast_elementwise so that it is aware of which op
+      // is being executed.
+      auto add_lambda = [](auto x, auto y, auto alpha_val) {
+        return y + alpha_val * x;
+      };
+      return torch::executor::handle_broadcast_elementwise<BinaryOpType::kAdd>(
+          ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
     } else {
-      // Catch failure to update logic when adding new broadcasting possibility.
-      ET_DCHECK(
-          selected_optimized_path ==
-          ElementwiseOptimizedPath::kBroadcast2dBy1d);
-      lhs = &a;
-      rhs = &b;
+      auto add_lambda = [](auto x, auto y, auto alpha_val) {
+        return x + alpha_val * y;
+      };
+      return torch::executor::handle_broadcast_elementwise<BinaryOpType::kAdd>(
+          ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
     }
-    auto error = resize_tensor(out, lhs->sizes());
-    ET_KERNEL_CHECK_MSG(
-        ctx,
-        error == Error::Ok,
-        InvalidArgument,
-        out,
-        "Failed to resize output tensor.");
-    ET_SWITCH_REALB_TYPES(out_type, ctx, "add.out", CTYPE, [&]() {
-      CTYPE alpha_val;
-      ET_KERNEL_CHECK(
-          ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
-
-      using Vec = executorch::vec::Vectorized<CTYPE>;
-      executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
-          [alpha_val](Vec x, Vec y) { return x + Vec(alpha_val) * y; },
-          out.mutable_data_ptr<CTYPE>(),
-          lhs->const_data_ptr<CTYPE>(),
-          rhs->const_data_ptr<CTYPE>(),
-          lhs->sizes()[lhs->dim() - 2],
-          lhs->sizes()[lhs->dim() - 1]);
-    });
   } else {
     ScalarType common_type =
         promoteTypes(a_type, b_type, /*half_to_float*/ true);
```
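The two `add_lambda` variants differ because `handle_broadcast_elementwise` swaps `lhs` and `rhs` internally on the reverse-argument paths, so the lambda sees `x` drawn from `b` and `y` from `a`. A scalar-level sketch (illustrative only) of why flipping the lambda restores `out = a + alpha * b`:

```cpp
#include <cassert>

int main() {
  const int a = 10, b = 3, alpha = 2;
  // Forward path: the helper passes x = a, y = b.
  auto add_forward = [](int x, int y, int alpha_val) {
    return x + alpha_val * y;
  };
  // Reverse-argument path: the helper swaps, passing x = b, y = a,
  // so the flipped lambda is needed to compute the same result.
  auto add_reversed = [](int x, int y, int alpha_val) {
    return y + alpha_val * x;
  };
  assert(add_forward(a, b, alpha) == a + alpha * b);
  assert(add_reversed(b, a, alpha) == a + alpha * b);
  return 0;
}
```

The same caveat would apply to any non-commutative op routed through this helper, such as sub or div.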
The third file deletes the mul-specific broadcast handlers, which the shared helper now subsumes:

```diff
@@ -68,114 +68,6 @@ template <
 struct MulInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
     : public ReportCanCastBug {};
 
-Tensor& handle_last_dim_broadcast(
-    KernelRuntimeContext& ctx,
-    const Tensor& a,
-    const Tensor& b,
-    Tensor& out,
-    const ElementwiseOptimizedPath selected_optimized_path) {
-  ScalarType out_type = out.scalar_type();
-  const Tensor* lhs;
-  const Tensor* rhs;
-  if (selected_optimized_path ==
-      ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments) {
-    lhs = &b;
-    rhs = &a;
-  } else {
-    lhs = &a;
-    rhs = &b;
-  }
-  auto error = resize_tensor(out, lhs->sizes());
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      error == Error::Ok,
-      InvalidArgument,
-      out,
-      "Failed to resize output tensor.");
-  const size_t outer_size = getLeadingDims(out, out.dim() - 1);
-  const auto broadcast_size = out.size(out.dim() - 1);
-  ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
-    using Vec = executorch::vec::Vectorized<CTYPE>;
-    executorch::vec::broadcasting_map_broadcast_last_dim<CTYPE>(
-        [](Vec x, Vec y) { return x * y; },
-        out.mutable_data_ptr<CTYPE>(),
-        lhs->const_data_ptr<CTYPE>(),
-        rhs->const_data_ptr<CTYPE>(),
-        outer_size,
-        broadcast_size);
-  });
-  return out;
-}
-
-Tensor& handle_broadcast_mul(
-    KernelRuntimeContext& ctx,
-    const Tensor& a,
-    const Tensor& b,
-    Tensor& out,
-    const ElementwiseOptimizedPath selected_optimized_path) {
-  if ((selected_optimized_path ==
-       ElementwiseOptimizedPath::kBroadcastLastDim) ||
-      (selected_optimized_path ==
-       ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments)) {
-    return handle_last_dim_broadcast(ctx, a, b, out, selected_optimized_path);
-  }
-
-  ScalarType out_type = out.scalar_type();
-  const Tensor* lhs;
-  const Tensor* rhs;
-  if ((selected_optimized_path ==
-       ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) ||
-      (selected_optimized_path ==
-       ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
-    lhs = &b;
-    rhs = &a;
-  } else {
-    // Catch failure to update logic when adding new broadcasting possibility.
-    ET_DCHECK(
-        (selected_optimized_path ==
-         ElementwiseOptimizedPath::kBroadcast2dBy1d) ||
-        (selected_optimized_path ==
-         ElementwiseOptimizedPath::kBroadcastNdByNd));
-    lhs = &a;
-    rhs = &b;
-  }
-  auto error = resize_tensor(out, lhs->sizes());
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      error == Error::Ok,
-      InvalidArgument,
-      out,
-      "Failed to resize output tensor.");
-  int64_t outer_size = 1;
-  int64_t broadcast_size;
-  int64_t inner_size;
-  if ((selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNd) ||
-      (selected_optimized_path ==
-       ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
-    int32_t broadcast_dim = internal::get_broadcast_dim(*lhs, *rhs);
-    int32_t broadcast_dim_lhs = lhs->dim() + broadcast_dim;
-    auto normalized_tensor_size_lhs =
-        get_normalized_tensor_size(*lhs, broadcast_dim_lhs);
-    outer_size = normalized_tensor_size_lhs[0];
-    broadcast_size = normalized_tensor_size_lhs[1];
-    inner_size = normalized_tensor_size_lhs[2];
-  } else {
-    broadcast_size = lhs->sizes()[lhs->dim() - 2];
-    inner_size = lhs->sizes()[lhs->dim() - 1];
-  }
-  ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
-    using Vec = executorch::vec::Vectorized<CTYPE>;
-    executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE>(
-        [](Vec x, Vec y) { return x * y; },
-        out.mutable_data_ptr<CTYPE>(),
-        lhs->const_data_ptr<CTYPE>(),
-        rhs->const_data_ptr<CTYPE>(),
-        outer_size,
-        broadcast_size,
-        inner_size);
-  });
-  return out;
-}
 } // namespace
 
 Tensor& opt_mul_out(
@@ -238,7 +130,13 @@ Tensor& opt_mul_out(
       out.numel());
     });
   } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
-    return handle_broadcast_mul(ctx, a, b, out, selected_optimized_path);
+    // alpha is unused for mul; it is threaded through for ops like add/sub.
+    auto mul_lambda = [](auto x, auto y, auto alpha) {
+      (void)alpha;
+      return x * y;
+    };
+    return torch::executor::handle_broadcast_elementwise<BinaryOpType::kMul>(
+        ctx, mul_lambda, a, b, out, selected_optimized_path);
   } else {
     ScalarType common_type =
         promoteTypes(a_type, b_type, /*half_to_float*/ true);
```
Review comment, on the `BinaryOpTypeName` trait:

you don't need to do this. see existing example: executorch/kernels/portable/cpu/op_rsub.cpp, lines 50 to 55 in c82a7df. the secret sauce is that the string literal has to be a `static constexpr const char []`, and then you can pass it to a `const char*` template argument directly.
Reply: Thanks. I was hoping you would point me to something better for this.
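For reference, a minimal sketch of the technique the reviewer describes, with all names hypothetical: an array with static storage duration can be passed where a `const char*` non-type template parameter is expected, whereas a bare string literal cannot.

```cpp
// Sketch only; kOpName, log_op, and kAddName are made-up names.
#include <cstdio>

template <const char* kOpName>
void log_op() {
  std::printf("running %s\n", kOpName);
}

// A `static constexpr const char[]` has static storage duration, so its
// address is a valid constant expression for a `const char*` template
// argument.
static constexpr const char kAddName[] = "add.out";

int main() {
  log_op<kAddName>();     // OK: array decays to a pointer to static storage
  // log_op<"add.out">(); // ill-formed: string literals cannot be used as
  //                      // template arguments
  return 0;
}
```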