Commit 87886a5

[ExecuTorch][Optimized] Use portable path for fallback in op_mul
Remove the earlier copy-paste of the portable impl used for the fallback path. By relying directly on the portable impl, we benefit from future optimizations made there and avoid missing fixes, as happened here.

Differential Revision: [D65628858](https://our.internmc.facebook.com/intern/diff/D65628858/)

[ghstack-poisoned]
1 parent e81749d commit 87886a5
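For reference, the shape of `opt_mul_out` after this change is roughly the following. This is a condensed sketch of the code visible in the diff below, not a compilable excerpt: names not shown in the diff (e.g. `select_optimized_path` and the `kTreatAsAllContiguous` branch) are assumptions about the elided surrounding context.

// Condensed sketch of the fallback structure after this change; the
// vectorized fast paths are untouched, only the final else-branch differs.
Tensor& opt_mul_out(
    KernelRuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out) {
  const auto selected_optimized_path = select_optimized_path(a, b, out);
  if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAsAllContiguous) {
    // ... vectorized contiguous fast path, unchanged by this commit ...
  } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
    return handle_broadcast_mul(ctx, a, b, out, selected_optimized_path);
  } else {
    // Fallback: delegate to the portable implementation instead of keeping a
    // copy-pasted duplicate of its type-switch loop, so future fixes and
    // optimizations in mul_out_impl are inherited automatically.
    mul_out_impl(ctx, a, b, out);
  }
  return out;
}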

File tree

kernels/optimized/cpu/op_mul.cpp
kernels/optimized/cpu/targets.bzl

2 files changed: +4, -51 lines

kernels/optimized/cpu/op_mul.cpp

Lines changed: 3 additions & 51 deletions
@@ -9,6 +9,7 @@
 #include <executorch/kernels/optimized/cpu/binary_ops.h>
 #include <executorch/kernels/optimized/vec/functional.h>
 #include <executorch/kernels/optimized/vec/vec.h>
+#include <executorch/kernels/portable/cpu/op_mul_impl.h>
 #include <executorch/kernels/portable/cpu/scalar_utils.h>
 #include <executorch/kernels/portable/cpu/util/broadcast_util.h>
 #include <executorch/runtime/core/exec_aten/util/tensor_util.h> // IWYU pragma: export
@@ -240,36 +241,7 @@ Tensor& opt_mul_out(
   } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
     return handle_broadcast_mul(ctx, a, b, out, selected_optimized_path);
   } else {
-    ScalarType common_type =
-        promoteTypes(a_type, b_type, /*half_to_float*/ true);
-    ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
-
-    ET_KERNEL_CHECK(
-        ctx,
-        resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-        InvalidArgument,
-        out);
-
-    ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.out", CTYPE_A, [&]() {
-      ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "mul.out", CTYPE_B, [&]() {
-        using CTYPE_IN = typename torch::executor::
-            promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
-        ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
-        ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "mul.out", CTYPE_OUT, [&]() {
-          apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-              [](const CTYPE_A val_a, const CTYPE_B val_b) {
-                CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-                CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-                CTYPE_IN value = a_casted * b_casted;
-
-                return static_cast<CTYPE_OUT>(value);
-              },
-              a,
-              b,
-              out);
-        });
-      });
-    });
+    mul_out_impl(ctx, a, b, out);
   }
 
   return out;
@@ -315,27 +287,7 @@ Tensor& opt_mul_scalar_out(
       });
     });
   } else {
-    ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.Scalar_out", CTYPE_A, [&]() {
-      ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "mul.Scalar_out", CTYPE_B, [&]() {
-        ET_SWITCH_REALB_TYPES(
-            common_type, ctx, "mul.Scalar_out", CTYPE_IN, [&]() {
-              ET_SWITCH_REALHBBF16_TYPES(
-                  out_type, ctx, "mul.Scalar_out", CTYPE_OUT, [&]() {
-                    CTYPE_B b_val;
-                    ET_EXTRACT_SCALAR(b, b_val);
-                    CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
-
-                    const size_t n = a.numel();
-                    const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
-                    CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
-                    for (auto i = 0; i < n; ++i) {
-                      out_data[i] = static_cast<CTYPE_OUT>(
-                          static_cast<CTYPE_IN>(a_data[i]) * b_casted);
-                    }
-                  });
-            });
-      });
-    });
+    mul_scalar_out_impl(ctx, a, b, out);
   }
 
   return out;
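The two fallback calls resolve against the newly included `op_mul_impl.h`. Judging purely from the call sites above, the portable header presumably declares entry points along these lines; this is a sketch inferred from the diff, not a verbatim copy of the header:

// Assumed declarations behind the fallback calls; signatures inferred from
// the call sites in this diff, not copied from op_mul_impl.h.
namespace torch {
namespace executor {
namespace native {

// Elementwise tensor-tensor multiply with broadcasting; writes into out.
Tensor& mul_out_impl(
    KernelRuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out);

// Tensor-scalar multiply; the scalar b is cast to the promoted compute type.
Tensor& mul_scalar_out_impl(
    KernelRuntimeContext& ctx, const Tensor& a, const Scalar& b, Tensor& out);

} // namespace native
} // namespace executor
} // namespace torch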

kernels/optimized/cpu/targets.bzl

Lines changed: 1 addition & 0 deletions
@@ -69,6 +69,7 @@ _OPTIMIZED_ATEN_OPS = (
         name = "op_mul",
         deps = [
             ":binary_ops",
+            "//executorch/kernels/portable/cpu:op_mul_impl",
             "//executorch/kernels/portable/cpu:scalar_utils",
             "//executorch/kernels/portable/cpu/util:broadcast_util",
             "//executorch/runtime/core/exec_aten/util:tensor_util",
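This `deps` entry is the build-system counterpart of the new `#include`: without it, the optimized `op_mul` target would not see the portable `op_mul_impl` header or link against `mul_out_impl`/`mul_scalar_out_impl` under Buck.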
