Skip to content

Commit 77a7ace

Browse files
[EE/BE][ET][Portable] Eliminate usage of ET_SWITCH_SCALAR_OBJ_TYPES in portable kernels (#12040)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #12010 by @manuelcandales ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/manuelcandales/117/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/manuelcandales/117/head Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/manuelcandales/116/orig Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/manuelcandales/117/orig @diff-train-skip-merge --------- Co-authored-by: Manuel Candales <[email protected]> Co-authored-by: Manuel Candales <[email protected]>
1 parent e402655 commit 77a7ace

File tree

9 files changed

+47
-99
lines changed

9 files changed

+47
-99
lines changed

kernels/portable/cpu/op_clamp.cpp

Lines changed: 20 additions & 23 deletions
Original file line number · Diff line number · Diff line change
@@ -26,10 +26,9 @@ using Tensor = executorch::aten::Tensor;
2626

2727
namespace {
2828

29-
template <typename CTYPE_VAL, typename CTYPE_OUT, typename CTYPE_CAST>
29+
template <typename CTYPE_OUT, typename CTYPE_CAST>
3030
/** Check if val, when cast to CTYPE_CAST, is not in the range of CTYPE_OUT */
31-
bool is_out_of_bounds(CTYPE_VAL val) {
32-
const CTYPE_CAST val_cast = static_cast<CTYPE_CAST>(val);
31+
bool is_out_of_bounds(CTYPE_CAST val_cast) {
3332
return val_cast < std::numeric_limits<CTYPE_OUT>::lowest() ||
3433
val_cast > std::numeric_limits<CTYPE_OUT>::max();
3534
}
@@ -41,26 +40,24 @@ ET_NODISCARD bool check_bounds(
4140
const char* val_name) {
4241
auto is_valid = true;
4342

44-
ET_SWITCH_SCALAR_OBJ_TYPES(val_type, ctx, "clamp.out", CTYPE_VAL, [&]() {
45-
CTYPE_VAL val = 0;
46-
utils::extract_scalar(val_scalar, &val);
47-
if (isIntegralType(out_type, /*includeBool=*/false)) {
48-
ET_SWITCH_INT_TYPES(out_type, ctx, "clamp.out", CTYPE_OUT, [&]() {
49-
if (is_out_of_bounds<CTYPE_VAL, CTYPE_OUT, long>(val)) {
50-
ET_LOG(Error, "%s value out of bounds", val_name);
51-
is_valid = false;
52-
}
53-
});
54-
} else if (isFloatingType(out_type)) {
55-
ET_SWITCH_FLOATH_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() {
56-
if (std::isfinite(val) &&
57-
is_out_of_bounds<CTYPE_VAL, CTYPE_OUT, double>(val)) {
58-
ET_LOG(Error, "%s value out of bounds", val_name);
59-
is_valid = false;
60-
}
61-
});
62-
}
63-
});
43+
if (isIntegralType(out_type, /*includeBool=*/false)) {
44+
const long val_long = utils::scalar_to<long>(val_scalar);
45+
ET_SWITCH_INT_TYPES(out_type, ctx, "clamp.out", CTYPE_OUT, [&]() {
46+
if (is_out_of_bounds<CTYPE_OUT, long>(val_long)) {
47+
ET_LOG(Error, "%s value out of bounds", val_name);
48+
is_valid = false;
49+
}
50+
});
51+
} else if (isFloatingType(out_type)) {
52+
ET_SWITCH_FLOATHBF16_TYPES(out_type, ctx, "clamp.out", CTYPE_OUT, [&]() {
53+
const double val_double = utils::scalar_to<double>(val_scalar);
54+
if (std::isfinite(val_double) &&
55+
is_out_of_bounds<CTYPE_OUT, double>(val_double)) {
56+
ET_LOG(Error, "%s value out of bounds", val_name);
57+
is_valid = false;
58+
}
59+
});
60+
}
6461

6562
return is_valid;
6663
}

kernels/portable/cpu/op_constant_pad_nd.cpp

Lines changed: 2 additions & 9 deletions
Original file line number · Diff line number · Diff line change
@@ -183,17 +183,10 @@ Tensor& constant_pad_nd_out(
183183
"Failed to resize output tensor.");
184184

185185
ScalarType in_type = in.scalar_type();
186-
ScalarType value_type = utils::get_scalar_dtype(value);
187186

188187
ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "constant_pad_nd.out", CTYPE, [&]() {
189-
CTYPE value_v;
190-
ET_SWITCH_SCALAR_OBJ_TYPES(
191-
value_type, ctx, "constant_pad_nd.out", CTYPE_VALUE, [&]() {
192-
CTYPE_VALUE val = 0;
193-
utils::extract_scalar(value, &val);
194-
value_v = static_cast<CTYPE>(val);
195-
});
196-
constant_pad_nd_out_impl<CTYPE>(in, pad, value_v, out);
188+
const CTYPE value_casted = utils::scalar_to<CTYPE>(value);
189+
constant_pad_nd_out_impl<CTYPE>(in, pad, value_casted, out);
197190
});
198191

199192
return out;

kernels/portable/cpu/op_fill.cpp

Lines changed: 1 addition & 7 deletions
Original file line number · Diff line number · Diff line change
@@ -26,7 +26,6 @@ Tensor& fill_scalar_out(
2626
(void)ctx;
2727

2828
ScalarType a_type = a.scalar_type();
29-
ScalarType b_type = utils::get_scalar_dtype(b);
3029
ScalarType out_type = out.scalar_type();
3130

3231
ET_KERNEL_CHECK(ctx, a_type == out_type, InvalidArgument, out);
@@ -43,12 +42,7 @@ Tensor& fill_scalar_out(
4342
"Failed to resize output tensor.");
4443

4544
ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "fill.Scalar_out", CTYPE_A, [&] {
46-
CTYPE_A b_casted;
47-
ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "fill.Scalar_out", CTYPE_B, [&] {
48-
CTYPE_B b_val = 0;
49-
utils::extract_scalar(b, &b_val);
50-
b_casted = static_cast<CTYPE_A>(b_val);
51-
});
45+
const CTYPE_A b_casted = utils::scalar_to<CTYPE_A>(b);
5246

5347
apply_unary_map_fn(
5448
[b_casted](const CTYPE_A val_a) { return b_casted; },

kernels/portable/cpu/op_full.cpp

Lines changed: 6 additions & 13 deletions
Original file line number · Diff line number · Diff line change
@@ -24,7 +24,6 @@ Tensor& full_out(
2424
Tensor& out) {
2525
(void)ctx;
2626

27-
ScalarType val_type = utils::get_scalar_dtype(fill_value);
2827
ScalarType out_type = out.scalar_type();
2928

3029
// Resize for dynamic shape
@@ -37,18 +36,12 @@ Tensor& full_out(
3736

3837
constexpr auto name = "full.out";
3938

40-
ET_SWITCH_SCALAR_OBJ_TYPES(val_type, ctx, name, CTYPE_VAL, [&] {
41-
CTYPE_VAL val;
42-
ET_KERNEL_CHECK(
43-
ctx, utils::extract_scalar(fill_value, &val), InvalidArgument, );
44-
45-
ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&] {
46-
CTYPE_OUT val_casted = static_cast<CTYPE_OUT>(val);
47-
auto data_out = out.mutable_data_ptr<CTYPE_OUT>();
48-
for (const auto i : c10::irange(out.numel())) {
49-
data_out[i] = val_casted;
50-
}
51-
});
39+
ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&] {
40+
CTYPE_OUT val_casted = utils::scalar_to<CTYPE_OUT>(fill_value);
41+
auto data_out = out.mutable_data_ptr<CTYPE_OUT>();
42+
for (const auto i : c10::irange(out.numel())) {
43+
data_out[i] = val_casted;
44+
}
5245
});
5346

5447
return out;

kernels/portable/cpu/op_hardtanh.cpp

Lines changed: 2 additions & 15 deletions
Original file line number · Diff line number · Diff line change
@@ -40,26 +40,13 @@ Tensor& hardtanh_out(
4040
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
4141

4242
ScalarType in_type = in.scalar_type();
43-
ScalarType min_type = utils::get_scalar_dtype(min);
44-
ScalarType max_type = utils::get_scalar_dtype(max);
4543
ScalarType out_type = out.scalar_type();
4644

4745
ET_KERNEL_CHECK(ctx, in_type == out_type, InvalidArgument, out);
4846

4947
ET_SWITCH_REALHBF16_TYPES(in_type, ctx, "hardtanh.out", CTYPE, [&]() {
50-
CTYPE min_casted;
51-
ET_SWITCH_SCALAR_OBJ_TYPES(min_type, ctx, "hardtanh.out", CTYPE_MIN, [&]() {
52-
CTYPE_MIN min_val = 0;
53-
utils::extract_scalar(min, &min_val);
54-
min_casted = static_cast<CTYPE>(min_val);
55-
});
56-
57-
CTYPE max_casted;
58-
ET_SWITCH_SCALAR_OBJ_TYPES(max_type, ctx, "hardtanh.out", CTYPE_MAX, [&]() {
59-
CTYPE_MAX max_val = 0;
60-
utils::extract_scalar(max, &max_val);
61-
max_casted = static_cast<CTYPE>(max_val);
62-
});
48+
const CTYPE min_casted = utils::scalar_to<CTYPE>(min);
49+
const CTYPE max_casted = utils::scalar_to<CTYPE>(max);
6350

6451
apply_unary_map_fn(
6552
[min_casted, max_casted](const CTYPE val_in) {

kernels/portable/cpu/op_leaky_relu.cpp

Lines changed: 1 addition & 8 deletions
Original file line number · Diff line number · Diff line change
@@ -39,19 +39,12 @@ Tensor& leaky_relu_out(
3939
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
4040

4141
ScalarType in_type = in.scalar_type();
42-
ScalarType sc_type = utils::get_scalar_dtype(negative_slope);
4342
ScalarType out_type = out.scalar_type();
4443

4544
ET_KERNEL_CHECK(ctx, in_type == out_type, InvalidArgument, out);
4645

4746
ET_SWITCH_FLOATHBF16_TYPES(in_type, ctx, "leaky_relu.out", CTYPE, [&]() {
48-
CTYPE negative_slope_casted = 0;
49-
ET_SWITCH_SCALAR_OBJ_TYPES(
50-
sc_type, ctx, "leaky_relu.out", CTYPE_MIN, [&]() {
51-
CTYPE_MIN negative_slope_val = 0;
52-
utils::extract_scalar(negative_slope, &negative_slope_val);
53-
negative_slope_casted = static_cast<CTYPE>(negative_slope_val);
54-
});
47+
const CTYPE negative_slope_casted = utils::scalar_to<CTYPE>(negative_slope);
5548

5649
apply_unary_map_fn(
5750
[negative_slope_casted](const CTYPE val_in) {

kernels/portable/cpu/op_scalar_tensor.cpp

Lines changed: 11 additions & 9 deletions
Original file line number · Diff line number · Diff line change
@@ -20,19 +20,21 @@ scalar_tensor_out(KernelRuntimeContext& ctx, const Scalar& s, Tensor& out) {
2020
ET_KERNEL_CHECK(
2121
ctx, resize_tensor(out, {}) == Error::Ok, InvalidArgument, out);
2222

23-
ScalarType s_type = utils::get_scalar_dtype(s);
2423
ScalarType out_type = out.scalar_type();
2524

2625
constexpr auto name = "scalar_tensor.out";
2726

28-
ET_SWITCH_REAL_TYPES_AND3(
29-
Half, Bool, BFloat16, out_type, ctx, name, CTYPE, [&]() {
30-
ET_SWITCH_SCALAR_OBJ_TYPES(s_type, ctx, name, CTYPE_S, [&]() {
31-
CTYPE_S val_s = 0;
32-
utils::extract_scalar(s, &val_s);
33-
out.mutable_data_ptr<CTYPE>()[0] = convert<CTYPE, CTYPE_S>(val_s);
34-
});
35-
});
27+
if (s.isFloatingPoint() &&
28+
executorch::runtime::isIntegralType(out_type, false)) {
29+
ET_SWITCH_INT_TYPES(out_type, ctx, name, CTYPE, [&]() {
30+
out.mutable_data_ptr<CTYPE>()[0] =
31+
static_cast<CTYPE>(utils::scalar_to<int64_t>(s));
32+
});
33+
} else {
34+
ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE, [&]() {
35+
out.mutable_data_ptr<CTYPE>()[0] = utils::scalar_to<CTYPE>(s);
36+
});
37+
}
3638

3739
return out;
3840
}

kernels/portable/cpu/op_scatter.cpp

Lines changed: 3 additions & 9 deletions
Original file line number · Diff line number · Diff line change
@@ -151,17 +151,11 @@ Tensor& scatter_value_out(
151151
ET_KERNEL_CHECK(
152152
ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
153153

154-
ScalarType val_type = utils::get_scalar_dtype(value);
155-
156154
constexpr auto name = "scatter.value_out";
157155

158-
ET_SWITCH_SCALAR_OBJ_TYPES(val_type, ctx, name, CTYPE_VAL, [&] {
159-
CTYPE_VAL val;
160-
ET_KERNEL_CHECK(ctx, utils::extract_scalar(value, &val), InvalidArgument, );
161-
162-
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
163-
scatter_value_helper<CTYPE>(in, dim, index, val, out);
164-
});
156+
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
157+
const CTYPE val = utils::scalar_to<CTYPE>(value);
158+
scatter_value_helper<CTYPE>(in, dim, index, val, out);
165159
});
166160

167161
return out;

kernels/portable/cpu/op_var.cpp

Lines changed: 1 addition & 6 deletions
Original file line number · Diff line number · Diff line change
@@ -127,12 +127,7 @@ Tensor& var_correction_out(
127127

128128
double correction_val = 1;
129129
if (correction.has_value()) {
130-
ScalarType corr_type = utils::get_scalar_dtype(correction.value());
131-
ET_SWITCH_SCALAR_OBJ_TYPES(corr_type, ctx, name, CTYPE_CORR, [&]() {
132-
CTYPE_CORR corr_val = 0;
133-
utils::extract_scalar(correction.value(), &corr_val);
134-
correction_val = static_cast<double>(corr_val);
135-
});
130+
correction_val = utils::scalar_to<double>(correction.value());
136131
}
137132

138133
const size_t num = get_reduced_dim_product(in, dim_list);

0 commit comments

Comments (0)