43 changes: 20 additions & 23 deletions kernels/portable/cpu/op_clamp.cpp
@@ -26,10 +26,9 @@ using Tensor = executorch::aten::Tensor;
 
 namespace {
 
-template <typename CTYPE_VAL, typename CTYPE_OUT, typename CTYPE_CAST>
+template <typename CTYPE_OUT, typename CTYPE_CAST>
 /** Check if val, when cast to CTYPE_CAST, is not in the range of CTYPE_OUT */
-bool is_out_of_bounds(CTYPE_VAL val) {
-  const CTYPE_CAST val_cast = static_cast<CTYPE_CAST>(val);
+bool is_out_of_bounds(CTYPE_CAST val_cast) {
   return val_cast < std::numeric_limits<CTYPE_OUT>::lowest() ||
       val_cast > std::numeric_limits<CTYPE_OUT>::max();
 }
@@ -41,26 +40,24 @@ ET_NODISCARD bool check_bounds(
     const char* val_name) {
   auto is_valid = true;
 
-  ET_SWITCH_SCALAR_OBJ_TYPES(val_type, ctx, "clamp.out", CTYPE_VAL, [&]() {
-    CTYPE_VAL val = 0;
-    utils::extract_scalar(val_scalar, &val);
-    if (isIntegralType(out_type, /*includeBool=*/false)) {
-      ET_SWITCH_INT_TYPES(out_type, ctx, "clamp.out", CTYPE_OUT, [&]() {
-        if (is_out_of_bounds<CTYPE_VAL, CTYPE_OUT, long>(val)) {
-          ET_LOG(Error, "%s value out of bounds", val_name);
-          is_valid = false;
-        }
-      });
-    } else if (isFloatingType(out_type)) {
-      ET_SWITCH_FLOATH_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() {
-        if (std::isfinite(val) &&
-            is_out_of_bounds<CTYPE_VAL, CTYPE_OUT, double>(val)) {
-          ET_LOG(Error, "%s value out of bounds", val_name);
-          is_valid = false;
-        }
-      });
-    }
-  });
+  if (isIntegralType(out_type, /*includeBool=*/false)) {
+    const long val_long = utils::scalar_to<long>(val_scalar);
+    ET_SWITCH_INT_TYPES(out_type, ctx, "clamp.out", CTYPE_OUT, [&]() {
+      if (is_out_of_bounds<CTYPE_OUT, long>(val_long)) {
+        ET_LOG(Error, "%s value out of bounds", val_name);
+        is_valid = false;
+      }
+    });
+  } else if (isFloatingType(out_type)) {
+    ET_SWITCH_FLOATHBF16_TYPES(out_type, ctx, "clamp.out", CTYPE_OUT, [&]() {
+      const double val_double = utils::scalar_to<double>(val_scalar);
+      if (std::isfinite(val_double) &&
+          is_out_of_bounds<CTYPE_OUT, double>(val_double)) {
+        ET_LOG(Error, "%s value out of bounds", val_name);
+        is_valid = false;
+      }
+    });
+  }
 
   return is_valid;
 }
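A note on this hunk: `is_out_of_bounds` now receives the value already cast to the wide comparison type, so `check_bounds` converts the scalar once per output-dtype branch (`long` for integral outputs, `double` for floating-point ones) via `utils::scalar_to`, instead of re-dispatching on the scalar's own dtype. A minimal standalone sketch of the check and the `std::isfinite` guard (illustrative only, outside the ExecuTorch build):

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <limits>

// Same shape as the refactored helper: the caller converts the clamp bound
// to a wide comparison type before dispatching on the output dtype.
template <typename CTYPE_OUT, typename CTYPE_CAST>
bool is_out_of_bounds(CTYPE_CAST val_cast) {
  return val_cast < std::numeric_limits<CTYPE_OUT>::lowest() ||
      val_cast > std::numeric_limits<CTYPE_OUT>::max();
}

int main() {
  // A clamp bound of 300 is not representable in an int8 output tensor.
  std::printf("%d\n", is_out_of_bounds<int8_t, long>(300L)); // 1
  std::printf("%d\n", is_out_of_bounds<int8_t, long>(-5L));  // 0
  // The floating-point branch runs the check only behind std::isfinite:
  // an infinite bound would otherwise report "out of bounds" for every
  // finite output dtype even though it is a legal clamp limit.
  const double inf = std::numeric_limits<double>::infinity();
  std::printf("%d\n", std::isfinite(inf) ? 1 : 0); // 0 -> check skipped
}
```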
11 changes: 2 additions & 9 deletions kernels/portable/cpu/op_constant_pad_nd.cpp
@@ -183,17 +183,10 @@ Tensor& constant_pad_nd_out(
       "Failed to resize output tensor.");
 
   ScalarType in_type = in.scalar_type();
-  ScalarType value_type = utils::get_scalar_dtype(value);
 
   ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "constant_pad_nd.out", CTYPE, [&]() {
-    CTYPE value_v;
-    ET_SWITCH_SCALAR_OBJ_TYPES(
-        value_type, ctx, "constant_pad_nd.out", CTYPE_VALUE, [&]() {
-          CTYPE_VALUE val = 0;
-          utils::extract_scalar(value, &val);
-          value_v = static_cast<CTYPE>(val);
-        });
-    constant_pad_nd_out_impl<CTYPE>(in, pad, value_v, out);
+    const CTYPE value_casted = utils::scalar_to<CTYPE>(value);
+    constant_pad_nd_out_impl<CTYPE>(in, pad, value_casted, out);
   });
 
   return out;
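The hunk above is the template for the rest of this PR: a nested `ET_SWITCH_SCALAR_OBJ_TYPES` + `utils::extract_scalar` + `static_cast` dance collapses into a single `utils::scalar_to<CTYPE>(value)` call. The diff does not show the helper itself; below is a hedged sketch of what such a conversion boils down to, using a hypothetical `MiniScalar` tagged union in place of ExecuTorch's real `Scalar` (the actual `utils::scalar_to` signature and its error handling may differ):

```cpp
#include <cstdio>

// Hypothetical stand-in for a Scalar: a tag plus one of three payloads.
struct MiniScalar {
  enum class Tag { Bool, Int, Double } tag;
  union {
    bool b;
    long long i;
    double d;
  };
};

// Sketch of a scalar_to<T>-style helper: one switch on the payload tag and
// a static_cast to the requested C type, replacing the per-call-site
// ET_SWITCH_SCALAR_OBJ_TYPES + extract_scalar + static_cast pattern.
template <typename T>
T scalar_to(const MiniScalar& s) {
  switch (s.tag) {
    case MiniScalar::Tag::Bool:
      return static_cast<T>(s.b);
    case MiniScalar::Tag::Int:
      return static_cast<T>(s.i);
    case MiniScalar::Tag::Double:
    default:
      return static_cast<T>(s.d);
  }
}

int main() {
  MiniScalar pad_value;
  pad_value.tag = MiniScalar::Tag::Double;
  pad_value.d = 1.5;
  // The kernel resolves CTYPE from the input tensor's dtype, then converts
  // once: a float output keeps 1.5, an int output truncates to 1.
  std::printf("%f %d\n", scalar_to<float>(pad_value), scalar_to<int>(pad_value));
}
```

The same one-line replacement recurs in `op_fill.cpp`, `op_full.cpp`, `op_hardtanh.cpp`, `op_leaky_relu.cpp`, `op_scatter.cpp`, and `op_var.cpp` below.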
8 changes: 1 addition & 7 deletions kernels/portable/cpu/op_fill.cpp
@@ -26,7 +26,6 @@ Tensor& fill_scalar_out(
   (void)ctx;
 
   ScalarType a_type = a.scalar_type();
-  ScalarType b_type = utils::get_scalar_dtype(b);
   ScalarType out_type = out.scalar_type();
 
   ET_KERNEL_CHECK(ctx, a_type == out_type, InvalidArgument, out);
@@ -43,12 +42,7 @@
       "Failed to resize output tensor.");
 
   ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "fill.Scalar_out", CTYPE_A, [&] {
-    CTYPE_A b_casted;
-    ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "fill.Scalar_out", CTYPE_B, [&] {
-      CTYPE_B b_val = 0;
-      utils::extract_scalar(b, &b_val);
-      b_casted = static_cast<CTYPE_A>(b_val);
-    });
+    const CTYPE_A b_casted = utils::scalar_to<CTYPE_A>(b);
 
     apply_unary_map_fn(
         [b_casted](const CTYPE_A val_a) { return b_casted; },
19 changes: 6 additions & 13 deletions kernels/portable/cpu/op_full.cpp
@@ -24,7 +24,6 @@ Tensor& full_out(
     Tensor& out) {
   (void)ctx;
 
-  ScalarType val_type = utils::get_scalar_dtype(fill_value);
   ScalarType out_type = out.scalar_type();
 
   // Resize for dynamic shape
@@ -37,18 +36,12 @@
 
   constexpr auto name = "full.out";
 
-  ET_SWITCH_SCALAR_OBJ_TYPES(val_type, ctx, name, CTYPE_VAL, [&] {
-    CTYPE_VAL val;
-    ET_KERNEL_CHECK(
-        ctx, utils::extract_scalar(fill_value, &val), InvalidArgument, );
-
-    ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&] {
-      CTYPE_OUT val_casted = static_cast<CTYPE_OUT>(val);
-      auto data_out = out.mutable_data_ptr<CTYPE_OUT>();
-      for (const auto i : c10::irange(out.numel())) {
-        data_out[i] = val_casted;
-      }
-    });
+  ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&] {
+    CTYPE_OUT val_casted = utils::scalar_to<CTYPE_OUT>(fill_value);
+    auto data_out = out.mutable_data_ptr<CTYPE_OUT>();
+    for (const auto i : c10::irange(out.numel())) {
+      data_out[i] = val_casted;
+    }
   });
 
   return out;
17 changes: 2 additions & 15 deletions kernels/portable/cpu/op_hardtanh.cpp
@@ -40,26 +40,13 @@ Tensor& hardtanh_out(
       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
 
   ScalarType in_type = in.scalar_type();
-  ScalarType min_type = utils::get_scalar_dtype(min);
-  ScalarType max_type = utils::get_scalar_dtype(max);
   ScalarType out_type = out.scalar_type();
 
   ET_KERNEL_CHECK(ctx, in_type == out_type, InvalidArgument, out);
 
   ET_SWITCH_REALHBF16_TYPES(in_type, ctx, "hardtanh.out", CTYPE, [&]() {
-    CTYPE min_casted;
-    ET_SWITCH_SCALAR_OBJ_TYPES(min_type, ctx, "hardtanh.out", CTYPE_MIN, [&]() {
-      CTYPE_MIN min_val = 0;
-      utils::extract_scalar(min, &min_val);
-      min_casted = static_cast<CTYPE>(min_val);
-    });
-
-    CTYPE max_casted;
-    ET_SWITCH_SCALAR_OBJ_TYPES(max_type, ctx, "hardtanh.out", CTYPE_MAX, [&]() {
-      CTYPE_MAX max_val = 0;
-      utils::extract_scalar(max, &max_val);
-      max_casted = static_cast<CTYPE>(max_val);
-    });
+    const CTYPE min_casted = utils::scalar_to<CTYPE>(min);
+    const CTYPE max_casted = utils::scalar_to<CTYPE>(max);
 
     apply_unary_map_fn(
         [min_casted, max_casted](const CTYPE val_in) {
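For context on this call site (and the identical one in `op_leaky_relu.cpp` below): the converted bounds are captured by the elementwise lambda handed to `apply_unary_map_fn`. A self-contained sketch of the inner loop, assuming `apply_unary_map_fn` simply applies the lambda to every element:

```cpp
#include <algorithm>
#include <cstdio>

// Minimal hardtanh over a raw buffer. In the kernel, min/max arrive as
// Scalars and are narrowed to the element type CTYPE once per dispatch via
// utils::scalar_to; plain doubles stand in for them here.
template <typename CTYPE>
void hardtanh_span(const CTYPE* in, CTYPE* out, int n, double min, double max) {
  const CTYPE min_casted = static_cast<CTYPE>(min);
  const CTYPE max_casted = static_cast<CTYPE>(max);
  for (int i = 0; i < n; ++i) {
    out[i] = std::min(std::max(in[i], min_casted), max_casted);
  }
}

int main() {
  const float in[] = {-2.0f, -0.5f, 0.5f, 2.0f};
  float out[4];
  hardtanh_span(in, out, 4, /*min=*/-1.0, /*max=*/1.0);
  for (float v : out) {
    std::printf("%g ", v); // -1 -0.5 0.5 1
  }
  std::printf("\n");
}
```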
9 changes: 1 addition & 8 deletions kernels/portable/cpu/op_leaky_relu.cpp
@@ -39,19 +39,12 @@ Tensor& leaky_relu_out(
       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
 
   ScalarType in_type = in.scalar_type();
-  ScalarType sc_type = utils::get_scalar_dtype(negative_slope);
   ScalarType out_type = out.scalar_type();
 
   ET_KERNEL_CHECK(ctx, in_type == out_type, InvalidArgument, out);
 
   ET_SWITCH_FLOATHBF16_TYPES(in_type, ctx, "leaky_relu.out", CTYPE, [&]() {
-    CTYPE negative_slope_casted = 0;
-    ET_SWITCH_SCALAR_OBJ_TYPES(
-        sc_type, ctx, "leaky_relu.out", CTYPE_MIN, [&]() {
-          CTYPE_MIN negative_slope_val = 0;
-          utils::extract_scalar(negative_slope, &negative_slope_val);
-          negative_slope_casted = static_cast<CTYPE>(negative_slope_val);
-        });
+    const CTYPE negative_slope_casted = utils::scalar_to<CTYPE>(negative_slope);
 
     apply_unary_map_fn(
         [negative_slope_casted](const CTYPE val_in) {
20 changes: 11 additions & 9 deletions kernels/portable/cpu/op_scalar_tensor.cpp
@@ -20,19 +20,21 @@ scalar_tensor_out(KernelRuntimeContext& ctx, const Scalar& s, Tensor& out) {
   ET_KERNEL_CHECK(
       ctx, resize_tensor(out, {}) == Error::Ok, InvalidArgument, out);
 
-  ScalarType s_type = utils::get_scalar_dtype(s);
   ScalarType out_type = out.scalar_type();
 
   constexpr auto name = "scalar_tensor.out";
 
-  ET_SWITCH_REAL_TYPES_AND3(
-      Half, Bool, BFloat16, out_type, ctx, name, CTYPE, [&]() {
-        ET_SWITCH_SCALAR_OBJ_TYPES(s_type, ctx, name, CTYPE_S, [&]() {
-          CTYPE_S val_s = 0;
-          utils::extract_scalar(s, &val_s);
-          out.mutable_data_ptr<CTYPE>()[0] = convert<CTYPE, CTYPE_S>(val_s);
-        });
-      });
+  if (s.isFloatingPoint() &&
+      executorch::runtime::isIntegralType(out_type, false)) {
+    ET_SWITCH_INT_TYPES(out_type, ctx, name, CTYPE, [&]() {
+      out.mutable_data_ptr<CTYPE>()[0] =
+          static_cast<CTYPE>(utils::scalar_to<int64_t>(s));
+    });
+  } else {
+    ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE, [&]() {
+      out.mutable_data_ptr<CTYPE>()[0] = utils::scalar_to<CTYPE>(s);
+    });
+  }
 
   return out;
 }
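The new branch in `op_scalar_tensor.cpp` is the one place where the conversion is not a straight `scalar_to<CTYPE>`: a floating-point `Scalar` destined for an integral output dtype is first converted to `int64_t`, then narrowed to the element type, presumably so all integer outputs share the same truncate-toward-zero semantics (the motivation is not stated in the diff). A minimal demonstration with a plain `double` standing in for the `Scalar`:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  const double s = 3.9;
  // Step 1: float -> int64_t truncates toward zero.
  const int64_t as_int64 = static_cast<int64_t>(s); // 3
  // Step 2: int64_t -> output element type narrows the already-truncated
  // integer, mirroring static_cast<CTYPE>(utils::scalar_to<int64_t>(s)).
  const int8_t out_elem = static_cast<int8_t>(as_int64);
  std::printf("%lld %d\n", static_cast<long long>(as_int64), out_elem); // 3 3
}
```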
12 changes: 3 additions & 9 deletions kernels/portable/cpu/op_scatter.cpp
@@ -151,17 +151,11 @@ Tensor& scatter_value_out(
   ET_KERNEL_CHECK(
       ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
 
-  ScalarType val_type = utils::get_scalar_dtype(value);
-
   constexpr auto name = "scatter.value_out";
 
-  ET_SWITCH_SCALAR_OBJ_TYPES(val_type, ctx, name, CTYPE_VAL, [&] {
-    CTYPE_VAL val;
-    ET_KERNEL_CHECK(ctx, utils::extract_scalar(value, &val), InvalidArgument, );
-
-    ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
-      scatter_value_helper<CTYPE>(in, dim, index, val, out);
-    });
+  ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
+    const CTYPE val = utils::scalar_to<CTYPE>(value);
+    scatter_value_helper<CTYPE>(in, dim, index, val, out);
   });
 
   return out;
7 changes: 1 addition & 6 deletions kernels/portable/cpu/op_var.cpp
@@ -127,12 +127,7 @@ Tensor& var_correction_out(
 
   double correction_val = 1;
   if (correction.has_value()) {
-    ScalarType corr_type = utils::get_scalar_dtype(correction.value());
-    ET_SWITCH_SCALAR_OBJ_TYPES(corr_type, ctx, name, CTYPE_CORR, [&]() {
-      CTYPE_CORR corr_val = 0;
-      utils::extract_scalar(correction.value(), &corr_val);
-      correction_val = static_cast<double>(corr_val);
-    });
+    correction_val = utils::scalar_to<double>(correction.value());
   }
 
   const size_t num = get_reduced_dim_product(in, dim_list);