-
Notifications
You must be signed in to change notification settings - Fork 713
[Executorch] Refactor op_mul's broadcasting utils #8204
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
27a79c4
cd9a0d7
e814bb7
ed79e8c
7d9494f
be44fb4
aedea37
ebf62fe
3029ca6
77eb1f3
53f8a14
f25833f
c7f9e88
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,6 +8,7 @@ | |
|
|
||
| #pragma once | ||
|
|
||
| #include <executorch/kernels/optimized/vec/functional.h> | ||
| #include <executorch/runtime/kernel/kernel_includes.h> | ||
|
|
||
| namespace torch { | ||
|
|
@@ -190,5 +191,116 @@ std::array<int32_t, 3> inline get_normalized_tensor_size( | |
| return normalized_tensor_size; | ||
| } | ||
|
|
||
| template <typename Op> | ||
| Tensor& handle_last_dim_broadcast_elementwise( | ||
| KernelRuntimeContext& ctx, | ||
| const Op& vec_fun, | ||
| const Tensor& a, | ||
| const Tensor& b, | ||
| Tensor& out, | ||
| const ElementwiseOptimizedPath selected_optimized_path) { | ||
| ScalarType out_type = out.scalar_type(); | ||
| const Tensor* lhs; | ||
| const Tensor* rhs; | ||
| if (selected_optimized_path == | ||
| ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments) { | ||
| lhs = &b; | ||
| rhs = &a; | ||
| } else { | ||
| lhs = &a; | ||
| rhs = &b; | ||
| } | ||
| auto error = resize_tensor(out, lhs->sizes()); | ||
| ET_KERNEL_CHECK_MSG( | ||
| ctx, | ||
| error == Error::Ok, | ||
| InvalidArgument, | ||
| out, | ||
| "Failed to resize output tensor."); | ||
| const size_t outer_size = getLeadingDims(out, out.dim() - 1); | ||
| const auto broadcast_size = out.size(out.dim() - 1); | ||
| ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() { | ||
| executorch::vec::broadcasting_map_broadcast_last_dim<CTYPE, Op>( | ||
| vec_fun, | ||
| out.mutable_data_ptr<CTYPE>(), | ||
| lhs->const_data_ptr<CTYPE>(), | ||
| rhs->const_data_ptr<CTYPE>(), | ||
| outer_size, | ||
| broadcast_size); | ||
| }); | ||
| return out; | ||
| } | ||
|
|
||
| template <typename Op> | ||
| Tensor& handle_broadcast_elementwise( | ||
| KernelRuntimeContext& ctx, | ||
| const Op& vec_fun, | ||
| const Tensor& a, | ||
| const Tensor& b, | ||
| Tensor& out, | ||
| const ElementwiseOptimizedPath selected_optimized_path) { | ||
| if ((selected_optimized_path == | ||
| ElementwiseOptimizedPath::kBroadcastLastDim) || | ||
| (selected_optimized_path == | ||
| ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments)) { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it occurs to me that we should separate the selected algorithm from whether to reverse arguments or not to make this read nicer, but that definitely doesn't have to go in this PR |
||
| return handle_last_dim_broadcast_elementwise( | ||
| ctx, vec_fun, a, b, out, selected_optimized_path); | ||
| } | ||
|
|
||
| ScalarType out_type = out.scalar_type(); | ||
| const Tensor* lhs; | ||
| const Tensor* rhs; | ||
| if ((selected_optimized_path == | ||
| ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) || | ||
| (selected_optimized_path == | ||
| ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) { | ||
| lhs = &b; | ||
| rhs = &a; | ||
| } else { | ||
| // Catch failure to update logic when adding new broadcasting possibility. | ||
| ET_DCHECK( | ||
| (selected_optimized_path == | ||
| ElementwiseOptimizedPath::kBroadcast2dBy1d) || | ||
| (selected_optimized_path == | ||
| ElementwiseOptimizedPath::kBroadcastNdByNd)); | ||
| lhs = &a; | ||
| rhs = &b; | ||
| } | ||
| auto error = resize_tensor(out, lhs->sizes()); | ||
| ET_KERNEL_CHECK_MSG( | ||
| ctx, | ||
| error == Error::Ok, | ||
| InvalidArgument, | ||
| out, | ||
| "Failed to resize output tensor."); | ||
| int64_t outer_size = 1; | ||
| int64_t broadcast_size; | ||
| int64_t inner_size; | ||
| if ((selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNd) || | ||
| (selected_optimized_path == | ||
| ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) { | ||
| int32_t broadcast_dim = internal::get_broadcast_dim(*lhs, *rhs); | ||
| int32_t broadcast_dim_lhs = lhs->dim() + broadcast_dim; | ||
| auto normalized_tensor_size_lhs = | ||
| get_normalized_tensor_size(*lhs, broadcast_dim_lhs); | ||
| outer_size = normalized_tensor_size_lhs[0]; | ||
| broadcast_size = normalized_tensor_size_lhs[1]; | ||
| inner_size = normalized_tensor_size_lhs[2]; | ||
| } else { | ||
| broadcast_size = lhs->sizes()[lhs->dim() - 2]; | ||
| inner_size = lhs->sizes()[lhs->dim() - 1]; | ||
| } | ||
| ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() { | ||
| executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE, Op>( | ||
| vec_fun, | ||
| out.mutable_data_ptr<CTYPE>(), | ||
| lhs->const_data_ptr<CTYPE>(), | ||
| rhs->const_data_ptr<CTYPE>(), | ||
| outer_size, | ||
| broadcast_size, | ||
| inner_size); | ||
| }); | ||
| return out; | ||
| } | ||
| } // namespace executor | ||
| } // namespace torch | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -68,114 +68,6 @@ template < | |
| struct MulInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> | ||
| : public ReportCanCastBug {}; | ||
|
|
||
| Tensor& handle_last_dim_broadcast( | ||
| KernelRuntimeContext& ctx, | ||
| const Tensor& a, | ||
| const Tensor& b, | ||
| Tensor& out, | ||
| const ElementwiseOptimizedPath selected_optimized_path) { | ||
| ScalarType out_type = out.scalar_type(); | ||
| const Tensor* lhs; | ||
| const Tensor* rhs; | ||
| if (selected_optimized_path == | ||
| ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments) { | ||
| lhs = &b; | ||
| rhs = &a; | ||
| } else { | ||
| lhs = &a; | ||
| rhs = &b; | ||
| } | ||
| auto error = resize_tensor(out, lhs->sizes()); | ||
| ET_KERNEL_CHECK_MSG( | ||
| ctx, | ||
| error == Error::Ok, | ||
| InvalidArgument, | ||
| out, | ||
| "Failed to resize output tensor."); | ||
| const size_t outer_size = getLeadingDims(out, out.dim() - 1); | ||
| const auto broadcast_size = out.size(out.dim() - 1); | ||
| ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() { | ||
| using Vec = executorch::vec::Vectorized<CTYPE>; | ||
| executorch::vec::broadcasting_map_broadcast_last_dim<CTYPE>( | ||
| [](Vec x, Vec y) { return x * y; }, | ||
| out.mutable_data_ptr<CTYPE>(), | ||
| lhs->const_data_ptr<CTYPE>(), | ||
| rhs->const_data_ptr<CTYPE>(), | ||
| outer_size, | ||
| broadcast_size); | ||
| }); | ||
| return out; | ||
| } | ||
|
|
||
| Tensor& handle_broadcast_mul( | ||
| KernelRuntimeContext& ctx, | ||
| const Tensor& a, | ||
| const Tensor& b, | ||
| Tensor& out, | ||
| const ElementwiseOptimizedPath selected_optimized_path) { | ||
| if ((selected_optimized_path == | ||
| ElementwiseOptimizedPath::kBroadcastLastDim) || | ||
| (selected_optimized_path == | ||
| ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments)) { | ||
| return handle_last_dim_broadcast(ctx, a, b, out, selected_optimized_path); | ||
| } | ||
|
|
||
| ScalarType out_type = out.scalar_type(); | ||
| const Tensor* lhs; | ||
| const Tensor* rhs; | ||
| if ((selected_optimized_path == | ||
| ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) || | ||
| (selected_optimized_path == | ||
| ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) { | ||
| lhs = &b; | ||
| rhs = &a; | ||
| } else { | ||
| // Catch failure to update logic when adding new broadcasting possibility. | ||
| ET_DCHECK( | ||
| (selected_optimized_path == | ||
| ElementwiseOptimizedPath::kBroadcast2dBy1d) || | ||
| (selected_optimized_path == | ||
| ElementwiseOptimizedPath::kBroadcastNdByNd)); | ||
| lhs = &a; | ||
| rhs = &b; | ||
| } | ||
| auto error = resize_tensor(out, lhs->sizes()); | ||
| ET_KERNEL_CHECK_MSG( | ||
| ctx, | ||
| error == Error::Ok, | ||
| InvalidArgument, | ||
| out, | ||
| "Failed to resize output tensor."); | ||
| int64_t outer_size = 1; | ||
| int64_t broadcast_size; | ||
| int64_t inner_size; | ||
| if ((selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNd) || | ||
| (selected_optimized_path == | ||
| ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) { | ||
| int32_t broadcast_dim = internal::get_broadcast_dim(*lhs, *rhs); | ||
| int32_t broadcast_dim_lhs = lhs->dim() + broadcast_dim; | ||
| auto normalized_tensor_size_lhs = | ||
| get_normalized_tensor_size(*lhs, broadcast_dim_lhs); | ||
| outer_size = normalized_tensor_size_lhs[0]; | ||
| broadcast_size = normalized_tensor_size_lhs[1]; | ||
| inner_size = normalized_tensor_size_lhs[2]; | ||
| } else { | ||
| broadcast_size = lhs->sizes()[lhs->dim() - 2]; | ||
| inner_size = lhs->sizes()[lhs->dim() - 1]; | ||
| } | ||
| ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() { | ||
| using Vec = executorch::vec::Vectorized<CTYPE>; | ||
| executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE>( | ||
| [](Vec x, Vec y) { return x * y; }, | ||
| out.mutable_data_ptr<CTYPE>(), | ||
| lhs->const_data_ptr<CTYPE>(), | ||
| rhs->const_data_ptr<CTYPE>(), | ||
| outer_size, | ||
| broadcast_size, | ||
| inner_size); | ||
| }); | ||
| return out; | ||
| } | ||
| } // namespace | ||
|
|
||
| Tensor& opt_mul_out( | ||
|
|
@@ -238,7 +130,9 @@ Tensor& opt_mul_out( | |
| out.numel()); | ||
| }); | ||
| } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) { | ||
| return handle_broadcast_mul(ctx, a, b, out, selected_optimized_path); | ||
| auto mul_lambda = [](auto x, auto y) { return x * y; }; | ||
| return torch::executor::handle_broadcast_elementwise( | ||
| ctx, mul_lambda, a, b, out, selected_optimized_path); | ||
|
||
| } else { | ||
| ScalarType common_type = | ||
| promoteTypes(a_type, b_type, /*half_to_float*/ true); | ||
|
|
||
There was a problem hiding this comment. Choose a reason for hiding this comment; the reason will be displayed to describe this comment to others. Learn more.
Reviewer comment: this isn't in mul anymore, but I see it's fixed in the next PR — close enough.