diff --git a/kernels/optimized/cpu/op_add.cpp b/kernels/optimized/cpu/op_add.cpp
index 88b102b5650..66524e3ccf5 100644
--- a/kernels/optimized/cpu/op_add.cpp
+++ b/kernels/optimized/cpu/op_add.cpp
@@ -9,6 +9,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -137,29 +138,7 @@ Tensor& opt_add_scalar_out(
           out.numel());
     });
   } else {
-    ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "add.Scalar_out", CTYPE_A, [&]() {
-      ET_SWITCH_REALB_TYPES(
-          common_type, ctx, "add.Scalar_out", CTYPE_IN, [&]() {
-            ET_SWITCH_REALHBBF16_TYPES(
-                out_type, ctx, "add.Scalar_out", CTYPE_OUT, [&]() {
-                  CTYPE_IN b_casted = utils::scalar_to<CTYPE_IN>(b);
-                  CTYPE_IN alpha_val;
-                  ET_KERNEL_CHECK(
-                      ctx,
-                      utils::extract_scalar(alpha, &alpha_val),
-                      InvalidArgument, );
-
-                  const size_t n = a.numel();
-                  const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
-                  CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
-                  for (auto i = 0; i < n; ++i) {
-                    out_data[i] = static_cast<CTYPE_OUT>(
-                        static_cast<CTYPE_IN>(a_data[i]) +
-                        alpha_val * b_casted);
-                  }
-                });
-          });
-    });
+    utils::add_scalar_out(ctx, a, b, alpha, out);
   }
 
   return out;
diff --git a/kernels/optimized/cpu/op_add_sub_impl.h b/kernels/optimized/cpu/op_add_sub_impl.h
index 37761b44c9b..d22dde68a43 100644
--- a/kernels/optimized/cpu/op_add_sub_impl.h
+++ b/kernels/optimized/cpu/op_add_sub_impl.h
@@ -9,6 +9,8 @@
 #include
 #include
 #include
+#include
+#include
 #include
 #include
 #include
@@ -19,55 +21,6 @@ namespace executor {
 namespace kernels {
 namespace impl {
 
-namespace {
-template <
-    bool can_cast,
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct AddInner;
-
-template <
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct AddInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
-  static void
-  run(const Tensor& a, const Tensor& b, CTYPE_IN alpha_val, Tensor& out) {
-    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-        // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
-        [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
-          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-          CTYPE_IN value = a_casted + alpha_val * b_casted;
-
-          return static_cast<CTYPE_OUT>(value);
-        },
-        a,
-        b,
-        out);
-  }
-};
-
-template <typename CTYPE_IN>
-struct ReportCanCastBug {
-  static void run(const Tensor&, const Tensor&, CTYPE_IN, Tensor&) {
-    ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
-  }
-};
-
-template <
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct AddInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
-    : public ReportCanCastBug<CTYPE_IN> {};
-
-} // namespace
-
 using Tensor = executorch::aten::Tensor;
 using ScalarType = executorch::aten::ScalarType;
 
@@ -203,40 +156,11 @@ Tensor& opt_add_sub_out_impl(
       }
     });
   } else {
-    ScalarType common_type =
-        promoteTypes(a_type, b_type, /*half_to_float*/ true);
-    ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
-
-    ET_KERNEL_CHECK(
-        ctx,
-        resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-        InvalidArgument,
-        out);
-
-    ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, op_name, CTYPE_A, [&]() {
-      ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, op_name, CTYPE_B, [&]() {
-        using CTYPE_IN = typename torch::executor::
-            promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
-        ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
-        ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, op_name, CTYPE_OUT, [&]() {
-          CTYPE_IN alpha_val;
-          ET_KERNEL_CHECK(
-              ctx,
-              torch::executor::native::utils::extract_scalar(alpha, &alpha_val),
-              InvalidArgument, );
-          if constexpr (is_sub) {
-            alpha_val = -alpha_val;
-          }
-
-          AddInner<
-              can_cast<CTYPE_IN, CTYPE_OUT>::value,
-              CTYPE_A,
-              CTYPE_B,
-              CTYPE_IN,
-              CTYPE_OUT>::run(a, b, alpha_val, out);
-        });
-      });
-    });
+    if constexpr (is_sub) {
+      native::utils::sub_out(ctx, a, b, alpha, out);
+    } else {
+      native::utils::add_out(ctx, a, b, alpha, out);
+    }
   }
 
   return out;
diff --git a/kernels/optimized/cpu/op_div.cpp b/kernels/optimized/cpu/op_div.cpp
index 7af2b4b4695..c1ca946156a 100644
--- a/kernels/optimized/cpu/op_div.cpp
+++ b/kernels/optimized/cpu/op_div.cpp
@@ -9,6 +9,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -18,26 +19,6 @@ namespace torch {
 namespace executor {
 namespace native {
 
-namespace {
-
-ScalarType get_compute_type(ScalarType a_type, ScalarType b_type) {
-  ET_CHECK(
-      !isComplexType(a_type) && !isQIntType(a_type) && !isBitsType(a_type));
-  ET_CHECK(
-      !isComplexType(b_type) && !isQIntType(b_type) && !isBitsType(b_type));
-
-  if (isFloatingType(a_type) && isFloatingType(b_type)) {
-    return promoteTypes(a_type, b_type);
-  } else if (isFloatingType(a_type)) {
-    return a_type;
-  } else if (isFloatingType(b_type)) {
-    return b_type;
-  }
-  return ScalarType::Float;
-}
-
-} // namespace
-
 Tensor& opt_div_out(
     KernelRuntimeContext& ctx,
     const Tensor& a,
@@ -139,34 +120,7 @@ Tensor& opt_div_out(
       }
     });
   } else {
-    ScalarType common_type = get_compute_type(a_type, b_type);
-    ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
-
-    ET_KERNEL_CHECK(
-        ctx,
-        resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-        InvalidArgument,
-        out);
-
-    ET_SWITCH_REALB_TYPES(a_type, ctx, "div.out", CTYPE_A, [&]() {
-      ET_SWITCH_REALB_TYPES(b_type, ctx, "div.out", CTYPE_B, [&]() {
-        ET_SWITCH_REALB_TYPES(common_type, ctx, "div.out", CTYPE_IN, [&]() {
-          ET_SWITCH_REALB_TYPES(out_type, ctx, "div.out", CTYPE_OUT, [&]() {
-            apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-                [](const CTYPE_A val_a, const CTYPE_B val_b) {
-                  CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-                  CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-                  CTYPE_IN value = a_casted / b_casted;
-
-                  return static_cast<CTYPE_OUT>(value);
-                },
-                a,
-                b,
-                out);
-          });
-        });
-      });
-    });
+    utils::div_out(ctx, a, b, out);
   }
 
   return out;
@@ -208,32 +162,7 @@ Tensor& opt_div_scalar_out(
       });
     });
   } else {
-    ET_SWITCH_REAL_TYPES_AND(
-        Bool, a_type, ctx, "div.Scalar_out", CTYPE_A, [&]() {
-          ET_SWITCH_REAL_TYPES_AND(
-              Bool, b_type, ctx, "div.Scalar_out", CTYPE_B, [&]() {
-                ET_SWITCH_REAL_TYPES(
-                    common_type, ctx, "div.Scalar_out", CTYPE_IN, [&]() {
-                      ET_SWITCH_REAL_TYPES(
-                          out_type, ctx, "div.Scalar_out", CTYPE_OUT, [&]() {
-                            CTYPE_B b_val;
-                            ET_EXTRACT_SCALAR(b, b_val);
-                            CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
-                            CTYPE_IN inv_b_casted = CTYPE_IN(1) / b_casted;
-
-                            const size_t n = a.numel();
-                            const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
-                            CTYPE_OUT* out_data =
-                                out.mutable_data_ptr<CTYPE_OUT>();
-                            for (auto i = 0; i < n; ++i) {
-                              out_data[i] = static_cast<CTYPE_OUT>(
-                                  static_cast<CTYPE_IN>(a_data[i]) *
-                                  inv_b_casted);
-                            }
-                          });
-                    });
-              });
-        });
+    utils::div_scalar_out(ctx, a, b, out);
   }
 
   return out;
diff --git a/kernels/optimized/cpu/op_le.cpp b/kernels/optimized/cpu/op_le.cpp
index 51fca9b0063..8a23c94e419 100644
--- a/kernels/optimized/cpu/op_le.cpp
+++ b/kernels/optimized/cpu/op_le.cpp
@@ -9,7 +9,7 @@
 #include
 #include
 #include
-#include
+#include
 #include
 #include
 #include
@@ -61,10 +61,7 @@ Tensor& opt_le_tensor_out(
         ctx, le_lambda, a, b, out, selected_optimized_path);
     });
   } else {
-    // @lint-ignore CLANGTIDY facebook-hte-CArray
-    static constexpr const char op_name[] = "le.Tensor_out";
-    return internal::comparison_tensor_out<std::less_equal, op_name>(
-        ctx, a, b, out);
+    utils::le_tensor_out(ctx, a, b, out);
   }
 
   return out;
@@ -107,34 +104,7 @@ Tensor& opt_le_scalar_out(
      });
    });
  } else {
-    ET_SWITCH_REAL_TYPES_AND(
-        Bool, a_type, ctx, "le.Scalar_out", CTYPE_A, [&]() {
-          ET_SWITCH_REAL_TYPES_AND(
-              Bool, b_type, ctx, "le.Scalar_out", CTYPE_B, [&]() {
-                ET_SWITCH_REAL_TYPES_AND(
-                    Bool, common_type, ctx, "le.Scalar_out", CTYPE_IN, [&]() {
-                      ET_SWITCH_REAL_TYPES_AND(
-                          Bool,
-                          out_type,
-                          ctx,
-                          "le.Scalar_out",
-                          CTYPE_OUT,
-                          [&]() {
-                            CTYPE_B b_val = 0;
-                            ET_EXTRACT_SCALAR(b, b_val);
-                            CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
-                            const size_t n = a.numel();
-                            const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
-                            CTYPE_OUT* out_data =
-                                out.mutable_data_ptr<CTYPE_OUT>();
-                            for (auto i = 0; i < n; ++i) {
-                              out_data[i] = static_cast<CTYPE_OUT>(
-                                  static_cast<CTYPE_IN>(a_data[i]) <= b_casted);
-                            }
-                          });
-                    });
-              });
-        });
+    utils::le_scalar_out(ctx, a, b, out);
   }
 
   return out;
diff --git a/kernels/optimized/cpu/op_mul.cpp b/kernels/optimized/cpu/op_mul.cpp
index 0d132ab1e03..c2577edd207 100644
--- a/kernels/optimized/cpu/op_mul.cpp
+++ b/kernels/optimized/cpu/op_mul.cpp
@@ -9,6 +9,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include // IWYU pragma: export
@@ -22,54 +23,6 @@ namespace native {
 using Tensor = executorch::aten::Tensor;
 using ScalarType = executorch::aten::ScalarType;
 
-namespace {
-
-template <
-    bool can_cast,
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct MulInner;
-
-template <
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct MulInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
-  static void run(const Tensor& a, const Tensor& b, Tensor& out) {
-    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-        // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
-        [](const CTYPE_A val_a, const CTYPE_B val_b) {
-          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-          CTYPE_IN value = a_casted * b_casted;
-
-          return static_cast<CTYPE_OUT>(value);
-        },
-        a,
-        b,
-        out);
-  }
-};
-
-struct ReportCanCastBug {
-  static void run(const Tensor&, const Tensor&, Tensor&) {
-    ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
-  }
-};
-
-template <
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct MulInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
-    : public ReportCanCastBug {};
-
-} // namespace
-
 Tensor& opt_mul_out(
     KernelRuntimeContext& ctx,
     const Tensor& a,
@@ -159,52 +112,7 @@ Tensor& opt_mul_out(
      });
    }
  } else {
-    ScalarType common_type =
-        promoteTypes(a_type, b_type, /*half_to_float*/ true);
-    ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
-
-    ET_KERNEL_CHECK(
-        ctx,
-        resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-        InvalidArgument,
-        out);
-
-    if (executorch::runtime::isComplexType(a_type) ||
-        executorch::runtime::isComplexType(b_type) ||
-        executorch::runtime::isComplexType(out_type)) {
-      ET_KERNEL_CHECK(
-          ctx, a_type == b_type && a_type == out_type, InvalidArgument, out);
-
-      ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
-        apply_binary_elementwise_fn<CTYPE, CTYPE, CTYPE>(
-            [](const CTYPE val_a, const CTYPE val_b) { return val_a * val_b; },
-            a,
-            b,
-            out);
-      });
-    } else {
-      ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.out", CTYPE_A, [&]() {
-        ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "mul.out", CTYPE_B, [&]() {
-          using CTYPE_IN = typename torch::executor::
-              promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
-          ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
-          ET_SWITCH_REALHBBF16_TYPES(
-              out_type, ctx, "mul.out", CTYPE_OUT, [&]() {
-                apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-                    [](const CTYPE_A val_a, const CTYPE_B val_b) {
-                      CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-                      CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-                      CTYPE_IN value = a_casted * b_casted;
-
-                      return static_cast<CTYPE_OUT>(value);
-                    },
-                    a,
-                    b,
-                    out);
-              });
-        });
-      });
-    }
+    utils::mul_out(ctx, a, b, out);
   }
 
   return out;
@@ -245,23 +153,7 @@ Tensor& opt_mul_scalar_out(
           out.numel());
     });
   } else {
-    ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.Scalar_out", CTYPE_A, [&]() {
-      ET_SWITCH_REALB_TYPES(
-          common_type, ctx, "mul.Scalar_out", CTYPE_IN, [&]() {
-            ET_SWITCH_REALHBBF16_TYPES(
-                out_type, ctx, "mul.Scalar_out", CTYPE_OUT, [&]() {
-                  CTYPE_IN b_casted = utils::scalar_to<CTYPE_IN>(b);
-
-                  const size_t n = a.numel();
-                  const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
-                  CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
-                  for (auto i = 0; i < n; ++i) {
-                    out_data[i] = static_cast<CTYPE_OUT>(
-                        static_cast<CTYPE_IN>(a_data[i]) * b_casted);
-                  }
-                });
-          });
-    });
+    utils::mul_scalar_out(ctx, a, b, out);
   }
 
   return out;
diff --git a/kernels/optimized/cpu/op_sub.cpp b/kernels/optimized/cpu/op_sub.cpp
index 58f8d2a7fdf..8caf31655c3 100644
--- a/kernels/optimized/cpu/op_sub.cpp
+++ b/kernels/optimized/cpu/op_sub.cpp
@@ -9,6 +9,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -190,27 +191,7 @@ Tensor& opt_sub_scalar_out(
           out.numel());
     });
   } else {
-    ET_SWITCH_REALHBF16_TYPES(a_type, ctx, "sub.Scalar_out", CTYPE_A, [&]() {
-      ET_SWITCH_REAL_TYPES(common_type, ctx, "sub.Scalar_out", CTYPE_IN, [&]() {
-        ET_SWITCH_REALHBF16_TYPES(
-            out_type, ctx, "sub.Scalar_out", CTYPE_OUT, [&]() {
-              CTYPE_IN b_casted = utils::scalar_to<CTYPE_IN>(b);
-              CTYPE_IN alpha_val;
-              ET_KERNEL_CHECK(
-                  ctx,
-                  utils::extract_scalar(alpha, &alpha_val),
-                  InvalidArgument, );
-
-              const size_t n = a.numel();
-              const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
-              CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
-              for (auto i = 0; i < n; ++i) {
-                out_data[i] = static_cast<CTYPE_OUT>(
-                    static_cast<CTYPE_IN>(a_data[i]) - alpha_val * b_casted);
-              }
-            });
-      });
-    });
+    utils::sub_scalar_out(ctx, a, b, alpha, out);
   }
 
   return out;
diff --git a/kernels/optimized/cpu/targets.bzl b/kernels/optimized/cpu/targets.bzl
index 3d9c5caf815..d923f60b057 100644
--- a/kernels/optimized/cpu/targets.bzl
+++ b/kernels/optimized/cpu/targets.bzl
@@ -28,6 +28,8 @@ def define_common_targets():
         visibility = ["//executorch/kernels/optimized/cpu/...", "@EXECUTORCH_CLIENTS",],
         exported_deps = [
             "//executorch/runtime/core:core",
+            "//executorch/kernels/portable/cpu:op_add_util",
+            "//executorch/kernels/portable/cpu:op_sub_util",
             "//executorch/kernels/portable/cpu/util:broadcast_indexes_range",
         ],
     )
diff --git a/shim_et/xplat/executorch/kernels/optimized/op_registration_util.bzl b/shim_et/xplat/executorch/kernels/optimized/op_registration_util.bzl
index 7d9b1a0c317..a5451b65a43 100644
--- a/shim_et/xplat/executorch/kernels/optimized/op_registration_util.bzl
+++ b/shim_et/xplat/executorch/kernels/optimized/op_registration_util.bzl
@@ -151,6 +151,7 @@ OPTIMIZED_ATEN_OPS = (
         deps = [
             ":binary_ops",
             ":add_sub_impl",
+            "//executorch/kernels/portable/cpu:op_add_util",
            "//executorch/kernels/portable/cpu:scalar_utils",
            "//executorch/kernels/portable/cpu/util:broadcast_util",
            "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch",
@@ -178,6 +179,7 @@ OPTIMIZED_ATEN_OPS = (
         }),
         deps = [
             ":binary_ops",
+            "//executorch/kernels/portable/cpu:op_div_util",
            "//executorch/kernels/portable/cpu:scalar_utils",
            "//executorch/kernels/portable/cpu/util:broadcast_util",
            "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch",
@@ -224,9 +226,9 @@ OPTIMIZED_ATEN_OPS = (
         name = "op_le",
         deps = [
             ":binary_ops",
+            "//executorch/kernels/portable/cpu:op_le_util",
            "//executorch/kernels/portable/cpu:scalar_utils",
            "//executorch/kernels/portable/cpu/util:broadcast_util",
-            "//executorch/kernels/portable/cpu/pattern:comparison_op",
            "//executorch/kernels/portable/cpu/util:elementwise_util",
            "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch",
        ],
@@ -257,6 +259,7 @@ OPTIMIZED_ATEN_OPS = (
         name = "op_mul",
         deps = [
             ":binary_ops",
+            "//executorch/kernels/portable/cpu:op_mul_util",
            "//executorch/kernels/portable/cpu:scalar_utils",
            "//executorch/kernels/portable/cpu/util:broadcast_util",
            "//executorch/runtime/core/exec_aten/util:tensor_util",
@@ -276,6 +279,7 @@ OPTIMIZED_ATEN_OPS = (
         deps = [
             ":binary_ops",
             ":add_sub_impl",
+            "//executorch/kernels/portable/cpu:op_sub_util",
            "//executorch/kernels/portable/cpu:scalar_utils",
            "//executorch/kernels/portable/cpu/util:broadcast_util",
            "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch",