  * LICENSE file in the root directory of this source tree.
  */

-#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <cmath>

 namespace torch {
 namespace executor {
 namespace native {

-using Tensor = exec_aten::Tensor;
-using ScalarType = exec_aten::ScalarType;
+namespace {
+
+ScalarType get_common_type(ScalarType a_type, ScalarType b_type) {
+  if (isFloatingType(a_type) && isFloatingType(b_type)) {
+    return promoteTypes(a_type, b_type);
+  } else if (isFloatingType(a_type)) {
+    return a_type;
+  } else if (isFloatingType(b_type)) {
+    return b_type;
+  }
+  return ScalarType::Float;
+}
+
+} // namespace

 Tensor& atan2_out(
     KernelRuntimeContext& ctx,
     const Tensor& a,
     const Tensor& b,
     Tensor& out) {
-  // Determine output size and resize for dynamic shapes
+  // Common Dtype
+  ScalarType common_type = get_common_type(a.scalar_type(), b.scalar_type());
+
+  // Check Dim Order
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
+
+  // Resize
   ET_KERNEL_CHECK(
       ctx,
       resize_to_broadcast_target_size(a, b, out) == Error::Ok,
       InvalidArgument,
       out);

-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
+  // Compute Dtype
+  ScalarType compute_type = utils::get_compute_type(common_type);
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "atan2.out";

-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = b.scalar_type();
-  ScalarType out_type = out.scalar_type();
-
-  ET_SWITCH_REALHB_TYPES(a_type, ctx, "atan2.out", CTYPE_A, [&]() {
-    ET_SWITCH_REALHB_TYPES(b_type, ctx, "atan2.out", CTYPE_B, [&]() {
-      ET_SWITCH_FLOATH_TYPES(out_type, ctx, "atan2.out", CTYPE_OUT, [&]() {
-        apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-            [](const CTYPE_A val_a, const CTYPE_B val_b) {
-              CTYPE_OUT casted_a = static_cast<CTYPE_OUT>(val_a);
-              CTYPE_OUT casted_b = static_cast<CTYPE_OUT>(val_b);
-              return static_cast<CTYPE_OUT>(std::atan2(casted_a, casted_b));
-            },
-            a,
-            b,
-            out);
-      });
-    });
+  ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+        [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
+          return std::atan2(val_a, val_b);
+        },
+        ctx,
+        a,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        b,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        out,
+        utils::SupportedTensorDtypes::FLOATHBF16);
   });

   return out;
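Note on the new dtype handling (commentary, not part of the diff): get_common_type always resolves to a floating-point common dtype, so atan2 computes in float even when both inputs are integral, and the single ET_SWITCH_FLOAT_TYPES over compute_type replaces the old three-level switch over the a, b, and out dtypes. Below is a minimal standalone sketch of that promotion rule only; it uses a hypothetical toy Dtype enum and helper names rather than executorch's ScalarType, isFloatingType, and promoteTypes, and it omits Half/BFloat16 for brevity.

#include <cassert>

// Toy stand-ins for executorch's dtype utilities, restricted to four dtypes
// so the example stays self-contained.
enum class Dtype { Int, Long, Float, Double };

bool is_floating(Dtype t) {
  return t == Dtype::Float || t == Dtype::Double;
}

// Simplified float-only promotion: pick the wider of the two float dtypes.
Dtype promote_floats(Dtype a, Dtype b) {
  return (a == Dtype::Double || b == Dtype::Double) ? Dtype::Double
                                                    : Dtype::Float;
}

// Mirrors the rule in get_common_type above: when both inputs are floating,
// promote; when one is floating, use it; otherwise fall back to Float.
Dtype common_type(Dtype a, Dtype b) {
  if (is_floating(a) && is_floating(b)) {
    return promote_floats(a, b);
  } else if (is_floating(a)) {
    return a;
  } else if (is_floating(b)) {
    return b;
  }
  return Dtype::Float;
}

int main() {
  assert(common_type(Dtype::Float, Dtype::Double) == Dtype::Double);
  assert(common_type(Dtype::Int, Dtype::Double) == Dtype::Double);
  assert(common_type(Dtype::Long, Dtype::Int) == Dtype::Float);
  return 0;
}

As the diff shows, the per-element casting the old lambda did by hand now happens inside utils::apply_bitensor_elementwise_fn, which accepts inputs in the REALHBBF16 dtype set, runs the lambda in CTYPE_COMPUTE, and writes results to an output in the FLOATHBF16 set, so the lambda body reduces to the std::atan2 call.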