Skip to content

Commit 9ff34e2

Browse files
[ET][Portable][Build Size] Introduce FLOATHBF16. Binary ops: atan2, div
Pull Request resolved: pytorch/executorch#6011

Build size reductions:
- div: 1.44 M -> 23 K
- atan2: 164 K -> 4 K

ghstack-source-id: 246985123
@exported-using-ghexport

Differential Revision: [D63909725](https://our.internmc.facebook.com/intern/diff/D63909725/)
1 parent 9468540 commit 9ff34e2

File tree

5 files changed

+241
-171
lines changed

5 files changed

+241
-171
lines changed

kernels/portable/cpu/op_atan2.cpp

Lines changed: 40 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,50 +6,66 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
9+
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
1010
#include <executorch/runtime/kernel/kernel_includes.h>
1111
#include <cmath>
1212

1313
namespace torch {
1414
namespace executor {
1515
namespace native {
1616

17-
using Tensor = exec_aten::Tensor;
18-
using ScalarType = exec_aten::ScalarType;
17+
namespace {
18+
19+
// Selects the dtype shared by the two operands of atan2.
//
// - Both inputs floating point: the result of promoteTypes(a_type, b_type).
// - Exactly one input floating point: that input's dtype.
// - Neither input floating point: ScalarType::Float.
ScalarType get_common_type(ScalarType a_type, ScalarType b_type) {
  const bool b_is_floating = isFloatingType(b_type);

  // `a` is not floating point, so `b` alone decides; fall back to Float when
  // neither operand is a floating-point type.
  if (!isFloatingType(a_type)) {
    return b_is_floating ? b_type : ScalarType::Float;
  }

  // `a` is floating point: promote only when `b` is floating point as well.
  return b_is_floating ? promoteTypes(a_type, b_type) : a_type;
}
29+
30+
} // namespace
1931

2032
Tensor& atan2_out(
2133
KernelRuntimeContext& ctx,
2234
const Tensor& a,
2335
const Tensor& b,
2436
Tensor& out) {
25-
// Determine output size and resize for dynamic shapes
37+
// Common Dtype
38+
ScalarType common_type = get_common_type(a.scalar_type(), b.scalar_type());
39+
40+
// Check Dim Order
41+
ET_KERNEL_CHECK(
42+
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
43+
44+
// Resize
2645
ET_KERNEL_CHECK(
2746
ctx,
2847
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
2948
InvalidArgument,
3049
out);
3150

32-
ET_KERNEL_CHECK(
33-
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
51+
// Compute Dtype
52+
ScalarType compute_type = utils::get_compute_type(common_type);
53+
54+
// @lint-ignore CLANGTIDY facebook-hte-CArray
55+
static constexpr const char op_name[] = "atan2.out";
3456

35-
ScalarType a_type = a.scalar_type();
36-
ScalarType b_type = b.scalar_type();
37-
ScalarType out_type = out.scalar_type();
38-
39-
ET_SWITCH_REALHB_TYPES(a_type, ctx, "atan2.out", CTYPE_A, [&]() {
40-
ET_SWITCH_REALHB_TYPES(b_type, ctx, "atan2.out", CTYPE_B, [&]() {
41-
ET_SWITCH_FLOATH_TYPES(out_type, ctx, "atan2.out", CTYPE_OUT, [&]() {
42-
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
43-
[](const CTYPE_A val_a, const CTYPE_B val_b) {
44-
CTYPE_OUT casted_a = static_cast<CTYPE_OUT>(val_a);
45-
CTYPE_OUT casted_b = static_cast<CTYPE_OUT>(val_b);
46-
return static_cast<CTYPE_OUT>(std::atan2(casted_a, casted_b));
47-
},
48-
a,
49-
b,
50-
out);
51-
});
52-
});
57+
ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
58+
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
59+
[](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
60+
return std::atan2(val_a, val_b);
61+
},
62+
ctx,
63+
a,
64+
utils::SupportedTensorDtypes::REALHBBF16,
65+
b,
66+
utils::SupportedTensorDtypes::REALHBBF16,
67+
out,
68+
utils::SupportedTensorDtypes::FLOATHBF16);
5369
});
5470

5571
return out;

0 commit comments

Comments
 (0)