From b8380d5b4f02f45ac515d61ca84e74b133758389 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Wed, 2 Oct 2024 16:04:36 -0700 Subject: [PATCH] [ExecuTorch] Simplify function pointers for apply_ternary_elementwise_fn Cleaning up some of the required boilerplate. I updated op_clamp and op_where, but continued to not optimize op_where for size/build time. Ideal usage optimizing for size/build time would look like op_clamp. Differential Revision: [D63790004](https://our.internmc.facebook.com/intern/diff/D63790004/) [ghstack-poisoned] --- kernels/portable/cpu/op_clamp.cpp | 38 ++++---------------- kernels/portable/cpu/op_where.cpp | 19 +++------- kernels/portable/cpu/util/broadcast_util.h | 40 ++++++++++++++++++++++ 3 files changed, 50 insertions(+), 47 deletions(-) diff --git a/kernels/portable/cpu/op_clamp.cpp b/kernels/portable/cpu/op_clamp.cpp index 069f6057fde..c73b2909ac6 100644 --- a/kernels/portable/cpu/op_clamp.cpp +++ b/kernels/portable/cpu/op_clamp.cpp @@ -66,15 +66,6 @@ ET_NODISCARD bool check_bounds( return is_valid; } -template -To load_and_convert(const void* fromPtr) { - return static_cast(*reinterpret_cast(fromPtr)); -} - -template -void convert_and_store(From f, void* dst) { - *reinterpret_cast(dst) = static_cast(f); -} } // namespace Tensor& clamp_out( @@ -221,26 +212,9 @@ Tensor& clamp_tensor_out( ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out); - constexpr auto name = "clamp.Tensor_out"; + static constexpr const char op_name[] = "clamp.Tensor_out"; - ET_SWITCH_REALHB_TYPES(common_type, ctx, name, CTYPE_COMMON, [&]() { - using ToCtypeCommonFn = CTYPE_COMMON (*)(const void*); - ToCtypeCommonFn in_to_common; - ET_SWITCH_REALHB_TYPES(in_type, ctx, name, CTYPE_IN, [&]() { - in_to_common = load_and_convert; - }); - ToCtypeCommonFn min_to_common; - ET_SWITCH_REALHB_TYPES(min_type, ctx, name, CTYPE_MIN, [&]() { - min_to_common = load_and_convert; - }); - ToCtypeCommonFn max_to_common; - ET_SWITCH_REALHB_TYPES(max_type, ctx, name, CTYPE_MAX, [&]() { - max_to_common = load_and_convert; - }); - void (*common_to_out)(CTYPE_COMMON, void*); - ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() { - common_to_out = convert_and_store; - }); + ET_SWITCH_REALHB_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() { apply_ternary_elementwise_fn( [has_min, has_max]( const CTYPE_COMMON val_in, @@ -259,10 +233,10 @@ Tensor& clamp_tensor_out( min, max, out, - in_to_common, - min_to_common, - max_to_common, - common_to_out); + get_load_to_common_fn_realhbbf16(in), + get_load_to_common_fn_realhbbf16(min), + get_load_to_common_fn_realhbbf16(max), + get_store_common_to_tensor_fn_realhbbf16(out)); }); return out; diff --git a/kernels/portable/cpu/op_where.cpp b/kernels/portable/cpu/op_where.cpp index 57f5f2fc16e..d93efcf5398 100644 --- a/kernels/portable/cpu/op_where.cpp +++ b/kernels/portable/cpu/op_where.cpp @@ -56,21 +56,10 @@ Tensor& where_out( b, cond, out, - [](const void* a_ptr) { - return static_cast( - *reinterpret_cast(a_ptr)); - }, - [](const void* b_ptr) { - return static_cast( - *reinterpret_cast(b_ptr)); - }, - [](const void* c_ptr) { - return static_cast( - *reinterpret_cast(c_ptr)); - }, - [](CTYPE_OUT result, void* out) { - *reinterpret_cast(out) = result; - }); + internal::load_and_convert, + internal::load_and_convert, + internal::load_and_convert, + internal::convert_and_store); }); }); diff --git a/kernels/portable/cpu/util/broadcast_util.h b/kernels/portable/cpu/util/broadcast_util.h index f5b5f0fd2c1..c75883322d6 100644 --- a/kernels/portable/cpu/util/broadcast_util.h +++ b/kernels/portable/cpu/util/broadcast_util.h @@ -270,6 +270,46 @@ size_t linearize_access_indexes( // Mapping with broadcasting // +namespace internal { +template +To load_and_convert(const void* fromPtr) { + return static_cast(*reinterpret_cast(fromPtr)); +} + +template +void convert_and_store(From f, void* dst) { + *reinterpret_cast(dst) = static_cast(f); +} +} // namespace internal + +template +using load_to_common_fn = CTYPE_COMMON (*)(const void*); + +template +load_to_common_fn get_load_to_common_fn_realhbbf16( + const Tensor& t) { + CTYPE_COMMON (*result)(const void*) = nullptr; + ET_SWITCH_REALHBBF16_TYPES( + t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() { + result = internal::load_and_convert; + }); + return result; +} + +template +using store_common_to_tensor_fn = void (*)(CTYPE_COMMON, void*); + +template +store_common_to_tensor_fn +get_store_common_to_tensor_fn_realhbbf16(const Tensor& t) { + void (*result)(CTYPE_COMMON, void*) = nullptr; + ET_SWITCH_REALHBBF16_TYPES( + t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() { + result = internal::convert_and_store; + }); + return result; +} + /** * Useful for binary elementwise operators. For each element of the inputs, * perform a computation and write to the corresponding element of the output.