From c2183fb318ef02c99e2b32a697eea66a5ce01df4 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Fri, 8 Nov 2024 07:14:11 -0800 Subject: [PATCH] [Exeuctorch][Portale] Refactor op_add to be reused by optimized fallback If we can use portable for optimized's fallback then we can remove copy pasted stuff and take advantage of size and build reduction efforts in portable. Differential Revision: [D65632373](https://our.internmc.facebook.com/intern/diff/D65632373/) [ghstack-poisoned] --- kernels/portable/cpu/op_add.cpp | 92 +------------ kernels/portable/cpu/op_add_impl.cpp | 123 ++++++++++++++++++ kernels/portable/cpu/op_add_impl.h | 31 +++++ kernels/portable/cpu/targets.bzl | 21 +++ .../kernels/portable/op_registration_util.bzl | 8 +- 5 files changed, 181 insertions(+), 94 deletions(-) create mode 100644 kernels/portable/cpu/op_add_impl.cpp create mode 100644 kernels/portable/cpu/op_add_impl.h diff --git a/kernels/portable/cpu/op_add.cpp b/kernels/portable/cpu/op_add.cpp index adb9d4ea723..66944ae0c4f 100644 --- a/kernels/portable/cpu/op_add.cpp +++ b/kernels/portable/cpu/op_add.cpp @@ -6,11 +6,9 @@ * LICENSE file in the root directory of this source tree. */ -#include -#include -#include #include -#include + +#include namespace torch { namespace executor { @@ -22,50 +20,7 @@ Tensor& add_out( const Tensor& b, const Scalar& alpha, Tensor& out) { - // Common Dtype - ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type()); - - // Check Common Dtype - ET_KERNEL_CHECK( - ctx, - (canCast(common_type, out.scalar_type()) && - check_alpha_type(utils::get_scalar_dtype(alpha), common_type)), - InvalidArgument, - out); - - // Check Dim Order - ET_KERNEL_CHECK( - ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out); - - // Resize - ET_KERNEL_CHECK( - ctx, - resize_to_broadcast_target_size(a, b, out) == Error::Ok, - InvalidArgument, - out); - - // Compute Dtype - ScalarType compute_type = utils::get_compute_type(common_type); - - // @lint-ignore CLANGTIDY facebook-hte-CArray - static constexpr const char op_name[] = "add.out"; - - ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - const CTYPE_COMPUTE val_alpha = utils::scalar_to(alpha); - utils::apply_bitensor_elementwise_fn( - [val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { - return val_a + val_alpha * val_b; - }, - ctx, - a, - utils::SupportedTensorDtypes::REALHBBF16, - b, - utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::REALHBBF16); - }); - - return out; + return add_out_impl(ctx, a, b, alpha, out); } Tensor& add_scalar_out( @@ -74,46 +29,7 @@ Tensor& add_scalar_out( const Scalar& b, const Scalar& alpha, Tensor& out) { - // Common Dtype - ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b); - - // Check Common Dtype - ET_KERNEL_CHECK( - ctx, - (common_type == out.scalar_type() && - check_alpha_type(utils::get_scalar_dtype(alpha), common_type)), - InvalidArgument, - out); - - // Check Dim Order - ET_KERNEL_CHECK( - ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out); - - // Resize - ET_KERNEL_CHECK( - ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out); - - // Compute Dtype - ScalarType compute_type = utils::get_compute_type(common_type); - - // @lint-ignore CLANGTIDY facebook-hte-CArray - static constexpr const char op_name[] = "add.Scalar_out"; - - ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_unitensor_elementwise_fn( - [b, alpha](const CTYPE_COMPUTE val_a) { - CTYPE_COMPUTE val_b = utils::scalar_to(b); - CTYPE_COMPUTE val_alpha = utils::scalar_to(alpha); - return val_a + val_alpha * val_b; - }, - ctx, - a, - utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::SAME_AS_COMMON); - }); - - return out; + return add_scalar_out_impl(ctx, a, b, alpha, out); } } // namespace native diff --git a/kernels/portable/cpu/op_add_impl.cpp b/kernels/portable/cpu/op_add_impl.cpp new file mode 100644 index 00000000000..6d5a8075818 --- /dev/null +++ b/kernels/portable/cpu/op_add_impl.cpp @@ -0,0 +1,123 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include + +#include + +namespace torch { +namespace executor { +namespace native { + +Tensor& add_out_impl( + KernelRuntimeContext& ctx, + const Tensor& a, + const Tensor& b, + const Scalar& alpha, + Tensor& out) { + // Common Dtype + ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type()); + + // Check Common Dtype + ET_KERNEL_CHECK( + ctx, + (canCast(common_type, out.scalar_type()) && + check_alpha_type(utils::get_scalar_dtype(alpha), common_type)), + InvalidArgument, + out); + + // Check Dim Order + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out); + + // Resize + ET_KERNEL_CHECK( + ctx, + resize_to_broadcast_target_size(a, b, out) == Error::Ok, + InvalidArgument, + out); + + // Compute Dtype + ScalarType compute_type = utils::get_compute_type(common_type); + + // @lint-ignore CLANGTIDY facebook-hte-CArray + static constexpr const char op_name[] = "add.out"; + + ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + const CTYPE_COMPUTE val_alpha = utils::scalar_to(alpha); + utils::apply_bitensor_elementwise_fn( + [val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { + return val_a + val_alpha * val_b; + }, + ctx, + a, + utils::SupportedTensorDtypes::REALHBBF16, + b, + utils::SupportedTensorDtypes::REALHBBF16, + out, + utils::SupportedTensorDtypes::REALHBBF16); + }); + + return out; +} + +Tensor& add_scalar_out_impl( + KernelRuntimeContext& ctx, + const Tensor& a, + const Scalar& b, + const Scalar& alpha, + Tensor& out) { + // Common Dtype + ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b); + + // Check Common Dtype + ET_KERNEL_CHECK( + ctx, + (common_type == out.scalar_type() && + check_alpha_type(utils::get_scalar_dtype(alpha), common_type)), + InvalidArgument, + out); + + // Check Dim Order + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out); + + // Resize + ET_KERNEL_CHECK( + ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out); + + // Compute Dtype + ScalarType compute_type = utils::get_compute_type(common_type); + + // @lint-ignore CLANGTIDY facebook-hte-CArray + static constexpr const char op_name[] = "add.Scalar_out"; + + ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + utils::apply_unitensor_elementwise_fn( + [b, alpha](const CTYPE_COMPUTE val_a) { + CTYPE_COMPUTE val_b = utils::scalar_to(b); + CTYPE_COMPUTE val_alpha = utils::scalar_to(alpha); + return val_a + val_alpha * val_b; + }, + ctx, + a, + utils::SupportedTensorDtypes::REALHBBF16, + out, + utils::SupportedTensorDtypes::SAME_AS_COMMON); + }); + + return out; +} + +} // namespace native +} // namespace executor +} // namespace torch diff --git a/kernels/portable/cpu/op_add_impl.h b/kernels/portable/cpu/op_add_impl.h new file mode 100644 index 00000000000..f764f48f7e6 --- /dev/null +++ b/kernels/portable/cpu/op_add_impl.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace torch { +namespace executor { +namespace native { + +Tensor& add_out_impl( + KernelRuntimeContext& ctx, + const Tensor& a, + const Tensor& b, + const Scalar& alpha, + Tensor& out); + +Tensor& add_scalar_out_impl( + KernelRuntimeContext& ctx, + const Tensor& a, + const Scalar& b, + const Scalar& alpha, + Tensor& out); + +} // namespace native +} // namespace executor +} // namespace torch diff --git a/kernels/portable/cpu/targets.bzl b/kernels/portable/cpu/targets.bzl index e6cb5c02ff4..afd98188a3f 100644 --- a/kernels/portable/cpu/targets.bzl +++ b/kernels/portable/cpu/targets.bzl @@ -131,6 +131,26 @@ def define_common_targets(): ], ) + runtime.cxx_library( + name = "op_add_impl", + srcs = ["op_add_impl.cpp"], + exported_headers = ["op_add_impl.h"], + visibility = [ + "//executorch/kernels/portable/cpu/...", + "//executorch/kernels/optimized/cpu/...", + "//executorch/kernels/portable/test/...", + "@EXECUTORCH_CLIENTS", + ], + exported_deps = [ + "//executorch/kernels/portable/cpu/util:broadcast_util", + "//executorch/kernels/portable/cpu/util:dtype_util", + "//executorch/kernels/portable/cpu/util:elementwise_util", + "//executorch/kernels/portable/cpu/util:kernel_ops_util", + "//executorch/kernels/portable/cpu:scalar_utils", + "//executorch/runtime/kernel:kernel_includes", + ], + ) + # The following will not participate in dtype selective build because # they are refactored such to be used in optimized op implementations as well # and we have not enabled selective build for optimized ops. @@ -142,6 +162,7 @@ def define_common_targets(): runtime.cxx_library( name = "all_impl_deps", deps = [ + "//executorch/kernels/portable/cpu:op_add_impl", "//executorch/kernels/portable/cpu:op_div_impl", "//executorch/kernels/portable/cpu:op_mul_impl", ], diff --git a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl index b1a612a7e74..e7686784833 100644 --- a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl +++ b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl @@ -220,11 +220,7 @@ ATEN_OPS = ( op_target( name = "op_add", deps = [ - "//executorch/kernels/portable/cpu/util:broadcast_util", - "//executorch/kernels/portable/cpu/util:dtype_util", - "//executorch/kernels/portable/cpu/util:elementwise_util", - "//executorch/kernels/portable/cpu/util:kernel_ops_util", - ":scalar_utils", + "//executorch/kernels/portable/cpu:op_add_impl", ], ), op_target( @@ -1268,4 +1264,4 @@ def portable_source_list(): def portable_header_list(): """All the header file names from //executorch/kernels/portable/cpu/""" - return ["selective_build.h", "scalar_utils.h", "math_constants.h", "vec_ops.h", "op_div_impl.h", "op_mul_impl.h"] + return ["selective_build.h", "scalar_utils.h", "math_constants.h", "vec_ops.h", "op_div_impl.h", "op_mul_impl.h", "op_add_impl.h"]