@@ -10,7 +10,7 @@
 #include <ATen/cpu/vec/vec.h>
 #include <executorch/kernels/optimized/cpu/binary_ops.h>
 #include <executorch/kernels/portable/cpu/scalar_utils.h>
-#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
 #include <executorch/runtime/core/exec_aten/util/tensor_util.h> // IWYU pragma: export
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <executorch/runtime/platform/assert.h>
@@ -22,76 +22,35 @@ namespace native {
 using Tensor = executorch::aten::Tensor;
 using ScalarType = executorch::aten::ScalarType;
 
-namespace {
-
-template <
-    bool can_cast,
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct MulInner;
-
-template <
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct MulInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
-  static void run(const Tensor& a, const Tensor& b, Tensor& out) {
-    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-        // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
-        [](const CTYPE_A val_a, const CTYPE_B val_b) {
-          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-          CTYPE_IN value = a_casted * b_casted;
-
-          return static_cast<CTYPE_OUT>(value);
-        },
-        a,
-        b,
-        out);
-  }
-};
-
-struct ReportCanCastBug {
-  static void run(const Tensor&, const Tensor&, Tensor&) {
-    ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
-  }
-};
-
-template <
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct MulInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
-    : public ReportCanCastBug {};
-
-} // namespace
-
 Tensor& opt_mul_out(
     KernelRuntimeContext& ctx,
     const Tensor& a,
     const Tensor& b,
     Tensor& out) {
-  (void)ctx;
-
   ScalarType a_type = a.scalar_type();
   ScalarType b_type = b.scalar_type();
   ScalarType out_type = out.scalar_type();
+  ScalarType common_type = promoteTypes(a_type, b_type);
+
+  ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
+
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
+
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
+      InvalidArgument,
+      out);
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "mul.out";
 
   if (b.numel() == 1) {
     if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half &&
         a_type != ScalarType::BFloat16) {
-      ET_KERNEL_CHECK(
-          ctx,
-          resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-          InvalidArgument,
-          out);
-
-      ET_SWITCH_REALB_TYPES(a_type, ctx, "mul.out", CTYPE, [&]() {
-        ET_SWITCH_REALB_TYPES(b_type, ctx, "mul.out", CTYPE_B, [&]() {
+      ET_SWITCH_REALB_TYPES(a_type, ctx, op_name, CTYPE, [&]() {
+        ET_SWITCH_REALB_TYPES(b_type, ctx, op_name, CTYPE_B, [&]() {
           CTYPE_B b_val = *b.const_data_ptr<CTYPE_B>();
           CTYPE b_casted = static_cast<CTYPE>(b_val);
 
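A note on the `op_name` change above: later in this diff, `op_name` is passed in the template argument list of `utils::apply_bitensor_elementwise_fn` and `utils::apply_unitensor_elementwise_fn`, and a string literal cannot be used as a non-type template argument, which is why the name moves into a `static constexpr const char[]`. A minimal standalone sketch of that constraint; `report_op` is a made-up stand-in, not an ExecuTorch API:

    #include <cstdio>

    // A template keyed off a compile-time name, loosely modeled on how the
    // elementwise_util helpers take `op_name` as a template parameter.
    template <const char* Name>
    void report_op() {
      std::printf("running op: %s\n", Name);
    }

    static constexpr const char op_name[] = "mul.out";

    int main() {
      report_op<op_name>(); // OK: array with static storage duration
      // report_op<"mul.out">(); // ill-formed: literal as template argument
    }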
@@ -111,17 +70,11 @@ Tensor& opt_mul_out(
 
   auto selected_optimized_path = select_optimized_path(a, b, out);
   if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {
-    ET_KERNEL_CHECK(
-        ctx,
-        resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-        InvalidArgument,
-        out);
-
     if (executorch::runtime::isComplexType(out_type)) {
       ET_KERNEL_CHECK(
           ctx, a_type == b_type && a_type == out_type, InvalidArgument, out);
 
-      ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
+      ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
         using Vec = at::vec::Vectorized<CTYPE>;
         at::vec::map2<CTYPE>(
             [](Vec x, Vec y) { return x * y; },
@@ -131,7 +84,7 @@ Tensor& opt_mul_out(
             out.numel());
       });
     } else {
-      ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
+      ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
         using Vec = at::vec::Vectorized<CTYPE>;
         at::vec::map2<CTYPE>(
             [](Vec x, Vec y) { return x * y; },
@@ -146,63 +99,47 @@ Tensor& opt_mul_out(
       ET_KERNEL_CHECK(
           ctx, a_type == b_type && a_type == out_type, InvalidArgument, out);
 
-      ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
+      ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
         auto mul_lambda = [](auto x, auto y) { return x * y; };
         torch::executor::handle_broadcast_elementwise<CTYPE>(
             ctx, mul_lambda, a, b, out, selected_optimized_path);
       });
     } else {
-      ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
+      ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
         auto mul_lambda = [](auto x, auto y) { return x * y; };
         torch::executor::handle_broadcast_elementwise<CTYPE>(
             ctx, mul_lambda, a, b, out, selected_optimized_path);
       });
     }
   } else {
-    ScalarType common_type =
-        promoteTypes(a_type, b_type, /*half_to_float*/ true);
-    ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
-
-    ET_KERNEL_CHECK(
-        ctx,
-        resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-        InvalidArgument,
-        out);
-
     if (executorch::runtime::isComplexType(a_type) ||
         executorch::runtime::isComplexType(b_type) ||
         executorch::runtime::isComplexType(out_type)) {
       ET_KERNEL_CHECK(
           ctx, a_type == b_type && a_type == out_type, InvalidArgument, out);
 
-      ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
+      ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
         apply_binary_elementwise_fn<CTYPE, CTYPE, CTYPE>(
             [](const CTYPE val_a, const CTYPE val_b) { return val_a * val_b; },
             a,
             b,
             out);
       });
     } else {
-      ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.out", CTYPE_A, [&]() {
-        ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "mul.out", CTYPE_B, [&]() {
-          using CTYPE_IN = typename torch::executor::
-              promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
-          ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
-          ET_SWITCH_REALHBBF16_TYPES(
-              out_type, ctx, "mul.out", CTYPE_OUT, [&]() {
-                apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-                    [](const CTYPE_A val_a, const CTYPE_B val_b) {
-                      CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-                      CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-                      CTYPE_IN value = a_casted * b_casted;
-
-                      return static_cast<CTYPE_OUT>(value);
-                    },
-                    a,
-                    b,
-                    out);
-              });
-        });
+      ScalarType compute_type = utils::internal::get_compute_type(common_type);
+
+      ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+        utils::apply_bitensor_elementwise_fn<
+            CTYPE_COMPUTE,
+            op_name,
+            utils::SupportedTensorDtypes::REALHBBF16>(
+            [](const auto val_a, const auto val_b) { return val_a * val_b; },
+            ctx,
+            a,
+            utils::SupportedTensorDtypes::REALHBBF16,
+            b,
+            utils::SupportedTensorDtypes::REALHBBF16,
+            out);
       });
     }
   }
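The `apply_bitensor_elementwise_fn` rewrite above replaces the old triple-nested `ET_SWITCH` (one instantiation of the inner loop per input-A/input-B/output dtype combination) with a single switch over one compute type; elements are converted on load and store instead. A standalone sketch of that structure, under assumed names and signatures; the real helper also handles broadcasting and dtype validation:

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    // One template instantiation per *compute* type; the input/output element
    // types are handled by runtime load/store converters rather than by extra
    // template parameters.
    template <typename CTYPE_COMPUTE, typename Op>
    void apply_bitensor_elementwise(
        Op op,
        const void* a,
        CTYPE_COMPUTE (*load_a)(const void*, std::size_t),
        const void* b,
        CTYPE_COMPUTE (*load_b)(const void*, std::size_t),
        void* out,
        void (*store_out)(void*, std::size_t, CTYPE_COMPUTE),
        std::size_t n) {
      for (std::size_t i = 0; i < n; ++i) {
        store_out(out, i, op(load_a(a, i), load_b(b, i)));
      }
    }

    int main() {
      const int32_t a[] = {1, 2, 3};
      const int16_t b[] = {4, 5, 6};
      float out[3];
      // int32 * int16 -> float, with all arithmetic done in float.
      apply_bitensor_elementwise<float>(
          [](float x, float y) { return x * y; },
          a,
          [](const void* p, std::size_t i) {
            return static_cast<float>(static_cast<const int32_t*>(p)[i]);
          },
          b,
          [](const void* p, std::size_t i) {
            return static_cast<float>(static_cast<const int16_t*>(p)[i]);
          },
          out,
          [](void* p, std::size_t i, float v) { static_cast<float*>(p)[i] = v; },
          3);
      for (float v : out) {
        std::printf("%g ", v); // prints: 4 10 18
      }
      std::printf("\n");
    }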
@@ -215,26 +152,24 @@ Tensor& opt_mul_scalar_out(
     const Tensor& a,
     const Scalar& b,
     Tensor& out) {
-  (void)ctx;
-
   ScalarType a_type = a.scalar_type();
-  ScalarType common_type =
-      utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false);
+  ScalarType common_type = utils::promote_type_with_scalar(a_type, b);
   ScalarType out_type = out.scalar_type();
 
-  ET_CHECK(common_type == out_type);
+  ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);
 
-  if (common_type == ScalarType::Half || common_type == ScalarType::BFloat16) {
-    common_type = ScalarType::Float;
-  }
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
 
-  // Resize for dynamic shape
-  auto error = resize_tensor(out, a.sizes());
-  ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");
+  ET_KERNEL_CHECK(
+      ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "mul.Scalar_out";
 
   if (a_type == common_type && a_type == out_type &&
       a_type != ScalarType::Half && a_type != ScalarType::BFloat16) {
-    ET_SWITCH_REALB_TYPES(a_type, ctx, "mul.Scalar_out", CTYPE, [&]() {
+    ET_SWITCH_REALB_TYPES(a_type, ctx, op_name, CTYPE, [&]() {
       CTYPE b_casted = utils::scalar_to<CTYPE>(b);
 
       using Vec = at::vec::Vectorized<CTYPE>;
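The `ET_CHECK`/`ET_CHECK_MSG` to `ET_KERNEL_CHECK` changes above alter failure behavior: the former abort the process, while `ET_KERNEL_CHECK` records the error on the kernel context and returns the given value, so an unsupported dtype combination surfaces to the caller as `Error::InvalidArgument`. A simplified model of the pattern, with a hypothetical macro and types rather than the runtime's actual definitions:

    #include <cstdio>

    enum class Error { Ok, InvalidArgument };

    struct KernelRuntimeContext {
      Error failure = Error::Ok;
      void fail(Error e) { failure = e; }
    };

    // Simplified stand-in for ET_KERNEL_CHECK: on a failed condition, record
    // the error on the context and return `retval` from the enclosing kernel
    // instead of aborting the whole process like an assert-style ET_CHECK.
    #define MY_KERNEL_CHECK(ctx, cond, error, retval) \
      do {                                            \
        if (!(cond)) {                                \
          (ctx).fail(Error::error);                   \
          return retval;                              \
        }                                             \
      } while (0)

    float& demo_kernel(KernelRuntimeContext& ctx, bool types_match, float& out) {
      MY_KERNEL_CHECK(ctx, types_match, InvalidArgument, out);
      out = 6.0f; // kernel body runs only when every check passed
      return out;
    }

    int main() {
      KernelRuntimeContext ctx;
      float out = 0.0f;
      demo_kernel(ctx, /*types_match=*/false, out);
      std::printf("failed=%d out=%g\n",
                  ctx.failure == Error::InvalidArgument, out);
    }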
@@ -245,22 +180,19 @@ Tensor& opt_mul_scalar_out(
           out.numel());
     });
   } else {
-    ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.Scalar_out", CTYPE_A, [&]() {
-      ET_SWITCH_REALB_TYPES(
-          common_type, ctx, "mul.Scalar_out", CTYPE_IN, [&]() {
-            ET_SWITCH_REALHBBF16_TYPES(
-                out_type, ctx, "mul.Scalar_out", CTYPE_OUT, [&]() {
-                  CTYPE_IN b_casted = utils::scalar_to<CTYPE_IN>(b);
-
-                  const size_t n = a.numel();
-                  const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
-                  CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
-                  for (auto i = 0; i < n; ++i) {
-                    out_data[i] = static_cast<CTYPE_OUT>(
-                        static_cast<CTYPE_IN>(a_data[i]) * b_casted);
-                  }
-                });
-          });
+    ScalarType compute_type = utils::internal::get_compute_type(common_type);
+
+    ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+      const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
+      utils::apply_unitensor_elementwise_fn<
+          CTYPE_COMPUTE,
+          op_name,
+          utils::SupportedTensorDtypes::SAME_AS_COMMON>(
+          [val_b](const auto val_a) { return val_a * val_b; },
+          ctx,
+          a,
+          utils::SupportedTensorDtypes::REALHBBF16,
+          out);
     });
   }
 
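In the new scalar path above, the `Scalar` is converted to the compute type once and captured by the lambda, so the inner kernel degenerates to a unary op; `SupportedTensorDtypes::SAME_AS_COMMON` constrains the output to (roughly) the common dtype, which the earlier `common_type == out_type` check already enforces. A unary counterpart to the earlier sketch, again with assumed names rather than the real helper's signature:

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    // Unary analogue of the bitensor sketch: the scalar operand is folded
    // into the lambda, so only one load converter is needed.
    template <typename CTYPE_COMPUTE, typename Op>
    void apply_unitensor_elementwise(
        Op op,
        const void* a,
        CTYPE_COMPUTE (*load_a)(const void*, std::size_t),
        CTYPE_COMPUTE* out, // sketch simplification: out holds the compute type
        std::size_t n) {
      for (std::size_t i = 0; i < n; ++i) {
        out[i] = op(load_a(a, i));
      }
    }

    int main() {
      const int32_t a[] = {1, 2, 3};
      float out[3];
      const float val_b = 2.5f; // scalar promoted to the compute type once
      apply_unitensor_elementwise<float>(
          [val_b](float val_a) { return val_a * val_b; },
          a,
          [](const void* p, std::size_t i) {
            return static_cast<float>(static_cast<const int32_t*>(p)[i]);
          },
          out,
          3);
      for (float v : out) {
        std::printf("%g ", v); // prints: 2.5 5 7.5
      }
      std::printf("\n");
    }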