Skip to content

Commit 7198f9f

Browse files
[ET][Portable][Build Size] Introduce FLOATHBF16. Binary ops: atan2, div
- div: 1.44 M -> 23 K - atan2: 164 K -> 4 K Differential Revision: [D63909725](https://our.internmc.facebook.com/intern/diff/D63909725/) [ghstack-poisoned]
1 parent e51b898 commit 7198f9f

File tree

5 files changed

+236
-171
lines changed

5 files changed

+236
-171
lines changed

kernels/portable/cpu/op_atan2.cpp

Lines changed: 39 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,50 +6,65 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
9+
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
1010
#include <executorch/runtime/kernel/kernel_includes.h>
1111
#include <cmath>
1212

1313
namespace torch {
1414
namespace executor {
1515
namespace native {
1616

17-
using Tensor = exec_aten::Tensor;
18-
using ScalarType = exec_aten::ScalarType;
17+
namespace {
18+
19+
ScalarType get_common_type(ScalarType a_type, ScalarType b_type) {
20+
if (isFloatingType(a_type) && isFloatingType(b_type)) {
21+
return promoteTypes(a_type, b_type);
22+
} else if (isFloatingType(a_type)) {
23+
return a_type;
24+
} else if (isFloatingType(b_type)) {
25+
return b_type;
26+
}
27+
return ScalarType::Float;
28+
}
29+
30+
} // namespace
1931

2032
Tensor& atan2_out(
2133
KernelRuntimeContext& ctx,
2234
const Tensor& a,
2335
const Tensor& b,
2436
Tensor& out) {
25-
// Determine output size and resize for dynamic shapes
37+
// Common Dtype
38+
ScalarType common_type = get_common_type(a.scalar_type(), b.scalar_type());
39+
40+
// Check Dim Order
41+
ET_KERNEL_CHECK(
42+
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
43+
44+
// Resize
2645
ET_KERNEL_CHECK(
2746
ctx,
2847
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
2948
InvalidArgument,
3049
out);
3150

32-
ET_KERNEL_CHECK(
33-
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
51+
// Compute Dtype
52+
ScalarType compute_type = utils::get_compute_type(common_type);
53+
54+
static constexpr const char op_name[] = "atan2.out";
3455

35-
ScalarType a_type = a.scalar_type();
36-
ScalarType b_type = b.scalar_type();
37-
ScalarType out_type = out.scalar_type();
38-
39-
ET_SWITCH_REALHB_TYPES(a_type, ctx, "atan2.out", CTYPE_A, [&]() {
40-
ET_SWITCH_REALHB_TYPES(b_type, ctx, "atan2.out", CTYPE_B, [&]() {
41-
ET_SWITCH_FLOATH_TYPES(out_type, ctx, "atan2.out", CTYPE_OUT, [&]() {
42-
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
43-
[](const CTYPE_A val_a, const CTYPE_B val_b) {
44-
CTYPE_OUT casted_a = static_cast<CTYPE_OUT>(val_a);
45-
CTYPE_OUT casted_b = static_cast<CTYPE_OUT>(val_b);
46-
return static_cast<CTYPE_OUT>(std::atan2(casted_a, casted_b));
47-
},
48-
a,
49-
b,
50-
out);
51-
});
52-
});
56+
ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
57+
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
58+
[](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
59+
return std::atan2(val_a, val_b);
60+
},
61+
ctx,
62+
a,
63+
utils::SupportedTensorDtypes::REALHBBF16,
64+
b,
65+
utils::SupportedTensorDtypes::REALHBBF16,
66+
out,
67+
utils::SupportedTensorDtypes::FLOATHBF16);
5368
});
5469

5570
return out;

0 commit comments

Comments
 (0)