@@ -9,35 +9,13 @@
 #include <executorch/kernels/optimized/cpu/binary_ops.h>
 #include <executorch/kernels/optimized/vec/functional.h>
 #include <executorch/kernels/optimized/vec/vec.h>
-#include <executorch/kernels/portable/cpu/scalar_utils.h>
-#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
-#include <executorch/runtime/kernel/kernel_includes.h>
+#include <executorch/kernels/portable/cpu/op_div_impl.h>
 #include <executorch/runtime/platform/assert.h>
 
 namespace torch {
 namespace executor {
 namespace native {
 
-namespace {
-
-ScalarType get_compute_type(ScalarType a_type, ScalarType b_type) {
-  ET_CHECK(
-      !isComplexType(a_type) && !isQIntType(a_type) && !isBitsType(a_type));
-  ET_CHECK(
-      !isComplexType(b_type) && !isQIntType(b_type) && !isBitsType(b_type));
-
-  if (isFloatingType(a_type) && isFloatingType(b_type)) {
-    return promoteTypes(a_type, b_type);
-  } else if (isFloatingType(a_type)) {
-    return a_type;
-  } else if (isFloatingType(b_type)) {
-    return b_type;
-  }
-  return ScalarType::Float;
-}
-
-} // namespace
-
 Tensor& opt_div_out(
     KernelRuntimeContext& ctx,
     const Tensor& a,
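For context on the hunk above: the deleted `get_compute_type` helper encoded div's type-promotion rule locally. Two floating inputs promote via `promoteTypes`, a float/integer mix computes in the floating operand's type, and two integer (or bool) inputs compute in `Float`, because `div.out` performs true division and integer arithmetic would truncate. A minimal standalone illustration of that last case (plain C++, not ExecuTorch code):

```cpp
#include <iostream>

int main() {
  int a = 3, b = 2;
  std::cout << a / b << "\n";                      // 1: integer division truncates
  std::cout << static_cast<float>(a) / b << "\n";  // 1.5: what true division yields
  return 0;
}
```

This promotion rule now lives behind the shared implementation called in the hunks below.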
@@ -163,34 +141,7 @@ Tensor& opt_div_out( |
       }
     });
   } else {
-    ScalarType common_type = get_compute_type(a_type, b_type);
-    ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
-
-    ET_KERNEL_CHECK(
-        ctx,
-        resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-        InvalidArgument,
-        out);
-
-    ET_SWITCH_REALB_TYPES(a_type, ctx, "div.out", CTYPE_A, [&]() {
-      ET_SWITCH_REALB_TYPES(b_type, ctx, "div.out", CTYPE_B, [&]() {
-        ET_SWITCH_REALB_TYPES(common_type, ctx, "div.out", CTYPE_IN, [&]() {
-          ET_SWITCH_REALB_TYPES(out_type, ctx, "div.out", CTYPE_OUT, [&]() {
-            apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-                [](const CTYPE_A val_a, const CTYPE_B val_b) {
-                  CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-                  CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-                  CTYPE_IN value = a_casted / b_casted;
-
-                  return static_cast<CTYPE_OUT>(value);
-                },
-                a,
-                b,
-                out);
-          });
-        });
-      });
-    });
+    div_out_impl(ctx, a, b, out);
   }
 
   return out;
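Both fallback paths now route through the shared portable implementation pulled in by the new `op_div_impl.h` include. Inferring from the two call sites in this diff, that header would declare something like the following; this is a sketch under those assumptions, not the verbatim header, and the `Tensor&` return type in particular is a guess:

```cpp
// Hypothetical reconstruction of op_div_impl.h, inferred from the call
// sites div_out_impl(ctx, a, b, out) and div_scalar_out_impl(ctx, a, b, out).
namespace torch {
namespace executor {
namespace native {

// Shared broadcast/mixed-dtype fallback for div.out.
Tensor& div_out_impl(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    Tensor& out);

// Shared fallback for div.Scalar_out (b is a Scalar, not a Tensor).
Tensor& div_scalar_out_impl(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Scalar& b,
    Tensor& out);

} // namespace native
} // namespace executor
} // namespace torch
```

The vectorized fast paths stay local to the optimized kernel, while the slow paths live in one place, so the portable and optimized kernels can no longer drift apart.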
@@ -232,32 +183,7 @@ Tensor& opt_div_scalar_out( |
       });
     });
   } else {
-    ET_SWITCH_REAL_TYPES_AND(
-        Bool, a_type, ctx, "div.Scalar_out", CTYPE_A, [&]() {
-          ET_SWITCH_REAL_TYPES_AND(
-              Bool, b_type, ctx, "div.Scalar_out", CTYPE_B, [&]() {
-                ET_SWITCH_REAL_TYPES(
-                    common_type, ctx, "div.Scalar_out", CTYPE_IN, [&]() {
-                      ET_SWITCH_REAL_TYPES(
-                          out_type, ctx, "div.Scalar_out", CTYPE_OUT, [&]() {
-                            CTYPE_B b_val;
-                            ET_EXTRACT_SCALAR(b, b_val);
-                            CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
-                            CTYPE_IN inv_b_casted = CTYPE_IN(1) / b_casted;
-
-                            const size_t n = a.numel();
-                            const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
-                            CTYPE_OUT* out_data =
-                                out.mutable_data_ptr<CTYPE_OUT>();
-                            for (auto i = 0; i < n; ++i) {
-                              out_data[i] = static_cast<CTYPE_OUT>(
-                                  static_cast<CTYPE_IN>(a_data[i]) *
-                                  inv_b_casted);
-                            }
-                          });
-                    });
-              });
-        });
+    div_scalar_out_impl(ctx, a, b, out);
   }
 
   return out;
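Worth noting what the deleted scalar fallback did: it precomputed `1 / b` once and multiplied every element by the reciprocal, paying one division instead of one per element, at a possible last-ulp accuracy cost versus dividing directly. Whether `div_scalar_out_impl` preserves that trick is not visible in this diff. A self-contained sketch of the technique (illustrative names, not ExecuTorch code):

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// Divide each element by the same scalar via a hoisted reciprocal:
// one floating-point division total, then n cheaper multiplications.
template <typename T>
void div_by_scalar(const T* in, T b, T* out, std::size_t n) {
  const T inv_b = T(1) / b;  // single division, hoisted out of the loop
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = in[i] * inv_b;  // may differ from in[i] / b in the last ulp
  }
}

int main() {
  std::vector<float> a{2.f, 4.f, 6.f};
  std::vector<float> out(a.size());
  div_by_scalar(a.data(), 2.f, out.data(), a.size());
  for (float v : out) {
    std::cout << v << " ";  // prints: 1 2 3
  }
  std::cout << "\n";
  return 0;
}
```

One small difference: the deleted loop's `for (auto i = 0; i < n; ++i)` compared a signed `int` index against a `size_t` bound; the sketch uses `std::size_t` for the index.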