diff --git a/kernels/portable/cpu/op_clamp.cpp b/kernels/portable/cpu/op_clamp.cpp
index 37c5d0f6c21..9f93caa40f8 100644
--- a/kernels/portable/cpu/op_clamp.cpp
+++ b/kernels/portable/cpu/op_clamp.cpp
@@ -12,7 +12,7 @@
 #include <cmath>
 
 #include <executorch/kernels/portable/cpu/scalar_utils.h>
-#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
 #include <executorch/kernels/portable/cpu/util/functional_util.h>
 #include <executorch/kernels/portable/cpu/util/math_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
@@ -215,7 +215,7 @@ Tensor& clamp_tensor_out(
   static constexpr const char op_name[] = "clamp.Tensor_out";
 
   ET_SWITCH_REALHB_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() {
-    apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>(
+    utils::apply_tritensor_elementwise_fn<CTYPE_COMMON, op_name>(
         [has_min, has_max](
             const CTYPE_COMMON val_in,
             const CTYPE_COMMON val_min,
@@ -230,13 +230,13 @@ Tensor& clamp_tensor_out(
           return val_out;
         },
         in,
-        SupportedTensorDtypes::REALHBBF16,
+        utils::SupportedTensorDtypes::REALHBBF16,
         min,
-        SupportedTensorDtypes::REALHBBF16,
+        utils::SupportedTensorDtypes::REALHBBF16,
         max,
-        SupportedTensorDtypes::REALHBBF16,
+        utils::SupportedTensorDtypes::REALHBBF16,
         out,
-        SupportedTensorDtypes::REALHBBF16);
+        utils::SupportedTensorDtypes::REALHBBF16);
   });
 
   return out;
diff --git a/kernels/portable/cpu/op_where.cpp b/kernels/portable/cpu/op_where.cpp
index 90cb2442a2a..3765590ee83 100644
--- a/kernels/portable/cpu/op_where.cpp
+++ b/kernels/portable/cpu/op_where.cpp
@@ -6,8 +6,7 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
-#include <executorch/kernels/portable/cpu/util/functional_util.h>
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
 namespace torch {
 namespace executor {
@@ -44,19 +43,20 @@ Tensor& where_out(
       cond_type == ScalarType::Bool || cond_type == ScalarType::Byte,
       "Unhandled dtype %s for where.self_out",
       torch::executor::toString(cond_type));
+
   ET_SWITCH_REALHBBF16_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() {
-    apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>(
+    utils::apply_tritensor_elementwise_fn<CTYPE_COMMON, op_name>(
         [](const CTYPE_COMMON val_a,
            const CTYPE_COMMON val_b,
            const CTYPE_COMMON val_c) { return val_c ? val_a : val_b; },
         a,
-        SupportedTensorDtypes::REALHBBF16,
+        utils::SupportedTensorDtypes::REALHBBF16,
         b,
-        SupportedTensorDtypes::REALHBBF16,
+        utils::SupportedTensorDtypes::REALHBBF16,
         cond,
-        SupportedTensorDtypes::BOOL_OR_BYTE,
+        utils::SupportedTensorDtypes::BOOL_OR_BYTE,
         out,
-        SupportedTensorDtypes::SAME_AS_COMMON);
+        utils::SupportedTensorDtypes::SAME_AS_COMMON);
   });
 
   return out;
diff --git a/kernels/portable/cpu/util/broadcast_util.h b/kernels/portable/cpu/util/broadcast_util.h
index beda475610f..92d35f322fb 100644
--- a/kernels/portable/cpu/util/broadcast_util.h
+++ b/kernels/portable/cpu/util/broadcast_util.h
@@ -270,125 +270,6 @@ size_t linearize_access_indexes(
 // Mapping with broadcasting
 //
 
-namespace internal {
-template <typename To, typename From>
-To load_and_convert(const void* fromPtr) {
-  return static_cast<To>(*reinterpret_cast<const From*>(fromPtr));
-}
-
-template <typename To, typename From>
-void convert_and_store(From f, void* dst) {
-  *reinterpret_cast<To*>(dst) = static_cast<To>(f);
-}
-
-template <typename CTYPE_COMMON>
-using load_to_common_fn = CTYPE_COMMON (*)(const void*);
-
-template <typename CTYPE_COMMON, const char* op_name>
-load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_realhbbf16(
-    const Tensor& t) {
-  CTYPE_COMMON (*result)(const void*) = nullptr;
-  ET_SWITCH_REALHBBF16_TYPES(
-      t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
-        result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
-      });
-  return result;
-}
-
-template <typename CTYPE_COMMON, const char* op_name>
-load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_bool_or_byte(
-    const Tensor& t) {
-  CTYPE_COMMON (*result)(const void*) = nullptr;
-  ET_SWITCH_TWO_TYPES(
-      Bool, Byte, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
-        result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
-      });
-  return result;
-}
-
-template <typename CTYPE_COMMON>
-using store_common_to_tensor_fn = void (*)(CTYPE_COMMON, void*);
-
-template <typename CTYPE_COMMON, const char* op_name>
-store_common_to_tensor_fn<CTYPE_COMMON>
-get_store_common_to_tensor_fn_realhbbf16(const Tensor& t) {
-  void (*result)(CTYPE_COMMON, void*) = nullptr;
-  ET_SWITCH_REALHBBF16_TYPES(
-      t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
-        result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
-      });
-  return result;
-}
-
-template <typename CTYPE_COMMON, const char* op_name>
-store_common_to_tensor_fn<CTYPE_COMMON>
-get_store_common_to_tensor_fn_bool_or_byte(const Tensor& t) {
-  void (*result)(CTYPE_COMMON, void*) = nullptr;
-  ET_SWITCH_TWO_TYPES(
-      Bool, Byte, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
-        result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
-      });
-  return result;
-}
-} // namespace internal
-
-enum class SupportedTensorDtypes {
-  REALHBBF16,
-  BOOL_OR_BYTE,
-  SAME_AS_COMMON,
-};
-
-namespace internal {
-template <typename CTYPE_COMMON, const char* op_name>
-load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn(
-    const Tensor& t,
-    SupportedTensorDtypes dtypes) {
-  switch (dtypes) {
-    case SupportedTensorDtypes::REALHBBF16:
-      return get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
-    case SupportedTensorDtypes::BOOL_OR_BYTE:
-      return get_load_to_common_fn_bool_or_byte<CTYPE_COMMON, op_name>(t);
-    case SupportedTensorDtypes::SAME_AS_COMMON: {
-      constexpr auto common_scalar_type =
-          CppTypeToScalarType<CTYPE_COMMON>::value;
-      ET_CHECK_MSG(
-          t.scalar_type() == common_scalar_type,
-          "Unhandled dtype %s for %s",
-          ::executorch::runtime::toString(common_scalar_type),
-          op_name);
-      return internal::load_and_convert<CTYPE_COMMON, CTYPE_COMMON>;
-    }
-  }
-  ET_CHECK(false);
-  return nullptr;
-}
-
-template <typename CTYPE_COMMON, const char* op_name>
-store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn(
-    const Tensor& t,
-    SupportedTensorDtypes dtypes) {
-  switch (dtypes) {
-    case SupportedTensorDtypes::REALHBBF16:
-      return get_store_common_to_tensor_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
-    case SupportedTensorDtypes::BOOL_OR_BYTE:
-      return get_store_common_to_tensor_fn_bool_or_byte<CTYPE_COMMON, op_name>(
-          t);
-    case SupportedTensorDtypes::SAME_AS_COMMON: {
-      constexpr auto common_scalar_type =
-          CppTypeToScalarType<CTYPE_COMMON>::value;
-      ET_CHECK_MSG(
-          t.scalar_type() == common_scalar_type,
-          "Unhandled dtype %s for %s",
-          ::executorch::runtime::toString(common_scalar_type),
-          op_name);
-      return internal::convert_and_store<CTYPE_COMMON, CTYPE_COMMON>;
-    }
-  }
-  ET_CHECK(false);
-  return nullptr;
-}
-} // namespace internal
-
 /**
  * Useful for binary elementwise operators. For each element of the inputs,
  * perform a computation and write to the corresponding element of the output.
  * Tensor broadcasting is applied wherever it is required.
@@ -432,56 +313,29 @@ inline void apply_binary_elementwise_fn(
  * Useful for ternary elementwise operators. For each element of the inputs,
  * perform a computation and write to the corresponding element of the output.
  * Tensor broadcasting is applied wherever it is required.
- *
- * In order to mitigate build time cost (straightforwardly |CTYPE_A| *
- * |CTYPE_B| * |CTYPE_C| * |CTYPE_OUT|), all arguments to compute_fun
- * are passed as CTYPE_COMMON.
- *
- * Each tensor's supported dtypes set must be provided. The tensor
- * will be checked to ensure that its dtype falls into that set.
- *
- * op_name is used to support dtype selective build, as with the
- * ET_SWITCH family of macros. Note: because of C++17 quirks, you
- * can't pass a string literal for op_name. Instead, you should do the
- * following:
- *
- * static constexpr const char op_name[] = "my_op";
- * apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>.
  */
-template <typename CTYPE_COMMON, const char* op_name, typename Op>
+template <
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_C,
+    typename CTYPE_OUT,
+    typename Op>
 inline void apply_ternary_elementwise_fn(
     const Op& compute_fun,
     const Tensor& a,
-    SupportedTensorDtypes a_dtypes,
     const Tensor& b,
-    SupportedTensorDtypes b_dtypes,
     const Tensor& c,
-    SupportedTensorDtypes c_dtypes,
-    const Tensor& out,
-    SupportedTensorDtypes out_dtypes) {
+    const Tensor& out) {
   const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
   const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
   const bool c_is_broadcasted = !out.sizes().equals(c.sizes());
   const bool any_is_broadcasted =
       (a_is_broadcasted || b_is_broadcasted || c_is_broadcasted);
 
-  const auto load_a_to_common =
-      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
-  const auto load_b_to_common =
-      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(b, b_dtypes);
-  const auto load_c_to_common =
-      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(c, c_dtypes);
-  const auto store_common_to_out =
-      internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
-          out, out_dtypes);
-  const char* const data_a = reinterpret_cast<const char*>(a.const_data_ptr());
-  const char* const data_b = reinterpret_cast<const char*>(b.const_data_ptr());
-  const char* const data_c = reinterpret_cast<const char*>(c.const_data_ptr());
-  const auto a_element_size = a.element_size();
-  const auto b_element_size = b.element_size();
-  const auto c_element_size = c.element_size();
-  const auto out_element_size = out.element_size();
-  char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
+  const CTYPE_A* const data_a = a.const_data_ptr<CTYPE_A>();
+  const CTYPE_B* const data_b = b.const_data_ptr<CTYPE_B>();
+  const CTYPE_C* const data_c = c.const_data_ptr<CTYPE_C>();
+  CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();
 
   for (size_t i = 0; i < out.numel(); ++i) {
     size_t a_linear_index = i;
@@ -503,11 +357,8 @@ inline void apply_ternary_elementwise_fn(
       }
     }
 
-    auto result = compute_fun(
-        load_a_to_common(&data_a[a_linear_index * a_element_size]),
-        load_b_to_common(&data_b[b_linear_index * b_element_size]),
-        load_c_to_common(&data_c[c_linear_index * c_element_size]));
-    store_common_to_out(result, &data_out[i * out_element_size]);
+    data_out[i] = compute_fun(
+        data_a[a_linear_index], data_b[b_linear_index], data_c[c_linear_index]);
   }
 }
 
diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h
new file mode 100644
index 00000000000..c7871258681
--- /dev/null
+++ b/kernels/portable/cpu/util/elementwise_util.h
@@ -0,0 +1,228 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+namespace torch {
+namespace executor {
+namespace native {
+namespace utils {
+
+namespace internal {
+
+template <typename To, typename From>
+To load_and_convert(const void* fromPtr) {
+  return static_cast<To>(*reinterpret_cast<const From*>(fromPtr));
+}
+
+template <typename To, typename From>
+void convert_and_store(From f, void* dst) {
+  *reinterpret_cast<To*>(dst) = static_cast<To>(f);
+}
+
+template <typename CTYPE_COMMON>
+using load_to_common_fn = CTYPE_COMMON (*)(const void*);
+
+template <typename CTYPE_COMMON, const char* op_name>
+load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_realhbbf16(
+    const Tensor& t) {
+  CTYPE_COMMON (*result)(const void*) = nullptr;
+  ET_SWITCH_REALHBBF16_TYPES(
+      t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
+        result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
+      });
+  return result;
+}
+
+template <typename CTYPE_COMMON, const char* op_name>
+load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_bool_or_byte(
+    const Tensor& t) {
+  CTYPE_COMMON (*result)(const void*) = nullptr;
+  ET_SWITCH_TWO_TYPES(
+      Bool, Byte, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
+        result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
+      });
+  return result;
+}
+
+template <typename CTYPE_COMMON>
+using store_common_to_tensor_fn = void (*)(CTYPE_COMMON, void*);
+
+template <typename CTYPE_COMMON, const char* op_name>
+store_common_to_tensor_fn<CTYPE_COMMON>
+get_store_common_to_tensor_fn_realhbbf16(const Tensor& t) {
+  void (*result)(CTYPE_COMMON, void*) = nullptr;
+  ET_SWITCH_REALHBBF16_TYPES(
+      t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
+        result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
+      });
+  return result;
+}
+
+template <typename CTYPE_COMMON, const char* op_name>
+store_common_to_tensor_fn<CTYPE_COMMON>
+get_store_common_to_tensor_fn_bool_or_byte(const Tensor& t) {
+  void (*result)(CTYPE_COMMON, void*) = nullptr;
+  ET_SWITCH_TWO_TYPES(
+      Bool, Byte, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
+        result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
+      });
+  return result;
+}
+
+} // namespace internal
+
+enum class SupportedTensorDtypes {
+  REALHBBF16,
+  BOOL_OR_BYTE,
+  SAME_AS_COMMON,
+};
+
+namespace internal {
+
+template <typename CTYPE_COMMON, const char* op_name>
+load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn(
+    const Tensor& t,
+    SupportedTensorDtypes dtypes) {
+  switch (dtypes) {
+    case SupportedTensorDtypes::REALHBBF16:
+      return get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
+    case SupportedTensorDtypes::BOOL_OR_BYTE:
+      return get_load_to_common_fn_bool_or_byte<CTYPE_COMMON, op_name>(t);
+    case SupportedTensorDtypes::SAME_AS_COMMON: {
+      constexpr auto common_scalar_type =
+          CppTypeToScalarType<CTYPE_COMMON>::value;
+      ET_CHECK_MSG(
+          t.scalar_type() == common_scalar_type,
+          "Unhandled dtype %s for %s",
+          ::executorch::runtime::toString(common_scalar_type),
+          op_name);
+      return internal::load_and_convert<CTYPE_COMMON, CTYPE_COMMON>;
+    }
+  }
+  ET_CHECK(false);
+  return nullptr;
+}
+
+template <typename CTYPE_COMMON, const char* op_name>
+store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn(
+    const Tensor& t,
+    SupportedTensorDtypes dtypes) {
+  switch (dtypes) {
+    case SupportedTensorDtypes::REALHBBF16:
+      return get_store_common_to_tensor_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
+    case SupportedTensorDtypes::BOOL_OR_BYTE:
+      return get_store_common_to_tensor_fn_bool_or_byte<CTYPE_COMMON, op_name>(
+          t);
+    case SupportedTensorDtypes::SAME_AS_COMMON: {
+      constexpr auto common_scalar_type =
+          CppTypeToScalarType<CTYPE_COMMON>::value;
+      ET_CHECK_MSG(
+          t.scalar_type() == common_scalar_type,
+          "Unhandled dtype %s for %s",
+          ::executorch::runtime::toString(common_scalar_type),
+          op_name);
+      return internal::convert_and_store<CTYPE_COMMON, CTYPE_COMMON>;
+    }
+  }
+  ET_CHECK(false);
+  return nullptr;
+}
+
+} // namespace internal
+
+/**
+ * Useful for tri-tensor elementwise operators. For each element of the inputs,
+ * perform a computation and write to the corresponding element of the output.
+ * Tensor broadcasting is applied wherever it is required.
+ *
+ * In order to mitigate build time cost (straightforwardly |CTYPE_A| *
+ * |CTYPE_B| * |CTYPE_C| * |CTYPE_OUT|), all arguments to compute_fun
+ * are passed as CTYPE_COMMON.
+ *
+ * Each tensor's supported dtypes set must be provided. The tensor
+ * will be checked to ensure that its dtype falls into that set.
+ *
+ * op_name is used to support dtype selective build, as with the
+ * ET_SWITCH family of macros. Note: because of C++17 quirks, you
+ * can't pass a string literal for op_name. Instead, you should do the
+ * following:
+ *
+ * static constexpr const char op_name[] = "my_op";
+ * apply_tritensor_elementwise_fn<CTYPE_COMMON, op_name>.
+ */
+template <typename CTYPE_COMMON, const char* op_name, typename Op>
+inline void apply_tritensor_elementwise_fn(
+    const Op& compute_fun,
+    const Tensor& a,
+    SupportedTensorDtypes a_dtypes,
+    const Tensor& b,
+    SupportedTensorDtypes b_dtypes,
+    const Tensor& c,
+    SupportedTensorDtypes c_dtypes,
+    const Tensor& out,
+    SupportedTensorDtypes out_dtypes) {
+  const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
+  const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
+  const bool c_is_broadcasted = !out.sizes().equals(c.sizes());
+  const bool any_is_broadcasted =
+      (a_is_broadcasted || b_is_broadcasted || c_is_broadcasted);
+
+  const auto load_a_to_common =
+      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
+  const auto load_b_to_common =
+      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(b, b_dtypes);
+  const auto load_c_to_common =
+      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(c, c_dtypes);
+  const auto store_common_to_out =
+      internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
+          out, out_dtypes);
+  const char* const data_a = reinterpret_cast<const char*>(a.const_data_ptr());
+  const char* const data_b = reinterpret_cast<const char*>(b.const_data_ptr());
+  const char* const data_c = reinterpret_cast<const char*>(c.const_data_ptr());
+  const auto a_element_size = a.element_size();
+  const auto b_element_size = b.element_size();
+  const auto c_element_size = c.element_size();
+  const auto out_element_size = out.element_size();
+  char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
+
+  for (size_t i = 0; i < out.numel(); ++i) {
+    size_t a_linear_index = i;
+    size_t b_linear_index = i;
+    size_t c_linear_index = i;
+
+    if (any_is_broadcasted) {
+      size_t out_indexes[kTensorDimensionLimit];
+      delinearize_index(i, out, out_indexes, kTensorDimensionLimit);
+
+      if (a_is_broadcasted) {
+        a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
+      }
+      if (b_is_broadcasted) {
+        b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
+      }
+      if (c_is_broadcasted) {
+        c_linear_index = linearize_access_indexes(out_indexes, out.dim(), c);
+      }
+    }
+
+    auto result = compute_fun(
+        load_a_to_common(&data_a[a_linear_index * a_element_size]),
+        load_b_to_common(&data_b[b_linear_index * b_element_size]),
+        load_c_to_common(&data_c[c_linear_index * c_element_size]));
+    store_common_to_out(result, &data_out[i * out_element_size]);
+  }
+}
+
+} // namespace utils
+} // namespace native
+} // namespace executor
+} // namespace torch
diff --git a/kernels/portable/cpu/util/targets.bzl b/kernels/portable/cpu/util/targets.bzl
index 82d3d84fa23..2285206728f 100644
--- a/kernels/portable/cpu/util/targets.bzl
+++ b/kernels/portable/cpu/util/targets.bzl
@@ -30,6 +30,7 @@ def define_common_targets():
             "//executorch/kernels/portable/cpu/util:select_copy_util",
             "//executorch/kernels/portable/cpu/util:advanced_index_util",
             "//executorch/kernels/portable/cpu/util:slice_util",
+            "//executorch/kernels/portable/cpu/util:elementwise_util",
         ],
         visibility = ["//executorch/...", "@EXECUTORCH_CLIENTS"],
     )
@@ -77,6 +78,20 @@ def define_common_targets():
         visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/optimized/cpu/...", "@EXECUTORCH_CLIENTS"],
     )
 
+    runtime.cxx_library(
+        name = "elementwise_util",
+        exported_headers = [
+            "elementwise_util.h",
+        ],
+        compiler_flags = ["-Wno-missing-prototypes"],
+        deps = [
+            ":broadcast_util",
+            "//executorch/kernels/portable/cpu:scalar_utils",
+            "//executorch/runtime/kernel:kernel_includes",
+        ],
+        visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/optimized/cpu/..."],
+    )
+
     runtime.cxx_library(
         name = "advanced_index_util",
         srcs = ["advanced_index_util.cpp"],
diff --git a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
index ef8f936571c..6fa797f6126 100644
--- a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
+++ b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
@@ -391,6 +391,7 @@ ATEN_OPS = (
         deps = [
             ":scalar_utils",
             "//executorch/kernels/portable/cpu/util:broadcast_util",
+            "//executorch/kernels/portable/cpu/util:elementwise_util",
            "//executorch/kernels/portable/cpu/util:functional_util",
             "//executorch/kernels/portable/cpu/util:math_util",
         ],
@@ -1186,8 +1187,7 @@ ATEN_OPS = (
         name = "op_where",
         deps = [
             "//executorch/kernels/portable/cpu/util:broadcast_util",
-            "//executorch/kernels/portable/cpu/util:functional_util",
-            "//executorch/runtime/core/exec_aten:lib",
+            "//executorch/kernels/portable/cpu/util:elementwise_util",
         ],
     ),
     op_target(
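
For reference, here is a minimal sketch of a kernel written against the new utility, modeled on the op_where.cpp call site above. It is not part of this diff: the op name, the fused multiply-add computation, and the function signature are illustrative, and the `KernelRuntimeContext` type and `promoteTypes` helper are assumed to be available via kernel_includes.h.

// Hedged usage sketch (illustrative, not part of this patch).
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

Tensor& fma_example_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    const Tensor& c,
    Tensor& out) {
  // C++17 cannot bind a string literal to a `const char*` non-type template
  // parameter, so op_name must be a named constant with static storage
  // duration, as the doc comment in elementwise_util.h notes.
  static constexpr const char op_name[] = "fma_example.out";

  const ScalarType common_type =
      promoteTypes(a.scalar_type(), promoteTypes(b.scalar_type(), c.scalar_type()));

  ET_SWITCH_REALHBBF16_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() {
    utils::apply_tritensor_elementwise_fn<CTYPE_COMMON, op_name>(
        // Every argument arrives already converted to CTYPE_COMMON, so only
        // one lambda instantiation per common type is built instead of the
        // full |CTYPE_A| * |CTYPE_B| * |CTYPE_C| * |CTYPE_OUT| product.
        [](const CTYPE_COMMON val_a,
           const CTYPE_COMMON val_b,
           const CTYPE_COMMON val_c) { return val_a * val_b + val_c; },
        a,
        utils::SupportedTensorDtypes::REALHBBF16,
        b,
        utils::SupportedTensorDtypes::REALHBBF16,
        c,
        utils::SupportedTensorDtypes::REALHBBF16,
        out,
        utils::SupportedTensorDtypes::SAME_AS_COMMON);
  });
  return out;
}

} // namespace native
} // namespace executor
} // namespace torch

Each SupportedTensorDtypes argument declares the dtype set the corresponding tensor is checked against at runtime; the per-element load/store function pointers resolved from it are what let the compute lambda stay monomorphic in CTYPE_COMMON.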