Skip to content

Commit 15b9c7d

Browse files
Clean up optimized op_mul
Differential Revision: D81199584. Pull Request resolved: #13763.
1 parent 82e7249 commit 15b9c7d

File tree

2 files changed

+62
-128
lines changed

2 files changed

+62
-128
lines changed

kernels/optimized/cpu/op_mul.cpp

Lines changed: 60 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
#include <ATen/cpu/vec/vec.h>
1111
#include <executorch/kernels/optimized/cpu/binary_ops.h>
1212
#include <executorch/kernels/portable/cpu/scalar_utils.h>
13-
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
13+
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
1414
#include <executorch/runtime/core/exec_aten/util/tensor_util.h> // IWYU pragma: export
1515
#include <executorch/runtime/kernel/kernel_includes.h>
1616
#include <executorch/runtime/platform/assert.h>
@@ -22,76 +22,35 @@ namespace native {
2222
using Tensor = executorch::aten::Tensor;
2323
using ScalarType = executorch::aten::ScalarType;
2424

25-
namespace {
26-
27-
template <
28-
bool can_cast,
29-
typename CTYPE_A,
30-
typename CTYPE_B,
31-
typename CTYPE_IN,
32-
typename CTYPE_OUT>
33-
struct MulInner;
34-
35-
template <
36-
typename CTYPE_A,
37-
typename CTYPE_B,
38-
typename CTYPE_IN,
39-
typename CTYPE_OUT>
40-
struct MulInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
41-
static void run(const Tensor& a, const Tensor& b, Tensor& out) {
42-
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
43-
// NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
44-
[](const CTYPE_A val_a, const CTYPE_B val_b) {
45-
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
46-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
47-
CTYPE_IN value = a_casted * b_casted;
48-
49-
return static_cast<CTYPE_OUT>(value);
50-
},
51-
a,
52-
b,
53-
out);
54-
}
55-
};
56-
57-
struct ReportCanCastBug {
58-
static void run(const Tensor&, const Tensor&, Tensor&) {
59-
ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
60-
}
61-
};
62-
63-
template <
64-
typename CTYPE_A,
65-
typename CTYPE_B,
66-
typename CTYPE_IN,
67-
typename CTYPE_OUT>
68-
struct MulInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
69-
: public ReportCanCastBug {};
70-
71-
} // namespace
72-
7325
Tensor& opt_mul_out(
7426
KernelRuntimeContext& ctx,
7527
const Tensor& a,
7628
const Tensor& b,
7729
Tensor& out) {
78-
(void)ctx;
79-
8030
ScalarType a_type = a.scalar_type();
8131
ScalarType b_type = b.scalar_type();
8232
ScalarType out_type = out.scalar_type();
33+
ScalarType common_type = promoteTypes(a_type, b_type);
34+
35+
ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
36+
37+
ET_KERNEL_CHECK(
38+
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
39+
40+
ET_KERNEL_CHECK(
41+
ctx,
42+
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
43+
InvalidArgument,
44+
out);
45+
46+
// @lint-ignore CLANGTIDY facebook-hte-CArray
47+
static constexpr const char op_name[] = "mul.out";
8348

8449
if (b.numel() == 1) {
8550
if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half &&
8651
a_type != ScalarType::BFloat16) {
87-
ET_KERNEL_CHECK(
88-
ctx,
89-
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
90-
InvalidArgument,
91-
out);
92-
93-
ET_SWITCH_REALB_TYPES(a_type, ctx, "mul.out", CTYPE, [&]() {
94-
ET_SWITCH_REALB_TYPES(b_type, ctx, "mul.out", CTYPE_B, [&]() {
52+
ET_SWITCH_REALB_TYPES(a_type, ctx, op_name, CTYPE, [&]() {
53+
ET_SWITCH_REALB_TYPES(b_type, ctx, op_name, CTYPE_B, [&]() {
9554
CTYPE_B b_val = *b.const_data_ptr<CTYPE_B>();
9655
CTYPE b_casted = static_cast<CTYPE>(b_val);
9756

@@ -111,17 +70,11 @@ Tensor& opt_mul_out(
11170

11271
auto selected_optimized_path = select_optimized_path(a, b, out);
11372
if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {
114-
ET_KERNEL_CHECK(
115-
ctx,
116-
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
117-
InvalidArgument,
118-
out);
119-
12073
if (executorch::runtime::isComplexType(out_type)) {
12174
ET_KERNEL_CHECK(
12275
ctx, a_type == b_type && a_type == out_type, InvalidArgument, out);
12376

124-
ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
77+
ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
12578
using Vec = at::vec::Vectorized<CTYPE>;
12679
at::vec::map2<CTYPE>(
12780
[](Vec x, Vec y) { return x * y; },
@@ -131,7 +84,7 @@ Tensor& opt_mul_out(
13184
out.numel());
13285
});
13386
} else {
134-
ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
87+
ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
13588
using Vec = at::vec::Vectorized<CTYPE>;
13689
at::vec::map2<CTYPE>(
13790
[](Vec x, Vec y) { return x * y; },
@@ -146,63 +99,47 @@ Tensor& opt_mul_out(
14699
ET_KERNEL_CHECK(
147100
ctx, a_type == b_type && a_type == out_type, InvalidArgument, out);
148101

149-
ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
102+
ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
150103
auto mul_lambda = [](auto x, auto y) { return x * y; };
151104
torch::executor::handle_broadcast_elementwise<CTYPE>(
152105
ctx, mul_lambda, a, b, out, selected_optimized_path);
153106
});
154107
} else {
155-
ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
108+
ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
156109
auto mul_lambda = [](auto x, auto y) { return x * y; };
157110
torch::executor::handle_broadcast_elementwise<CTYPE>(
158111
ctx, mul_lambda, a, b, out, selected_optimized_path);
159112
});
160113
}
161114
} else {
162-
ScalarType common_type =
163-
promoteTypes(a_type, b_type, /*half_to_float*/ true);
164-
ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
165-
166-
ET_KERNEL_CHECK(
167-
ctx,
168-
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
169-
InvalidArgument,
170-
out);
171-
172115
if (executorch::runtime::isComplexType(a_type) ||
173116
executorch::runtime::isComplexType(b_type) ||
174117
executorch::runtime::isComplexType(out_type)) {
175118
ET_KERNEL_CHECK(
176119
ctx, a_type == b_type && a_type == out_type, InvalidArgument, out);
177120

178-
ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
121+
ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
179122
apply_binary_elementwise_fn<CTYPE, CTYPE, CTYPE>(
180123
[](const CTYPE val_a, const CTYPE val_b) { return val_a * val_b; },
181124
a,
182125
b,
183126
out);
184127
});
185128
} else {
186-
ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.out", CTYPE_A, [&]() {
187-
ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "mul.out", CTYPE_B, [&]() {
188-
using CTYPE_IN = typename torch::executor::
189-
promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
190-
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
191-
ET_SWITCH_REALHBBF16_TYPES(
192-
out_type, ctx, "mul.out", CTYPE_OUT, [&]() {
193-
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
194-
[](const CTYPE_A val_a, const CTYPE_B val_b) {
195-
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
196-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
197-
CTYPE_IN value = a_casted * b_casted;
198-
199-
return static_cast<CTYPE_OUT>(value);
200-
},
201-
a,
202-
b,
203-
out);
204-
});
205-
});
129+
ScalarType compute_type = utils::internal::get_compute_type(common_type);
130+
131+
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
132+
utils::apply_bitensor_elementwise_fn<
133+
CTYPE_COMPUTE,
134+
op_name,
135+
utils::SupportedTensorDtypes::REALHBBF16>(
136+
[](const auto val_a, const auto val_b) { return val_a * val_b; },
137+
ctx,
138+
a,
139+
utils::SupportedTensorDtypes::REALHBBF16,
140+
b,
141+
utils::SupportedTensorDtypes::REALHBBF16,
142+
out);
206143
});
207144
}
208145
}
@@ -215,26 +152,24 @@ Tensor& opt_mul_scalar_out(
215152
const Tensor& a,
216153
const Scalar& b,
217154
Tensor& out) {
218-
(void)ctx;
219-
220155
ScalarType a_type = a.scalar_type();
221-
ScalarType common_type =
222-
utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false);
156+
ScalarType common_type = utils::promote_type_with_scalar(a_type, b);
223157
ScalarType out_type = out.scalar_type();
224158

225-
ET_CHECK(common_type == out_type);
159+
ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);
226160

227-
if (common_type == ScalarType::Half || common_type == ScalarType::BFloat16) {
228-
common_type = ScalarType::Float;
229-
}
161+
ET_KERNEL_CHECK(
162+
ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
230163

231-
// Resize for dynamic shape
232-
auto error = resize_tensor(out, a.sizes());
233-
ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");
164+
ET_KERNEL_CHECK(
165+
ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
166+
167+
// @lint-ignore CLANGTIDY facebook-hte-CArray
168+
static constexpr const char op_name[] = "mul.Scalar_out";
234169

235170
if (a_type == common_type && a_type == out_type &&
236171
a_type != ScalarType::Half && a_type != ScalarType::BFloat16) {
237-
ET_SWITCH_REALB_TYPES(a_type, ctx, "mul.Scalar_out", CTYPE, [&]() {
172+
ET_SWITCH_REALB_TYPES(a_type, ctx, op_name, CTYPE, [&]() {
238173
CTYPE b_casted = utils::scalar_to<CTYPE>(b);
239174

240175
using Vec = at::vec::Vectorized<CTYPE>;
@@ -245,22 +180,19 @@ Tensor& opt_mul_scalar_out(
245180
out.numel());
246181
});
247182
} else {
248-
ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.Scalar_out", CTYPE_A, [&]() {
249-
ET_SWITCH_REALB_TYPES(
250-
common_type, ctx, "mul.Scalar_out", CTYPE_IN, [&]() {
251-
ET_SWITCH_REALHBBF16_TYPES(
252-
out_type, ctx, "mul.Scalar_out", CTYPE_OUT, [&]() {
253-
CTYPE_IN b_casted = utils::scalar_to<CTYPE_IN>(b);
254-
255-
const size_t n = a.numel();
256-
const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
257-
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
258-
for (auto i = 0; i < n; ++i) {
259-
out_data[i] = static_cast<CTYPE_OUT>(
260-
static_cast<CTYPE_IN>(a_data[i]) * b_casted);
261-
}
262-
});
263-
});
183+
ScalarType compute_type = utils::internal::get_compute_type(common_type);
184+
185+
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
186+
const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
187+
utils::apply_unitensor_elementwise_fn<
188+
CTYPE_COMPUTE,
189+
op_name,
190+
utils::SupportedTensorDtypes::SAME_AS_COMMON>(
191+
[val_b](const auto val_a) { return val_a * val_b; },
192+
ctx,
193+
a,
194+
utils::SupportedTensorDtypes::REALHBBF16,
195+
out);
264196
});
265197
}
266198

shim_et/xplat/executorch/kernels/optimized/op_registration_util.bzl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,8 @@ OPTIMIZED_ATEN_OPS = (
261261
":binary_ops",
262262
"//executorch/kernels/portable/cpu:scalar_utils",
263263
"//executorch/kernels/portable/cpu/util:broadcast_util",
264+
"//executorch/kernels/portable/cpu/util:dtype_util",
265+
"//executorch/kernels/portable/cpu/util:elementwise_util",
264266
"//executorch/runtime/core/exec_aten/util:tensor_util",
265267
"//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch",
266268
],

0 commit comments

Comments (0)