diff --git a/kernels/optimized/cpu/op_add.cpp b/kernels/optimized/cpu/op_add.cpp
index dbf828e5882..5f164f1eb13 100644
--- a/kernels/optimized/cpu/op_add.cpp
+++ b/kernels/optimized/cpu/op_add.cpp
@@ -14,59 +14,11 @@
 #include
 #include
+#include <executorch/kernels/optimized/cpu/op_add_sub_impl.h>
+
 namespace torch {
 namespace executor {
 namespace native {
-namespace {
-
-template <
-    bool can_cast,
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct AddInner;
-
-template <
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct AddInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
-  static void
-  run(const Tensor& a, const Tensor& b, CTYPE_IN alpha_val, Tensor& out) {
-    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-        // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
-        [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
-          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-          CTYPE_IN value = a_casted + alpha_val * b_casted;
-
-          return static_cast<CTYPE_OUT>(value);
-        },
-        a,
-        b,
-        out);
-  }
-};
-
-template <typename CTYPE_IN>
-struct ReportCanCastBug {
-  static void run(const Tensor&, const Tensor&, CTYPE_IN, Tensor&) {
-    ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
-  }
-};
-
-template <
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct AddInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
-    : public ReportCanCastBug<CTYPE_IN> {};
-
-} // namespace
-
 using Tensor = executorch::aten::Tensor;
 using ScalarType = executorch::aten::ScalarType;
 
@@ -76,8 +28,6 @@ Tensor& opt_add_out(
     const Tensor& b,
     const Scalar& alpha,
     Tensor& out) {
-  (void)ctx;
-
   ScalarType a_type = a.scalar_type();
   ScalarType b_type = b.scalar_type();
   ScalarType out_type = out.scalar_type();
@@ -95,7 +45,9 @@ Tensor& opt_add_out(
     ET_SWITCH_REALB_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() {
       CTYPE alpha_val;
       ET_KERNEL_CHECK(
-          ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
+          ctx,
+          torch::executor::native::utils::extract_scalar(alpha, &alpha_val),
+          InvalidArgument, );
 
       CTYPE_B b_val = *b.const_data_ptr<CTYPE_B>();
       CTYPE b_casted = static_cast<CTYPE>(b_val);
@@ -115,100 +67,9 @@ Tensor& opt_add_out(
     return opt_add_out(ctx, b, a, alpha, out);
   }
 
-  auto selected_optimized_path = select_optimized_path(a, b, out);
-  if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {
-    // Resize for dynamic shape
-    auto error = resize_tensor(out, a.sizes());
-    ET_KERNEL_CHECK_MSG(
-        ctx,
-        error == Error::Ok,
-        InvalidArgument,
-        out,
-        "Failed to resize output tensor.");
-
-    ET_SWITCH_REALB_TYPES(a_type, ctx, "add.out", CTYPE, [&]() {
-      CTYPE alpha_val;
-      ET_KERNEL_CHECK(
-          ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
-
-      using Vec = executorch::vec::Vectorized<CTYPE>;
-      executorch::vec::map2<CTYPE>(
-          [alpha_val](Vec x, Vec y) { return x + Vec(alpha_val) * y; },
-          out.mutable_data_ptr<CTYPE>(),
-          a.const_data_ptr<CTYPE>(),
-          b.const_data_ptr<CTYPE>(),
-          out.numel());
-    });
-  } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
-    ET_SWITCH_REALB_TYPES(out_type, ctx, "add.out", CTYPE, [&]() {
-      CTYPE alpha_val;
-      ET_KERNEL_CHECK_MSG(
-          ctx,
-          utils::extract_scalar(alpha, &alpha_val),
-          InvalidArgument,
-          out,
-          "Failed to extract scalar alpha.");
-      using Vec = executorch::vec::Vectorized<CTYPE>;
-      Vec alpha_val_vec(alpha_val);
-      if (selected_optimized_path ==
-              ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
-          selected_optimized_path ==
-              ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
-          selected_optimized_path ==
-              ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
-        // Reason we swap out args here is because handle_broadcast_elementwise
-        // handles this selected_optimized_path option a bit differently.
-        // This should really be resolved in handle_broadcast_elementwise.
-        // However, the current blocker is that handle_broadcast_elementwise
-        // tries to be agnostic of op. This should be fixed, likely by moving
-        // lambda creation to handle_broadcast_elementwise and it be aware of
-        // which op is being executed.
-        auto add_lambda = [&alpha_val_vec](auto x, auto y) {
-          return y + alpha_val_vec * x;
-        };
-        return torch::executor::handle_broadcast_elementwise<CTYPE>(
-            ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
-      } else {
-        auto add_lambda = [&alpha_val_vec](auto x, auto y) {
-          return x + alpha_val_vec * y;
-        };
-        return torch::executor::handle_broadcast_elementwise<CTYPE>(
-            ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
-      }
-    });
-  } else {
-    ScalarType common_type =
-        promoteTypes(a_type, b_type, /*half_to_float*/ true);
-    ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
-
-    ET_KERNEL_CHECK(
-        ctx,
-        resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-        InvalidArgument,
-        out);
-
-    ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "add.out", CTYPE_A, [&]() {
-      ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() {
-        using CTYPE_IN = typename torch::executor::
-            promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
-        ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
-        ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "add.out", CTYPE_OUT, [&]() {
-          CTYPE_IN alpha_val;
-          ET_KERNEL_CHECK(
-              ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
-
-          AddInner<
-              can_cast<CTYPE_IN, CTYPE_OUT>::value,
-              CTYPE_A,
-              CTYPE_B,
-              CTYPE_IN,
-              CTYPE_OUT>::run(a, b, alpha_val, out);
-        });
-      });
-    });
-  }
-
-  return out;
+  static constexpr const char op_name[] = "add.out";
+  return torch::executor::kernels::impl::opt_add_sub_out_impl<false, op_name>(
+      ctx, a, b, alpha, out);
 }
 
 Tensor& opt_add_scalar_out(
diff --git a/kernels/optimized/cpu/op_add_sub_impl.h b/kernels/optimized/cpu/op_add_sub_impl.h
new file mode 100644
index 00000000000..6fb8574688b
--- /dev/null
+++ b/kernels/optimized/cpu/op_add_sub_impl.h
@@ -0,0 +1,218 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace torch {
+namespace executor {
+namespace kernels {
+namespace impl {
+
+namespace {
+template <
+    bool can_cast,
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_IN,
+    typename CTYPE_OUT>
+struct AddInner;
+
+template <
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_IN,
+    typename CTYPE_OUT>
+struct AddInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
+  static void
+  run(const Tensor& a, const Tensor& b, CTYPE_IN alpha_val, Tensor& out) {
+    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
+        // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
+        [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
+          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
+          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
+          CTYPE_IN value = a_casted + alpha_val * b_casted;
+
+          return static_cast<CTYPE_OUT>(value);
+        },
+        a,
+        b,
+        out);
+  }
+};
+
+template <typename CTYPE_IN>
+struct ReportCanCastBug {
+  static void run(const Tensor&, const Tensor&, CTYPE_IN, Tensor&) {
+    ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
+  }
+};
+
+template <
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_IN,
+    typename CTYPE_OUT>
+struct AddInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
+    : public ReportCanCastBug<CTYPE_IN> {};
+
+} // namespace
+
+using Tensor = executorch::aten::Tensor;
+using ScalarType = executorch::aten::ScalarType;
+
+template <bool is_sub, const char* op_name>
+Tensor& opt_add_sub_out_impl(
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    const Tensor& b,
+    const Scalar& alpha,
+    Tensor& out) {
+  (void)ctx;
+
+  ScalarType a_type = a.scalar_type();
+  ScalarType b_type = b.scalar_type();
+  ScalarType out_type = out.scalar_type();
+
+  auto selected_optimized_path = select_optimized_path(a, b, out);
+  if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {
+    // Resize for dynamic shape
+    auto error = resize_tensor(out, a.sizes());
+    ET_KERNEL_CHECK_MSG(
+        ctx,
+        error == Error::Ok,
+        InvalidArgument,
+        out,
+        "Failed to resize output tensor.");
+
+    ET_SWITCH_REALB_TYPES(a_type, ctx, op_name, CTYPE, [&]() {
+      CTYPE alpha_val;
+      ET_KERNEL_CHECK(
+          ctx,
+          torch::executor::native::utils::extract_scalar(alpha, &alpha_val),
+          InvalidArgument, );
+      if constexpr (is_sub) {
+        alpha_val = -alpha_val;
+      }
+      using Vec = executorch::vec::Vectorized<CTYPE>;
+      executorch::vec::map2<CTYPE>(
+          [alpha_val](Vec x, Vec y) { return x + Vec(alpha_val) * y; },
+          out.mutable_data_ptr<CTYPE>(),
+          a.const_data_ptr<CTYPE>(),
+          b.const_data_ptr<CTYPE>(),
+          out.numel());
+    });
+  } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
+    // We cannot apply the -alpha trick here because alpha is a Scalar, which
+    // does not support the unary - operator (at least not right now).
+    ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
+      CTYPE alpha_val;
+      ET_KERNEL_CHECK_MSG(
+          ctx,
+          torch::executor::native::utils::extract_scalar(alpha, &alpha_val),
+          InvalidArgument,
+          out,
+          "Failed to extract scalar alpha.");
+      using Vec = executorch::vec::Vectorized<CTYPE>;
+      Vec alpha_val_vec(alpha_val);
+      if constexpr (is_sub) {
+        if (selected_optimized_path ==
+                ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
+            selected_optimized_path ==
+                ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
+            selected_optimized_path ==
+                ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
+          auto add_lambda = [&alpha_val_vec](auto x, auto y) {
+            return y - alpha_val_vec * x;
+          };
+          return torch::executor::handle_broadcast_elementwise<CTYPE>(
+              ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
+        } else {
+          auto add_lambda = [&alpha_val_vec](auto x, auto y) {
+            return x - alpha_val_vec * y;
+          };
+          return torch::executor::handle_broadcast_elementwise<CTYPE>(
+              ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
+        }
+      } else {
+        if (selected_optimized_path ==
+                ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
+            selected_optimized_path ==
+                ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
+            selected_optimized_path ==
+                ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
+          // We swap the lambda arguments here because
+          // handle_broadcast_elementwise handles this selected_optimized_path
+          // option a bit differently. This should really be resolved in
+          // handle_broadcast_elementwise itself; the current blocker is that
+          // handle_broadcast_elementwise tries to stay agnostic of the op.
+          // The likely fix is to move lambda creation into
+          // handle_broadcast_elementwise and make it aware of which op is
+          // being executed.
+          auto add_lambda = [&alpha_val_vec](auto x, auto y) {
+            return y + alpha_val_vec * x;
+          };
+          return torch::executor::handle_broadcast_elementwise<CTYPE>(
+              ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
+        } else {
+          auto add_lambda = [&alpha_val_vec](auto x, auto y) {
+            return x + alpha_val_vec * y;
+          };
+          return torch::executor::handle_broadcast_elementwise<CTYPE>(
+              ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
+        }
+      }
+    });
+  } else {
+    ScalarType common_type =
+        promoteTypes(a_type, b_type, /*half_to_float*/ true);
+    ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
+
+    ET_KERNEL_CHECK(
+        ctx,
+        resize_to_broadcast_target_size(a, b, out) == Error::Ok,
+        InvalidArgument,
+        out);
+
+    ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, op_name, CTYPE_A, [&]() {
+      ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, op_name, CTYPE_B, [&]() {
+        using CTYPE_IN = typename torch::executor::
+            promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
+        ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
+        ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, op_name, CTYPE_OUT, [&]() {
+          CTYPE_IN alpha_val;
+          ET_KERNEL_CHECK(
+              ctx,
+              torch::executor::native::utils::extract_scalar(alpha, &alpha_val),
+              InvalidArgument, );
+          if constexpr (is_sub) {
+            alpha_val = -alpha_val;
+          }
+
+          AddInner<
+              can_cast<CTYPE_IN, CTYPE_OUT>::value,
+              CTYPE_A,
+              CTYPE_B,
+              CTYPE_IN,
+              CTYPE_OUT>::run(a, b, alpha_val, out);
+        });
+      });
+    });
+  }
+
+  return out;
+}
+} // namespace impl
+} // namespace kernels
+} // namespace executor
+} // namespace torch
diff --git a/kernels/optimized/cpu/targets.bzl b/kernels/optimized/cpu/targets.bzl
index 1c62b683b8f..94ceb1f4dc1 100644
--- a/kernels/optimized/cpu/targets.bzl
+++ b/kernels/optimized/cpu/targets.bzl
@@ -6,6 +6,7 @@ _OPTIMIZED_ATEN_OPS = (
         name = "op_add",
         deps = [
             ":binary_ops",
+            ":add_sub_impl",
            "//executorch/kernels/portable/cpu:scalar_utils",
            "//executorch/kernels/portable/cpu/util:broadcast_util",
        ],
@@ -123,6 +124,14 @@ def define_common_targets():
     aten_op_targets = [":{}".format(op["name"]) for op in enabled_ops]
     all_op_targets = aten_op_targets
 
+    runtime.cxx_library(
+        name = "add_sub_impl",
+        srcs = [],
+        exported_headers = ["op_add_sub_impl.h"],
+        visibility = ["//executorch/kernels/optimized/cpu/..."],
+        exported_deps = ["//executorch/runtime/core:core"],
+    )
+
     runtime.cxx_library(
         name = "binary_ops",
         exported_headers = ["binary_ops.h"],
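
Note: op_sub.cpp is not touched by this diff, but the `is_sub` template parameter and the shared `add_sub_impl` target suggest the optimized sub kernel is meant to reuse the same implementation. Below is a minimal sketch of what that call site would presumably look like; the exact opt_sub_out signature, surrounding checks, and the "sub.out" name are assumptions, not part of this diff.

// Hypothetical call site in kernels/optimized/cpu/op_sub.cpp -- not part of
// this diff; shown only to illustrate how the shared impl is parameterized.
Tensor& opt_sub_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    const Scalar& alpha,
    Tensor& out) {
  // op_name is a non-type template parameter so the ET_SWITCH_* macros inside
  // the shared implementation report errors against the right operator.
  static constexpr const char op_name[] = "sub.out";
  // is_sub = true makes the impl negate alpha (or use the subtracting lambdas
  // on the broadcast paths) instead of adding.
  return torch::executor::kernels::impl::opt_add_sub_out_impl<true, op_name>(
      ctx, a, b, alpha, out);
}

A real opt_sub_out would presumably keep its own scalar fast path and type checks before delegating, as opt_add_out does in the first file of this diff.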