140 changes: 34 additions & 106 deletions kernels/portable/cpu/op_add.cpp
@@ -7,64 +7,14 @@
*/

#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/kernels/portable/cpu/util/functional_util.h>
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
#include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>

namespace torch {
namespace executor {
namespace native {
namespace {

template <
bool can_cast,
typename CTYPE_A,
typename CTYPE_B,
typename CTYPE_IN,
typename CTYPE_OUT>
struct AddInner;

template <
typename CTYPE_A,
typename CTYPE_B,
typename CTYPE_IN,
typename CTYPE_OUT>
struct AddInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
static void
run(const Tensor& a, const Tensor& b, CTYPE_IN alpha_val, Tensor& out) {
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
// NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
[alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
CTYPE_IN value = a_casted + alpha_val * b_casted;

return static_cast<CTYPE_OUT>(value);
},
a,
b,
out);
}
};

template <typename CTYPE_IN>
struct ReportCanCastBug {
static void run(const Tensor&, const Tensor&, CTYPE_IN, Tensor&) {
ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
}
};

template <
typename CTYPE_A,
typename CTYPE_B,
typename CTYPE_IN,
typename CTYPE_OUT>
struct AddInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
: public ReportCanCastBug<CTYPE_IN> {};

} // namespace

Tensor& add_out(
KernelRuntimeContext& ctx,
@@ -80,7 +30,9 @@ Tensor& add_out(

ET_KERNEL_CHECK(
ctx,
executorch::runtime::tensor_is_realhbbf16_type(out),
(executorch::runtime::tensor_is_realhbbf16_type(a) &&
executorch::runtime::tensor_is_realhbbf16_type(b) &&
executorch::runtime::tensor_is_realhbbf16_type(out)),
InvalidArgument,
out);
ET_KERNEL_CHECK(
@@ -96,25 +48,20 @@ Tensor& add_out(
ET_KERNEL_CHECK(
ctx, check_alpha_type(alpha_type, common_type), InvalidArgument, out);

constexpr auto name = "add.out";

ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, name, CTYPE_A, [&]() {
ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
using CTYPE_IN = typename torch::executor::
promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
CTYPE_IN alpha_val;
utils::extract_scalar(alpha, &alpha_val);

ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() {
AddInner<
can_cast<CTYPE_IN, CTYPE_OUT>::value,
CTYPE_A,
CTYPE_B,
CTYPE_IN,
CTYPE_OUT>::run(a, b, alpha_val, out);
});
});
static constexpr const char op_name[] = "add.out";

ET_SWITCH_REALB_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() {
utils::apply_bitensor_elementwise_fn<CTYPE_COMMON, op_name>(
[alpha](const CTYPE_COMMON val_a, const CTYPE_COMMON val_b) {
CTYPE_COMMON val_alpha = utils::scalar_to<CTYPE_COMMON>(alpha);
return val_a + val_alpha * val_b;
},
a,
utils::SupportedTensorDtypes::REALHBBF16,
b,
utils::SupportedTensorDtypes::REALHBBF16,
out,
utils::SupportedTensorDtypes::REALHBBF16);
});

return out;
@@ -138,14 +85,14 @@ Tensor& add_scalar_out(

ET_KERNEL_CHECK(
ctx,
executorch::runtime::tensor_is_realhbbf16_type(out),
(executorch::runtime::tensor_is_realhbbf16_type(a) &&
executorch::runtime::tensor_is_realhbbf16_type(out)),
InvalidArgument,
out);
ET_KERNEL_CHECK(
ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);

ScalarType a_type = a.scalar_type();
ScalarType b_type = utils::get_scalar_dtype(b);
ScalarType alpha_type = utils::get_scalar_dtype(alpha);
ScalarType common_type =
utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false);
@@ -155,42 +102,23 @@ Tensor& add_scalar_out(
ET_KERNEL_CHECK(
ctx, check_alpha_type(alpha_type, common_type), InvalidArgument, out);

if (common_type == ScalarType::Half) {
if (common_type == ScalarType::Half || common_type == ScalarType::BFloat16) {
common_type = ScalarType::Float;
}

constexpr auto name = "add.Scalar_out";

ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, name, CTYPE_A, [&]() {
ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
using CTYPE_IN = typename utils::promote_type_with_scalar_type<
CTYPE_A,
CTYPE_B,
/*half_to_float*/ true>::type;
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);

CTYPE_B b_val;
utils::extract_scalar(b, &b_val);
CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);

CTYPE_IN alpha_val;
utils::extract_scalar(alpha, &alpha_val);

using CTYPE_OUT = typename std::conditional<
std::is_same<CTYPE_A, internal::F2>::value,
internal::F2,
CTYPE_IN>::type;

apply_unary_map_fn(
[b_casted, alpha_val](const CTYPE_A val_a) {
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
CTYPE_IN value = a_casted + alpha_val * b_casted;
return static_cast<CTYPE_OUT>(value);
},
a.const_data_ptr<CTYPE_A>(),
out.mutable_data_ptr<CTYPE_OUT>(),
out.numel());
});
static constexpr const char op_name[] = "add.Scalar_out";

ET_SWITCH_REALB_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() {
utils::apply_unitensor_elementwise_fn<CTYPE_COMMON, op_name>(
[b, alpha](const CTYPE_COMMON val_a) {
CTYPE_COMMON val_b = utils::scalar_to<CTYPE_COMMON>(b);
CTYPE_COMMON val_alpha = utils::scalar_to<CTYPE_COMMON>(alpha);
return val_a + val_alpha * val_b;
},
a,
utils::SupportedTensorDtypes::REALHBBF16,
out,
utils::SupportedTensorDtypes::REALHBBF16);
});

return out;
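
Note: the rewritten add_out dispatches once on the promoted common type and delegates per-dtype loading, storing, and broadcasting to utils::apply_bitensor_elementwise_fn, so the kernel only supplies the a + alpha * b lambda. Below is a minimal standalone sketch of that computation in plain C++, with hypothetical names and no ExecuTorch APIs (broadcasting omitted); it illustrates the cast-to-common-type pattern, not the actual implementation.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// All operands are converted to CTYPE_COMMON before the arithmetic, then the
// result is cast back to the output element type.
template <typename CTYPE_COMMON, typename TA, typename TB, typename TOUT>
void add_out_sketch(
    const std::vector<TA>& a,
    const std::vector<TB>& b,
    CTYPE_COMMON alpha,
    std::vector<TOUT>& out) {
  for (std::size_t i = 0; i < out.size(); ++i) {
    const CTYPE_COMMON val_a = static_cast<CTYPE_COMMON>(a[i]);
    const CTYPE_COMMON val_b = static_cast<CTYPE_COMMON>(b[i]);
    out[i] = static_cast<TOUT>(val_a + alpha * val_b);
  }
}

int main() {
  std::vector<int32_t> a{1, 2, 3};
  std::vector<float> b{0.5f, 0.5f, 0.5f};
  std::vector<float> out(3);
  add_out_sketch<float>(a, b, /*alpha=*/2.0f, out); // out == {2, 3, 4}
  for (float v : out) {
    std::cout << v << ' ';
  }
  return 0;
}
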
56 changes: 19 additions & 37 deletions kernels/portable/cpu/op_clamp.cpp
@@ -13,7 +13,6 @@

#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
#include <executorch/kernels/portable/cpu/util/functional_util.h>
#include <executorch/kernels/portable/cpu/util/math_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

@@ -122,43 +121,26 @@ Tensor& clamp_out(

ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);

ET_SWITCH_REALH_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() {
// Extract optional min value
CTYPE_OUT min = 0;
if (has_min) {
ET_SWITCH_SCALAR_OBJ_TYPES(min_type, ctx, "clamp", CTYPE_MIN, [&]() {
CTYPE_MIN min_val = 0;
utils::extract_scalar(min_opt.value(), &min_val);
min = static_cast<CTYPE_OUT>(min_val);
});
}

// Extract optional max value
CTYPE_OUT max = 0;
if (has_max) {
ET_SWITCH_SCALAR_OBJ_TYPES(max_type, ctx, "clamp", CTYPE_MAX, [&]() {
CTYPE_MAX max_val = 0;
utils::extract_scalar(max_opt.value(), &max_val);
max = static_cast<CTYPE_OUT>(max_val);
});
}
static constexpr const char op_name[] = "clamp.out";

ET_SWITCH_REALHB_TYPES(in_type, ctx, "clamp", CTYPE_IN, [&]() {
apply_unary_map_fn(
[has_min, min, has_max, max](const CTYPE_IN val_in) {
CTYPE_OUT val_out = static_cast<CTYPE_OUT>(val_in);
if (has_min) {
val_out = utils::max_override(val_out, min);
}
if (has_max) {
val_out = utils::min_override(val_out, max);
}
return val_out;
},
in.const_data_ptr<CTYPE_IN>(),
out.mutable_data_ptr<CTYPE_OUT>(),
in.numel());
});
ET_SWITCH_REALHB_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() {
utils::apply_unitensor_elementwise_fn<CTYPE_COMMON, op_name>(
[has_min, min_opt, has_max, max_opt](const CTYPE_COMMON val_in) {
CTYPE_COMMON val_out = val_in;
if (has_min) {
val_out = utils::max_override(
val_out, utils::scalar_to<CTYPE_COMMON>(min_opt.value()));
}
if (has_max) {
val_out = utils::min_override(
val_out, utils::scalar_to<CTYPE_COMMON>(max_opt.value()));
}
return val_out;
},
in,
utils::SupportedTensorDtypes::REALHBBF16,
out,
utils::SupportedTensorDtypes::REALHBBF16);
});

return out;
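
Note: clamp_out now applies both optional bounds inside a single per-element lambda on the common type, converting min/max with utils::scalar_to only when they are present. Below is a standalone sketch of the same logic using std::optional, with hypothetical names and no ExecuTorch APIs; it is an illustration, not the kernel's implementation.

#include <algorithm>
#include <optional>
#include <vector>

// Clamp each element to [min, max]; either bound may be absent.
template <typename T>
void clamp_sketch(
    std::vector<T>& data,
    std::optional<T> min_opt,
    std::optional<T> max_opt) {
  for (T& v : data) {
    if (min_opt.has_value()) {
      v = std::max(v, *min_opt); // lower bound (the kernel uses utils::max_override)
    }
    if (max_opt.has_value()) {
      v = std::min(v, *max_opt); // upper bound (the kernel uses utils::min_override)
    }
  }
}

int main() {
  std::vector<float> vals{-1.0f, 0.5f, 2.0f};
  clamp_sketch<float>(vals, 0.0f, std::optional<float>{1.0f}); // {0, 0.5, 1}
  return 0;
}
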
110 changes: 109 additions & 1 deletion kernels/portable/cpu/util/elementwise_util.h
@@ -16,6 +16,33 @@ namespace executor {
namespace native {
namespace utils {

/*
* Convert Scalar to C++ type
*/

template <typename T>
T scalar_to(const Scalar& s) {
if (s.isBoolean()) {
return static_cast<T>(s.to<bool>());
} else if (s.isFloatingPoint()) {
return static_cast<T>(s.to<double>());
} else {
return static_cast<T>(s.to<int64_t>());
}
}

template <>
inline double scalar_to<double>(const Scalar& s) {
return s.isFloatingPoint() ? s.to<double>()
: static_cast<double>(s.to<int64_t>());
}

template <>
inline int64_t scalar_to<int64_t>(const Scalar& s) {
return s.isFloatingPoint() ? static_cast<int64_t>(s.to<double>())
: s.to<int64_t>();
}
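
// Illustrative examples of the conversions above (assuming Scalar is
// constructed from a bool, int64_t, or double):
//   scalar_to<float>(Scalar(true))        -> 1.0f  (boolean path)
//   scalar_to<float>(Scalar(int64_t{3}))  -> 3.0f  (integral path)
//   scalar_to<int64_t>(Scalar(2.5))       -> 2     (double truncated toward zero)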

namespace internal {

template <typename To, typename From>
@@ -139,6 +166,86 @@ store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn(

} // namespace internal

template <typename CTYPE_COMMON, const char* op_name, typename Op>
inline void apply_unitensor_elementwise_fn(
const Op& compute_fun,
const Tensor& a,
SupportedTensorDtypes a_dtypes,
const Tensor& out,
SupportedTensorDtypes out_dtypes) {
const auto load_a_to_common =
internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
const auto store_common_to_out =
internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
out, out_dtypes);
const char* const data_a = reinterpret_cast<const char*>(a.const_data_ptr());
const auto a_element_size = a.element_size();
const auto out_element_size = out.element_size();
char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());

auto out_numel = out.numel();
for (size_t i = 0; i < out_numel; ++i) {
auto result = compute_fun(load_a_to_common(&data_a[i * a_element_size]));
store_common_to_out(result, &data_out[i * out_element_size]);
}
}

/**
* Useful for bi-tensor elementwise operators. For each element of the inputs,
* perform a computation and write to the corresponding element of the output.
* Tensor broadcasting is applied wherever it is required.
*/
template <typename CTYPE_COMMON, const char* op_name, typename Op>
inline void apply_bitensor_elementwise_fn(
const Op& compute_fun,
const Tensor& a,
SupportedTensorDtypes a_dtypes,
const Tensor& b,
SupportedTensorDtypes b_dtypes,
const Tensor& out,
SupportedTensorDtypes out_dtypes) {
const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
const bool any_is_broadcasted = (a_is_broadcasted || b_is_broadcasted);

const auto load_a_to_common =
internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
const auto load_b_to_common =
internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(b, b_dtypes);
const auto store_common_to_out =
internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
out, out_dtypes);
const char* const data_a = reinterpret_cast<const char*>(a.const_data_ptr());
const char* const data_b = reinterpret_cast<const char*>(b.const_data_ptr());
const auto a_element_size = a.element_size();
const auto b_element_size = b.element_size();
const auto out_element_size = out.element_size();
char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());

auto out_numel = out.numel();
for (size_t i = 0; i < out_numel; ++i) {
size_t a_linear_index = i;
size_t b_linear_index = i;

if (any_is_broadcasted) {
size_t out_indexes[kTensorDimensionLimit];
delinearize_index(i, out, out_indexes, kTensorDimensionLimit);

if (a_is_broadcasted) {
a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
}
if (b_is_broadcasted) {
b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
}
}

auto result = compute_fun(
load_a_to_common(&data_a[a_linear_index * a_element_size]),
load_b_to_common(&data_b[b_linear_index * b_element_size]));
store_common_to_out(result, &data_out[i * out_element_size]);
}
}

/**
* Useful for tri-tensor elementwise operators. For each element of the inputs,
* perform a computation and write to the corresponding element of the output.
@@ -194,7 +301,8 @@ inline void apply_tritensor_elementwise_fn(
const auto out_element_size = out.element_size();
char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());

for (size_t i = 0; i < out.numel(); ++i) {
auto out_numel = out.numel();
for (size_t i = 0; i < out_numel; ++i) {
size_t a_linear_index = i;
size_t b_linear_index = i;
size_t c_linear_index = i;
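
Note: apply_bitensor_elementwise_fn and apply_tritensor_elementwise_fn resolve broadcasting per output element: the flat output index is delinearized into per-dimension coordinates, and each broadcasted input maps those coordinates back to its own flat index, with size-1 dimensions pinned to coordinate 0. Below is a standalone sketch of that index mapping with hypothetical helper names; it is a simplified model of what delinearize_index and linearize_access_indexes do, assuming equal ranks and a contiguous row-major layout.

#include <cassert>
#include <cstddef>
#include <vector>

// Convert a flat (row-major) index into per-dimension coordinates.
std::vector<std::size_t> delinearize(
    std::size_t i, const std::vector<std::size_t>& sizes) {
  std::vector<std::size_t> coords(sizes.size());
  for (std::size_t d = sizes.size(); d-- > 0;) {
    coords[d] = i % sizes[d];
    i /= sizes[d];
  }
  return coords;
}

// Map output coordinates to a flat index into a possibly broadcasted input:
// a size-1 input dimension always contributes coordinate 0.
std::size_t linearize_broadcast(
    const std::vector<std::size_t>& out_coords,
    const std::vector<std::size_t>& in_sizes) {
  std::size_t index = 0;
  for (std::size_t d = 0; d < in_sizes.size(); ++d) {
    const std::size_t coord = (in_sizes[d] == 1) ? 0 : out_coords[d];
    index = index * in_sizes[d] + coord;
  }
  return index;
}

int main() {
  const std::vector<std::size_t> out_sizes{2, 3};
  const std::vector<std::size_t> in_sizes{1, 3}; // broadcast along dim 0
  const auto coords = delinearize(4, out_sizes); // {1, 1}
  assert(linearize_broadcast(coords, in_sizes) == 1);
  return 0;
}

For out sizes {2, 3} and an input of sizes {1, 3}, output index 4 (coordinates {1, 1}) therefore maps to input index 1, which is how the same input row is reused across the broadcasted dimension.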