|
7 | 7 | */ |
8 | 8 |
|
9 | 9 | #include <executorch/kernels/portable/cpu/scalar_utils.h> |
10 | | -#include <executorch/kernels/portable/cpu/util/broadcast_util.h> |
| 10 | +#include <executorch/kernels/portable/cpu/util/elementwise_util.h> |
11 | 11 | #include <executorch/kernels/portable/cpu/util/math_util.h> |
12 | 12 | #include <executorch/runtime/kernel/kernel_includes.h> |
13 | 13 |
|
14 | 14 | namespace torch { |
15 | 15 | namespace executor { |
16 | 16 | namespace native { |
17 | | -namespace { |
18 | | - |
19 | | -template < |
20 | | - bool can_cast, |
21 | | - typename CTYPE_A, |
22 | | - typename CTYPE_B, |
23 | | - typename CTYPE_IN, |
24 | | - typename CTYPE_OUT> |
25 | | -struct MaximumInner; |
26 | | - |
27 | | -template < |
28 | | - typename CTYPE_A, |
29 | | - typename CTYPE_B, |
30 | | - typename CTYPE_IN, |
31 | | - typename CTYPE_OUT> |
32 | | -struct MaximumInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> { |
33 | | - static void run(const Tensor& a, const Tensor& b, Tensor& out) { |
34 | | - apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>( |
35 | | - // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue) |
36 | | - [](const CTYPE_A val_a, const CTYPE_B val_b) { |
37 | | - CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a); |
38 | | - CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b); |
39 | | - CTYPE_IN value = utils::max_override(a_casted, b_casted); |
40 | | - |
41 | | - return static_cast<CTYPE_OUT>(value); |
42 | | - }, |
43 | | - a, |
44 | | - b, |
45 | | - out); |
46 | | - } |
47 | | -}; |
48 | | - |
49 | | -struct ReportCanCastBug { |
50 | | - static void run(const Tensor&, const Tensor&, Tensor&) { |
51 | | - ET_DCHECK_MSG(false, "BUG: canCast should have been checked above"); |
52 | | - } |
53 | | -}; |
54 | | - |
55 | | -template < |
56 | | - typename CTYPE_A, |
57 | | - typename CTYPE_B, |
58 | | - typename CTYPE_IN, |
59 | | - typename CTYPE_OUT> |
60 | | -struct MaximumInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> |
61 | | - : public ReportCanCastBug {}; |
62 | | - |
63 | | -} // namespace |
64 | 17 |
|
65 | 18 | Tensor& maximum_out( |
66 | 19 | KernelRuntimeContext& ctx, |
67 | 20 | const Tensor& a, |
68 | 21 | const Tensor& b, |
69 | 22 | Tensor& out) { |
70 | | - (void)ctx; |
| 23 | + // Common Dtype |
| 24 | + ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type()); |
| 25 | + |
| 26 | + // Check Common Dtype |
| 27 | + ET_KERNEL_CHECK( |
| 28 | + ctx, canCast(common_type, out.scalar_type()), InvalidArgument, out); |
71 | 29 |
|
| 30 | + // Check Dim Order |
| 31 | + ET_KERNEL_CHECK( |
| 32 | + ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out); |
| 33 | + |
| 34 | + // Resize |
72 | 35 | ET_KERNEL_CHECK( |
73 | 36 | ctx, |
74 | 37 | resize_to_broadcast_target_size(a, b, out) == Error::Ok, |
75 | 38 | InvalidArgument, |
76 | 39 | out); |
77 | 40 |
|
78 | | - ET_KERNEL_CHECK( |
79 | | - ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out); |
80 | | - |
81 | | - ScalarType a_type = a.scalar_type(); |
82 | | - ScalarType b_type = b.scalar_type(); |
83 | | - ScalarType common_type = promoteTypes(a_type, b_type, /*half_to_float*/ true); |
84 | | - ScalarType out_type = out.scalar_type(); |
| 41 | + // Compute Dtype |
| 42 | + ScalarType compute_type = utils::get_compute_type(common_type); |
85 | 43 |
|
86 | | - ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out); |
| 44 | + static constexpr const char op_name[] = "maximum.out"; |
87 | 45 |
|
88 | | - ET_SWITCH_REALHB_TYPES(a_type, ctx, "maximum.out", CTYPE_A, [&]() { |
89 | | - ET_SWITCH_REALHB_TYPES(b_type, ctx, "maximum.out", CTYPE_B, [&]() { |
90 | | - using CTYPE_IN = typename torch::executor:: |
91 | | - promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type; |
92 | | - ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type); |
93 | | - ET_SWITCH_REALHB_TYPES(out_type, ctx, "maximum.out", CTYPE_OUT, [&]() { |
94 | | - MaximumInner< |
95 | | - can_cast<CTYPE_IN, CTYPE_OUT>::value, |
96 | | - CTYPE_A, |
97 | | - CTYPE_B, |
98 | | - CTYPE_IN, |
99 | | - CTYPE_OUT>::run(a, b, out); |
100 | | - }); |
101 | | - }); |
| 46 | + ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { |
| 47 | + utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>( |
| 48 | + [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { |
| 49 | + return utils::max_override(val_a, val_b); |
| 50 | + }, |
| 51 | + ctx, |
| 52 | + a, |
| 53 | + utils::SupportedTensorDtypes::REALHBBF16, |
| 54 | + b, |
| 55 | + utils::SupportedTensorDtypes::REALHBBF16, |
| 56 | + out, |
| 57 | + utils::SupportedTensorDtypes::REALHBBF16); |
102 | 58 | }); |
103 | 59 |
|
104 | 60 | return out; |
|
0 commit comments