diff --git a/kernels/portable/cpu/op_add.cpp b/kernels/portable/cpu/op_add.cpp
index 8b6dbec1a45..05b5eab6f03 100644
--- a/kernels/portable/cpu/op_add.cpp
+++ b/kernels/portable/cpu/op_add.cpp
@@ -7,8 +7,7 @@
  */
 
 #include <executorch/kernels/portable/cpu/scalar_utils.h>
-#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
-#include <executorch/kernels/portable/cpu/util/functional_util.h>
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
 #include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <executorch/runtime/platform/assert.h>
@@ -16,55 +15,6 @@
 namespace torch {
 namespace executor {
 namespace native {
 
-namespace {
-
-template <
-    bool can_cast,
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct AddInner;
-
-template <
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct AddInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
-  static void
-  run(const Tensor& a, const Tensor& b, CTYPE_IN alpha_val, Tensor& out) {
-    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-        // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
-        [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
-          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-          CTYPE_IN value = a_casted + alpha_val * b_casted;
-
-          return static_cast<CTYPE_OUT>(value);
-        },
-        a,
-        b,
-        out);
-  }
-};
-
-template <typename CTYPE_IN>
-struct ReportCanCastBug {
-  static void run(const Tensor&, const Tensor&, CTYPE_IN, Tensor&) {
-    ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
-  }
-};
-
-template <
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct AddInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
-    : public ReportCanCastBug<CTYPE_IN> {};
-
-} // namespace
 
 Tensor& add_out(
     KernelRuntimeContext& ctx,
@@ -80,7 +30,9 @@ Tensor& add_out(
 
   ET_KERNEL_CHECK(
       ctx,
-      executorch::runtime::tensor_is_realhbbf16_type(out),
+      (executorch::runtime::tensor_is_realhbbf16_type(a) &&
+       executorch::runtime::tensor_is_realhbbf16_type(b) &&
+       executorch::runtime::tensor_is_realhbbf16_type(out)),
       InvalidArgument,
       out);
   ET_KERNEL_CHECK(
@@ -96,25 +48,20 @@ Tensor& add_out(
   ET_KERNEL_CHECK(
       ctx, check_alpha_type(alpha_type, common_type), InvalidArgument, out);
 
-  constexpr auto name = "add.out";
-
-  ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, name, CTYPE_A, [&]() {
-    ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
-      using CTYPE_IN = typename torch::executor::
-          promote_types<CTYPE_A, CTYPE_B>::type;
-      ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
-      CTYPE_IN alpha_val;
-      utils::extract_scalar(alpha, &alpha_val);
-
-      ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() {
-        AddInner<
-            can_cast<CTYPE_IN, CTYPE_OUT>::value,
-            CTYPE_A,
-            CTYPE_B,
-            CTYPE_IN,
-            CTYPE_OUT>::run(a, b, alpha_val, out);
-      });
-    });
+  static constexpr const char op_name[] = "add.out";
+
+  ET_SWITCH_REALB_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() {
+    utils::apply_bitensor_elementwise_fn<CTYPE_COMMON, op_name>(
+        [alpha](const CTYPE_COMMON val_a, const CTYPE_COMMON val_b) {
+          CTYPE_COMMON val_alpha = utils::scalar_to<CTYPE_COMMON>(alpha);
+          return val_a + val_alpha * val_b;
+        },
+        a,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        b,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        out,
+        utils::SupportedTensorDtypes::REALHBBF16);
   });
 
   return out;
@@ -138,14 +85,14 @@ Tensor& add_scalar_out(
 
   ET_KERNEL_CHECK(
      ctx,
-      executorch::runtime::tensor_is_realhbbf16_type(out),
+      (executorch::runtime::tensor_is_realhbbf16_type(a) &&
+       executorch::runtime::tensor_is_realhbbf16_type(out)),
       InvalidArgument,
       out);
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
 
   ScalarType a_type = a.scalar_type();
-  ScalarType b_type = utils::get_scalar_dtype(b);
   ScalarType alpha_type = utils::get_scalar_dtype(alpha);
   ScalarType common_type =
       utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false);
@@ -155,42 +102,23 @@ Tensor& add_scalar_out(
   ET_KERNEL_CHECK(
       ctx, check_alpha_type(alpha_type, common_type), InvalidArgument, out);
 
-  if (common_type == ScalarType::Half) {
+  if (common_type == ScalarType::Half || common_type == ScalarType::BFloat16) {
     common_type = ScalarType::Float;
   }
 
-  constexpr auto name = "add.Scalar_out";
-
-  ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, name, CTYPE_A, [&]() {
-    ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
-      using CTYPE_IN = typename utils::promote_type_with_scalar_type<
-          CTYPE_A,
-          CTYPE_B,
-          /*half_to_float*/ true>::type;
-      ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
-
-      CTYPE_B b_val;
-      utils::extract_scalar(b, &b_val);
-      CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
-
-      CTYPE_IN alpha_val;
-      utils::extract_scalar(alpha, &alpha_val);
-
-      using CTYPE_OUT = typename std::conditional<
-          std::is_same<CTYPE_A, internal::F2>::value,
-          internal::F2,
-          CTYPE_IN>::type;
-
-      apply_unary_map_fn(
-          [b_casted, alpha_val](const CTYPE_A val_a) {
-            CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-            CTYPE_IN value = a_casted + alpha_val * b_casted;
-            return static_cast<CTYPE_OUT>(value);
-          },
-          a.const_data_ptr<CTYPE_A>(),
-          out.mutable_data_ptr<CTYPE_OUT>(),
-          out.numel());
-    });
+  static constexpr const char op_name[] = "add.Scalar_out";
+
+  ET_SWITCH_REALB_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() {
+    utils::apply_unitensor_elementwise_fn<CTYPE_COMMON, op_name>(
+        [b, alpha](const CTYPE_COMMON val_a) {
+          CTYPE_COMMON val_b = utils::scalar_to<CTYPE_COMMON>(b);
+          CTYPE_COMMON val_alpha = utils::scalar_to<CTYPE_COMMON>(alpha);
+          return val_a + val_alpha * val_b;
+        },
+        a,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        out,
+        utils::SupportedTensorDtypes::REALHBBF16);
   });
 
   return out;
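Note: the change from `constexpr auto name` to `static constexpr const char op_name[]` is load-bearing, not cosmetic. The new elementwise helpers take the operator name as a `const char*` non-type template parameter, which must point to an object with static storage duration; a string literal bound to `auto` (a plain `const char*`) cannot be used as such a template argument. A minimal standalone sketch of the constraint (hypothetical names, not executorch API):

```cpp
// Why op_name must be a static constexpr char array: a const char*
// non-type template parameter needs an object with static storage
// duration, and a string literal cannot be a template argument.
template <const char* Name>
struct NamedOp {
  static const char* name() {
    return Name;
  }
};

static constexpr const char kAddOut[] = "add.out"; // static storage duration

int main() {
  NamedOp<kAddOut> op; // OK: address of a static constexpr array
  // NamedOp<"add.out"> bad; // ill-formed: string literals are not
  //                         // permitted as template arguments
  return op.name()[0] == 'a' ? 0 : 1;
}
```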
diff --git a/kernels/portable/cpu/op_clamp.cpp b/kernels/portable/cpu/op_clamp.cpp
index 9f93caa40f8..c8bb9297a13 100644
--- a/kernels/portable/cpu/op_clamp.cpp
+++ b/kernels/portable/cpu/op_clamp.cpp
@@ -13,7 +13,6 @@
 
 #include <executorch/kernels/portable/cpu/scalar_utils.h>
 #include <executorch/kernels/portable/cpu/util/elementwise_util.h>
-#include <executorch/kernels/portable/cpu/util/functional_util.h>
 #include <executorch/kernels/portable/cpu/util/math_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
@@ -122,43 +121,26 @@ Tensor& clamp_out(
 
   ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);
 
-  ET_SWITCH_REALH_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() {
-    // Extract optional min value
-    CTYPE_OUT min = 0;
-    if (has_min) {
-      ET_SWITCH_SCALAR_OBJ_TYPES(min_type, ctx, "clamp", CTYPE_MIN, [&]() {
-        CTYPE_MIN min_val = 0;
-        utils::extract_scalar(min_opt.value(), &min_val);
-        min = static_cast<CTYPE_OUT>(min_val);
-      });
-    }
-
-    // Extract optional max value
-    CTYPE_OUT max = 0;
-    if (has_max) {
-      ET_SWITCH_SCALAR_OBJ_TYPES(max_type, ctx, "clamp", CTYPE_MAX, [&]() {
-        CTYPE_MAX max_val = 0;
-        utils::extract_scalar(max_opt.value(), &max_val);
-        max = static_cast<CTYPE_OUT>(max_val);
-      });
-    }
+  static constexpr const char op_name[] = "clamp.out";
 
-    ET_SWITCH_REALHB_TYPES(in_type, ctx, "clamp", CTYPE_IN, [&]() {
-      apply_unary_map_fn(
-          [has_min, min, has_max, max](const CTYPE_IN val_in) {
-            CTYPE_OUT val_out = static_cast<CTYPE_OUT>(val_in);
-            if (has_min) {
-              val_out = utils::max_override(val_out, min);
-            }
-            if (has_max) {
-              val_out = utils::min_override(val_out, max);
-            }
-            return val_out;
-          },
-          in.const_data_ptr<CTYPE_IN>(),
-          out.mutable_data_ptr<CTYPE_OUT>(),
-          in.numel());
-    });
+  ET_SWITCH_REALHB_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() {
+    utils::apply_unitensor_elementwise_fn<CTYPE_COMMON, op_name>(
+        [has_min, min_opt, has_max, max_opt](const CTYPE_COMMON val_in) {
+          CTYPE_COMMON val_out = val_in;
+          if (has_min) {
+            val_out = utils::max_override(
+                val_out, utils::scalar_to<CTYPE_COMMON>(min_opt.value()));
+          }
+          if (has_max) {
+            val_out = utils::min_override(
+                val_out, utils::scalar_to<CTYPE_COMMON>(max_opt.value()));
+          }
+          return val_out;
+        },
+        in,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        out,
+        utils::SupportedTensorDtypes::REALHBBF16);
   });
 
   return out;
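For reference, the shape of the new clamp inner loop as a self-contained sketch: the bounds stay as `Scalar`s and are converted to the compute type inside the per-element lambda, rather than being pre-extracted through nested `ET_SWITCH_SCALAR_OBJ_TYPES` switches. Here `std::optional` stands in for the kernel's optional `Scalar` bounds, and `std::max`/`std::min` stand in for `utils::max_override`/`utils::min_override`:

```cpp
#include <algorithm>
#include <cassert>
#include <optional>

// Simplified stand-in for the new clamp_out lambda: apply the lower
// bound first, then the upper bound, all in the common compute type.
template <typename T>
T clamp_one(T val, std::optional<T> lo, std::optional<T> hi) {
  T out = val;
  if (lo.has_value()) {
    out = std::max(out, *lo);
  }
  if (hi.has_value()) {
    out = std::min(out, *hi);
  }
  return out;
}

int main() {
  assert(clamp_one<float>(5.0f, 0.0f, 3.0f) == 3.0f);
  assert(clamp_one<int>(-2, 0, std::nullopt) == 0);
  return 0;
}
```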
diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h
index c7871258681..19dcbd73ba0 100644
--- a/kernels/portable/cpu/util/elementwise_util.h
+++ b/kernels/portable/cpu/util/elementwise_util.h
@@ -16,6 +16,33 @@ namespace executor {
 namespace native {
 namespace utils {
 
+/*
+ * Convert Scalar to C++ type
+ */
+
+template <typename T>
+T scalar_to(const Scalar& s) {
+  if (s.isBoolean()) {
+    return static_cast<T>(s.to<bool>());
+  } else if (s.isFloatingPoint()) {
+    return static_cast<T>(s.to<double>());
+  } else {
+    return static_cast<T>(s.to<int64_t>());
+  }
+}
+
+template <>
+inline double scalar_to<double>(const Scalar& s) {
+  return s.isFloatingPoint() ? s.to<double>()
+                             : static_cast<double>(s.to<int64_t>());
+}
+
+template <>
+inline int64_t scalar_to<int64_t>(const Scalar& s) {
+  return s.isFloatingPoint() ? static_cast<int64_t>(s.to<double>())
+                             : s.to<int64_t>();
+}
+
 namespace internal {
 
 template <typename CTYPE_COMMON, const char* op_name>
@@ -139,6 +166,86 @@ store_common_to_tensor_fn<CTYPE_COMMON, op_name> get_store_common_to_tensor_fn(
 
 } // namespace internal
 
+template <typename CTYPE_COMMON, const char* op_name, typename Op>
+inline void apply_unitensor_elementwise_fn(
+    const Op& compute_fun,
+    const Tensor& a,
+    SupportedTensorDtypes a_dtypes,
+    const Tensor& out,
+    SupportedTensorDtypes out_dtypes) {
+  const auto load_a_to_common =
+      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
+  const auto store_common_to_out =
+      internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
+          out, out_dtypes);
+  const char* const data_a = reinterpret_cast<const char*>(a.const_data_ptr());
+  const auto a_element_size = a.element_size();
+  const auto out_element_size = out.element_size();
+  char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
+
+  auto out_numel = out.numel();
+  for (size_t i = 0; i < out_numel; ++i) {
+    auto result = compute_fun(load_a_to_common(&data_a[i * a_element_size]));
+    store_common_to_out(result, &data_out[i * out_element_size]);
+  }
+}
+
+/**
+ * Useful for bi-tensor elementwise operators. For each element of the inputs,
+ * perform a computation and write to the corresponding element of the output.
+ * Tensor broadcasting is applied wherever it is required.
+ */
+template <typename CTYPE_COMMON, const char* op_name, typename Op>
+inline void apply_bitensor_elementwise_fn(
+    const Op& compute_fun,
+    const Tensor& a,
+    SupportedTensorDtypes a_dtypes,
+    const Tensor& b,
+    SupportedTensorDtypes b_dtypes,
+    const Tensor& out,
+    SupportedTensorDtypes out_dtypes) {
+  const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
+  const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
+  const bool any_is_broadcasted = (a_is_broadcasted || b_is_broadcasted);
+
+  const auto load_a_to_common =
+      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
+  const auto load_b_to_common =
+      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(b, b_dtypes);
+  const auto store_common_to_out =
+      internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
+          out, out_dtypes);
+  const char* const data_a = reinterpret_cast<const char*>(a.const_data_ptr());
+  const char* const data_b = reinterpret_cast<const char*>(b.const_data_ptr());
+  const auto a_element_size = a.element_size();
+  const auto b_element_size = b.element_size();
+  const auto out_element_size = out.element_size();
+  char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
+
+  auto out_numel = out.numel();
+  for (size_t i = 0; i < out_numel; ++i) {
+    size_t a_linear_index = i;
+    size_t b_linear_index = i;
+
+    if (any_is_broadcasted) {
+      size_t out_indexes[kTensorDimensionLimit];
+      delinearize_index(i, out, out_indexes, kTensorDimensionLimit);
+
+      if (a_is_broadcasted) {
+        a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
+      }
+      if (b_is_broadcasted) {
+        b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
+      }
+    }
+
+    auto result = compute_fun(
+        load_a_to_common(&data_a[a_linear_index * a_element_size]),
+        load_b_to_common(&data_b[b_linear_index * b_element_size]));
+    store_common_to_out(result, &data_out[i * out_element_size]);
+  }
+}
+
 /**
  * Useful for tri-tensor elementwise operators. For each element of the inputs,
  * perform a computation and write to the corresponding element of the output.
@@ -194,7 +301,8 @@ inline void apply_tritensor_elementwise_fn(
   const auto out_element_size = out.element_size();
   char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
 
-  for (size_t i = 0; i < out.numel(); ++i) {
+  auto out_numel = out.numel();
+  for (size_t i = 0; i < out_numel; ++i) {
     size_t a_linear_index = i;
     size_t b_linear_index = i;
     size_t c_linear_index = i;
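The design both new helpers share, as a standalone simplified sketch (hypothetical names, not the executorch API): rather than instantiating the elementwise loop for every (input dtype, output dtype) combination, a load function and a store function are resolved once per tensor from its runtime dtype, and a single byte-addressed loop computes in the common type. This trades a function-pointer call per element for far fewer template instantiations, which presumably matters for the portable kernels' code-size budget.

```cpp
#include <cstddef>
#include <cstdint>

// Type-erased load/store, as in get_load_to_common_fn /
// get_store_common_to_tensor_fn; float plays the common-type role here.
using LoadFn = float (*)(const void*);
using StoreFn = void (*)(float, void*);

template <typename T>
float load_as_float(const void* p) {
  return static_cast<float>(*static_cast<const T*>(p));
}

template <typename T>
void store_from_float(float v, void* p) {
  *static_cast<T*>(p) = static_cast<T>(v);
}

enum class Dtype { F32, I32 };

LoadFn pick_load(Dtype d) {
  return d == Dtype::F32 ? load_as_float<float> : load_as_float<int32_t>;
}

StoreFn pick_store(Dtype d) {
  return d == Dtype::F32 ? store_from_float<float> : store_from_float<int32_t>;
}

// One compiled loop serves every dtype pair; the element_size() stride
// arithmetic mirrors apply_unitensor_elementwise_fn.
void add_one_elementwise(
    const void* in, Dtype in_t, size_t in_elem_size,
    void* out, Dtype out_t, size_t out_elem_size,
    size_t numel) {
  const LoadFn load = pick_load(in_t);
  const StoreFn store = pick_store(out_t);
  const char* const src = static_cast<const char*>(in);
  char* const dst = static_cast<char*>(out);
  for (size_t i = 0; i < numel; ++i) {
    store(load(&src[i * in_elem_size]) + 1.0f, &dst[i * out_elem_size]);
  }
}

int main() {
  const int32_t in[3] = {1, 2, 3};
  float out[3] = {};
  add_one_elementwise(
      in, Dtype::I32, sizeof(int32_t), out, Dtype::F32, sizeof(float), 3);
  return out[2] == 4.0f ? 0 : 1;
}
```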
diff --git a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
index 6fa797f6126..42f4624d76d 100644
--- a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
+++ b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
@@ -214,7 +214,7 @@ ATEN_OPS = (
         name = "op_add",
         deps = [
             "//executorch/kernels/portable/cpu/util:broadcast_util",
-            "//executorch/kernels/portable/cpu/util:functional_util",
+            "//executorch/kernels/portable/cpu/util:elementwise_util",
             "//executorch/kernels/portable/cpu/util:kernel_ops_util",
             ":scalar_utils",
         ],
@@ -392,7 +392,6 @@ ATEN_OPS = (
             ":scalar_utils",
             "//executorch/kernels/portable/cpu/util:broadcast_util",
             "//executorch/kernels/portable/cpu/util:elementwise_util",
-            "//executorch/kernels/portable/cpu/util:functional_util",
             "//executorch/kernels/portable/cpu/util:math_util",
         ],
     ),
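The .bzl edits swap the `functional_util` dependency for `elementwise_util` on `op_add` and drop the now-unused `functional_util` from `op_clamp`, matching the include changes above. One closing usage note on the new `scalar_to<T>` helper: the generic overload dispatches on the Scalar's runtime payload tag (bool, double, or int64_t), and the `double`/`int64_t` specializations avoid the intermediate cast. A hedged sketch (the `exec_aten::Scalar` alias and include path are assumptions about the surrounding codebase; values are illustrative only):

```cpp
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>

// Assumption: exec_aten::Scalar is the Scalar type used by portable kernels.
using exec_aten::Scalar;
using torch::executor::native::utils::scalar_to;

void scalar_to_demo() {
  Scalar s(2.5); // floating-point payload
  float f = scalar_to<float>(s); // 2.5f, converted via the double payload
  int64_t i = scalar_to<int64_t>(s); // truncates: static_cast<int64_t>(2.5) == 2
  (void)f;
  (void)i;
}
```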