9 | 9 | #include <executorch/kernels/optimized/cpu/binary_ops.h> |
10 | 10 | #include <executorch/kernels/optimized/vec/functional.h> |
11 | 11 | #include <executorch/kernels/optimized/vec/vec.h> |
| 12 | +#include <executorch/kernels/portable/cpu/op_mul_impl.h> |
12 | 13 | #include <executorch/kernels/portable/cpu/scalar_utils.h> |
13 | 14 | #include <executorch/kernels/portable/cpu/util/broadcast_util.h> |
14 | 15 | #include <executorch/runtime/core/exec_aten/util/tensor_util.h> // IWYU pragma: export |
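(The only change in the include block is the new `op_mul_impl.h` header; judging from the calls added in the hunks below, it is what declares the portable `mul_out_impl` / `mul_scalar_out_impl` fallbacks this file now delegates to.)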
@@ -240,36 +241,7 @@ Tensor& opt_mul_out( |
240 | 241 | } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) { |
241 | 242 | return handle_broadcast_mul(ctx, a, b, out, selected_optimized_path); |
242 | 243 | } else { |
243 | | - ScalarType common_type = |
244 | | - promoteTypes(a_type, b_type, /*half_to_float*/ true); |
245 | | - ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out); |
246 | | - |
247 | | - ET_KERNEL_CHECK( |
248 | | - ctx, |
249 | | - resize_to_broadcast_target_size(a, b, out) == Error::Ok, |
250 | | - InvalidArgument, |
251 | | - out); |
252 | | - |
253 | | - ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.out", CTYPE_A, [&]() { |
254 | | - ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "mul.out", CTYPE_B, [&]() { |
255 | | - using CTYPE_IN = typename torch::executor:: |
256 | | - promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type; |
257 | | - ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type); |
258 | | - ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "mul.out", CTYPE_OUT, [&]() { |
259 | | - apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>( |
260 | | - [](const CTYPE_A val_a, const CTYPE_B val_b) { |
261 | | - CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a); |
262 | | - CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b); |
263 | | - CTYPE_IN value = a_casted * b_casted; |
264 | | - |
265 | | - return static_cast<CTYPE_OUT>(value); |
266 | | - }, |
267 | | - a, |
268 | | - b, |
269 | | - out); |
270 | | - }); |
271 | | - }); |
272 | | - }); |
| 244 | + mul_out_impl(ctx, a, b, out); |
273 | 245 | } |
274 | 246 |
275 | 247 | return out; |
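The deleted branch was the inline slow path: promote the two operand dtypes (with `half_to_float`), verify the promoted type can be cast to the output dtype, resize `out` to the broadcast target size, then run a cast-multiply-cast loop through the nested `ET_SWITCH_*` macros. That logic now lives behind the portable `mul_out_impl`. As a minimal sketch of the promote/cast/multiply pattern the removed lines implemented, in plain templated C++ rather than the `ET_SWITCH_*` macro machinery (the `mul_elementwise` helper and its signature are invented for illustration, not ExecuTorch API):

```cpp
#include <cstddef>

// Illustrative flat loop: cast both operands to the promoted common type
// IN, multiply there, then narrow the product to the output type OUT.
// IN is not deducible and must be supplied explicitly at the call site.
template <typename A, typename B, typename IN, typename OUT>
void mul_elementwise(const A* a, const B* b, OUT* out, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) {
    const IN a_casted = static_cast<IN>(a[i]);
    const IN b_casted = static_cast<IN>(b[i]);
    out[i] = static_cast<OUT>(a_casted * b_casted);
  }
}
```

Note that the real fallback goes through `apply_binary_elementwise_fn`, which also handles broadcast index mapping between `a`, `b`, and `out`; the flat loop above elides that.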
@@ -315,27 +287,7 @@ Tensor& opt_mul_scalar_out( |
315 | 287 | }); |
316 | 288 | }); |
317 | 289 | } else { |
318 | | - ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.Scalar_out", CTYPE_A, [&]() { |
319 | | - ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "mul.Scalar_out", CTYPE_B, [&]() { |
320 | | - ET_SWITCH_REALB_TYPES( |
321 | | - common_type, ctx, "mul.Scalar_out", CTYPE_IN, [&]() { |
322 | | - ET_SWITCH_REALHBBF16_TYPES( |
323 | | - out_type, ctx, "mul.Scalar_out", CTYPE_OUT, [&]() { |
324 | | - CTYPE_B b_val; |
325 | | - ET_EXTRACT_SCALAR(b, b_val); |
326 | | - CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val); |
327 | | - |
328 | | - const size_t n = a.numel(); |
329 | | - const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>(); |
330 | | - CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>(); |
331 | | - for (auto i = 0; i < n; ++i) { |
332 | | - out_data[i] = static_cast<CTYPE_OUT>( |
333 | | - static_cast<CTYPE_IN>(a_data[i]) * b_casted); |
334 | | - } |
335 | | - }); |
336 | | - }); |
337 | | - }); |
338 | | - }); |
| 290 | + mul_scalar_out_impl(ctx, a, b, out); |
339 | 291 | } |
340 | 292 |
341 | 293 | return out; |
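The scalar overload gets the same treatment: the removed branch extracted the `Scalar` operand once via `ET_EXTRACT_SCALAR`, cast it to the common type outside the loop, and multiplied through `a` element by element, all of which `mul_scalar_out_impl` now covers. A sketch of that pattern under the same caveats as above (plain C++, hypothetical `mul_scalar` helper, not the portable kernel's source):

```cpp
#include <cstddef>

// The scalar is converted to the promoted type IN once, outside the loop;
// each element of `a` is then promoted, multiplied, and narrowed to OUT.
template <typename A, typename B, typename IN, typename OUT>
void mul_scalar(const A* a_data, B b_val, OUT* out_data, std::size_t n) {
  const IN b_casted = static_cast<IN>(b_val);
  for (std::size_t i = 0; i < n; ++i) {
    out_data[i] = static_cast<OUT>(static_cast<IN>(a_data[i]) * b_casted);
  }
}
```

Hoisting the scalar cast out of the loop is the one micro-optimization the removed code made over the naive per-element version.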
|