diff --git a/kernels/portable/cpu/op_add.cpp b/kernels/portable/cpu/op_add.cpp
index adb9d4ea723..66944ae0c4f 100644
--- a/kernels/portable/cpu/op_add.cpp
+++ b/kernels/portable/cpu/op_add.cpp
@@ -6,11 +6,9 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-#include <executorch/kernels/portable/cpu/scalar_utils.h>
-#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
-#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
-#include <executorch/runtime/platform/assert.h>
+
+#include <executorch/kernels/portable/cpu/op_add_impl.h>
 namespace torch {
 namespace executor {
 namespace native {
@@ -22,50 +20,7 @@ Tensor& add_out(
     const Tensor& b,
     const Scalar& alpha,
     Tensor& out) {
-  // Common Dtype
-  ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
-
-  // Check Common Dtype
-  ET_KERNEL_CHECK(
-      ctx,
-      (canCast(common_type, out.scalar_type()) &&
-       check_alpha_type(utils::get_scalar_dtype(alpha), common_type)),
-      InvalidArgument,
-      out);
-
-  // Check Dim Order
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
-
-  // Resize
-  ET_KERNEL_CHECK(
-      ctx,
-      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-      InvalidArgument,
-      out);
-
-  // Compute Dtype
-  ScalarType compute_type = utils::get_compute_type(common_type);
-
-  // @lint-ignore CLANGTIDY facebook-hte-CArray
-  static constexpr const char op_name[] = "add.out";
-
-  ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
-    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
-        [val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
-          return val_a + val_alpha * val_b;
-        },
-        ctx,
-        a,
-        utils::SupportedTensorDtypes::REALHBBF16,
-        b,
-        utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::REALHBBF16);
-  });
-
-  return out;
+  return add_out_impl(ctx, a, b, alpha, out);
 }
 
 Tensor& add_scalar_out(
@@ -74,46 +29,7 @@ Tensor& add_scalar_out(
     const Scalar& b,
     const Scalar& alpha,
     Tensor& out) {
-  // Common Dtype
-  ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);
-
-  // Check Common Dtype
-  ET_KERNEL_CHECK(
-      ctx,
-      (common_type == out.scalar_type() &&
-       check_alpha_type(utils::get_scalar_dtype(alpha), common_type)),
-      InvalidArgument,
-      out);
-
-  // Check Dim Order
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
-
-  // Resize
-  ET_KERNEL_CHECK(
-      ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
-
-  // Compute Dtype
-  ScalarType compute_type = utils::get_compute_type(common_type);
-
-  // @lint-ignore CLANGTIDY facebook-hte-CArray
-  static constexpr const char op_name[] = "add.Scalar_out";
-
-  ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
-        [b, alpha](const CTYPE_COMPUTE val_a) {
-          CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
-          CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
-          return val_a + val_alpha * val_b;
-        },
-        ctx,
-        a,
-        utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::SAME_AS_COMMON);
-  });
-
-  return out;
+  return add_scalar_out_impl(ctx, a, b, alpha, out);
 }
 
 } // namespace native
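Taken together, the two hunks above reduce add_out and add_scalar_out to one-line wrappers around shared *_impl functions. A minimal standalone sketch of that wrapper-plus-shared-impl pattern, with plain ints standing in for ExecuTorch's context/Tensor/Scalar machinery (all names here are illustrative, not the library's API):

#include <cstdio>

namespace impl {
// Shared implementation: built once into its own library target so that
// other kernel variants can link against it directly.
int add_out_impl(int a, int b, int alpha) {
  return a + alpha * b;
}
} // namespace impl

// Registered entry point: now a thin wrapper, so the registration surface
// stays stable while the logic lives in the shared target.
int add_out(int a, int b, int alpha) {
  return impl::add_out_impl(a, b, alpha);
}

int main() {
  std::printf("%d\n", add_out(1, 2, 3)); // prints 7: 1 + 3 * 2
  return 0;
}

The payoff of the indirection is below: the body lands in a new op_add_impl.cpp that other build targets can depend on without pulling in the operator registration.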
diff --git a/kernels/portable/cpu/op_add_impl.cpp b/kernels/portable/cpu/op_add_impl.cpp
new file mode 100644
index 00000000000..6d5a8075818
--- /dev/null
+++ b/kernels/portable/cpu/op_add_impl.cpp
@@ -0,0 +1,123 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/kernels/portable/cpu/scalar_utils.h>
+#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
+#include <executorch/kernels/portable/cpu/util/dtype_util.h>
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
+#include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
+
+#include <executorch/kernels/portable/cpu/op_add_impl.h>
+
+namespace torch {
+namespace executor {
+namespace native {
+
+Tensor& add_out_impl(
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    const Tensor& b,
+    const Scalar& alpha,
+    Tensor& out) {
+  // Common Dtype
+  ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
+
+  // Check Common Dtype
+  ET_KERNEL_CHECK(
+      ctx,
+      (canCast(common_type, out.scalar_type()) &&
+       check_alpha_type(utils::get_scalar_dtype(alpha), common_type)),
+      InvalidArgument,
+      out);
+
+  // Check Dim Order
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
+
+  // Resize
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
+      InvalidArgument,
+      out);
+
+  // Compute Dtype
+  ScalarType compute_type = utils::get_compute_type(common_type);
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "add.out";
+
+  ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+    const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
+    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+        [val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
+          return val_a + val_alpha * val_b;
+        },
+        ctx,
+        a,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        b,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        out,
+        utils::SupportedTensorDtypes::REALHBBF16);
+  });
+
+  return out;
+}
+
+Tensor& add_scalar_out_impl(
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    const Scalar& b,
+    const Scalar& alpha,
+    Tensor& out) {
+  // Common Dtype
+  ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);
+
+  // Check Common Dtype
+  ET_KERNEL_CHECK(
+      ctx,
+      (common_type == out.scalar_type() &&
+       check_alpha_type(utils::get_scalar_dtype(alpha), common_type)),
+      InvalidArgument,
+      out);
+
+  // Check Dim Order
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
+
+  // Resize
+  ET_KERNEL_CHECK(
+      ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
+
+  // Compute Dtype
+  ScalarType compute_type = utils::get_compute_type(common_type);
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "add.Scalar_out";
+
+  ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+        [b, alpha](const CTYPE_COMPUTE val_a) {
+          CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
+          CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
+          return val_a + val_alpha * val_b;
+        },
+        ctx,
+        a,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        out,
+        utils::SupportedTensorDtypes::SAME_AS_COMMON);
+  });
+
+  return out;
+}
+
+} // namespace native
+} // namespace executor
+} // namespace torch
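The moved body first resolves a compute dtype, then applies an elementwise functor computing val_a + val_alpha * val_b over the (possibly broadcast) inputs. A self-contained sketch of just that inner step, with a hypothetical apply_bitensor_elementwise standing in for utils::apply_bitensor_elementwise_fn (no broadcasting, dtype conversion, or error handling):

#include <cstddef>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for utils::apply_bitensor_elementwise_fn: applies a
// binary functor over two equal-length buffers. The real helper additionally
// handles broadcasting and per-element dtype conversion, driven by the
// SupportedTensorDtypes arguments.
template <typename CTYPE_COMPUTE, typename Op>
void apply_bitensor_elementwise(
    Op op,
    const std::vector<CTYPE_COMPUTE>& a,
    const std::vector<CTYPE_COMPUTE>& b,
    std::vector<CTYPE_COMPUTE>& out) {
  for (std::size_t i = 0; i < a.size(); ++i) {
    out[i] = op(a[i], b[i]);
  }
}

int main() {
  std::vector<float> a = {1.0f, 2.0f, 3.0f};
  std::vector<float> b = {10.0f, 20.0f, 30.0f};
  std::vector<float> out(a.size());
  const float val_alpha = 0.5f; // alpha scales b, as in add.out
  apply_bitensor_elementwise<float>(
      [val_alpha](float val_a, float val_b) {
        return val_a + val_alpha * val_b;
      },
      a,
      b,
      out);
  for (float v : out) {
    std::printf("%g\n", v); // prints 6, 12, 18
  }
  return 0;
}

In the real kernel, ET_SWITCH_REALB_TYPES instantiates this per supported compute dtype, which is why the body had to move into a separately buildable translation unit rather than be duplicated by each consumer.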
diff --git a/kernels/portable/cpu/op_add_impl.h b/kernels/portable/cpu/op_add_impl.h
new file mode 100644
index 00000000000..f764f48f7e6
--- /dev/null
+++ b/kernels/portable/cpu/op_add_impl.h
@@ -0,0 +1,31 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+namespace torch {
+namespace executor {
+namespace native {
+
+Tensor& add_out_impl(
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    const Tensor& b,
+    const Scalar& alpha,
+    Tensor& out);
+
+Tensor& add_scalar_out_impl(
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    const Scalar& b,
+    const Scalar& alpha,
+    Tensor& out);
+
+} // namespace native
+} // namespace executor
+} // namespace torch
diff --git a/kernels/portable/cpu/targets.bzl b/kernels/portable/cpu/targets.bzl
index e6cb5c02ff4..afd98188a3f 100644
--- a/kernels/portable/cpu/targets.bzl
+++ b/kernels/portable/cpu/targets.bzl
@@ -131,6 +131,26 @@ def define_common_targets():
         ],
     )
 
+    runtime.cxx_library(
+        name = "op_add_impl",
+        srcs = ["op_add_impl.cpp"],
+        exported_headers = ["op_add_impl.h"],
+        visibility = [
+            "//executorch/kernels/portable/cpu/...",
+            "//executorch/kernels/optimized/cpu/...",
+            "//executorch/kernels/portable/test/...",
+            "@EXECUTORCH_CLIENTS",
+        ],
+        exported_deps = [
+            "//executorch/kernels/portable/cpu/util:broadcast_util",
+            "//executorch/kernels/portable/cpu/util:dtype_util",
+            "//executorch/kernels/portable/cpu/util:elementwise_util",
+            "//executorch/kernels/portable/cpu/util:kernel_ops_util",
+            "//executorch/kernels/portable/cpu:scalar_utils",
+            "//executorch/runtime/kernel:kernel_includes",
+        ],
+    )
+
     # The following will not participate in dtype selective build because
     # they are refactored such to be used in optimized op implementations as well
     # and we have not enabled selective build for optimized ops.
@@ -142,6 +162,7 @@ def define_common_targets():
     runtime.cxx_library(
         name = "all_impl_deps",
         deps = [
+            "//executorch/kernels/portable/cpu:op_add_impl",
             "//executorch/kernels/portable/cpu:op_div_impl",
             "//executorch/kernels/portable/cpu:op_mul_impl",
         ],
diff --git a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
index b1a612a7e74..e7686784833 100644
--- a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
+++ b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
@@ -220,11 +220,7 @@ ATEN_OPS = (
     op_target(
         name = "op_add",
         deps = [
-            "//executorch/kernels/portable/cpu/util:broadcast_util",
-            "//executorch/kernels/portable/cpu/util:dtype_util",
-            "//executorch/kernels/portable/cpu/util:elementwise_util",
-            "//executorch/kernels/portable/cpu/util:kernel_ops_util",
-            ":scalar_utils",
+            "//executorch/kernels/portable/cpu:op_add_impl",
         ],
     ),
     op_target(
@@ -1268,4 +1264,4 @@ def portable_source_list():
 
 def portable_header_list():
     """All the header file names from //executorch/kernels/portable/cpu/"""
-    return ["selective_build.h", "scalar_utils.h", "math_constants.h", "vec_ops.h", "op_div_impl.h", "op_mul_impl.h"]
+    return ["selective_build.h", "scalar_utils.h", "math_constants.h", "vec_ops.h", "op_div_impl.h", "op_mul_impl.h", "op_add_impl.h"]
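The new visibility entry for //executorch/kernels/optimized/cpu/... is what lets an optimized kernel library link the portable implementation directly, for example as a fallback when a vectorized fast path does not apply. A hedged sketch of such a caller, assuming the ExecuTorch headers are on the include path and that the torch::executor::Tensor/Scalar aliases are available; can_use_fast_path and my_add_out are hypothetical names, and only add_out_impl comes from this patch:

#include <executorch/kernels/portable/cpu/op_add_impl.h>

namespace example {

using torch::executor::KernelRuntimeContext;
using torch::executor::Scalar;
using torch::executor::Tensor;

// Hypothetical predicate: a real optimized kernel would check for
// contiguous layouts and a directly vectorizable dtype combination.
bool can_use_fast_path(const Tensor&, const Tensor&, const Tensor&) {
  return false;
}

Tensor& my_add_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    const Scalar& alpha,
    Tensor& out) {
  if (can_use_fast_path(a, b, out)) {
    // ... vectorized implementation would go here ...
  }
  // Fall back to the shared portable implementation from this patch.
  return torch::executor::native::add_out_impl(ctx, a, b, alpha, out);
}

} // namespace example

Keeping op_add_impl (like op_div_impl and op_mul_impl before it) in the all_impl_deps target also records that these shared implementations are excluded from dtype selective build, per the comment in targets.bzl.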