Skip to content

Commit 010e68a

Browse files
Refactor op_div: fix bug + enable Half/Bfloat16
Differential Revision: D81169893 Pull Request resolved: #13740
1 parent e2a6538 commit 010e68a

File tree

3 files changed

+101
-102
lines changed

3 files changed

+101
-102
lines changed

kernels/optimized/cpu/op_div.cpp

Lines changed: 80 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
#include <ATen/cpu/vec/vec.h>
1111
#include <executorch/kernels/optimized/cpu/binary_ops.h>
1212
#include <executorch/kernels/portable/cpu/scalar_utils.h>
13-
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
13+
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
1414
#include <executorch/runtime/kernel/kernel_includes.h>
1515
#include <executorch/runtime/platform/assert.h>
1616

@@ -20,7 +20,7 @@ namespace native {
2020

2121
namespace {
2222

23-
ScalarType get_compute_type(ScalarType a_type, ScalarType b_type) {
23+
ScalarType get_common_type(ScalarType a_type, ScalarType b_type) {
2424
ET_CHECK(
2525
!isComplexType(a_type) && !isQIntType(a_type) && !isBitsType(a_type));
2626
ET_CHECK(
@@ -43,14 +43,27 @@ Tensor& opt_div_out(
4343
const Tensor& a,
4444
const Tensor& b,
4545
Tensor& out) {
46-
(void)ctx;
46+
// Check Dim Order
47+
ET_KERNEL_CHECK(
48+
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
49+
50+
// Resize
51+
ET_KERNEL_CHECK(
52+
ctx,
53+
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
54+
InvalidArgument,
55+
out);
56+
57+
// @lint-ignore CLANGTIDY facebook-hte-CArray
58+
static constexpr const char op_name[] = "div.out";
4759

4860
ScalarType a_type = a.scalar_type();
4961
ScalarType b_type = b.scalar_type();
5062
ScalarType out_type = out.scalar_type();
5163

5264
if (a.numel() == 1 || b.numel() == 1) {
53-
if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half) {
65+
if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half &&
66+
a_type != ScalarType::BFloat16) {
5467
const Tensor* tensor;
5568
const Tensor* scalar;
5669
ScalarType tensor_type;
@@ -66,13 +79,8 @@ Tensor& opt_div_out(
6679
scalar = &b;
6780
scalar_type = b_type;
6881
}
69-
ET_KERNEL_CHECK(
70-
ctx,
71-
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
72-
InvalidArgument,
73-
out);
74-
ET_SWITCH_REALB_TYPES(tensor_type, ctx, "div.out", CTYPE, [&]() {
75-
ET_SWITCH_REALB_TYPES(scalar_type, ctx, "div.out", CTYPE_SCALAR, [&]() {
82+
ET_SWITCH_REALB_TYPES(tensor_type, ctx, op_name, CTYPE, [&]() {
83+
ET_SWITCH_REALB_TYPES(scalar_type, ctx, op_name, CTYPE_SCALAR, [&]() {
7684
CTYPE_SCALAR scalar_val = *scalar->const_data_ptr<CTYPE_SCALAR>();
7785
CTYPE scalar_casted = static_cast<CTYPE>(scalar_val);
7886

@@ -101,16 +109,7 @@ Tensor& opt_div_out(
101109

102110
auto selected_optimized_path = select_optimized_path(a, b, out);
103111
if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {
104-
// Resize for dynamic shape
105-
auto error = resize_tensor(out, a.sizes());
106-
ET_KERNEL_CHECK_MSG(
107-
ctx,
108-
error == Error::Ok,
109-
InvalidArgument,
110-
out,
111-
"Failed to resize output tensor.");
112-
113-
ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, "div.out", CTYPE, [&]() {
112+
ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
114113
using Vec = at::vec::Vectorized<CTYPE>;
115114
at::vec::map2<CTYPE>(
116115
[](Vec x, Vec y) { return x / y; },
@@ -122,7 +121,7 @@ Tensor& opt_div_out(
122121
} else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
123122
// Reason for using alpha is because handle_broadcast_elementwise
124123
// is used for add and sub as well:
125-
ET_SWITCH_REALB_TYPES(out_type, ctx, "div.out", CTYPE, [&]() {
124+
ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
126125
if (selected_optimized_path ==
127126
ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
128127
selected_optimized_path ==
@@ -139,33 +138,21 @@ Tensor& opt_div_out(
139138
}
140139
});
141140
} else {
142-
ScalarType common_type = get_compute_type(a_type, b_type);
143-
ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
144-
145-
ET_KERNEL_CHECK(
146-
ctx,
147-
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
148-
InvalidArgument,
149-
out);
150-
151-
ET_SWITCH_REALB_TYPES(a_type, ctx, "div.out", CTYPE_A, [&]() {
152-
ET_SWITCH_REALB_TYPES(b_type, ctx, "div.out", CTYPE_B, [&]() {
153-
ET_SWITCH_REALB_TYPES(common_type, ctx, "div.out", CTYPE_IN, [&]() {
154-
ET_SWITCH_REALB_TYPES(out_type, ctx, "div.out", CTYPE_OUT, [&]() {
155-
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
156-
[](const CTYPE_A val_a, const CTYPE_B val_b) {
157-
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
158-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
159-
CTYPE_IN value = a_casted / b_casted;
160-
161-
return static_cast<CTYPE_OUT>(value);
162-
},
163-
a,
164-
b,
165-
out);
166-
});
167-
});
168-
});
141+
ScalarType common_type = get_common_type(a.scalar_type(), b.scalar_type());
142+
ScalarType compute_type = utils::get_compute_type(common_type);
143+
144+
ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
145+
utils::apply_bitensor_elementwise_fn<
146+
CTYPE_COMPUTE,
147+
op_name,
148+
utils::SupportedTensorDtypes::FLOATHBF16>(
149+
[](const auto val_a, const auto val_b) { return val_a / val_b; },
150+
ctx,
151+
a,
152+
utils::SupportedTensorDtypes::REALHBBF16,
153+
b,
154+
utils::SupportedTensorDtypes::REALHBBF16,
155+
out);
169156
});
170157
}
171158

@@ -177,63 +164,57 @@ Tensor& opt_div_scalar_out(
177164
const Tensor& a,
178165
const Scalar& b,
179166
Tensor& out) {
180-
(void)ctx;
181-
182167
ScalarType a_type = a.scalar_type();
183168
ScalarType b_type = utils::get_scalar_dtype(b);
184169
ScalarType common_type = isFloatingType(a_type) ? a_type : ScalarType::Float;
185170
ScalarType out_type = out.scalar_type();
186171

187-
ET_CHECK(common_type == out_type);
188-
189-
// Resize for dynamic shape
190-
auto error = resize_tensor(out, a.sizes());
191-
ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");
192-
193-
if (a_type == common_type && a_type == out_type) {
194-
ET_SWITCH_REAL_TYPES(a_type, ctx, "div.Scalar_out", CTYPE, [&]() {
195-
ET_SWITCH_REAL_TYPES_AND(
196-
Bool, b_type, ctx, "div.Scalar_out", CTYPE_B, [&]() {
197-
CTYPE_B b_val;
198-
ET_EXTRACT_SCALAR(b, b_val);
199-
CTYPE b_casted = static_cast<CTYPE>(b_val);
200-
201-
using Vec = at::vec::Vectorized<CTYPE>;
202-
Vec inv_b_casted_vec(CTYPE(1) / b_casted);
203-
at::vec::map<CTYPE>(
204-
[inv_b_casted_vec](Vec x) { return x * inv_b_casted_vec; },
205-
out.mutable_data_ptr<CTYPE>(),
206-
a.const_data_ptr<CTYPE>(),
207-
out.numel());
208-
});
172+
// Check Common Dtype
173+
ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);
174+
175+
// Check Dim Order
176+
ET_KERNEL_CHECK(
177+
ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
178+
179+
// Resize
180+
ET_KERNEL_CHECK(
181+
ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
182+
183+
// @lint-ignore CLANGTIDY facebook-hte-CArray
184+
static constexpr const char op_name[] = "div.Scalar_out";
185+
186+
if (a_type == common_type && a_type == out_type &&
187+
a_type != ScalarType::Half && a_type != ScalarType::BFloat16) {
188+
ET_SWITCH_REAL_TYPES(a_type, ctx, op_name, CTYPE, [&]() {
189+
ET_SWITCH_REALB_TYPES(b_type, ctx, op_name, CTYPE_B, [&]() {
190+
CTYPE_B b_val;
191+
ET_EXTRACT_SCALAR(b, b_val);
192+
CTYPE b_casted = static_cast<CTYPE>(b_val);
193+
194+
using Vec = at::vec::Vectorized<CTYPE>;
195+
Vec inv_b_casted_vec(CTYPE(1) / b_casted);
196+
at::vec::map<CTYPE>(
197+
[inv_b_casted_vec](Vec x) { return x * inv_b_casted_vec; },
198+
out.mutable_data_ptr<CTYPE>(),
199+
a.const_data_ptr<CTYPE>(),
200+
out.numel());
201+
});
209202
});
210203
} else {
211-
ET_SWITCH_REAL_TYPES_AND(
212-
Bool, a_type, ctx, "div.Scalar_out", CTYPE_A, [&]() {
213-
ET_SWITCH_REAL_TYPES_AND(
214-
Bool, b_type, ctx, "div.Scalar_out", CTYPE_B, [&]() {
215-
ET_SWITCH_REAL_TYPES(
216-
common_type, ctx, "div.Scalar_out", CTYPE_IN, [&]() {
217-
ET_SWITCH_REAL_TYPES(
218-
out_type, ctx, "div.Scalar_out", CTYPE_OUT, [&]() {
219-
CTYPE_B b_val;
220-
ET_EXTRACT_SCALAR(b, b_val);
221-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
222-
CTYPE_IN inv_b_casted = CTYPE_IN(1) / b_casted;
223-
224-
const size_t n = a.numel();
225-
const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
226-
CTYPE_OUT* out_data =
227-
out.mutable_data_ptr<CTYPE_OUT>();
228-
for (auto i = 0; i < n; ++i) {
229-
out_data[i] = static_cast<CTYPE_OUT>(
230-
static_cast<CTYPE_IN>(a_data[i]) *
231-
inv_b_casted);
232-
}
233-
});
234-
});
235-
});
236-
});
204+
ScalarType compute_type = utils::get_compute_type(common_type);
205+
206+
ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
207+
const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
208+
utils::apply_unitensor_elementwise_fn<
209+
CTYPE_COMPUTE,
210+
op_name,
211+
utils::SupportedTensorDtypes::SAME_AS_COMMON>(
212+
[val_b](const auto val_a) { return val_a / val_b; },
213+
ctx,
214+
a,
215+
utils::SupportedTensorDtypes::REALHBBF16,
216+
out);
217+
});
237218
}
238219

239220
return out;

kernels/test/op_div_test.cpp

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ class OpDivOutTest : public OperatorTest {
5454
#define ENUMERATE_TEST_ENTRY(ctype, dtype) \
5555
test_div<DTYPE_A, DTYPE_B, ScalarType::dtype>();
5656

57-
ET_FORALL_FLOAT_TYPES(ENUMERATE_TEST_ENTRY)
57+
ET_FORALL_FLOATHBF16_TYPES(ENUMERATE_TEST_ENTRY)
5858

5959
#undef ENUMERATE_TEST_ENTRY
6060
}
@@ -64,7 +64,7 @@ class OpDivOutTest : public OperatorTest {
6464
#define ENUMERATE_TEST_ENTRY(ctype, dtype) \
6565
test_div_enumerate_out_types<DTYPE_A, ScalarType::dtype>();
6666

67-
ET_FORALL_REAL_TYPES(ENUMERATE_TEST_ENTRY)
67+
ET_FORALL_REALHBF16_TYPES(ENUMERATE_TEST_ENTRY)
6868

6969
#undef ENUMERATE_TEST_ENTRY
7070
}
@@ -183,7 +183,7 @@ void OpDivOutTest::test_div_enumerate_a_types() {
183183
#define ENUMERATE_TEST_ENTRY(ctype, dtype) \
184184
test_div_enumerate_b_types<ScalarType::dtype>();
185185

186-
ET_FORALL_REAL_TYPES(ENUMERATE_TEST_ENTRY)
186+
ET_FORALL_REALHBF16_TYPES(ENUMERATE_TEST_ENTRY)
187187

188188
test_div<ScalarType::Bool, ScalarType::Float, ScalarType::Float>();
189189

@@ -283,6 +283,22 @@ TEST_F(OpDivOutTest, BroadcastScalarSupported2) {
283283
EXPECT_TENSOR_EQ(out, ret);
284284
}
285285

286+
TEST_F(OpDivOutTest, BroadcastSupported3) {
287+
TensorFactory<ScalarType::Float> tf;
288+
289+
Tensor a = tf.make({5}, {2, 3, 4, 5, 6});
290+
Tensor b = tf.make({1, 5}, {2, 1, 2, 2, 3});
291+
292+
// Destination for the broadcasting div. Follow the broadcasting rules in
293+
// https://fburl.com/n9wl4d0o
294+
Tensor out = tf.zeros({1, 5});
295+
296+
op_div_out(a, b, out);
297+
298+
Tensor ret = tf.make({1, 5}, {1, 3, 2, 2.5, 2});
299+
EXPECT_TENSOR_EQ(out, ret);
300+
}
301+
286302
TEST_F(OpDivOutTest, BroadcastScalarRank0Supported) {
287303
TensorFactory<ScalarType::Float> tf;
288304

shim_et/xplat/executorch/kernels/optimized/op_registration_util.bzl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,8 @@ OPTIMIZED_ATEN_OPS = (
180180
":binary_ops",
181181
"//executorch/kernels/portable/cpu:scalar_utils",
182182
"//executorch/kernels/portable/cpu/util:broadcast_util",
183+
"//executorch/kernels/portable/cpu/util:dtype_util",
184+
"//executorch/kernels/portable/cpu/util:elementwise_util",
183185
"//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch",
184186
],
185187
),

0 commit comments

Comments
 (0)