
Commit d5c76af

Refactor op_div for shareability with optimized op_div
Pull Request resolved: #6713
ghstack-source-id: 252372852
@exported-using-ghexport
Differential Revision: [D65606664](https://our.internmc.facebook.com/intern/diff/D65606664/)
1 parent 03b1ef2 commit d5c76af
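
The portable op_div.cpp shown below is reduced to thin wrappers: each div variant now forwards to a matching *_impl function pulled in via op_div_impl.h, so the same implementation can be shared with the optimized op_div kernel. The shared header itself is among the other changed files and is not shown on this page; the following is only a sketch of the declarations implied by the wrapper signatures and call sites in the diff, not the actual file contents.

// Hypothetical sketch of the shared interface, inferred from the wrapper
// signatures and call sites visible in the op_div.cpp diff below.
#pragma once

#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

Tensor& div_out_impl(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    Tensor& out);

Tensor& div_out_mode_impl(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    exec_aten::optional<exec_aten::string_view> mode,
    Tensor& out);

Tensor& div_scalar_out_impl(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Scalar& b,
    Tensor& out);

Tensor& div_scalar_mode_out_impl(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Scalar& b,
    exec_aten::optional<exec_aten::string_view> mode,
    Tensor& out);

} // namespace native
} // namespace executor
} // namespace torch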

File tree

5 files changed (+363, -240 lines)


kernels/portable/cpu/op_div.cpp

Lines changed: 5 additions & 233 deletions
@@ -6,72 +6,18 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-#include <executorch/kernels/portable/cpu/scalar_utils.h>
-#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
-#include <executorch/kernels/portable/cpu/util/math_util.h>
-#include <executorch/runtime/kernel/kernel_includes.h>
-#include <executorch/runtime/platform/assert.h>
-#include <cmath>
+#include <executorch/kernels/portable/cpu/op_div_impl.h>
 
 namespace torch {
 namespace executor {
 namespace native {
 
-namespace {
-
-ScalarType get_common_type(ScalarType a_type, ScalarType b_type) {
-  if (isFloatingType(a_type) && isFloatingType(b_type)) {
-    return promoteTypes(a_type, b_type);
-  } else if (isFloatingType(a_type)) {
-    return a_type;
-  } else if (isFloatingType(b_type)) {
-    return b_type;
-  }
-  return ScalarType::Float;
-}
-
-} // namespace
-
 Tensor& div_out(
     KernelRuntimeContext& ctx,
     const Tensor& a,
     const Tensor& b,
     Tensor& out) {
-  // Common Dtype
-  ScalarType common_type = get_common_type(a.scalar_type(), b.scalar_type());
-
-  // Check Dim Order
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
-
-  // Resize
-  ET_KERNEL_CHECK(
-      ctx,
-      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-      InvalidArgument,
-      out);
-
-  // Compute Dtype
-  ScalarType compute_type = utils::get_compute_type(common_type);
-
-  // @lint-ignore CLANGTIDY facebook-hte-CArray
-  static constexpr const char op_name[] = "div.out";
-
-  ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
-        [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
-          return val_a / val_b;
-        },
-        ctx,
-        a,
-        utils::SupportedTensorDtypes::REALHBBF16,
-        b,
-        utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::FLOATHBF16);
-  });
-
-  return out;
+  return div_out_impl(ctx, a, b, out);
 }
 
 Tensor& div_out_mode(
@@ -80,124 +26,15 @@ Tensor& div_out_mode(
     const Tensor& b,
     exec_aten::optional<exec_aten::string_view> mode,
     Tensor& out) {
-  if (!mode.has_value()) {
-    return div_out(ctx, a, b, out);
-  }
-
-  auto mode_val = mode.value();
-
-  // Check mode
-  ET_KERNEL_CHECK(
-      ctx, mode_val == "trunc" || mode_val == "floor", InvalidArgument, out);
-
-  // Common Dtype
-  ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
-
-  // Check Common Dtype
-  ET_KERNEL_CHECK(
-      ctx,
-      (canCast(common_type, out.scalar_type()) &&
-       common_type != ScalarType::Bool),
-      InvalidArgument,
-      out);
-
-  // Check Dim Order
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
-
-  // Resize
-  ET_KERNEL_CHECK(
-      ctx,
-      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-      InvalidArgument,
-      out);
-
-  // Compute Dtype
-  ScalarType compute_type = utils::get_compute_type(common_type);
-
-  // @lint-ignore CLANGTIDY facebook-hte-CArray
-  static constexpr const char op_name[] = "div.out_mode";
-
-  const bool mode_is_trunc = mode_val == "trunc";
-  bool div_by_zero_error = false;
-
-  ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
-        [mode_is_trunc, &div_by_zero_error](
-            const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
-          if (is_integral_type<CTYPE_COMPUTE, /*includeBool=*/true>::value) {
-            if (val_b == 0) {
-              div_by_zero_error = true;
-              return static_cast<CTYPE_COMPUTE>(0);
-            }
-          }
-          CTYPE_COMPUTE value = val_a / val_b;
-          if (mode_is_trunc) {
-            value = std::trunc(value);
-          } else {
-            // We established above that the mode is either trunc or floor, so
-            // it must be floor.
-            value = utils::floor_divide(val_a, val_b);
-          }
-          return value;
-        },
-        ctx,
-        a,
-        utils::SupportedTensorDtypes::REALHBBF16,
-        b,
-        utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::REALHBF16);
-  });
-
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      !div_by_zero_error,
-      InvalidArgument,
-      out,
-      "Div mode operation encountered integer division by zero");
-
-  return out;
+  return div_out_mode_impl(ctx, a, b, mode, out);
 }
 
 Tensor& div_scalar_out(
     KernelRuntimeContext& ctx,
     const Tensor& a,
     const Scalar& b,
     Tensor& out) {
-  // Common Dtype
-  ScalarType common_type =
-      isFloatingType(a.scalar_type()) ? a.scalar_type() : ScalarType::Float;
-
-  // Check Common Dtype
-  ET_KERNEL_CHECK(ctx, common_type == out.scalar_type(), InvalidArgument, out);
-
-  // Check Dim Order
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
-
-  // Resize
-  ET_KERNEL_CHECK(
-      ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
-
-  // Compute Dtype
-  ScalarType compute_type = utils::get_compute_type(common_type);
-
-  // @lint-ignore CLANGTIDY facebook-hte-CArray
-  static constexpr const char op_name[] = "div.Scalar_out";
-
-  ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
-    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
-        [val_b](const CTYPE_COMPUTE val_a) { return val_a / val_b; },
-        ctx,
-        a,
-        utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::SAME_AS_COMMON);
-  });
-
-  return out;
+  return div_scalar_out_impl(ctx, a, b, out);
 }
 
 Tensor& div_scalar_mode_out(
@@ -206,72 +43,7 @@ Tensor& div_scalar_mode_out(
     const Scalar& b,
     exec_aten::optional<exec_aten::string_view> mode,
     Tensor& out) {
-  if (!mode.has_value()) {
-    return div_scalar_out(ctx, a, b, out);
-  }
-
-  auto mode_val = mode.value();
-
-  // Check mode
-  ET_KERNEL_CHECK(
-      ctx, mode_val == "trunc" || mode_val == "floor", InvalidArgument, out);
-
-  // Common Dtype
-  ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);
-
-  // Check Common Dtype
-  ET_KERNEL_CHECK(
-      ctx,
-      (canCast(common_type, out.scalar_type()) &&
-       common_type != ScalarType::Bool),
-      InvalidArgument,
-      out);
-
-  // Check for intergral division by zero
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      !(executorch::runtime::isIntegralType(common_type, true) &&
-        utils::scalar_to<double>(b) == 0),
-      InvalidArgument,
-      out,
-      "Div mode operation encountered integer division by zero");
-
-  // Check Dim Order
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
-
-  // Resize
-  ET_KERNEL_CHECK(
-      ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
-
-  // Compute Dtype
-  ScalarType compute_type = utils::get_compute_type(common_type);
-
-  const bool mode_is_trunc = mode_val == "trunc";
-
-  // @lint-ignore CLANGTIDY facebook-hte-CArray
-  static constexpr const char op_name[] = "div.Scalar_mode_out";
-
-  ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
-    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
-        [val_b, mode_is_trunc](const CTYPE_COMPUTE val_a) {
-          CTYPE_COMPUTE value = val_a / val_b;
-          if (mode_is_trunc) {
-            value = std::trunc(value);
-          } else {
-            value = utils::floor_divide(val_a, val_b);
-          }
-          return value;
-        },
-        ctx,
-        a,
-        utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::REALHBF16);
-  });
-
-  return out;
+  return div_scalar_mode_out_impl(ctx, a, b, mode, out);
 }
 
 } // namespace native
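
Per the commit title, the point of the split is that the optimized op_div kernel can reuse this shared implementation instead of duplicating it, for example as a fallback when its fast path does not apply. The snippet below is purely illustrative of that pattern and is not taken from the actual optimized kernel; opt_div_out, can_use_fast_path, and fast_path_div are hypothetical names.

// Illustrative sketch only (hypothetical names): an optimized kernel keeps
// its fast path and falls back to the shared implementation otherwise.
#include <executorch/kernels/portable/cpu/op_div_impl.h>

namespace torch {
namespace executor {
namespace native {

Tensor& opt_div_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    Tensor& out) {
  // Hypothetical predicate and fast path; the real checks live in the
  // optimized kernel, which is among the other changed files not shown here.
  if (can_use_fast_path(a, b, out)) {
    return fast_path_div(ctx, a, b, out);
  }
  // Otherwise reuse the shared portable implementation introduced here.
  return div_out_impl(ctx, a, b, out);
}

} // namespace native
} // namespace executor
} // namespace torch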
