Skip to content

Commit 3dac421

Browse files
Non-fatal error when ET_SWITCH encounters unsupported dtype
Differential Revision: D80141272. Pull Request resolved: #13359.
1 parent 624b38e commit 3dac421

File tree

24 files changed

+348
-169
lines changed

24 files changed

+348
-169
lines changed

backends/cadence/fusion_g3/operators/op_clamp.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ bool is_out_of_bounds(CTYPE_VAL val) {
4545
}
4646

4747
ET_NODISCARD bool check_bounds(
48+
KernelRuntimeContext& ctx,
4849
const Scalar& val_scalar,
4950
const ScalarType& val_type,
5051
const ScalarType& out_type,
@@ -107,14 +108,14 @@ Tensor& clamp_out(
107108
if (has_min) {
108109
ET_KERNEL_CHECK(
109110
ctx,
110-
check_bounds(min_opt.value(), min_type, out_type, "minimum"),
111+
check_bounds(ctx, min_opt.value(), min_type, out_type, "minimum"),
111112
InvalidArgument,
112113
out);
113114
}
114115
if (has_max) {
115116
ET_KERNEL_CHECK(
116117
ctx,
117-
check_bounds(max_opt.value(), max_type, out_type, "maximum"),
118+
check_bounds(ctx, max_opt.value(), max_type, out_type, "maximum"),
118119
InvalidArgument,
119120
out);
120121
}

extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -265,9 +265,15 @@ - (NSString *)description {
265265
auto const count = _tensor->numel();
266266
os << "\n count: " << count << ",";
267267
os << "\n scalars: [";
268+
// Create a minimal context for error handling in ET_SWITCH
269+
struct {
270+
[[noreturn]] void fail(torch::executor::Error /* error */) {
271+
ET_CHECK_MSG(false, "Unsupported dtype in description");
272+
}
273+
} ctx;
268274
ET_SWITCH_REALHBBF16_TYPES(
269275
static_cast<ScalarType>(_tensor->scalar_type()),
270-
nullptr,
276+
ctx,
271277
"description",
272278
CTYPE,
273279
[&] {
@@ -488,9 +494,15 @@ - (instancetype)initWithScalars:(NSArray<NSNumber *> *)scalars
488494
"Number of scalars does not match the shape");
489495
std::vector<uint8_t> data;
490496
data.resize(count * ExecuTorchSizeOfDataType(dataType));
497+
// Create a minimal context for error handling in ET_SWITCH
498+
struct {
499+
[[noreturn]] void fail(torch::executor::Error /* error */) {
500+
ET_CHECK_MSG(false, "Unsupported dtype in initWithScalars");
501+
}
502+
} ctx;
491503
for (NSUInteger index = 0; index < count; ++index) {
492504
ET_SWITCH_REALHBBF16_AND_UINT_TYPES(
493-
static_cast<ScalarType>(dataType), nil, "initWithScalars", CTYPE, [&] {
505+
static_cast<ScalarType>(dataType), ctx, "initWithScalars", CTYPE, [&] {
494506
reinterpret_cast<CTYPE *>(data.data())[index] = utils::toType<CTYPE>(scalars[index]);
495507
}
496508
);
@@ -801,8 +813,14 @@ + (instancetype)fullTensorWithShape:(NSArray<NSNumber *> *)shape
801813
dataType:(ExecuTorchDataType)dataType
802814
shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism {
803815
Scalar fillValue;
816+
// Create a minimal context for error handling in ET_SWITCH
817+
struct {
818+
[[noreturn]] void fail(torch::executor::Error /* error */) {
819+
ET_CHECK_MSG(false, "Unsupported dtype in fullTensor");
820+
}
821+
} ctx;
804822
ET_SWITCH_REALHBBF16_AND_UINT_TYPES(
805-
static_cast<ScalarType>(dataType), nil, "fullTensor", CTYPE, [&] {
823+
static_cast<ScalarType>(dataType), ctx, "fullTensor", CTYPE, [&] {
806824
fillValue = utils::toType<CTYPE>(scalar);
807825
}
808826
);

extension/llm/runner/text_decoder_runner.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,20 @@ class ET_EXPERIMENTAL TextDecoderRunner {
6868
const executorch::aten::Tensor& logits_tensor,
6969
const float temperature = 0.0f) {
7070
int32_t result = 0;
71+
72+
// Create a minimal context for error handling in ET_SWITCH
73+
struct {
74+
[[noreturn]] void fail(torch::executor::Error /* error */) {
75+
ET_CHECK_MSG(false, "Unsupported dtype in logits_to_token");
76+
}
77+
} ctx;
78+
7179
ET_SWITCH_THREE_TYPES(
7280
Float,
7381
Half,
7482
BFloat16,
7583
logits_tensor.scalar_type(),
76-
unused,
84+
ctx,
7785
"logits_to_token",
7886
CTYPE,
7987
[&]() {

extension/tensor/tensor_ptr.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,15 @@ inline TensorPtr make_tensor_ptr(
111111
runtime::canCast(deduced_type, type),
112112
"Cannot cast deduced type to specified type.");
113113
std::vector<uint8_t> casted_data(data.size() * runtime::elementSize(type));
114-
ET_SWITCH_REALHBBF16_TYPES(type, nullptr, "make_tensor_ptr", CTYPE, [&] {
114+
115+
// Create a minimal context for error handling in ET_SWITCH
116+
struct {
117+
[[noreturn]] void fail(torch::executor::Error /* error */) {
118+
ET_CHECK_MSG(false, "Unsupported dtype in make_tensor_ptr");
119+
}
120+
} ctx;
121+
122+
ET_SWITCH_REALHBBF16_TYPES(type, ctx, "make_tensor_ptr", CTYPE, [&] {
115123
std::transform(
116124
data.begin(),
117125
data.end(),

extension/tensor/tensor_ptr_maker.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,14 @@ TensorPtr random_strided(
8989
empty_strided(std::move(sizes), std::move(strides), type, dynamism);
9090
std::default_random_engine gen{std::random_device{}()};
9191

92-
ET_SWITCH_REALHBBF16_TYPES(type, nullptr, "random_strided", CTYPE, [&] {
92+
// Create a minimal context for error handling in ET_SWITCH
93+
struct {
94+
[[noreturn]] void fail(torch::executor::Error /* error */) {
95+
ET_CHECK_MSG(false, "Unsupported dtype in random_strided");
96+
}
97+
} ctx;
98+
99+
ET_SWITCH_REALHBBF16_TYPES(type, ctx, "random_strided", CTYPE, [&] {
93100
std::generate_n(tensor->mutable_data_ptr<CTYPE>(), tensor->numel(), [&]() {
94101
return static_cast<CTYPE>(distribution(gen));
95102
});
@@ -124,7 +131,14 @@ TensorPtr full_strided(
124131
executorch::aten::TensorShapeDynamism dynamism) {
125132
auto tensor =
126133
empty_strided(std::move(sizes), std::move(strides), type, dynamism);
127-
ET_SWITCH_REALHBBF16_TYPES(type, nullptr, "full_strided", CTYPE, [&] {
134+
// Create a minimal context for error handling in ET_SWITCH
135+
struct {
136+
[[noreturn]] void fail(torch::executor::Error /* error */) {
137+
ET_CHECK_MSG(false, "Unsupported data type in full_strided");
138+
}
139+
} ctx;
140+
141+
ET_SWITCH_REALHBBF16_TYPES(type, ctx, "full_strided", CTYPE, [&] {
128142
CTYPE value;
129143
ET_EXTRACT_SCALAR(fill_value, value);
130144
std::fill(

kernels/optimized/cpu/op_add_sub_impl.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -144,13 +144,13 @@ Tensor& opt_add_sub_out_impl(
144144
} else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
145145
// Cannot apply the trick of -alpha here because alpha is Scalar without
146146
// support for - operator. At least not right now.
147-
ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
147+
ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() -> void {
148148
CTYPE alpha_val;
149149
ET_KERNEL_CHECK_MSG(
150150
ctx,
151151
torch::executor::native::utils::extract_scalar(alpha, &alpha_val),
152152
InvalidArgument,
153-
out,
153+
,
154154
"Failed to extract scalar alpha.");
155155
using Vec = at::vec::Vectorized<CTYPE>;
156156
Vec alpha_val_vec(alpha_val);
@@ -164,13 +164,13 @@ Tensor& opt_add_sub_out_impl(
164164
auto add_lambda = [&alpha_val_vec](auto x, auto y) {
165165
return y - alpha_val_vec * x;
166166
};
167-
return torch::executor::handle_broadcast_elementwise<CTYPE>(
167+
torch::executor::handle_broadcast_elementwise<CTYPE>(
168168
ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
169169
} else {
170170
auto add_lambda = [&alpha_val_vec](auto x, auto y) {
171171
return x - alpha_val_vec * y;
172172
};
173-
return torch::executor::handle_broadcast_elementwise<CTYPE>(
173+
torch::executor::handle_broadcast_elementwise<CTYPE>(
174174
ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
175175
}
176176
} else {
@@ -191,13 +191,13 @@ Tensor& opt_add_sub_out_impl(
191191
auto add_lambda = [&alpha_val_vec](auto x, auto y) {
192192
return y + alpha_val_vec * x;
193193
};
194-
return torch::executor::handle_broadcast_elementwise<CTYPE>(
194+
torch::executor::handle_broadcast_elementwise<CTYPE>(
195195
ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
196196
} else {
197197
auto add_lambda = [&alpha_val_vec](auto x, auto y) {
198198
return x + alpha_val_vec * y;
199199
};
200-
return torch::executor::handle_broadcast_elementwise<CTYPE>(
200+
torch::executor::handle_broadcast_elementwise<CTYPE>(
201201
ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
202202
}
203203
}

kernels/optimized/cpu/op_div.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,11 +130,11 @@ Tensor& opt_div_out(
130130
selected_optimized_path ==
131131
ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
132132
auto div_lambda = [](auto x, auto y) { return y / x; };
133-
return torch::executor::handle_broadcast_elementwise<CTYPE>(
133+
torch::executor::handle_broadcast_elementwise<CTYPE>(
134134
ctx, div_lambda, a, b, out, selected_optimized_path);
135135
} else {
136136
auto div_lambda = [](auto x, auto y) { return x / y; };
137-
return torch::executor::handle_broadcast_elementwise<CTYPE>(
137+
torch::executor::handle_broadcast_elementwise<CTYPE>(
138138
ctx, div_lambda, a, b, out, selected_optimized_path);
139139
}
140140
});

kernels/optimized/cpu/op_le.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ Tensor& opt_le_tensor_out(
5757
// Handle optimized broadcast cases
5858
ET_SWITCH_REALB_TYPES(out_type, ctx, "le.Tensor_out", CTYPE, [&]() {
5959
auto le_lambda = [](auto x, auto y) { return x.le(y); };
60-
return torch::executor::handle_broadcast_elementwise<CTYPE>(
60+
torch::executor::handle_broadcast_elementwise<CTYPE>(
6161
ctx, le_lambda, a, b, out, selected_optimized_path);
6262
});
6363
} else {

kernels/optimized/cpu/op_mul.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,13 +148,13 @@ Tensor& opt_mul_out(
148148

149149
ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
150150
auto mul_lambda = [](auto x, auto y) { return x * y; };
151-
return torch::executor::handle_broadcast_elementwise<CTYPE>(
151+
torch::executor::handle_broadcast_elementwise<CTYPE>(
152152
ctx, mul_lambda, a, b, out, selected_optimized_path);
153153
});
154154
} else {
155155
ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
156156
auto mul_lambda = [](auto x, auto y) { return x * y; };
157-
return torch::executor::handle_broadcast_elementwise<CTYPE>(
157+
torch::executor::handle_broadcast_elementwise<CTYPE>(
158158
ctx, mul_lambda, a, b, out, selected_optimized_path);
159159
});
160160
}

kernels/portable/cpu/op_clamp.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ bool is_out_of_bounds(CTYPE_CAST val_cast) {
3434
}
3535

3636
ET_NODISCARD bool check_bounds(
37+
KernelRuntimeContext& ctx,
3738
const Scalar& val_scalar,
3839
const torch::executor::native::ScalarType& val_type,
3940
const torch::executor::native::ScalarType& out_type,
@@ -107,14 +108,14 @@ Tensor& clamp_out(
107108
if (has_min) {
108109
ET_KERNEL_CHECK(
109110
ctx,
110-
check_bounds(min_opt.value(), min_type, out_type, "minimum"),
111+
check_bounds(ctx, min_opt.value(), min_type, out_type, "minimum"),
111112
InvalidArgument,
112113
out);
113114
}
114115
if (has_max) {
115116
ET_KERNEL_CHECK(
116117
ctx,
117-
check_bounds(max_opt.value(), max_type, out_type, "maximum"),
118+
check_bounds(ctx, max_opt.value(), max_type, out_type, "maximum"),
118119
InvalidArgument,
119120
out);
120121
}

0 commit comments

Comments (0)