12 changes: 6 additions & 6 deletions kernels/portable/cpu/op_clamp.cpp
@@ -12,7 +12,7 @@
 #include <limits>
 
 #include <executorch/kernels/portable/cpu/scalar_utils.h>
-#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
 #include <executorch/kernels/portable/cpu/util/functional_util.h>
 #include <executorch/kernels/portable/cpu/util/math_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
@@ -215,7 +215,7 @@ Tensor& clamp_tensor_out(
   static constexpr const char op_name[] = "clamp.Tensor_out";
 
   ET_SWITCH_REALHB_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() {
-    apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>(
+    utils::apply_tritensor_elementwise_fn<CTYPE_COMMON, op_name>(
         [has_min, has_max](
             const CTYPE_COMMON val_in,
             const CTYPE_COMMON val_min,
@@ -230,13 +230,13 @@ Tensor& clamp_tensor_out(
           return val_out;
         },
         in,
-        SupportedTensorDtypes::REALHBBF16,
+        utils::SupportedTensorDtypes::REALHBBF16,
         min,
-        SupportedTensorDtypes::REALHBBF16,
+        utils::SupportedTensorDtypes::REALHBBF16,
         max,
-        SupportedTensorDtypes::REALHBBF16,
+        utils::SupportedTensorDtypes::REALHBBF16,
         out,
-        SupportedTensorDtypes::REALHBBF16);
+        utils::SupportedTensorDtypes::REALHBBF16);
   });
 
   return out;
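The new call shape is easier to read outside the diff. The following is a hypothetical sketch, not code from this PR: a lerp-like ternary kernel written against `utils::apply_tritensor_elementwise_fn` as it is used above. `my_ternary.out`, `in_a`, `in_b`, and `weight` are invented names; `ctx`, `common_type`, and `out` are assumed to be in scope as in a typical portable kernel.

```cpp
// Hypothetical sketch (assumptions noted above): all arithmetic happens in
// CTYPE_COMMON; the utility converts each tensor's native dtype on load/store.
static constexpr const char op_name[] = "my_ternary.out"; // must be a named
// array with static storage: C++17 rejects string literals as const char*
// template arguments (see the doc comment removed from broadcast_util.h below).

ET_SWITCH_REALHB_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() {
  utils::apply_tritensor_elementwise_fn<CTYPE_COMMON, op_name>(
      [](CTYPE_COMMON a, CTYPE_COMMON b, CTYPE_COMMON w) {
        return a + w * (b - a); // lerp
      },
      in_a,
      utils::SupportedTensorDtypes::REALHBBF16,
      in_b,
      utils::SupportedTensorDtypes::REALHBBF16,
      weight,
      utils::SupportedTensorDtypes::REALHBBF16,
      out,
      utils::SupportedTensorDtypes::REALHBBF16);
});
```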
14 changes: 7 additions & 7 deletions kernels/portable/cpu/op_where.cpp
@@ -6,8 +6,7 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
-#include <executorch/kernels/portable/cpu/util/functional_util.h>
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
 namespace torch {
@@ -44,19 +43,20 @@ Tensor& where_out(
       cond_type == ScalarType::Bool || cond_type == ScalarType::Byte,
       "Unhandled dtype %s for where.self_out",
       torch::executor::toString(cond_type));
+
   ET_SWITCH_REALHBBF16_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() {
-    apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>(
+    utils::apply_tritensor_elementwise_fn<CTYPE_COMMON, op_name>(
         [](const CTYPE_COMMON val_a,
            const CTYPE_COMMON val_b,
            const CTYPE_COMMON val_c) { return val_c ? val_a : val_b; },
         a,
-        SupportedTensorDtypes::REALHBBF16,
+        utils::SupportedTensorDtypes::REALHBBF16,
         b,
-        SupportedTensorDtypes::REALHBBF16,
+        utils::SupportedTensorDtypes::REALHBBF16,
         cond,
-        SupportedTensorDtypes::BOOL_OR_BYTE,
+        utils::SupportedTensorDtypes::BOOL_OR_BYTE,
         out,
-        SupportedTensorDtypes::SAME_AS_COMMON);
+        utils::SupportedTensorDtypes::SAME_AS_COMMON);
   });
 
   return out;
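Note the three different dtype sets in this call: the value inputs accept the full REALHBBF16 set, the condition is restricted to Bool or Byte, and the output must match the compute type exactly. A self-contained toy model of that per-tensor check, with enum members and semantics assumed for illustration (not the ExecuTorch definitions):

```cpp
#include <cassert>

// Toy model of the per-tensor dtype-set check (assumed semantics).
enum class ScalarType { Float, Half, BFloat16, Bool, Byte };
enum class SupportedTensorDtypes { REALHBBF16, BOOL_OR_BYTE, SAME_AS_COMMON };

static bool dtype_allowed(
    ScalarType t, SupportedTensorDtypes set, ScalarType common) {
  switch (set) {
    case SupportedTensorDtypes::REALHBBF16:
      return true; // every member of this toy enum falls in the set
    case SupportedTensorDtypes::BOOL_OR_BYTE:
      return t == ScalarType::Bool || t == ScalarType::Byte;
    case SupportedTensorDtypes::SAME_AS_COMMON:
      return t == common; // out must be exactly the compute dtype
  }
  return false;
}

int main() {
  const ScalarType common = ScalarType::Float; // as if CTYPE_COMMON == float
  assert(dtype_allowed(ScalarType::Bool, SupportedTensorDtypes::BOOL_OR_BYTE, common));
  assert(!dtype_allowed(ScalarType::Float, SupportedTensorDtypes::BOOL_OR_BYTE, common));
  assert(dtype_allowed(ScalarType::Float, SupportedTensorDtypes::SAME_AS_COMMON, common));
}
```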
175 changes: 13 additions & 162 deletions kernels/portable/cpu/util/broadcast_util.h
@@ -270,125 +270,6 @@ size_t linearize_access_indexes(
 // Mapping with broadcasting
 //
 
-namespace internal {
-template <typename To, typename From>
-To load_and_convert(const void* fromPtr) {
-  return static_cast<To>(*reinterpret_cast<const From*>(fromPtr));
-}
-
-template <typename To, typename From>
-void convert_and_store(From f, void* dst) {
-  *reinterpret_cast<To*>(dst) = static_cast<To>(f);
-}
-
-template <typename CTYPE_COMMON>
-using load_to_common_fn = CTYPE_COMMON (*)(const void*);
-
-template <typename CTYPE_COMMON, const char* op_name>
-load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_realhbbf16(
-    const Tensor& t) {
-  CTYPE_COMMON (*result)(const void*) = nullptr;
-  ET_SWITCH_REALHBBF16_TYPES(
-      t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
-        result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
-      });
-  return result;
-}
-
-template <typename CTYPE_COMMON, const char* op_name>
-load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_bool_or_byte(
-    const Tensor& t) {
-  CTYPE_COMMON (*result)(const void*) = nullptr;
-  ET_SWITCH_TWO_TYPES(
-      Bool, Byte, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
-        result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
-      });
-  return result;
-}
-
-template <typename CTYPE_COMMON>
-using store_common_to_tensor_fn = void (*)(CTYPE_COMMON, void*);
-
-template <typename CTYPE_COMMON, const char* op_name>
-store_common_to_tensor_fn<CTYPE_COMMON>
-get_store_common_to_tensor_fn_realhbbf16(const Tensor& t) {
-  void (*result)(CTYPE_COMMON, void*) = nullptr;
-  ET_SWITCH_REALHBBF16_TYPES(
-      t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
-        result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
-      });
-  return result;
-}
-
-template <typename CTYPE_COMMON, const char* op_name>
-store_common_to_tensor_fn<CTYPE_COMMON>
-get_store_common_to_tensor_fn_bool_or_byte(const Tensor& t) {
-  void (*result)(CTYPE_COMMON, void*) = nullptr;
-  ET_SWITCH_TWO_TYPES(
-      Bool, Byte, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
-        result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
-      });
-  return result;
-}
-} // namespace internal
-
-enum class SupportedTensorDtypes {
-  REALHBBF16,
-  BOOL_OR_BYTE,
-  SAME_AS_COMMON,
-};
-
-namespace internal {
-template <typename CTYPE_COMMON, const char* op_name>
-load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn(
-    const Tensor& t,
-    SupportedTensorDtypes dtypes) {
-  switch (dtypes) {
-    case SupportedTensorDtypes::REALHBBF16:
-      return get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
-    case SupportedTensorDtypes::BOOL_OR_BYTE:
-      return get_load_to_common_fn_bool_or_byte<CTYPE_COMMON, op_name>(t);
-    case SupportedTensorDtypes::SAME_AS_COMMON: {
-      constexpr auto common_scalar_type =
-          CppTypeToScalarType<CTYPE_COMMON>::value;
-      ET_CHECK_MSG(
-          t.scalar_type() == common_scalar_type,
-          "Unhandled dtype %s for %s",
-          ::executorch::runtime::toString(common_scalar_type),
-          op_name);
-      return internal::load_and_convert<CTYPE_COMMON, CTYPE_COMMON>;
-    }
-  }
-  ET_CHECK(false);
-  return nullptr;
-}
-
-template <typename CTYPE_COMMON, const char* op_name>
-store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn(
-    const Tensor& t,
-    SupportedTensorDtypes dtypes) {
-  switch (dtypes) {
-    case SupportedTensorDtypes::REALHBBF16:
-      return get_store_common_to_tensor_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
-    case SupportedTensorDtypes::BOOL_OR_BYTE:
-      return get_store_common_to_tensor_fn_bool_or_byte<CTYPE_COMMON, op_name>(
-          t);
-    case SupportedTensorDtypes::SAME_AS_COMMON: {
-      constexpr auto common_scalar_type =
-          CppTypeToScalarType<CTYPE_COMMON>::value;
-      ET_CHECK_MSG(
-          t.scalar_type() == common_scalar_type,
-          "Unhandled dtype %s for %s",
-          ::executorch::runtime::toString(common_scalar_type),
-          op_name);
-      return internal::convert_and_store<CTYPE_COMMON, CTYPE_COMMON>;
-    }
-  }
-  ET_CHECK(false);
-  return nullptr;
-}
-} // namespace internal
-
 /**
  * Useful for binary elementwise operators. For each element of the inputs,
  * perform a computation and write to the corresponding element of the output.
@@ -432,56 +313,29 @@ inline void apply_binary_elementwise_fn(
  * Useful for ternary elementwise operators. For each element of the inputs,
  * perform a computation and write to the corresponding element of the output.
  * Tensor broadcasting is applied wherever it is required.
- *
- * In order to mitigate build time cost (straightforwardly |CTYPE_A| *
- * |CTYPE_B| * |CTYPE_C| * |CTYPE_OUT|), all arguments to compute_fun
- * are passed as CTYPE_COMMON.
- *
- * Each tensor's supported dtypes set must be provided. The tensor
- * will be checked to ensure that its dtype falls into that set.
- *
- * op_name is used to support dtype selective build, as with the
- * ET_SWITCH family of macros. Note: because of C++17 quirks, you
- * can't pass a string literal for op_name. Instead, you should do the
- * following:
- *
- *   static constexpr const char op_name[] = "my_op";
- *   apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>.
  */
-template <typename CTYPE_COMMON, const char* op_name, typename Op>
+template <
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_C,
+    typename CTYPE_OUT,
+    typename Op>
 inline void apply_ternary_elementwise_fn(
     const Op& compute_fun,
     const Tensor& a,
-    SupportedTensorDtypes a_dtypes,
     const Tensor& b,
-    SupportedTensorDtypes b_dtypes,
     const Tensor& c,
-    SupportedTensorDtypes c_dtypes,
-    const Tensor& out,
-    SupportedTensorDtypes out_dtypes) {
+    const Tensor& out) {
   const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
   const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
   const bool c_is_broadcasted = !out.sizes().equals(c.sizes());
   const bool any_is_broadcasted =
       (a_is_broadcasted || b_is_broadcasted || c_is_broadcasted);
 
-  const auto load_a_to_common =
-      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
-  const auto load_b_to_common =
-      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(b, b_dtypes);
-  const auto load_c_to_common =
-      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(c, c_dtypes);
-  const auto store_common_to_out =
-      internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
-          out, out_dtypes);
-  const char* const data_a = reinterpret_cast<const char*>(a.const_data_ptr());
-  const char* const data_b = reinterpret_cast<const char*>(b.const_data_ptr());
-  const char* const data_c = reinterpret_cast<const char*>(c.const_data_ptr());
-  const auto a_element_size = a.element_size();
-  const auto b_element_size = b.element_size();
-  const auto c_element_size = c.element_size();
-  const auto out_element_size = out.element_size();
-  char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
+  const CTYPE_A* const data_a = a.const_data_ptr<CTYPE_A>();
+  const CTYPE_B* const data_b = b.const_data_ptr<CTYPE_B>();
+  const CTYPE_C* const data_c = c.const_data_ptr<CTYPE_C>();
+  CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();
 
   for (size_t i = 0; i < out.numel(); ++i) {
     size_t a_linear_index = i;
@@ -503,11 +357,8 @@ inline void apply_ternary_elementwise_fn(
       }
     }
 
-    auto result = compute_fun(
-        load_a_to_common(&data_a[a_linear_index * a_element_size]),
-        load_b_to_common(&data_b[b_linear_index * b_element_size]),
-        load_c_to_common(&data_c[c_linear_index * c_element_size]));
-    store_common_to_out(result, &data_out[i * out_element_size]);
+    data_out[i] = compute_fun(
+        data_a[a_linear_index], data_b[b_linear_index], data_c[c_linear_index]);
   }
 }
 
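The helpers deleted above (presumably relocated to elementwise_util.h, the header both kernels now include) implement a type-erasure pattern: the compute loop is instantiated once per common type, and per-dtype conversion goes through small function pointers chosen at runtime. A self-contained sketch of the technique, simplified from the deleted code (the real version dispatches via the ET_SWITCH_* macros):

```cpp
#include <cstdint>
#include <cstdio>

// Load an element of unknown static type from raw memory, converting to a
// common compute type -- mirrors load_and_convert in the deleted block.
template <typename To, typename From>
To load_and_convert(const void* from) {
  return static_cast<To>(*reinterpret_cast<const From*>(from));
}

enum class Dtype { F32, I16 }; // stand-in for ScalarType

using LoadFn = float (*)(const void*); // load_to_common_fn, common type = float

// Runtime dtype -> function pointer, so the summation loop below is compiled
// once for float rather than once per input dtype.
LoadFn get_load_fn(Dtype d) {
  switch (d) {
    case Dtype::F32:
      return load_and_convert<float, float>;
    case Dtype::I16:
      return load_and_convert<float, int16_t>;
  }
  return nullptr;
}

int main() {
  const int16_t raw[3] = {1, 2, 3};
  const LoadFn load = get_load_fn(Dtype::I16);
  float sum = 0;
  for (int i = 0; i < 3; ++i) {
    // Byte-stride indexing, as in data_a[a_linear_index * a_element_size].
    sum += load(reinterpret_cast<const char*>(raw) + i * sizeof(int16_t));
  }
  std::printf("%f\n", sum); // prints 6.000000
}
```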
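The loop retained in apply_ternary_elementwise_fn maps each output element back to a (possibly broadcast) element of each input via delinearize_index and linearize_access_indexes. A standalone model of that remapping, assuming NumPy-style broadcasting where each input dimension either matches the output dimension or is 1:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Map a linear index into out_sizes to a linear index into in_sizes (same
// rank; each input dim equals the output dim or 1). Dims of size 1 are pinned
// to coordinate 0 -- the broadcast case.
size_t broadcast_linear_index(
    size_t out_linear,
    const std::vector<size_t>& out_sizes,
    const std::vector<size_t>& in_sizes) {
  size_t in_linear = 0;
  size_t in_stride = 1;
  for (size_t d = out_sizes.size(); d-- > 0;) { // innermost dim first
    const size_t coord = out_linear % out_sizes[d];
    out_linear /= out_sizes[d];
    in_linear += (in_sizes[d] == 1 ? 0 : coord) * in_stride;
    in_stride *= in_sizes[d];
  }
  return in_linear;
}

int main() {
  // out is 2x3, input is 1x3 broadcast along dim 0: both output rows read
  // the same three input elements.
  assert(broadcast_linear_index(0, {2, 3}, {1, 3}) == 0); // out (0,0) -> in (0,0)
  assert(broadcast_linear_index(4, {2, 3}, {1, 3}) == 1); // out (1,1) -> in (0,1)
  assert(broadcast_linear_index(5, {2, 3}, {1, 3}) == 2); // out (1,2) -> in (0,2)
}
```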