 #include <ATen/cpu/vec/vec.h>
 #include <executorch/kernels/optimized/cpu/binary_ops.h>
 #include <executorch/kernels/portable/cpu/scalar_utils.h>
-#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <executorch/runtime/platform/assert.h>
@@ -20,7 +20,7 @@ namespace native {

 namespace {

-ScalarType get_compute_type(ScalarType a_type, ScalarType b_type) {
+ScalarType get_common_type(ScalarType a_type, ScalarType b_type) {
   ET_CHECK(
       !isComplexType(a_type) && !isQIntType(a_type) && !isBitsType(a_type));
   ET_CHECK(
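(The body of `get_common_type` continues past this hunk; it resolves the dtype the two inputs promote to. A minimal sketch of the idea, assuming ExecuTorch's `promoteTypes` helper with PyTorch-style promotion rules — the helper name and header path are assumptions, not shown by this diff:)

```cpp
// Hedged sketch, not the file's actual implementation: resolve the
// promoted dtype of the two inputs before dividing. promoteTypes and
// the header path below are assumptions.
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>

using executorch::aten::ScalarType;

ScalarType common_type_sketch(ScalarType a, ScalarType b) {
  // e.g. Float / Int -> Float; Int / Long -> Long
  return executorch::runtime::promoteTypes(a, b);
}
```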
@@ -43,14 +43,27 @@ Tensor& opt_div_out(
     const Tensor& a,
     const Tensor& b,
     Tensor& out) {
-  (void)ctx;
+  // Check Dim Order
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
+
+  // Resize
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
+      InvalidArgument,
+      out);
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "div.out";

   ScalarType a_type = a.scalar_type();
   ScalarType b_type = b.scalar_type();
   ScalarType out_type = out.scalar_type();

   if (a.numel() == 1 || b.numel() == 1) {
-    if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half) {
+    if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half &&
+        a_type != ScalarType::BFloat16) {
       const Tensor* tensor;
       const Tensor* scalar;
       ScalarType tensor_type;
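Note on the new `op_name` constant: the `ET_SWITCH` macros and the `utils::apply_*_elementwise_fn` helpers used later in this diff take the operator name as a non-type template parameter, which a bare string literal cannot be. A minimal standalone sketch of the C++ rule involved (`OpTag` and `kDivOut` are hypothetical names):

```cpp
// Why op_name must be a named constexpr array: a pointer to an array
// with static storage duration is a valid non-type template argument
// (since C++17 even with internal linkage); a string literal is not.
template <const char* Name>
struct OpTag {
  static const char* name() {
    return Name;
  }
};

static constexpr const char kDivOut[] = "div.out";

using DivTag = OpTag<kDivOut>;  // OK
// using Bad = OpTag<"div.out">;  // error: literal as template argument
```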
@@ -66,13 +79,8 @@ Tensor& opt_div_out(
         scalar = &b;
         scalar_type = b_type;
       }
-      ET_KERNEL_CHECK(
-          ctx,
-          resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-          InvalidArgument,
-          out);
-      ET_SWITCH_REALB_TYPES(tensor_type, ctx, "div.out", CTYPE, [&]() {
-        ET_SWITCH_REALB_TYPES(scalar_type, ctx, "div.out", CTYPE_SCALAR, [&]() {
+      ET_SWITCH_REALB_TYPES(tensor_type, ctx, op_name, CTYPE, [&]() {
+        ET_SWITCH_REALB_TYPES(scalar_type, ctx, op_name, CTYPE_SCALAR, [&]() {
           CTYPE_SCALAR scalar_val = *scalar->const_data_ptr<CTYPE_SCALAR>();
           CTYPE scalar_casted = static_cast<CTYPE>(scalar_val);

@@ -101,16 +109,7 @@ Tensor& opt_div_out(

   auto selected_optimized_path = select_optimized_path(a, b, out);
   if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {
-    // Resize for dynamic shape
-    auto error = resize_tensor(out, a.sizes());
-    ET_KERNEL_CHECK_MSG(
-        ctx,
-        error == Error::Ok,
-        InvalidArgument,
-        out,
-        "Failed to resize output tensor.");
-
-    ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, "div.out", CTYPE, [&]() {
+    ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
       using Vec = at::vec::Vectorized<CTYPE>;
       at::vec::map2<CTYPE>(
           [](Vec x, Vec y) { return x / y; },
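For context, `at::vec::map2` applies the vectorized lambda over both inputs one SIMD register at a time, including the ragged tail. A self-contained sketch of the same pattern on raw contiguous buffers (function and buffer names are illustrative):

```cpp
// Illustrative use of the at::vec::map2 pattern from the hunk above:
// elementwise division of two contiguous float buffers.
#include <ATen/cpu/vec/vec.h>
#include <cstdint>

void div_contiguous(float* out, const float* a, const float* b, int64_t n) {
  using Vec = at::vec::Vectorized<float>;
  at::vec::map2<float>(
      [](Vec x, Vec y) { return x / y; },  // one SIMD division per step
      out,  // output buffer
      a,    // first input
      b,    // second input
      n);   // element count; map2 handles the non-multiple-of-lane tail
}
```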
@@ -122,7 +121,7 @@ Tensor& opt_div_out(
   } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
     // Reason for using alpha is because handle_broadcast_elementwise
     // is used for add and sub as well:
-    ET_SWITCH_REALB_TYPES(out_type, ctx, "div.out", CTYPE, [&]() {
+    ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
       if (selected_optimized_path ==
               ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
           selected_optimized_path ==
@@ -139,33 +138,21 @@ Tensor& opt_div_out(
       }
     });
   } else {
-    ScalarType common_type = get_compute_type(a_type, b_type);
-    ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
-
-    ET_KERNEL_CHECK(
-        ctx,
-        resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-        InvalidArgument,
-        out);
-
-    ET_SWITCH_REALB_TYPES(a_type, ctx, "div.out", CTYPE_A, [&]() {
-      ET_SWITCH_REALB_TYPES(b_type, ctx, "div.out", CTYPE_B, [&]() {
-        ET_SWITCH_REALB_TYPES(common_type, ctx, "div.out", CTYPE_IN, [&]() {
-          ET_SWITCH_REALB_TYPES(out_type, ctx, "div.out", CTYPE_OUT, [&]() {
-            apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-                [](const CTYPE_A val_a, const CTYPE_B val_b) {
-                  CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-                  CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-                  CTYPE_IN value = a_casted / b_casted;
-
-                  return static_cast<CTYPE_OUT>(value);
-                },
-                a,
-                b,
-                out);
-          });
-        });
-      });
+    ScalarType common_type = get_common_type(a.scalar_type(), b.scalar_type());
+    ScalarType compute_type = utils::get_compute_type(common_type);
+
+    ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+      utils::apply_bitensor_elementwise_fn<
+          CTYPE_COMPUTE,
+          op_name,
+          utils::SupportedTensorDtypes::FLOATHBF16>(
+          [](const auto val_a, const auto val_b) { return val_a / val_b; },
+          ctx,
+          a,
+          utils::SupportedTensorDtypes::REALHBBF16,
+          b,
+          utils::SupportedTensorDtypes::REALHBBF16,
+          out);
     });
   }

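Conceptually, `utils::apply_bitensor_elementwise_fn` collapses the old four-deep `ET_SWITCH` nest: each element is converted to the single compute type, the lambda runs once in that type, and the result is converted to the output dtype on store, with broadcasting and dtype dispatch handled internally. A simplified scalar sketch of that load/compute/store pattern, with hypothetical names and no broadcasting:

```cpp
// Simplified sketch of the pattern behind apply_bitensor_elementwise_fn.
// The real helper also dispatches load/store conversions by runtime
// dtype (the SupportedTensorDtypes arguments) and applies broadcasting.
#include <cstdint>

template <
    typename CTYPE_COMPUTE,
    typename CTYPE_A,
    typename CTYPE_B,
    typename CTYPE_OUT,
    typename Op>
void bitensor_sketch(
    const Op& op,
    const CTYPE_A* a,
    const CTYPE_B* b,
    CTYPE_OUT* out,
    int64_t n) {
  for (int64_t i = 0; i < n; ++i) {
    // Convert both operands to the compute type, apply the op once,
    // then convert to the output dtype on store.
    const CTYPE_COMPUTE lhs = static_cast<CTYPE_COMPUTE>(a[i]);
    const CTYPE_COMPUTE rhs = static_cast<CTYPE_COMPUTE>(b[i]);
    out[i] = static_cast<CTYPE_OUT>(op(lhs, rhs));
  }
}
```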
@@ -177,63 +164,57 @@ Tensor& opt_div_scalar_out(
     const Tensor& a,
     const Scalar& b,
     Tensor& out) {
-  (void)ctx;
-
   ScalarType a_type = a.scalar_type();
   ScalarType b_type = utils::get_scalar_dtype(b);
   ScalarType common_type = isFloatingType(a_type) ? a_type : ScalarType::Float;
   ScalarType out_type = out.scalar_type();

-  ET_CHECK(common_type == out_type);
-
-  // Resize for dynamic shape
-  auto error = resize_tensor(out, a.sizes());
-  ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");
-
-  if (a_type == common_type && a_type == out_type) {
-    ET_SWITCH_REAL_TYPES(a_type, ctx, "div.Scalar_out", CTYPE, [&]() {
-      ET_SWITCH_REAL_TYPES_AND(
-          Bool, b_type, ctx, "div.Scalar_out", CTYPE_B, [&]() {
-            CTYPE_B b_val;
-            ET_EXTRACT_SCALAR(b, b_val);
-            CTYPE b_casted = static_cast<CTYPE>(b_val);
-
-            using Vec = at::vec::Vectorized<CTYPE>;
-            Vec inv_b_casted_vec(CTYPE(1) / b_casted);
-            at::vec::map<CTYPE>(
-                [inv_b_casted_vec](Vec x) { return x * inv_b_casted_vec; },
-                out.mutable_data_ptr<CTYPE>(),
-                a.const_data_ptr<CTYPE>(),
-                out.numel());
-          });
+  // Check Common Dtype
+  ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);
+
+  // Check Dim Order
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
+
+  // Resize
+  ET_KERNEL_CHECK(
+      ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "div.Scalar_out";
+
+  if (a_type == common_type && a_type == out_type &&
+      a_type != ScalarType::Half && a_type != ScalarType::BFloat16) {
+    ET_SWITCH_REAL_TYPES(a_type, ctx, op_name, CTYPE, [&]() {
+      ET_SWITCH_REALB_TYPES(b_type, ctx, op_name, CTYPE_B, [&]() {
+        CTYPE_B b_val;
+        ET_EXTRACT_SCALAR(b, b_val);
+        CTYPE b_casted = static_cast<CTYPE>(b_val);
+
+        using Vec = at::vec::Vectorized<CTYPE>;
+        Vec inv_b_casted_vec(CTYPE(1) / b_casted);
+        at::vec::map<CTYPE>(
+            [inv_b_casted_vec](Vec x) { return x * inv_b_casted_vec; },
+            out.mutable_data_ptr<CTYPE>(),
+            a.const_data_ptr<CTYPE>(),
+            out.numel());
+      });
     });
   } else {
-    ET_SWITCH_REAL_TYPES_AND(
-        Bool, a_type, ctx, "div.Scalar_out", CTYPE_A, [&]() {
-          ET_SWITCH_REAL_TYPES_AND(
-              Bool, b_type, ctx, "div.Scalar_out", CTYPE_B, [&]() {
-                ET_SWITCH_REAL_TYPES(
-                    common_type, ctx, "div.Scalar_out", CTYPE_IN, [&]() {
-                      ET_SWITCH_REAL_TYPES(
-                          out_type, ctx, "div.Scalar_out", CTYPE_OUT, [&]() {
-                            CTYPE_B b_val;
-                            ET_EXTRACT_SCALAR(b, b_val);
-                            CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
-                            CTYPE_IN inv_b_casted = CTYPE_IN(1) / b_casted;
-
-                            const size_t n = a.numel();
-                            const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
-                            CTYPE_OUT* out_data =
-                                out.mutable_data_ptr<CTYPE_OUT>();
-                            for (auto i = 0; i < n; ++i) {
-                              out_data[i] = static_cast<CTYPE_OUT>(
-                                  static_cast<CTYPE_IN>(a_data[i]) *
-                                  inv_b_casted);
-                            }
-                          });
-                    });
-              });
-        });
+    ScalarType compute_type = utils::get_compute_type(common_type);
+
+    ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+      const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
+      utils::apply_unitensor_elementwise_fn<
+          CTYPE_COMPUTE,
+          op_name,
+          utils::SupportedTensorDtypes::SAME_AS_COMMON>(
+          [val_b](const auto val_a) { return val_a / val_b; },
+          ctx,
+          a,
+          utils::SupportedTensorDtypes::REALHBBF16,
+          out);
+    });
   }

   return out;
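The vectorized fast path above keeps the divide-by-reciprocal trick: `1/b` is computed once and every element is multiplied by it, since a multiply is much cheaper than a divide on typical CPUs, at the cost of a possible last-bit difference versus true division. A scalar sketch of the same idea (names illustrative):

```cpp
// Scalar sketch of the reciprocal trick in the fast path above: one
// division up front, then a multiply per element. Results can differ
// from exact per-element division by roughly one ulp.
#include <cstdint>

void div_by_scalar(float* out, const float* a, float b, int64_t n) {
  const float inv_b = 1.0f / b;  // computed once
  for (int64_t i = 0; i < n; ++i) {
    out[i] = a[i] * inv_b;
  }
}
```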