|
11 | 11 | #include <executorch/kernels/optimized/vec/vec.h> |
12 | 12 | #include <executorch/kernels/portable/cpu/scalar_utils.h> |
13 | 13 | #include <executorch/kernels/portable/cpu/util/broadcast_util.h> |
| 14 | +#include <executorch/runtime/core/exec_aten/util/tensor_util.h> // IWYU pragma: export |
14 | 15 | #include <executorch/runtime/kernel/kernel_includes.h> |
15 | 16 | #include <executorch/runtime/platform/assert.h> |
16 | 17 |
|
@@ -66,6 +67,117 @@ template < |
66 | 67 | typename CTYPE_OUT> |
67 | 68 | struct MulInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> |
68 | 69 | : public ReportCanCastBug {}; |
| 70 | + |
| 71 | +Tensor& handle_last_dim_broadcast( |
| 72 | + KernelRuntimeContext& ctx, |
| 73 | + const Tensor& a, |
| 74 | + const Tensor& b, |
| 75 | + Tensor& out, |
| 76 | + const ElementwiseOptimizedPath selected_optimized_path) { |
| 77 | + ScalarType out_type = out.scalar_type(); |
| 78 | + const Tensor* lhs; |
| 79 | + const Tensor* rhs; |
| 80 | + if (selected_optimized_path == |
| 81 | + ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments) { |
| 82 | + lhs = &b; |
| 83 | + rhs = &a; |
| 84 | + } else { |
| 85 | + lhs = &a; |
| 86 | + rhs = &b; |
| 87 | + } |
| 88 | + auto error = resize_tensor(out, lhs->sizes()); |
| 89 | + ET_KERNEL_CHECK_MSG( |
| 90 | + ctx, |
| 91 | + error == Error::Ok, |
| 92 | + InvalidArgument, |
| 93 | + out, |
| 94 | + "Failed to resize output tensor."); |
| 95 | + const size_t outer_size = getLeadingDims(out, out.dim() - 1); |
| 96 | + const auto broadcast_size = out.size(out.dim() - 1); |
| 97 | + ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() { |
| 98 | + using Vec = executorch::vec::Vectorized<CTYPE>; |
| 99 | + executorch::vec::broadcasting_map_broadcast_last_dim<CTYPE>( |
| 100 | + [](Vec x, Vec y) { return x * y; }, |
| 101 | + out.mutable_data_ptr<CTYPE>(), |
| 102 | + lhs->const_data_ptr<CTYPE>(), |
| 103 | + rhs->const_data_ptr<CTYPE>(), |
| 104 | + outer_size, |
| 105 | + broadcast_size); |
| 106 | + }); |
| 107 | + return out; |
| 108 | +} |
| 109 | + |
| 110 | +Tensor& handle_broadcast_mul( |
| 111 | + KernelRuntimeContext& ctx, |
| 112 | + const Tensor& a, |
| 113 | + const Tensor& b, |
| 114 | + Tensor& out, |
| 115 | + const ElementwiseOptimizedPath selected_optimized_path) { |
| 116 | + |
| 117 | + if ((selected_optimized_path == |
| 118 | + ElementwiseOptimizedPath::kBroadcastLastDim) || |
| 119 | + (selected_optimized_path == ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments)) { |
| 120 | + return handle_last_dim_broadcast( |
| 121 | + ctx, a, b, out, selected_optimized_path); |
| 122 | + } |
| 123 | + |
| 124 | + ScalarType out_type = out.scalar_type(); |
| 125 | + const Tensor* lhs; |
| 126 | + const Tensor* rhs; |
| 127 | + if ((selected_optimized_path == |
| 128 | + ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) || |
| 129 | + (selected_optimized_path == |
| 130 | + ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) { |
| 131 | + lhs = &b; |
| 132 | + rhs = &a; |
| 133 | + } else { |
| 134 | + // Catch failure to update logic when adding new broadcasting possibility. |
| 135 | + ET_DCHECK( |
| 136 | + (selected_optimized_path == |
| 137 | + ElementwiseOptimizedPath::kBroadcast2dBy1d) || |
| 138 | + (selected_optimized_path == |
| 139 | + ElementwiseOptimizedPath::kBroadcastNdByNd)); |
| 140 | + lhs = &a; |
| 141 | + rhs = &b; |
| 142 | + } |
| 143 | + auto error = resize_tensor(out, lhs->sizes()); |
| 144 | + ET_KERNEL_CHECK_MSG( |
| 145 | + ctx, |
| 146 | + error == Error::Ok, |
| 147 | + InvalidArgument, |
| 148 | + out, |
| 149 | + "Failed to resize output tensor."); |
| 150 | + int64_t outer_size = 1; |
| 151 | + int64_t broadcast_size; |
| 152 | + int64_t inner_size; |
| 153 | + if ((selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNd) || |
| 154 | + (selected_optimized_path == |
| 155 | + ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) { |
| 156 | + int32_t broadcast_dim = internal::get_broadcast_dim(*lhs, *rhs); |
| 157 | + int32_t broadcast_dim_lhs = lhs->dim() + broadcast_dim; |
| 158 | + int32_t broadcast_dim_rhs = rhs->dim() + broadcast_dim; |
| 159 | + auto normalized_tensor_size_lhs = |
| 160 | + get_normalized_tensor_size(*lhs, broadcast_dim_lhs); |
| 161 | + outer_size = normalized_tensor_size_lhs[0]; |
| 162 | + broadcast_size = normalized_tensor_size_lhs[1]; |
| 163 | + inner_size = normalized_tensor_size_lhs[2]; |
| 164 | + } else { |
| 165 | + broadcast_size = lhs->sizes()[lhs->dim() - 2]; |
| 166 | + inner_size = lhs->sizes()[lhs->dim() - 1]; |
| 167 | + } |
| 168 | + ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() { |
| 169 | + using Vec = executorch::vec::Vectorized<CTYPE>; |
| 170 | + executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE>( |
| 171 | + [](Vec x, Vec y) { return x * y; }, |
| 172 | + out.mutable_data_ptr<CTYPE>(), |
| 173 | + lhs->const_data_ptr<CTYPE>(), |
| 174 | + rhs->const_data_ptr<CTYPE>(), |
| 175 | + outer_size, |
| 176 | + broadcast_size, |
| 177 | + inner_size); |
| 178 | + }); |
| 179 | + return out; |
| 180 | +} |
69 | 181 | } // namespace |
70 | 182 |
|
71 | 183 | Tensor& opt_mul_out( |
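
For orientation, here is a minimal scalar sketch of what the last-dim broadcast helper added above is assumed to compute: after `resize_tensor`, `out` and `lhs` share the full output shape viewed as [outer_size, broadcast_size], while `rhs` has a trailing dimension of 1 and contributes one value per outer row. The vectorized `broadcasting_map_broadcast_last_dim` call is assumed to produce the same result over SIMD lanes; the name `last_dim_broadcast_mul_ref` and its raw-pointer signature are illustrative only, not part of the ExecuTorch API.

    #include <cstddef>

    // Scalar reference for the last-dim broadcast multiply (illustrative only).
    // out and lhs are viewed as [outer_size, broadcast_size]; rhs is assumed to
    // hold one value per outer row (its last dim is 1 and broadcasts across the
    // row), matching how handle_last_dim_broadcast sizes its arguments.
    template <typename T>
    void last_dim_broadcast_mul_ref(
        T* out,
        const T* lhs,
        const T* rhs,
        std::size_t outer_size,
        std::size_t broadcast_size) {
      for (std::size_t o = 0; o < outer_size; ++o) {
        for (std::size_t j = 0; j < broadcast_size; ++j) {
          out[o * broadcast_size + j] = lhs[o * broadcast_size + j] * rhs[o];
        }
      }
    }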
@@ -128,56 +240,7 @@ Tensor& opt_mul_out( |
128 | 240 | out.numel()); |
129 | 241 | }); |
130 | 242 | } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) { |
131 | | - const Tensor* lhs; |
132 | | - const Tensor* rhs; |
133 | | - if ((selected_optimized_path == |
134 | | - ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) || |
135 | | - (selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) { |
136 | | - lhs = &b; |
137 | | - rhs = &a; |
138 | | - } else { |
139 | | - // Catch failure to update logic when adding new broadcasting possibility. |
140 | | - ET_DCHECK( |
141 | | - (selected_optimized_path == |
142 | | - ElementwiseOptimizedPath::kBroadcast2dBy1d) || |
143 | | - (selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNd)); |
144 | | - lhs = &a; |
145 | | - rhs = &b; |
146 | | - } |
147 | | - auto error = resize_tensor(out, lhs->sizes()); |
148 | | - ET_KERNEL_CHECK_MSG( |
149 | | - ctx, |
150 | | - error == Error::Ok, |
151 | | - InvalidArgument, |
152 | | - out, |
153 | | - "Failed to resize output tensor."); |
154 | | - int64_t outer_size = 1; |
155 | | - int64_t broadcast_size; |
156 | | - int64_t inner_size; |
157 | | - if ((selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNd) || |
158 | | - (selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) { |
159 | | - int32_t broadcast_dim = internal::get_broadcast_dim(*lhs, *rhs); |
160 | | - int32_t broadcast_dim_lhs = lhs->dim() + broadcast_dim; |
161 | | - int32_t broadcast_dim_rhs = rhs->dim() + broadcast_dim; |
162 | | - auto normalized_tensor_size_lhs = get_normalized_tensor_size(*lhs, broadcast_dim_lhs); |
163 | | - outer_size = normalized_tensor_size_lhs[0]; |
164 | | - broadcast_size = normalized_tensor_size_lhs[1]; |
165 | | - inner_size = normalized_tensor_size_lhs[2]; |
166 | | - } else { |
167 | | - broadcast_size = lhs->sizes()[lhs->dim() - 2]; |
168 | | - inner_size = lhs->sizes()[lhs->dim() - 1]; |
169 | | - } |
170 | | - ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() { |
171 | | - using Vec = executorch::vec::Vectorized<CTYPE>; |
172 | | - executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE>( |
173 | | - [](Vec x, Vec y) { return x * y; }, |
174 | | - out.mutable_data_ptr<CTYPE>(), |
175 | | - lhs->const_data_ptr<CTYPE>(), |
176 | | - rhs->const_data_ptr<CTYPE>(), |
177 | | - outer_size, |
178 | | - broadcast_size, |
179 | | - inner_size); |
180 | | - }); |
| 243 | + return handle_broadcast_mul(ctx, a, b, out, selected_optimized_path); |
181 | 244 | } else { |
182 | 245 | ScalarType common_type = |
183 | 246 | promoteTypes(a_type, b_type, /*half_to_float*/ true); |
|
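Likewise, a scalar sketch of the 3-D view used by the Nd-by-Nd path in `handle_broadcast_mul`: `get_normalized_tensor_size` is assumed to collapse `lhs` into [outer_size, broadcast_size, inner_size] around the broadcast dimension, with `rhs` treated as the unsqueezed view [outer_size, 1, inner_size], so every block of broadcast_size rows in `lhs` multiplies the same `rhs` row. The indexing of `rhs` here is inferred from the helper's name and arguments, and `nd_by_nd_broadcast_mul_ref` is a hypothetical name, not a kernel function.

    #include <cstddef>

    // Scalar reference for the Nd-by-Nd broadcast multiply (illustrative only).
    // lhs is viewed as [outer_size, broadcast_size, inner_size]; rhs is assumed
    // to be the "unsqueezed" view [outer_size, 1, inner_size], so each block of
    // broadcast_size rows in lhs is multiplied elementwise by the same rhs row.
    template <typename T>
    void nd_by_nd_broadcast_mul_ref(
        T* out,
        const T* lhs,
        const T* rhs,
        std::size_t outer_size,
        std::size_t broadcast_size,
        std::size_t inner_size) {
      for (std::size_t o = 0; o < outer_size; ++o) {
        for (std::size_t b = 0; b < broadcast_size; ++b) {
          for (std::size_t i = 0; i < inner_size; ++i) {
            const std::size_t idx = (o * broadcast_size + b) * inner_size + i;
            out[idx] = lhs[idx] * rhs[o * inner_size + i];
          }
        }
      }
    }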