@@ -9,6 +9,7 @@
 #include <executorch/kernels/optimized/cpu/binary_ops.h>
 #include <executorch/kernels/optimized/vec/functional.h>
 #include <executorch/kernels/optimized/vec/vec.h>
+#include <executorch/kernels/portable/cpu/op_sub_impl.h>
 #include <executorch/kernels/portable/cpu/scalar_utils.h>
 #include <executorch/kernels/portable/cpu/util/broadcast_util.h>
 #include <executorch/runtime/core/exec_aten/util/tensor_util.h>
@@ -210,35 +211,7 @@ Tensor& opt_sub_out(
       }
     });
   } else {
-    ScalarType common_type =
-        promoteTypes(a_type, b_type, /*half_to_float*/ true);
-    ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
-
-    ET_KERNEL_CHECK(
-        ctx,
-        resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-        InvalidArgument,
-        out);
-
-    ET_SWITCH_REALH_TYPES(a_type, ctx, "sub.out", CTYPE_A, [&]() {
-      ET_SWITCH_REALH_TYPES(b_type, ctx, "sub.out", CTYPE_B, [&]() {
-        using CTYPE_IN = typename torch::executor::
-            promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
-        ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
-        ET_SWITCH_REALH_TYPES(out_type, ctx, "sub.out", CTYPE_OUT, [&]() {
-          CTYPE_IN alpha_val;
-          ET_KERNEL_CHECK(
-              ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
-
-          SubInner<
-              can_cast<CTYPE_IN, CTYPE_OUT>::value,
-              CTYPE_A,
-              CTYPE_B,
-              CTYPE_IN,
-              CTYPE_OUT>::run(a, b, alpha_val, out);
-        });
-      });
-    });
+    sub_out_impl(ctx, a, b, alpha, out);
   }
 
   return out;
@@ -290,31 +263,7 @@ Tensor& opt_sub_scalar_out(
           });
     });
   } else {
-    ET_SWITCH_REALH_TYPES(a_type, ctx, "sub.Scalar_out", CTYPE_A, [&]() {
-      ET_SWITCH_SCALAR_OBJ_REAL_TYPES(
-          b_type, ctx, "sub.Scalar_out", CTYPE_B, [&]() {
-            ET_SWITCH_REAL_TYPES(
-                common_type, ctx, "sub.Scalar_out", CTYPE_IN, [&]() {
-                  ET_SWITCH_REALH_TYPES(
-                      out_type, ctx, "sub.Scalar_out", CTYPE_OUT, [&]() {
-                        CTYPE_B b_val;
-                        ET_EXTRACT_SCALAR(b, b_val);
-                        CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
-                        CTYPE_IN alpha_val;
-                        ET_EXTRACT_SCALAR(alpha, alpha_val);
-
-                        const size_t n = a.numel();
-                        const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
-                        CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
-                        for (auto i = 0; i < n; ++i) {
-                          out_data[i] = static_cast<CTYPE_OUT>(
-                              static_cast<CTYPE_IN>(a_data[i]) -
-                              alpha_val * b_casted);
-                        }
-                      });
-                });
-          });
-    });
+    sub_scalar_out_impl(ctx, a, b, alpha, out);
   }
 
   return out;
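
For reference: the shared fallbacks called above are declared in executorch/kernels/portable/cpu/op_sub_impl.h. Only the function names and argument order are confirmed by this diff; the sketch below shows what those declarations presumably look like, with parameter types inferred from the opt_sub_out / opt_sub_scalar_out call sites rather than copied from the header.

// Sketch only -- assumed declarations, inferred from the call sites in this
// commit; the authoritative versions live in op_sub_impl.h.
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

using exec_aten::Scalar;
using exec_aten::Tensor;

// Shared slow path: out = a - alpha * b with dtype promotion and
// broadcasting, where b is a tensor. Used when the optimized paths
// in opt_sub_out() do not apply.
Tensor& sub_out_impl(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    const Scalar& alpha,
    Tensor& out);

// Scalar variant: out = a - alpha * b, where b is a Scalar.
Tensor& sub_scalar_out_impl(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Scalar& b,
    const Scalar& alpha,
    Tensor& out);

} // namespace native
} // namespace executor
} // namespace torch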