diff --git a/kernels/optimized/cpu/op_mul.cpp b/kernels/optimized/cpu/op_mul.cpp index 7feee4e1566..2240a41c838 100644 --- a/kernels/optimized/cpu/op_mul.cpp +++ b/kernels/optimized/cpu/op_mul.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include // IWYU pragma: export @@ -240,36 +241,7 @@ Tensor& opt_mul_out( } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) { return handle_broadcast_mul(ctx, a, b, out, selected_optimized_path); } else { - ScalarType common_type = - promoteTypes(a_type, b_type, /*half_to_float*/ true); - ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out); - - ET_KERNEL_CHECK( - ctx, - resize_to_broadcast_target_size(a, b, out) == Error::Ok, - InvalidArgument, - out); - - ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.out", CTYPE_A, [&]() { - ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "mul.out", CTYPE_B, [&]() { - using CTYPE_IN = typename torch::executor:: - promote_types::type; - ET_DCHECK(CppTypeToScalarType::value == common_type); - ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "mul.out", CTYPE_OUT, [&]() { - apply_binary_elementwise_fn( - [](const CTYPE_A val_a, const CTYPE_B val_b) { - CTYPE_IN a_casted = static_cast(val_a); - CTYPE_IN b_casted = static_cast(val_b); - CTYPE_IN value = a_casted * b_casted; - - return static_cast(value); - }, - a, - b, - out); - }); - }); - }); + mul_out_impl(ctx, a, b, out); } return out; @@ -315,27 +287,7 @@ Tensor& opt_mul_scalar_out( }); }); } else { - ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.Scalar_out", CTYPE_A, [&]() { - ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "mul.Scalar_out", CTYPE_B, [&]() { - ET_SWITCH_REALB_TYPES( - common_type, ctx, "mul.Scalar_out", CTYPE_IN, [&]() { - ET_SWITCH_REALHBBF16_TYPES( - out_type, ctx, "mul.Scalar_out", CTYPE_OUT, [&]() { - CTYPE_B b_val; - ET_EXTRACT_SCALAR(b, b_val); - CTYPE_IN b_casted = static_cast(b_val); - - const size_t n = a.numel(); - const CTYPE_A* a_data = a.const_data_ptr(); - CTYPE_OUT* out_data = out.mutable_data_ptr(); - for (auto i = 0; i < n; ++i) { - out_data[i] = static_cast( - static_cast(a_data[i]) * b_casted); - } - }); - }); - }); - }); + mul_scalar_out_impl(ctx, a, b, out); } return out; diff --git a/kernels/optimized/cpu/targets.bzl b/kernels/optimized/cpu/targets.bzl index dc186d70a8b..c3dc61269b5 100644 --- a/kernels/optimized/cpu/targets.bzl +++ b/kernels/optimized/cpu/targets.bzl @@ -69,6 +69,7 @@ _OPTIMIZED_ATEN_OPS = ( name = "op_mul", deps = [ ":binary_ops", + "//executorch/kernels/portable/cpu:op_mul_impl", "//executorch/kernels/portable/cpu:scalar_utils", "//executorch/kernels/portable/cpu/util:broadcast_util", "//executorch/runtime/core/exec_aten/util:tensor_util",