diff --git a/kernels/portable/cpu/op_clamp.cpp b/kernels/portable/cpu/op_clamp.cpp
index 924780b29ab..069f6057fde 100644
--- a/kernels/portable/cpu/op_clamp.cpp
+++ b/kernels/portable/cpu/op_clamp.cpp
@@ -66,6 +66,15 @@ ET_NODISCARD bool check_bounds(
   return is_valid;
 }
 
+template <typename To, typename From>
+To load_and_convert(const void* fromPtr) {
+  return static_cast<To>(*reinterpret_cast<const From*>(fromPtr));
+}
+
+template <typename To, typename From>
+void convert_and_store(From f, void* dst) {
+  *reinterpret_cast<To*>(dst) = static_cast<To>(f);
+}
 } // namespace
 
 Tensor& clamp_out(
@@ -214,41 +223,46 @@ Tensor& clamp_tensor_out(
 
   constexpr auto name = "clamp.Tensor_out";
 
-  ET_SWITCH_REALHB_TYPES(in_type, ctx, name, CTYPE_IN, [&]() {
+  ET_SWITCH_REALHB_TYPES(common_type, ctx, name, CTYPE_COMMON, [&]() {
+    using ToCtypeCommonFn = CTYPE_COMMON (*)(const void*);
+    ToCtypeCommonFn in_to_common;
+    ET_SWITCH_REALHB_TYPES(in_type, ctx, name, CTYPE_IN, [&]() {
+      in_to_common = load_and_convert<CTYPE_COMMON, CTYPE_IN>;
+    });
+    ToCtypeCommonFn min_to_common;
     ET_SWITCH_REALHB_TYPES(min_type, ctx, name, CTYPE_MIN, [&]() {
-      ET_SWITCH_REALHB_TYPES(max_type, ctx, name, CTYPE_MAX, [&]() {
-        ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() {
-          using CTYPE_MINMAX = typename torch::executor::
-              promote_types<CTYPE_MIN, CTYPE_MAX>::type;
-          using CTYPE = typename torch::executor::
-              promote_types<CTYPE_IN, CTYPE_MINMAX>::type;
-          apply_ternary_elementwise_fn<
-              CTYPE_IN,
-              CTYPE_MIN,
-              CTYPE_MAX,
-              CTYPE_OUT>(
-              [has_min, has_max](
-                  const CTYPE_IN val_in,
-                  const CTYPE_MIN val_min,
-                  const CTYPE_MAX val_max) {
-                CTYPE val_out = static_cast<CTYPE>(val_in);
-                if (has_min) {
-                  val_out =
-                      utils::max_override(val_out, static_cast<CTYPE>(val_min));
-                }
-                if (has_max) {
-                  val_out =
-                      utils::min_override(val_out, static_cast<CTYPE>(val_max));
-                }
-                return static_cast<CTYPE_OUT>(val_out);
-              },
-              in,
-              min,
-              max,
-              out);
-        });
-      });
+      min_to_common = load_and_convert<CTYPE_COMMON, CTYPE_MIN>;
+    });
+    ToCtypeCommonFn max_to_common;
+    ET_SWITCH_REALHB_TYPES(max_type, ctx, name, CTYPE_MAX, [&]() {
+      max_to_common = load_and_convert<CTYPE_COMMON, CTYPE_MAX>;
+    });
+    void (*common_to_out)(CTYPE_COMMON, void*);
+    ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() {
+      common_to_out = convert_and_store<CTYPE_OUT, CTYPE_COMMON>;
     });
+    apply_ternary_elementwise_fn<CTYPE_COMMON>(
+        [has_min, has_max](
+            const CTYPE_COMMON val_in,
+            const CTYPE_COMMON val_min,
+            const CTYPE_COMMON val_max) {
+          CTYPE_COMMON val_out = val_in;
+          if (has_min) {
+            val_out = utils::max_override(val_out, val_min);
+          }
+          if (has_max) {
+            val_out = utils::min_override(val_out, val_max);
+          }
+          return val_out;
+        },
+        in,
+        min,
+        max,
+        out,
+        in_to_common,
+        min_to_common,
+        max_to_common,
+        common_to_out);
   });
 
   return out;
diff --git a/kernels/portable/cpu/op_where.cpp b/kernels/portable/cpu/op_where.cpp
index a7736247597..57f5f2fc16e 100644
--- a/kernels/portable/cpu/op_where.cpp
+++ b/kernels/portable/cpu/op_where.cpp
@@ -48,16 +48,29 @@ Tensor& where_out(
     ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
       using CTYPE_OUT =
           typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
-      apply_ternary_elementwise_fn<CTYPE_A, CTYPE_B, uint8_t, CTYPE_OUT>(
-          [](const CTYPE_A val_a, const CTYPE_B val_b, const uint8_t val_c) {
-            CTYPE_OUT a_casted = static_cast<CTYPE_OUT>(val_a);
-            CTYPE_OUT b_casted = static_cast<CTYPE_OUT>(val_b);
-            return val_c ? a_casted : b_casted;
-          },
+      apply_ternary_elementwise_fn<CTYPE_OUT>(
+          [](const CTYPE_OUT val_a,
+             const CTYPE_OUT val_b,
+             const CTYPE_OUT val_c) { return val_c ? val_a : val_b; },
           a,
           b,
           cond,
-          out);
+          out,
+          [](const void* a_ptr) {
+            return static_cast<CTYPE_OUT>(
+                *reinterpret_cast<const CTYPE_A*>(a_ptr));
+          },
+          [](const void* b_ptr) {
+            return static_cast<CTYPE_OUT>(
+                *reinterpret_cast<const CTYPE_B*>(b_ptr));
+          },
+          [](const void* c_ptr) {
+            return static_cast<CTYPE_OUT>(
+                *reinterpret_cast<const uint8_t*>(c_ptr));
+          },
+          [](CTYPE_OUT result, void* out) {
+            *reinterpret_cast<CTYPE_OUT*>(out) = result;
+          });
     });
   });
 
diff --git a/kernels/portable/cpu/util/broadcast_util.h b/kernels/portable/cpu/util/broadcast_util.h
index 92d35f322fb..f5b5f0fd2c1 100644
--- a/kernels/portable/cpu/util/broadcast_util.h
+++ b/kernels/portable/cpu/util/broadcast_util.h
@@ -313,29 +313,44 @@ inline void apply_binary_elementwise_fn(
  * Useful for ternary elementwise operators. For each element of the inputs,
  * perform a computation and write to the corresponding element of the output.
  * Tensor broadcasting is applied wherever it is required.
+ *
+ * In order to mitigate build time cost (straightforwardly |CTYPE_A| *
+ * |CTYPE_B| * |CTYPE_C| * |CTYPE_OUT|), all arguments to compute_fun
+ * are passed as CTYPE_COMMON. We require compute_fun to return
+ * CTYPE_COMMON, and we require loading conversion functions from each
+ * input type to CTYPE_COMMON and a storing conversion from
+ * CTYPE_COMMON to CTYPE_OUT be provided. Each conversion function
+ * must take a void* pointing to an element of the corresponding
+ * tensor, load that element, and convert it to CTYPE_COMMON. The
+ * storing conversion function must have the signature
+ * void(CTYPE_COMMON, void*), convert the given element to CTYPE_OUT,
+ * and store it to the given location.
  */
-template <
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_C,
-    typename CTYPE_OUT,
-    typename Op>
+template <typename CTYPE_COMMON, typename Op>
 inline void apply_ternary_elementwise_fn(
     const Op& compute_fun,
     const Tensor& a,
     const Tensor& b,
     const Tensor& c,
-    const Tensor& out) {
+    const Tensor& out,
+    CTYPE_COMMON (*load_a_to_common)(const void*),
+    CTYPE_COMMON (*load_b_to_common)(const void*),
+    CTYPE_COMMON (*load_c_to_common)(const void*),
+    void (*store_common_to_out)(CTYPE_COMMON, void*)) {
   const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
   const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
   const bool c_is_broadcasted = !out.sizes().equals(c.sizes());
   const bool any_is_broadcasted =
       (a_is_broadcasted || b_is_broadcasted || c_is_broadcasted);
 
-  const CTYPE_A* const data_a = a.const_data_ptr<CTYPE_A>();
-  const CTYPE_B* const data_b = b.const_data_ptr<CTYPE_B>();
-  const CTYPE_C* const data_c = c.const_data_ptr<CTYPE_C>();
-  CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();
+  const char* const data_a = reinterpret_cast<const char*>(a.const_data_ptr());
+  const char* const data_b = reinterpret_cast<const char*>(b.const_data_ptr());
+  const char* const data_c = reinterpret_cast<const char*>(c.const_data_ptr());
+  const auto a_element_size = a.element_size();
+  const auto b_element_size = b.element_size();
+  const auto c_element_size = c.element_size();
+  const auto out_element_size = out.element_size();
+  char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
 
   for (size_t i = 0; i < out.numel(); ++i) {
     size_t a_linear_index = i;
@@ -357,8 +372,11 @@ inline void apply_ternary_elementwise_fn(
       }
     }
 
-    data_out[i] = compute_fun(
-        data_a[a_linear_index], data_b[b_linear_index], data_c[c_linear_index]);
+    auto result = compute_fun(
+        load_a_to_common(&data_a[a_linear_index * a_element_size]),
+        load_b_to_common(&data_b[b_linear_index * b_element_size]),
+        load_c_to_common(&data_c[c_linear_index * c_element_size]));
+    store_common_to_out(result, &data_out[i * out_element_size]);
   }
 }
 
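The patch hinges on one idea: the inner compute function only ever sees a single CTYPE_COMMON, and the per-dtype work is pushed into tiny load/store converters selected once per tensor. Below is a minimal standalone sketch of that pattern, assuming nothing from ExecuTorch beyond what the diff itself shows: the `apply_ternary_demo` harness, its raw-pointer/element-size parameters, and the `main` driver are hypothetical stand-ins for the real Tensor-based `apply_ternary_elementwise_fn` (which additionally handles broadcasting and ET_SWITCH_* dtype dispatch); only `load_and_convert` / `convert_and_store` mirror the helpers added to op_clamp.cpp.

```cpp
// Standalone sketch of the CTYPE_COMMON load/store pattern from the diff.
// The buffer-based harness below is illustrative only, not ExecuTorch's API.
#include <cstddef>
#include <cstdint>
#include <iostream>

// Same shape as the helpers the diff adds to op_clamp.cpp: read an element of
// type From through a void* and convert it to To.
template <typename To, typename From>
To load_and_convert(const void* from_ptr) {
  return static_cast<To>(*reinterpret_cast<const From*>(from_ptr));
}

// Convert a From value to To and write it through a void*.
template <typename To, typename From>
void convert_and_store(From f, void* dst) {
  *reinterpret_cast<To*>(dst) = static_cast<To>(f);
}

// Simplified analogue of apply_ternary_elementwise_fn: compute_fun only sees
// CTYPE_COMMON, so it is instantiated once per common type instead of once
// per (A, B, C, OUT) combination. No Tensor, no broadcasting.
template <typename CTYPE_COMMON, typename Op>
void apply_ternary_demo(
    const Op& compute_fun,
    const void* a, size_t a_elem_size,
    const void* b, size_t b_elem_size,
    const void* c, size_t c_elem_size,
    void* out, size_t out_elem_size,
    size_t numel,
    CTYPE_COMMON (*load_a)(const void*),
    CTYPE_COMMON (*load_b)(const void*),
    CTYPE_COMMON (*load_c)(const void*),
    void (*store_out)(CTYPE_COMMON, void*)) {
  const char* pa = reinterpret_cast<const char*>(a);
  const char* pb = reinterpret_cast<const char*>(b);
  const char* pc = reinterpret_cast<const char*>(c);
  char* po = reinterpret_cast<char*>(out);
  for (size_t i = 0; i < numel; ++i) {
    CTYPE_COMMON result = compute_fun(
        load_a(pa + i * a_elem_size),
        load_b(pb + i * b_elem_size),
        load_c(pc + i * c_elem_size));
    store_out(result, po + i * out_elem_size);
  }
}

int main() {
  // Clamp-like example: float input, int32 bounds, double output.
  const float in[3] = {-5.f, 0.5f, 9.f};
  const int32_t lo[3] = {0, 0, 0};
  const int32_t hi[3] = {4, 4, 4};
  double out[3];
  using Common = double; // stands in for CTYPE_COMMON
  apply_ternary_demo<Common>(
      [](Common v, Common mn, Common mx) {
        return v < mn ? mn : (v > mx ? mx : v);
      },
      in, sizeof(float),
      lo, sizeof(int32_t),
      hi, sizeof(int32_t),
      out, sizeof(double),
      3,
      load_and_convert<Common, float>,
      load_and_convert<Common, int32_t>,
      load_and_convert<Common, int32_t>,
      convert_and_store<double, Common>);
  for (double v : out) {
    std::cout << v << ' '; // prints: 0 0.5 4
  }
  std::cout << '\n';
}
```

The payoff is the one stated in the new doc comment: the heavy compute body is instantiated roughly once per common type, while the dtype-specific code shrinks to one-line converters, so template instantiations drop from |CTYPE_A| * |CTYPE_B| * |CTYPE_C| * |CTYPE_OUT| to a linear number.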