diff --git a/kernels/optimized/cpu/op_sub.cpp b/kernels/optimized/cpu/op_sub.cpp index 51ff4fbd571..b4c9d1e9752 100644 --- a/kernels/optimized/cpu/op_sub.cpp +++ b/kernels/optimized/cpu/op_sub.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -210,35 +211,7 @@ Tensor& opt_sub_out( } }); } else { - ScalarType common_type = - promoteTypes(a_type, b_type, /*half_to_float*/ true); - ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out); - - ET_KERNEL_CHECK( - ctx, - resize_to_broadcast_target_size(a, b, out) == Error::Ok, - InvalidArgument, - out); - - ET_SWITCH_REALH_TYPES(a_type, ctx, "sub.out", CTYPE_A, [&]() { - ET_SWITCH_REALH_TYPES(b_type, ctx, "sub.out", CTYPE_B, [&]() { - using CTYPE_IN = typename torch::executor:: - promote_types::type; - ET_DCHECK(CppTypeToScalarType::value == common_type); - ET_SWITCH_REALH_TYPES(out_type, ctx, "sub.out", CTYPE_OUT, [&]() { - CTYPE_IN alpha_val; - ET_KERNEL_CHECK( - ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, ); - - SubInner< - can_cast::value, - CTYPE_A, - CTYPE_B, - CTYPE_IN, - CTYPE_OUT>::run(a, b, alpha_val, out); - }); - }); - }); + sub_out_impl(ctx, a, b, alpha, out); } return out; @@ -290,31 +263,7 @@ Tensor& opt_sub_scalar_out( }); }); } else { - ET_SWITCH_REALH_TYPES(a_type, ctx, "sub.Scalar_out", CTYPE_A, [&]() { - ET_SWITCH_SCALAR_OBJ_REAL_TYPES( - b_type, ctx, "sub.Scalar_out", CTYPE_B, [&]() { - ET_SWITCH_REAL_TYPES( - common_type, ctx, "sub.Scalar_out", CTYPE_IN, [&]() { - ET_SWITCH_REALH_TYPES( - out_type, ctx, "sub.Scalar_out", CTYPE_OUT, [&]() { - CTYPE_B b_val; - ET_EXTRACT_SCALAR(b, b_val); - CTYPE_IN b_casted = static_cast(b_val); - CTYPE_IN alpha_val; - ET_EXTRACT_SCALAR(alpha, alpha_val); - - const size_t n = a.numel(); - const CTYPE_A* a_data = a.const_data_ptr(); - CTYPE_OUT* out_data = out.mutable_data_ptr(); - for (auto i = 0; i < n; ++i) { - out_data[i] = static_cast( - static_cast(a_data[i]) - - alpha_val * b_casted); - } - }); - }); - }); - }); + sub_scalar_out_impl(ctx, a, b, alpha, out); } return out; diff --git a/kernels/optimized/cpu/targets.bzl b/kernels/optimized/cpu/targets.bzl index c4ed508fa60..04efec9946c 100644 --- a/kernels/optimized/cpu/targets.bzl +++ b/kernels/optimized/cpu/targets.bzl @@ -88,6 +88,7 @@ _OPTIMIZED_ATEN_OPS = ( name = "op_sub", deps = [ ":binary_ops", + "//executorch/kernels/portable/cpu:op_sub_impl", "//executorch/kernels/portable/cpu:scalar_utils", "//executorch/kernels/portable/cpu/util:broadcast_util", ],