43 changes: 20 additions & 23 deletions kernels/portable/cpu/op_clamp.cpp
@@ -26,10 +26,9 @@ using Tensor = executorch::aten::Tensor;
 
 namespace {
 
-template <typename CTYPE_VAL, typename CTYPE_OUT, typename CTYPE_CAST>
+template <typename CTYPE_OUT, typename CTYPE_CAST>
 /** Check if val, when cast to CTYPE_CAST, is not in the range of CTYPE_OUT */
-bool is_out_of_bounds(CTYPE_VAL val) {
-  const CTYPE_CAST val_cast = static_cast<CTYPE_CAST>(val);
+bool is_out_of_bounds(CTYPE_CAST val_cast) {
   return val_cast < std::numeric_limits<CTYPE_OUT>::lowest() ||
       val_cast > std::numeric_limits<CTYPE_OUT>::max();
 }
@@ -41,26 +40,24 @@ ET_NODISCARD bool check_bounds(
     const char* val_name) {
   auto is_valid = true;
 
-  ET_SWITCH_SCALAR_OBJ_TYPES(val_type, ctx, "clamp.out", CTYPE_VAL, [&]() {
-    CTYPE_VAL val = 0;
-    utils::extract_scalar(val_scalar, &val);
-    if (isIntegralType(out_type, /*includeBool=*/false)) {
-      ET_SWITCH_INT_TYPES(out_type, ctx, "clamp.out", CTYPE_OUT, [&]() {
-        if (is_out_of_bounds<CTYPE_VAL, CTYPE_OUT, long>(val)) {
-          ET_LOG(Error, "%s value out of bounds", val_name);
-          is_valid = false;
-        }
-      });
-    } else if (isFloatingType(out_type)) {
-      ET_SWITCH_FLOATH_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() {
-        if (std::isfinite(val) &&
-            is_out_of_bounds<CTYPE_VAL, CTYPE_OUT, double>(val)) {
-          ET_LOG(Error, "%s value out of bounds", val_name);
-          is_valid = false;
-        }
-      });
-    }
-  });
+  if (isIntegralType(out_type, /*includeBool=*/false)) {
+    const long val_long = utils::scalar_to<long>(val_scalar);
+    ET_SWITCH_INT_TYPES(out_type, ctx, "clamp.out", CTYPE_OUT, [&]() {
+      if (is_out_of_bounds<CTYPE_OUT, long>(val_long)) {
+        ET_LOG(Error, "%s value out of bounds", val_name);
+        is_valid = false;
+      }
+    });
+  } else if (isFloatingType(out_type)) {
+    ET_SWITCH_FLOATHBF16_TYPES(out_type, ctx, "clamp.out", CTYPE_OUT, [&]() {
+      const double val_double = utils::scalar_to<double>(val_scalar);
+      if (std::isfinite(val_double) &&
+          is_out_of_bounds<CTYPE_OUT, double>(val_double)) {
+        ET_LOG(Error, "%s value out of bounds", val_name);
+        is_valid = false;
+      }
+    });
+  }
 
   return is_valid;
 }
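For reference, a self-contained sketch (plain C++, independent of ExecuTorch and its macros) of what the reworked check computes: the clamp bound is widened to a carrier type (long for integral outputs, double for floating-point outputs) and compared against the output dtype's representable range.

```cpp
#include <cstdint>
#include <cstdio>
#include <limits>

template <typename CTYPE_OUT, typename CTYPE_CAST>
bool is_out_of_bounds(CTYPE_CAST val_cast) {
  return val_cast < std::numeric_limits<CTYPE_OUT>::lowest() ||
      val_cast > std::numeric_limits<CTYPE_OUT>::max();
}

int main() {
  // A clamp bound of 300 is not representable by an int8 output tensor.
  std::printf("%d\n", is_out_of_bounds<int8_t, long>(300L));   // 1
  std::printf("%d\n", is_out_of_bounds<int8_t, long>(100L));   // 0
  // 1e40 exceeds float's range; note that check_bounds only rejects
  // finite out-of-range values (see the std::isfinite guard above).
  std::printf("%d\n", is_out_of_bounds<float, double>(1e40));  // 1
  return 0;
}
```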
11 changes: 2 additions & 9 deletions kernels/portable/cpu/op_constant_pad_nd.cpp
@@ -183,17 +183,10 @@ Tensor& constant_pad_nd_out(
       "Failed to resize output tensor.");
 
   ScalarType in_type = in.scalar_type();
-  ScalarType value_type = utils::get_scalar_dtype(value);
 
   ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "constant_pad_nd.out", CTYPE, [&]() {
-    CTYPE value_v;
-    ET_SWITCH_SCALAR_OBJ_TYPES(
-        value_type, ctx, "constant_pad_nd.out", CTYPE_VALUE, [&]() {
-          CTYPE_VALUE val = 0;
-          utils::extract_scalar(value, &val);
-          value_v = static_cast<CTYPE>(val);
-        });
-    constant_pad_nd_out_impl<CTYPE>(in, pad, value_v, out);
+    const CTYPE value_casted = utils::scalar_to<CTYPE>(value);
+    constant_pad_nd_out_impl<CTYPE>(in, pad, value_casted, out);
   });
 
   return out;
10 changes: 3 additions & 7 deletions kernels/portable/cpu/op_fill.cpp
@@ -26,7 +26,6 @@ Tensor& fill_scalar_out(
   (void)ctx;
 
   ScalarType a_type = a.scalar_type();
-  ScalarType b_type = utils::get_scalar_dtype(b);
   ScalarType out_type = out.scalar_type();
 
   ET_KERNEL_CHECK(ctx, a_type == out_type, InvalidArgument, out);
@@ -43,12 +42,9 @@
       "Failed to resize output tensor.");
 
   ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "fill.Scalar_out", CTYPE_A, [&] {
-    CTYPE_A b_casted;
-    ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "fill.Scalar_out", CTYPE_B, [&] {
-      CTYPE_B b_val = 0;
-      utils::extract_scalar(b, &b_val);
-      b_casted = static_cast<CTYPE_A>(b_val);
-    });
+    auto opt_b_casted = utils::internal::check_overflow_scalar_cast<CTYPE_A>(b);
+    ET_KERNEL_CHECK(ctx, opt_b_casted.has_value(), InvalidArgument, );
+    auto b_casted = opt_b_casted.value();
 
     apply_unary_map_fn(
         [b_casted](const CTYPE_A val_a) { return b_casted; },
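fill is the one kernel in this diff that rejects out-of-range scalars instead of silently casting them. Below is a self-contained model of that behavior (plain C++17; the real helper delegates to c10::overflows, which also covers edge cases such as NaN, rather than the bare range comparison sketched here).

```cpp
#include <cstdio>
#include <limits>
#include <optional>

// Simplified stand-in for utils::internal::check_overflow_scalar_cast.
template <typename To>
std::optional<To> checked_cast(double in) {
  if (in < static_cast<double>(std::numeric_limits<To>::lowest()) ||
      in > static_cast<double>(std::numeric_limits<To>::max())) {
    return std::nullopt;  // the kernel surfaces this as InvalidArgument
  }
  return static_cast<To>(in);
}

int main() {
  std::printf("%d\n", checked_cast<float>(1e40).has_value());  // 0: rejected
  std::printf("%d\n", checked_cast<float>(2.5).has_value());   // 1: accepted
  return 0;
}
```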
19 changes: 6 additions & 13 deletions kernels/portable/cpu/op_full.cpp
@@ -24,7 +24,6 @@ Tensor& full_out(
     Tensor& out) {
   (void)ctx;
 
-  ScalarType val_type = utils::get_scalar_dtype(fill_value);
   ScalarType out_type = out.scalar_type();
 
   // Resize for dynamic shape
@@ -37,18 +36,12 @@
 
   constexpr auto name = "full.out";
 
-  ET_SWITCH_SCALAR_OBJ_TYPES(val_type, ctx, name, CTYPE_VAL, [&] {
-    CTYPE_VAL val;
-    ET_KERNEL_CHECK(
-        ctx, utils::extract_scalar(fill_value, &val), InvalidArgument, );
-
-    ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&] {
-      CTYPE_OUT val_casted = static_cast<CTYPE_OUT>(val);
-      auto data_out = out.mutable_data_ptr<CTYPE_OUT>();
-      for (const auto i : c10::irange(out.numel())) {
-        data_out[i] = val_casted;
-      }
-    });
+  ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&] {
+    CTYPE_OUT val_casted = utils::scalar_to<CTYPE_OUT>(fill_value);
+    auto data_out = out.mutable_data_ptr<CTYPE_OUT>();
+    for (const auto i : c10::irange(out.numel())) {
+      data_out[i] = val_casted;
+    }
   });
 
   return out;
17 changes: 2 additions & 15 deletions kernels/portable/cpu/op_hardtanh.cpp
@@ -40,26 +40,13 @@ Tensor& hardtanh_out(
       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
 
   ScalarType in_type = in.scalar_type();
-  ScalarType min_type = utils::get_scalar_dtype(min);
-  ScalarType max_type = utils::get_scalar_dtype(max);
   ScalarType out_type = out.scalar_type();
 
   ET_KERNEL_CHECK(ctx, in_type == out_type, InvalidArgument, out);
 
   ET_SWITCH_REALHBF16_TYPES(in_type, ctx, "hardtanh.out", CTYPE, [&]() {
-    CTYPE min_casted;
-    ET_SWITCH_SCALAR_OBJ_TYPES(min_type, ctx, "hardtanh.out", CTYPE_MIN, [&]() {
-      CTYPE_MIN min_val = 0;
-      utils::extract_scalar(min, &min_val);
-      min_casted = static_cast<CTYPE>(min_val);
-    });
-
-    CTYPE max_casted;
-    ET_SWITCH_SCALAR_OBJ_TYPES(max_type, ctx, "hardtanh.out", CTYPE_MAX, [&]() {
-      CTYPE_MAX max_val = 0;
-      utils::extract_scalar(max, &max_val);
-      max_casted = static_cast<CTYPE>(max_val);
-    });
+    const CTYPE min_casted = utils::scalar_to<CTYPE>(min);
+    const CTYPE max_casted = utils::scalar_to<CTYPE>(max);
 
     apply_unary_map_fn(
         [min_casted, max_casted](const CTYPE val_in) {
9 changes: 1 addition & 8 deletions kernels/portable/cpu/op_leaky_relu.cpp
@@ -39,19 +39,12 @@ Tensor& leaky_relu_out(
       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
 
   ScalarType in_type = in.scalar_type();
-  ScalarType sc_type = utils::get_scalar_dtype(negative_slope);
   ScalarType out_type = out.scalar_type();
 
   ET_KERNEL_CHECK(ctx, in_type == out_type, InvalidArgument, out);
 
   ET_SWITCH_FLOATHBF16_TYPES(in_type, ctx, "leaky_relu.out", CTYPE, [&]() {
-    CTYPE negative_slope_casted = 0;
-    ET_SWITCH_SCALAR_OBJ_TYPES(
-        sc_type, ctx, "leaky_relu.out", CTYPE_MIN, [&]() {
-          CTYPE_MIN negative_slope_val = 0;
-          utils::extract_scalar(negative_slope, &negative_slope_val);
-          negative_slope_casted = static_cast<CTYPE>(negative_slope_val);
-        });
+    const CTYPE negative_slope_casted = utils::scalar_to<CTYPE>(negative_slope);
 
     apply_unary_map_fn(
         [negative_slope_casted](const CTYPE val_in) {
20 changes: 11 additions & 9 deletions kernels/portable/cpu/op_scalar_tensor.cpp
@@ -20,19 +20,21 @@ scalar_tensor_out(KernelRuntimeContext& ctx, const Scalar& s, Tensor& out) {
   ET_KERNEL_CHECK(
       ctx, resize_tensor(out, {}) == Error::Ok, InvalidArgument, out);
 
-  ScalarType s_type = utils::get_scalar_dtype(s);
   ScalarType out_type = out.scalar_type();
 
   constexpr auto name = "scalar_tensor.out";
 
-  ET_SWITCH_REAL_TYPES_AND3(
-      Half, Bool, BFloat16, out_type, ctx, name, CTYPE, [&]() {
-        ET_SWITCH_SCALAR_OBJ_TYPES(s_type, ctx, name, CTYPE_S, [&]() {
-          CTYPE_S val_s = 0;
-          utils::extract_scalar(s, &val_s);
-          out.mutable_data_ptr<CTYPE>()[0] = convert<CTYPE, CTYPE_S>(val_s);
-        });
-      });
+  if (s.isFloatingPoint() &&
+      executorch::runtime::isIntegralType(out_type, false)) {
+    ET_SWITCH_INT_TYPES(out_type, ctx, name, CTYPE, [&]() {
+      out.mutable_data_ptr<CTYPE>()[0] =
+          static_cast<CTYPE>(utils::scalar_to<int64_t>(s));
+    });
+  } else {
+    ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE, [&]() {
+      out.mutable_data_ptr<CTYPE>()[0] = utils::scalar_to<CTYPE>(s);
+    });
+  }
 
   return out;
 }
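The new explicit branch pins down the semantics of filling an integral tensor from a floating-point scalar: the value is truncated toward zero through int64_t and then narrowed to the output dtype. A standalone illustration (plain C++, hypothetical values):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  double s = -3.9;
  // What the integral-output branch above evaluates for an int32 tensor:
  int32_t v = static_cast<int32_t>(static_cast<int64_t>(s));
  std::printf("%d\n", v);  // -3: truncated toward zero, not rounded
  return 0;
}
```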
12 changes: 3 additions & 9 deletions kernels/portable/cpu/op_scatter.cpp
@@ -151,17 +151,11 @@ Tensor& scatter_value_out(
   ET_KERNEL_CHECK(
       ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
 
-  ScalarType val_type = utils::get_scalar_dtype(value);
-
   constexpr auto name = "scatter.value_out";
 
-  ET_SWITCH_SCALAR_OBJ_TYPES(val_type, ctx, name, CTYPE_VAL, [&] {
-    CTYPE_VAL val;
-    ET_KERNEL_CHECK(ctx, utils::extract_scalar(value, &val), InvalidArgument, );
-
-    ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
-      scatter_value_helper<CTYPE>(in, dim, index, val, out);
-    });
+  ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
+    const CTYPE val = utils::scalar_to<CTYPE>(value);
+    scatter_value_helper<CTYPE>(in, dim, index, val, out);
   });
 
   return out;
7 changes: 1 addition & 6 deletions kernels/portable/cpu/op_var.cpp
@@ -127,12 +127,7 @@ Tensor& var_correction_out(
 
   double correction_val = 1;
   if (correction.has_value()) {
-    ScalarType corr_type = utils::get_scalar_dtype(correction.value());
-    ET_SWITCH_SCALAR_OBJ_TYPES(corr_type, ctx, name, CTYPE_CORR, [&]() {
-      CTYPE_CORR corr_val = 0;
-      utils::extract_scalar(correction.value(), &corr_val);
-      correction_val = static_cast<double>(corr_val);
-    });
+    correction_val = utils::scalar_to<double>(correction.value());
   }
 
   const size_t num = get_reduced_dim_product(in, dim_list);
52 changes: 51 additions & 1 deletion kernels/portable/cpu/scalar_utils.h
@@ -8,10 +8,10 @@
 
 #pragma once
 
-#include <algorithm>
 #include <cmath>
 #include <limits>
 
+#include <c10/util/overflows.h>
 #include <executorch/kernels/portable/cpu/selective_build.h>
 #include <executorch/runtime/core/exec_aten/exec_aten.h>
 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
@@ -261,6 +261,56 @@ bool extract_scalar(Scalar scalar, BOOL_T* out_val) {
   return false;
 }
 
+/*
+ * Convert Scalar to C++ type
+ */
+
+template <typename T>
+T scalar_to(const Scalar& s) {
+  if (s.isBoolean()) {
+    return static_cast<T>(s.to<bool>());
+  } else if (s.isFloatingPoint()) {
+    return static_cast<T>(s.to<double>());
+  } else {
+    return static_cast<T>(s.to<int64_t>());
+  }
+}
+
+template <>
+inline double scalar_to<double>(const Scalar& s) {
+  return s.isFloatingPoint() ? s.to<double>()
+                             : static_cast<double>(s.to<int64_t>());
+}
+
+template <>
+inline int64_t scalar_to<int64_t>(const Scalar& s) {
+  return s.isFloatingPoint() ? static_cast<int64_t>(s.to<double>())
+                             : s.to<int64_t>();
+}
+
+namespace internal {
+
+template <typename To, typename From>
+std::optional<To> check_overflow_cast(From in) {
+  // Converting to bool can't overflow so we exclude this case from checking.
+  if (!std::is_same_v<To, bool> && c10::overflows<To, From>(in)) {
+    return std::nullopt;
+  }
+  return static_cast<To>(in);
+}
+
+template <typename To>
+std::optional<To> check_overflow_scalar_cast(const Scalar& in) {
+  if (in.isBoolean()) {
+    return check_overflow_cast<To>(in.to<bool>());
+  } else if (in.isFloatingPoint()) {
+    return check_overflow_cast<To>(in.to<double>());
+  } else {
+    return check_overflow_cast<To>(in.to<int64_t>());
+  }
+}
+
+} // namespace internal
 } // namespace utils
 } // namespace native
 } // namespace executor
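A minimal standalone model of the dispatch these templates encode (plain C++; MockScalar is an invented stand-in for executorch::aten::Scalar, which is a tagged union of bool, int64_t, and double):

```cpp
#include <cstdint>
#include <cstdio>

// Invented stand-in for executorch::aten::Scalar, for illustration only.
struct MockScalar {
  enum class Tag { Bool, Int, Double } tag;
  union { bool b; int64_t i; double d; };
  bool isBoolean() const { return tag == Tag::Bool; }
  bool isFloatingPoint() const { return tag == Tag::Double; }
};

// Mirrors the primary scalar_to template: unwrap by tag, then static_cast.
template <typename T>
T scalar_to(const MockScalar& s) {
  if (s.isBoolean()) return static_cast<T>(s.b);
  if (s.isFloatingPoint()) return static_cast<T>(s.d);
  return static_cast<T>(s.i);
}

int main() {
  MockScalar pi;
  pi.tag = MockScalar::Tag::Double;
  pi.d = 3.14;
  std::printf("%d\n", scalar_to<int32_t>(pi));  // 3: truncated
  MockScalar two;
  two.tag = MockScalar::Tag::Int;
  two.i = 2;
  std::printf("%f\n", scalar_to<float>(two));   // 2.000000
  return 0;
}
```

Note that the double and int64_t specializations in the header do not branch on isBoolean(); they only distinguish floating-point from integral scalars, so unlike the primary template they are not meant for boolean inputs.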
29 changes: 1 addition & 28 deletions kernels/portable/cpu/util/elementwise_util.h
@@ -9,6 +9,7 @@
 #pragma once
 
 #include <c10/util/irange.h>
+#include <executorch/kernels/portable/cpu/scalar_utils.h>
 #include <executorch/kernels/portable/cpu/util/broadcast_indexes_range.h>
 #include <executorch/kernels/portable/cpu/util/broadcast_util.h>
 #include <executorch/kernels/portable/cpu/util/dtype_util.h>
@@ -27,34 +28,6 @@ namespace torch {
 namespace executor {
 namespace native {
 namespace utils {
-
-/*
- * Convert Scalar to C++ type
- */
-
-template <typename T>
-T scalar_to(const Scalar& s) {
-  if (s.isBoolean()) {
-    return static_cast<T>(s.to<bool>());
-  } else if (s.isFloatingPoint()) {
-    return static_cast<T>(s.to<double>());
-  } else {
-    return static_cast<T>(s.to<int64_t>());
-  }
-}
-
-template <>
-inline double scalar_to<double>(const Scalar& s) {
-  return s.isFloatingPoint() ? s.to<double>()
-                             : static_cast<double>(s.to<int64_t>());
-}
-
-template <>
-inline int64_t scalar_to<int64_t>(const Scalar& s) {
-  return s.isFloatingPoint() ? static_cast<int64_t>(s.to<double>())
-                             : s.to<int64_t>();
-}
-
 namespace internal {
 /**
  * Causes these utility functions to make sure to respect Tensor
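Note that the scalar_to definitions removed here are not deleted from the codebase: they move verbatim into scalar_utils.h (see the hunk above), and the new #include keeps them visible to existing users of this header.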
2 changes: 1 addition & 1 deletion kernels/portable/cpu/util/targets.bzl
@@ -116,9 +116,9 @@ def define_common_targets():
             "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch",
             "//executorch/runtime/kernel:kernel_runtime_context",
             "//executorch/extension/threadpool:threadpool",
+            "//executorch/kernels/portable/cpu:scalar_utils",
         ],
         deps = [
-            "//executorch/kernels/portable/cpu:scalar_utils",
            "//executorch/runtime/kernel:kernel_includes",
         ],
         visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/optimized/cpu/...", "@EXECUTORCH_CLIENTS"],
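The move from deps to exported_deps follows from the header change above: elementwise_util.h now #includes scalar_utils.h, so the dependency has to propagate to anything that compiles against this header rather than staying private to the target.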