diff --git a/kernels/portable/cpu/op_sub.cpp b/kernels/portable/cpu/op_sub.cpp
index 6217f82c3b1..15bf4ec1473 100644
--- a/kernels/portable/cpu/op_sub.cpp
+++ b/kernels/portable/cpu/op_sub.cpp
@@ -11,6 +11,8 @@
 #include <executorch/kernels/portable/cpu/util/elementwise_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
+#include <executorch/kernels/portable/cpu/op_sub_impl.h>
+
 namespace torch {
 namespace executor {
 namespace native {
@@ -21,55 +23,7 @@ Tensor& sub_out(
     const Tensor& b,
     const Scalar& alpha,
     Tensor& out) {
-  ScalarType alpha_type = utils::get_scalar_dtype(alpha);
-
-  // Check alpha type
-  ET_KERNEL_CHECK(ctx, alpha_type != ScalarType::Bool, InvalidArgument, out);
-
-  // Common Dtype
-  ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
-
-  // Check Common Dtype
-  ET_KERNEL_CHECK(
-      ctx,
-      (canCast(common_type, out.scalar_type()) &&
-       canCast(alpha_type, common_type)),
-      InvalidArgument,
-      out);
-
-  // Check Dim Order
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
-
-  // Resize
-  ET_KERNEL_CHECK(
-      ctx,
-      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-      InvalidArgument,
-      out);
-
-  // Compute Dtype
-  ScalarType compute_type = utils::get_compute_type(common_type);
-
-  // @lint-ignore CLANGTIDY facebook-hte-CArray
-  static constexpr const char op_name[] = "sub.out";
-
-  ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
-    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
-        [val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
-          return val_a - val_alpha * val_b;
-        },
-        ctx,
-        a,
-        utils::SupportedTensorDtypes::REALHBF16,
-        b,
-        utils::SupportedTensorDtypes::REALHBF16,
-        out,
-        utils::SupportedTensorDtypes::REALHBF16);
-  });
-
-  return out;
+  return sub_out_impl(ctx, a, b, alpha, out);
 }
 
 Tensor& sub_scalar_out(
@@ -78,50 +32,7 @@ Tensor& sub_scalar_out(
     const Scalar& b,
     const Scalar& alpha,
     Tensor& out) {
-  ScalarType alpha_type = utils::get_scalar_dtype(alpha);
-
-  // Check alpha type
-  ET_KERNEL_CHECK(ctx, alpha_type != ScalarType::Bool, InvalidArgument, out);
-
-  // Common Dtype
-  ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);
-
-  // Check Common Dtype
-  ET_KERNEL_CHECK(
-      ctx,
-      (common_type == out.scalar_type() && canCast(alpha_type, common_type)),
-      InvalidArgument,
-      out);
-
-  // Check Dim Order
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
-
-  // Resize
-  ET_KERNEL_CHECK(
-      ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
-
-  // Compute Dtype
-  ScalarType compute_type = utils::get_compute_type(common_type);
-
-  // @lint-ignore CLANGTIDY facebook-hte-CArray
-  static constexpr const char op_name[] = "sub.Scalar_out";
-
-  ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
-    const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
-    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
-        [val_b, val_alpha](const CTYPE_COMPUTE val_a) {
-          return val_a - val_alpha * val_b;
-        },
-        ctx,
-        a,
-        utils::SupportedTensorDtypes::REALHBF16,
-        out,
-        utils::SupportedTensorDtypes::SAME_AS_COMMON);
-  });
-
-  return out;
+  return sub_scalar_out_impl(ctx, a, b, alpha, out);
 }
 
 } // namespace native
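Note for reviewers: both entry points keep their registered signatures and now simply forward to the _impl functions introduced below, so observable kernel behavior is unchanged. The contract on both paths is out = a - alpha * b. As a minimal reference for that semantics (plain C++, same-shape float case only; illustrative, not the ExecuTorch API, which additionally handles broadcasting, dim-order checks, and dtype promotion via apply_bitensor_elementwise_fn):

    #include <cstddef>
    #include <vector>

    // Reference semantics of sub.out for same-shape float inputs:
    // out[i] = a[i] - alpha * b[i].
    std::vector<float> sub_reference(
        const std::vector<float>& a,
        const std::vector<float>& b,
        float alpha) {
      std::vector<float> out(a.size());
      for (std::size_t i = 0; i < a.size(); ++i) {
        out[i] = a[i] - alpha * b[i];
      }
      return out;
    }

    // Example: sub_reference({5, 7}, {1, 2}, 2.0f) yields {3, 3}.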
diff --git a/kernels/portable/cpu/op_sub_impl.cpp b/kernels/portable/cpu/op_sub_impl.cpp
new file mode 100644
index 00000000000..e18be372cd6
--- /dev/null
+++ b/kernels/portable/cpu/op_sub_impl.cpp
@@ -0,0 +1,131 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/kernels/portable/cpu/scalar_utils.h>
+#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
+#include <executorch/kernels/portable/cpu/util/dtype_util.h>
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
+
+#include <executorch/kernels/portable/cpu/op_sub_impl.h>
+
+namespace torch {
+namespace executor {
+namespace native {
+
+Tensor& sub_out_impl(
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    const Tensor& b,
+    const Scalar& alpha,
+    Tensor& out) {
+  ScalarType alpha_type = utils::get_scalar_dtype(alpha);
+
+  // Check alpha type
+  ET_KERNEL_CHECK(ctx, alpha_type != ScalarType::Bool, InvalidArgument, out);
+
+  // Common Dtype
+  ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
+
+  // Check Common Dtype
+  ET_KERNEL_CHECK(
+      ctx,
+      (canCast(common_type, out.scalar_type()) &&
+       canCast(alpha_type, common_type)),
+      InvalidArgument,
+      out);
+
+  // Check Dim Order
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
+
+  // Resize
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
+      InvalidArgument,
+      out);
+
+  // Compute Dtype
+  ScalarType compute_type = utils::get_compute_type(common_type);
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "sub.out";
+
+  ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+    const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
+    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+        [val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
+          return val_a - val_alpha * val_b;
+        },
+        ctx,
+        a,
+        utils::SupportedTensorDtypes::REALHBF16,
+        b,
+        utils::SupportedTensorDtypes::REALHBF16,
+        out,
+        utils::SupportedTensorDtypes::REALHBF16);
+  });
+
+  return out;
+}
+
+Tensor& sub_scalar_out_impl(
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    const Scalar& b,
+    const Scalar& alpha,
+    Tensor& out) {
+  ScalarType alpha_type = utils::get_scalar_dtype(alpha);
+
+  // Check alpha type
+  ET_KERNEL_CHECK(ctx, alpha_type != ScalarType::Bool, InvalidArgument, out);
+
+  // Common Dtype
+  ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);
+
+  // Check Common Dtype
+  ET_KERNEL_CHECK(
+      ctx,
+      (common_type == out.scalar_type() && canCast(alpha_type, common_type)),
+      InvalidArgument,
+      out);
+
+  // Check Dim Order
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
+
+  // Resize
+  ET_KERNEL_CHECK(
+      ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
+
+  // Compute Dtype
+  ScalarType compute_type = utils::get_compute_type(common_type);
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "sub.Scalar_out";
+
+  ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+    const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
+    const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
+    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+        [val_b, val_alpha](const CTYPE_COMPUTE val_a) {
+          return val_a - val_alpha * val_b;
+        },
+        ctx,
+        a,
+        utils::SupportedTensorDtypes::REALHBF16,
+        out,
+        utils::SupportedTensorDtypes::SAME_AS_COMMON);
+  });
+
+  return out;
+}
+
+} // namespace native
+} // namespace executor
+} // namespace torch
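This split mirrors the earlier op_add_impl/op_div_impl/op_mul_impl refactors: the portable implementation becomes a standalone library that optimized kernels can call as a fallback. A hedged sketch of that intended reuse, assuming a hypothetical optimized entry point opt_sub_out (not part of this PR) and the ExecuTorch tensor helpers (tensors_have_same_shape_and_dtype, const_data_ptr/mutable_data_ptr) as used elsewhere in the portable kernels:

    #include <cstddef>

    #include <executorch/kernels/portable/cpu/op_sub_impl.h>

    namespace torch {
    namespace executor {
    namespace native {

    // Hypothetical optimized entry point: handle the easy case inline and
    // defer validation, broadcasting, and dtype promotion to the portable impl.
    Tensor& opt_sub_out(
        KernelRuntimeContext& ctx,
        const Tensor& a,
        const Tensor& b,
        const Scalar& alpha,
        Tensor& out) {
      // Fast path: all three tensors already share shape and float dtype,
      // so no resize or promotion is needed.
      if (a.scalar_type() == ScalarType::Float &&
          tensors_have_same_shape_and_dtype(a, b, out)) {
        const float alpha_val = utils::scalar_to<float>(alpha);
        const float* a_data = a.const_data_ptr<float>();
        const float* b_data = b.const_data_ptr<float>();
        float* out_data = out.mutable_data_ptr<float>();
        const std::size_t n = static_cast<std::size_t>(out.numel());
        for (std::size_t i = 0; i < n; ++i) {
          out_data[i] = a_data[i] - alpha_val * b_data[i];
        }
        return out;
      }
      // Portable fallback: full checks, broadcasting, dtype promotion.
      return sub_out_impl(ctx, a, b, alpha, out);
    }

    } // namespace native
    } // namespace executor
    } // namespace torch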
diff --git a/kernels/portable/cpu/op_sub_impl.h b/kernels/portable/cpu/op_sub_impl.h
new file mode 100644
index 00000000000..ddff2e2a39d
--- /dev/null
+++ b/kernels/portable/cpu/op_sub_impl.h
@@ -0,0 +1,34 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/kernels/portable/cpu/scalar_utils.h>
+#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+namespace torch {
+namespace executor {
+namespace native {
+
+Tensor& sub_out_impl(
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    const Tensor& b,
+    const Scalar& alpha,
+    Tensor& out);
+
+Tensor& sub_scalar_out_impl(
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    const Scalar& b,
+    const Scalar& alpha,
+    Tensor& out);
+
+} // namespace native
+} // namespace executor
+} // namespace torch
diff --git a/kernels/portable/cpu/targets.bzl b/kernels/portable/cpu/targets.bzl
index afd98188a3f..e3f69e089ee 100644
--- a/kernels/portable/cpu/targets.bzl
+++ b/kernels/portable/cpu/targets.bzl
@@ -151,6 +151,26 @@ def define_common_targets():
         ],
     )
 
+    runtime.cxx_library(
+        name = "op_sub_impl",
+        srcs = ["op_sub_impl.cpp"],
+        exported_headers = ["op_sub_impl.h"],
+        visibility = [
+            "//executorch/kernels/portable/cpu/...",
+            "//executorch/kernels/optimized/cpu/...",
+            "//executorch/kernels/portable/test/...",
+            "@EXECUTORCH_CLIENTS",
+        ],
+        exported_deps = [
+            "//executorch/kernels/portable/cpu/util:broadcast_util",
+            "//executorch/kernels/portable/cpu/util:dtype_util",
+            "//executorch/kernels/portable/cpu/util:elementwise_util",
+            "//executorch/kernels/portable/cpu/util:kernel_ops_util",
+            "//executorch/kernels/portable/cpu:scalar_utils",
+            "//executorch/runtime/kernel:kernel_includes",
+        ],
+    )
+
     # The following will not participate in dtype selective build because
     # they are refactored such to be used in optimized op implementations as well
     # and we have not enabled selective build for optimized ops.
@@ -165,6 +185,7 @@ def define_common_targets():
             "//executorch/kernels/portable/cpu:op_add_impl",
             "//executorch/kernels/portable/cpu:op_div_impl",
             "//executorch/kernels/portable/cpu:op_mul_impl",
+            "//executorch/kernels/portable/cpu:op_sub_impl",
         ],
         visibility = ["//executorch/...", "@EXECUTORCH_CLIENTS"],
     )
diff --git a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
index e7686784833..9c0e0eba3c9 100644
--- a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
+++ b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
@@ -1138,10 +1138,7 @@ ATEN_OPS = (
     op_target(
         name = "op_sub",
         deps = [
-            ":scalar_utils",
-            "//executorch/kernels/portable/cpu/util:broadcast_util",
-            "//executorch/kernels/portable/cpu/util:dtype_util",
-            "//executorch/kernels/portable/cpu/util:elementwise_util",
+            "//executorch/kernels/portable/cpu:op_sub_impl",
         ],
     ),
     op_target(
@@ -1264,4 +1261,4 @@ def portable_source_list():
 
 def portable_header_list():
     """All the header file names from //executorch/kernels/portable/cpu/"""
-    return ["selective_build.h", "scalar_utils.h", "math_constants.h", "vec_ops.h", "op_div_impl.h", "op_mul_impl.h", "op_add_impl.h"]
+    return ["selective_build.h", "scalar_utils.h", "math_constants.h", "vec_ops.h", "op_div_impl.h", "op_mul_impl.h", "op_add_impl.h", "op_sub_impl.h"]
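Since the new target's visibility includes //executorch/kernels/portable/test/..., the impl functions can also be unit-tested directly against the exported header. A gtest-style sketch, under the assumption that a test target links op_sub_impl together with the ExecuTorch testing utilities (TensorFactory, EXPECT_TENSOR_CLOSE); the test itself is illustrative and not part of this PR:

    #include <executorch/kernels/portable/cpu/op_sub_impl.h>
    #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
    #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
    #include <gtest/gtest.h>

    using torch::executor::testing::TensorFactory;

    TEST(OpSubImplTest, AlphaScalesRhs) {
      TensorFactory<exec_aten::ScalarType::Float> tf;
      exec_aten::Tensor a = tf.make({2, 2}, {1.0, 2.0, 3.0, 4.0});
      exec_aten::Tensor b = tf.make({2, 2}, {1.0, 1.0, 1.0, 1.0});
      exec_aten::Tensor out = tf.zeros({2, 2});

      torch::executor::KernelRuntimeContext ctx{};
      // Computes out = a - 2 * b through the refactored impl entry point.
      torch::executor::native::sub_out_impl(ctx, a, b, /*alpha=*/2, out);

      EXPECT_TENSOR_CLOSE(out, tf.make({2, 2}, {-1.0, 0.0, 1.0, 2.0}));
    }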