From 33f75f24ad339429ff8c405124caa78de7ccf65a Mon Sep 17 00:00:00 2001
From: Kimish Patel
Date: Fri, 8 Nov 2024 07:14:15 -0800
Subject: [PATCH] [Executorch][Optimized] Use portable's impl for optimized
 op_add's fallback

This way we can take advantage of the size and build optimization efforts
in portable kernels.

Differential Revision: [D65632375](https://our.internmc.facebook.com/intern/diff/D65632375/)

[ghstack-poisoned]
---
 kernels/optimized/cpu/op_add.cpp  | 56 ++-----
 kernels/optimized/cpu/targets.bzl |  1 +
 2 files changed, 4 insertions(+), 53 deletions(-)

diff --git a/kernels/optimized/cpu/op_add.cpp b/kernels/optimized/cpu/op_add.cpp
index 2b31a8d5db9..535eaf2a8cb 100644
--- a/kernels/optimized/cpu/op_add.cpp
+++ b/kernels/optimized/cpu/op_add.cpp
@@ -9,6 +9,7 @@
 #include <executorch/kernels/optimized/cpu/binary_ops.h>
 #include <executorch/kernels/optimized/vec/functional.h>
 #include <executorch/kernels/optimized/vec/vec.h>
+#include <executorch/kernels/portable/cpu/op_add_impl.h>
 #include <executorch/kernels/portable/cpu/scalar_utils.h>
 #include <executorch/kernels/portable/cpu/util/broadcast_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
@@ -176,35 +177,7 @@ Tensor& opt_add_out(
           lhs->sizes()[lhs->dim() - 1]);
     });
   } else {
-    ScalarType common_type =
-        promoteTypes(a_type, b_type, /*half_to_float*/ true);
-    ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
-
-    ET_KERNEL_CHECK(
-        ctx,
-        resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-        InvalidArgument,
-        out);
-
-    ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "add.out", CTYPE_A, [&]() {
-      ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() {
-        using CTYPE_IN = typename torch::executor::
-            promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
-        ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
-        ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "add.out", CTYPE_OUT, [&]() {
-          CTYPE_IN alpha_val;
-          ET_KERNEL_CHECK(
-              ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
-
-          AddInner<
-              can_cast<CTYPE_IN, CTYPE_OUT>::value,
-              CTYPE_A,
-              CTYPE_B,
-              CTYPE_IN,
-              CTYPE_OUT>::run(a, b, alpha_val, out);
-        });
-      });
-    });
+    add_out_impl(ctx, a, b, alpha, out);
   }
 
   return out;
@@ -255,30 +228,7 @@ Tensor& opt_add_scalar_out(
       });
     });
   } else {
-    ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "add.Scalar_out", CTYPE_A, [&]() {
-      ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "add.Scalar_out", CTYPE_B, [&]() {
-        ET_SWITCH_REALB_TYPES(
-            common_type, ctx, "add.Scalar_out", CTYPE_IN, [&]() {
-              ET_SWITCH_REALHBBF16_TYPES(
-                  out_type, ctx, "add.Scalar_out", CTYPE_OUT, [&]() {
-                    CTYPE_B b_val;
-                    ET_EXTRACT_SCALAR(b, b_val);
-                    CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
-                    CTYPE_IN alpha_val;
-                    ET_EXTRACT_SCALAR(alpha, alpha_val);
-
-                    const size_t n = a.numel();
-                    const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
-                    CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
-                    for (auto i = 0; i < n; ++i) {
-                      out_data[i] = static_cast<CTYPE_OUT>(
-                          static_cast<CTYPE_IN>(a_data[i]) +
-                          alpha_val * b_casted);
-                    }
-                  });
-            });
-      });
-    });
+    add_scalar_out_impl(ctx, a, b, alpha, out);
   }
 
   return out;
diff --git a/kernels/optimized/cpu/targets.bzl b/kernels/optimized/cpu/targets.bzl
index c3dc61269b5..c4ed508fa60 100644
--- a/kernels/optimized/cpu/targets.bzl
+++ b/kernels/optimized/cpu/targets.bzl
@@ -6,6 +6,7 @@ _OPTIMIZED_ATEN_OPS = (
         name = "op_add",
         deps = [
            ":binary_ops",
+            "//executorch/kernels/portable/cpu:op_add_impl",
            "//executorch/kernels/portable/cpu:scalar_utils",
            "//executorch/kernels/portable/cpu/util:broadcast_util",
        ],
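
Note on the delegation (not part of the applied patch): both fallback branches
now route through the shared portable implementation provided by the new
//executorch/kernels/portable/cpu:op_add_impl dependency. Based on the call
sites in the diff above, the entry points in op_add_impl.h are presumably
declared along these lines; this is a sketch for orientation, with the
KernelRuntimeContext parameter type and the torch::executor::native namespace
inferred from ExecuTorch's portable-kernel conventions, not the verbatim
header:

    #include <executorch/runtime/kernel/kernel_includes.h>

    namespace torch {
    namespace executor {
    namespace native {

    // General broadcasting tensor+tensor add with type promotion:
    // out = a + alpha * b. Called by opt_add_out's slow-path fallback.
    Tensor& add_out_impl(
        KernelRuntimeContext& ctx,
        const Tensor& a,
        const Tensor& b,
        const Scalar& alpha,
        Tensor& out);

    // Tensor+scalar add: out = a + alpha * b, where b is a Scalar.
    // Called by opt_add_scalar_out's fallback.
    Tensor& add_scalar_out_impl(
        KernelRuntimeContext& ctx,
        const Tensor& a,
        const Scalar& b,
        const Scalar& alpha,
        Tensor& out);

    } // namespace native
    } // namespace executor
    } // namespace torch

With this split, the optimized kernel keeps its vectorized fast paths and the
binary links only one copy of the general broadcasting/type-promotion code,
which is the size and build benefit the commit message refers to.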