|
6 | 6 | * LICENSE file in the root directory of this source tree. |
7 | 7 | */ |
8 | 8 |
|
9 | | -// patternlint-disable-next-line executorch-cpp-nostdinc |
10 | | -#include <functional> |
11 | | - |
12 | 9 | #include <executorch/kernels/portable/cpu/pattern/bitwise_op.h> |
13 | | -#include <executorch/kernels/portable/cpu/scalar_utils.h> |
14 | | -#include <executorch/kernels/portable/cpu/util/broadcast_util.h> |
15 | | -#include <executorch/kernels/portable/cpu/util/functional_util.h> |
16 | | -#include <executorch/runtime/kernel/kernel_includes.h> |
17 | 10 |
|
18 | 11 | namespace torch { |
19 | 12 | namespace executor { |
20 | 13 | namespace native { |
21 | 14 |
|
22 | | -using Tensor = exec_aten::Tensor; |
23 | | - |
24 | 15 | Tensor& bitwise_or_Tensor_out( |
25 | 16 | KernelRuntimeContext& ctx, |
26 | 17 | const Tensor& a, |
27 | 18 | const Tensor& b, |
28 | 19 | Tensor& out) { |
29 | | - ET_KERNEL_CHECK( |
30 | | - ctx, |
31 | | - resize_to_broadcast_target_size(a, b, out) == Error::Ok, |
32 | | - InvalidArgument, |
33 | | - out); |
34 | | - |
35 | | - ET_KERNEL_CHECK( |
36 | | - ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out); |
37 | | - |
38 | | - ScalarType a_type = a.scalar_type(); |
39 | | - ScalarType b_type = b.scalar_type(); |
40 | | - ScalarType common_type = promoteTypes(a_type, b_type); |
41 | | - ScalarType out_type = out.scalar_type(); |
42 | | - |
43 | | - ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out); |
44 | | - |
45 | | - ET_SWITCH_INT_TYPES_AND( |
46 | | - Bool, a_type, ctx, "bitwise_or.Tensor_out", CTYPE_A, [&]() { |
47 | | - ET_SWITCH_INT_TYPES_AND( |
48 | | - Bool, b_type, ctx, "bitwise_or.Tensor_out", CTYPE_B, [&]() { |
49 | | - using CTYPE_IN = typename torch::executor:: |
50 | | - promote_types<CTYPE_A, CTYPE_B>::type; |
51 | | - ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type); |
52 | | - ET_SWITCH_REAL_TYPES_AND( |
53 | | - Bool, |
54 | | - out_type, |
55 | | - ctx, |
56 | | - "bitwise_or.Tensor_out", |
57 | | - CTYPE_OUT, |
58 | | - [&]() { |
59 | | - internal::BitwiseOpInner< |
60 | | - can_cast<CTYPE_IN, CTYPE_OUT>::value, |
61 | | - std::bit_or, |
62 | | - CTYPE_A, |
63 | | - CTYPE_B, |
64 | | - CTYPE_IN, |
65 | | - CTYPE_OUT>::run(a, b, out); |
66 | | - }); |
67 | | - }); |
68 | | - }); |
69 | | - |
70 | | - return out; |
| 20 | + static constexpr const char op_name[] = "bitwise_or.Tensor_out"; |
| 21 | + return internal::bitwise_tensor_out<op_name>(ctx, a, b, out); |
71 | 22 | } |
72 | 23 |
|
73 | 24 | Tensor& bitwise_or_Scalar_out( |
74 | 25 | KernelRuntimeContext& ctx, |
75 | 26 | const Tensor& a, |
76 | 27 | const Scalar& b, |
77 | 28 | Tensor& out) { |
78 | | - (void)ctx; |
79 | | - |
80 | | - ET_KERNEL_CHECK( |
81 | | - ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out); |
82 | | - |
83 | | - // Resize for dynamic shape |
84 | | - ET_KERNEL_CHECK_MSG( |
85 | | - ctx, |
86 | | - resize_tensor(out, a.sizes()) == Error::Ok, |
87 | | - InvalidArgument, |
88 | | - out, |
89 | | - "Failed to resize output tensor."); |
90 | | - |
91 | | - ScalarType a_type = a.scalar_type(); |
92 | | - ScalarType b_type = utils::get_scalar_dtype(b); |
93 | | - ScalarType common_type = utils::promote_type_with_scalar(a_type, b); |
94 | | - ScalarType out_type = out.scalar_type(); |
95 | | - |
96 | | - ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out); |
97 | | - |
98 | | - ET_SWITCH_INT_TYPES_AND( |
99 | | - Bool, a_type, ctx, "bitwise_or.Scalar_out", CTYPE_A, [&]() { |
100 | | - ET_SWITCH_SCALAR_OBJ_INTB_TYPES( |
101 | | - b_type, ctx, "bitwise_or.Scalar_out", CTYPE_B, [&]() { |
102 | | - CTYPE_B val_b = 0; |
103 | | - utils::extract_scalar(b, &val_b); |
104 | | - ET_SWITCH_INT_TYPES_AND( |
105 | | - Bool, |
106 | | - common_type, |
107 | | - ctx, |
108 | | - "bitwise_or.Scalar_out", |
109 | | - CTYPE_IN, |
110 | | - [&]() { |
111 | | - ET_SWITCH_REAL_TYPES_AND( |
112 | | - Bool, |
113 | | - out_type, |
114 | | - ctx, |
115 | | - "bitwise_or.Scalar_out", |
116 | | - CTYPE_OUT, |
117 | | - [&]() { |
118 | | - apply_unary_map_fn( |
119 | | - [val_b](const CTYPE_A val_a) { |
120 | | - CTYPE_IN a_casted = |
121 | | - static_cast<CTYPE_IN>(val_a); |
122 | | - CTYPE_IN b_casted = |
123 | | - static_cast<CTYPE_IN>(val_b); |
124 | | - CTYPE_IN value = |
125 | | - std::bit_or<CTYPE_IN>()(a_casted, b_casted); |
126 | | - |
127 | | - return static_cast<CTYPE_OUT>(value); |
128 | | - }, |
129 | | - a.const_data_ptr<CTYPE_A>(), |
130 | | - out.mutable_data_ptr<CTYPE_OUT>(), |
131 | | - out.numel()); |
132 | | - }); |
133 | | - }); |
134 | | - }); |
135 | | - }); |
136 | | - |
137 | | - return out; |
| 29 | + static constexpr const char op_name[] = "bitwise_or.Scalar_out"; |
| 30 | + return internal::bitwise_scalar_out<op_name>(ctx, a, b, out); |
138 | 31 | } |
139 | 32 |
|
140 | 33 | } // namespace native |
|
0 commit comments