Commit ca35dd9

manuelcandales authored and facebook-github-bot committed
Non-fatal error when ET_SWITCH encounters unsupported dtype
Summary: Modify the `ET_INTERNAL_SWITCH` macro so that we fail non-fatally in the default clause when we encounter an unsupported dtype.

Differential Revision: D80141272
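For orientation, below is a minimal, self-contained sketch of the dispatch pattern the summary describes. All names here (`Context`, `switch_real_types`, and so on) are illustrative stand-ins, not the real `ET_INTERNAL_SWITCH` macro or ExecuTorch API: the point is only that the default clause records an error on the kernel runtime context and returns, instead of aborting fatally.

#include <cstdio>

// Illustrative stand-ins only -- not the actual ExecuTorch types.
enum class ScalarType { Float, Int, Bool, Half };
enum class Error { Ok, InvalidArgument };

// Minimal stand-in for KernelRuntimeContext: it can record a failure.
struct Context {
  Error error = Error::Ok;
  void fail(Error e) { error = e; }
};

// Each supported dtype runs the functor with a value of the concrete C
// type; an unsupported dtype marks the context as failed and returns,
// rather than aborting the whole process as a fatal check would.
template <typename Fn>
void switch_real_types(ScalarType t, Context& ctx, Fn&& fn) {
  switch (t) {
    case ScalarType::Float: fn(float{}); break;
    case ScalarType::Int:   fn(int{});   break;
    case ScalarType::Bool:  fn(bool{});  break;
    default:
      ctx.fail(Error::InvalidArgument); // non-fatal: caller checks ctx
      break;
  }
}

int main() {
  Context ctx;
  switch_real_types(ScalarType::Half, ctx, [](auto v) {
    std::printf("dispatched on a %zu-byte ctype\n", sizeof(v));
  });
  if (ctx.error != Error::Ok) {
    std::printf("unsupported dtype handled non-fatally\n");
  }
  return 0;
}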
1 parent 310a05d commit ca35dd9

File tree

13 files changed: +181, -107 lines

kernels/optimized/cpu/op_add_sub_impl.h

Lines changed: 6 additions & 6 deletions
@@ -144,13 +144,13 @@ Tensor& opt_add_sub_out_impl(
   } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
     // Cannot apply the trick of -alpha here because alpha is Scalar without
     // support for - operator. At least not right now.
-    ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
+    ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() -> void {
       CTYPE alpha_val;
       ET_KERNEL_CHECK_MSG(
           ctx,
           torch::executor::native::utils::extract_scalar(alpha, &alpha_val),
           InvalidArgument,
-          out,
+          ,
           "Failed to extract scalar alpha.");
       using Vec = at::vec::Vectorized<CTYPE>;
       Vec alpha_val_vec(alpha_val);
@@ -164,13 +164,13 @@ Tensor& opt_add_sub_out_impl(
         auto add_lambda = [&alpha_val_vec](auto x, auto y) {
           return y - alpha_val_vec * x;
         };
-        return torch::executor::handle_broadcast_elementwise<CTYPE>(
+        torch::executor::handle_broadcast_elementwise<CTYPE>(
            ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
       } else {
         auto add_lambda = [&alpha_val_vec](auto x, auto y) {
           return x - alpha_val_vec * y;
         };
-        return torch::executor::handle_broadcast_elementwise<CTYPE>(
+        torch::executor::handle_broadcast_elementwise<CTYPE>(
            ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
       }
     } else {
@@ -191,13 +191,13 @@ Tensor& opt_add_sub_out_impl(
         auto add_lambda = [&alpha_val_vec](auto x, auto y) {
           return y + alpha_val_vec * x;
         };
-        return torch::executor::handle_broadcast_elementwise<CTYPE>(
+        torch::executor::handle_broadcast_elementwise<CTYPE>(
            ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
       } else {
         auto add_lambda = [&alpha_val_vec](auto x, auto y) {
           return x + alpha_val_vec * y;
         };
-        return torch::executor::handle_broadcast_elementwise<CTYPE>(
+        torch::executor::handle_broadcast_elementwise<CTYPE>(
            ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
       }
     }
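A reading of the change above (an inference from the diff, not text from the commit): because the switch's default clause can now fail without ever invoking the lambda, no value from the lambda can be relied on, so the kernels stop `return`ing from it and pin its type with `-> void`; for the same reason `ET_KERNEL_CHECK_MSG` is passed an empty retval argument, since a void lambda has nothing to return on failure. A tiny self-contained illustration of the shape:

#include <iostream>

// Illustrative only: a dispatcher that may skip the body entirely
// composes cleanly with a void body, because no value needs to be
// produced on the "unsupported dtype, skipped" path.
template <typename Fn>
void dispatch(bool dtype_supported, Fn&& body) {
  if (dtype_supported) {
    body();
  }
  // else: fall through non-fatally; there is no value to propagate.
}

int main() {
  dispatch(true, []() -> void { std::cout << "kernel body ran\n"; });
  dispatch(false, []() -> void { std::cout << "never printed\n"; });
  return 0;
}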

kernels/optimized/cpu/op_div.cpp

Lines changed: 2 additions & 2 deletions
@@ -130,11 +130,11 @@ Tensor& opt_div_out(
         selected_optimized_path ==
             ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
       auto div_lambda = [](auto x, auto y) { return y / x; };
-      return torch::executor::handle_broadcast_elementwise<CTYPE>(
+      torch::executor::handle_broadcast_elementwise<CTYPE>(
          ctx, div_lambda, a, b, out, selected_optimized_path);
     } else {
       auto div_lambda = [](auto x, auto y) { return x / y; };
-      return torch::executor::handle_broadcast_elementwise<CTYPE>(
+      torch::executor::handle_broadcast_elementwise<CTYPE>(
          ctx, div_lambda, a, b, out, selected_optimized_path);
     }
   });

kernels/optimized/cpu/op_le.cpp

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@ Tensor& opt_le_tensor_out(
     // Handle optimized broadcast cases
     ET_SWITCH_REALB_TYPES(out_type, ctx, "le.Tensor_out", CTYPE, [&]() {
       auto le_lambda = [](auto x, auto y) { return x.le(y); };
-      return torch::executor::handle_broadcast_elementwise<CTYPE>(
+      torch::executor::handle_broadcast_elementwise<CTYPE>(
          ctx, le_lambda, a, b, out, selected_optimized_path);
     });
   } else {

kernels/optimized/cpu/op_mul.cpp

Lines changed: 2 additions & 2 deletions
@@ -148,13 +148,13 @@ Tensor& opt_mul_out(
 
     ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
       auto mul_lambda = [](auto x, auto y) { return x * y; };
-      return torch::executor::handle_broadcast_elementwise<CTYPE>(
+      torch::executor::handle_broadcast_elementwise<CTYPE>(
          ctx, mul_lambda, a, b, out, selected_optimized_path);
     });
   } else {
     ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
       auto mul_lambda = [](auto x, auto y) { return x * y; };
-      return torch::executor::handle_broadcast_elementwise<CTYPE>(
+      torch::executor::handle_broadcast_elementwise<CTYPE>(
          ctx, mul_lambda, a, b, out, selected_optimized_path);
     });
   }

kernels/portable/cpu/op_clamp.cpp

Lines changed: 3 additions & 2 deletions
@@ -34,6 +34,7 @@ bool is_out_of_bounds(CTYPE_CAST val_cast) {
 }
 
 ET_NODISCARD bool check_bounds(
+    KernelRuntimeContext& ctx,
     const Scalar& val_scalar,
     const torch::executor::native::ScalarType& val_type,
     const torch::executor::native::ScalarType& out_type,
@@ -104,14 +105,14 @@ Tensor& clamp_out(
   if (has_min) {
     ET_KERNEL_CHECK(
         ctx,
-        check_bounds(min_opt.value(), min_type, out_type, "minimum"),
+        check_bounds(ctx, min_opt.value(), min_type, out_type, "minimum"),
         InvalidArgument,
         out);
   }
   if (has_max) {
     ET_KERNEL_CHECK(
         ctx,
-        check_bounds(max_opt.value(), max_type, out_type, "maximum"),
+        check_bounds(ctx, max_opt.value(), max_type, out_type, "maximum"),
         InvalidArgument,
         out);
   }
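The clamp change shows the companion pattern: helper predicates that dispatch on dtype now receive the `KernelRuntimeContext`, so a failure inside them can be recorded on the context while the predicate simply reports false to its caller. A hedged sketch with stand-in types (not ExecuTorch's real signatures):

#include <cstdio>

// Stand-ins, not the real ExecuTorch types.
enum class ScalarType { Float, Int, ComplexFloat };
enum class Error { Ok, InvalidArgument };
struct Context {
  Error error = Error::Ok;
  void fail(Error e) { error = e; }
};

// A check_bounds-like predicate: it takes the context so the dtype switch
// inside can flag an unsupported dtype non-fatally and still return false.
bool check_bounds(Context& ctx, ScalarType t, double val) {
  switch (t) {
    case ScalarType::Float:
      return val >= -3.4e38 && val <= 3.4e38;
    case ScalarType::Int:
      return val >= -2147483648.0 && val <= 2147483647.0;
    default:
      ctx.fail(Error::InvalidArgument); // non-fatal report
      return false;
  }
}

int main() {
  Context ctx;
  const bool ok = check_bounds(ctx, ScalarType::ComplexFloat, 1.0);
  std::printf("ok=%d error_recorded=%d\n", ok, ctx.error != Error::Ok);
  return 0;
}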

kernels/portable/cpu/op_convolution.cpp

Lines changed: 1 addition & 1 deletion
@@ -415,7 +415,7 @@ Tensor& convolution_out(
   ET_SWITCH_REALH_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
     const auto load_bias = bias.has_value()
         ? utils::internal::get_load_to_compute_fn<CTYPE, name>(
-              bias.value(), utils::SupportedTensorDtypes::REALHBF16)
+              ctx, bias.value(), utils::SupportedTensorDtypes::REALHBF16)
         : nullptr;
     convolution_wrapper<CTYPE>(
         in,

kernels/portable/cpu/op_cumsum.cpp

Lines changed: 2 additions & 2 deletions
@@ -111,10 +111,10 @@ Tensor& cumsum_out(
   // @lint-ignore CLANGTIDY facebook-hte-CArray
   static constexpr const char op_name[] = "cumsum.out";
 
-  ET_SWITCH_REALHBBF16_TYPES(out.scalar_type(), ctx, op_name, CTYPE_OUT, [&] {
+  ET_SWITCH_REALHBBF16_TYPES(out.scalar_type(), ctx, op_name, CTYPE_OUT, [&]() {
     const auto load_self =
         utils::internal::get_load_to_compute_fn<CTYPE_OUT, op_name>(
-            self, utils::SupportedTensorDtypes::REALHBBF16);
+            ctx, self, utils::SupportedTensorDtypes::REALHBBF16);
     cumsum_tensors<CTYPE_OUT>(self, load_self, dim, out);
   });

kernels/portable/cpu/op_index_put.cpp

Lines changed: 2 additions & 1 deletion
@@ -160,6 +160,7 @@ Tensor& index_put_out(
 namespace {
 
 bool check_special_case_in_place_args(
+    KernelRuntimeContext& ctx,
     Tensor& in,
     TensorOptList indices,
     const Tensor& values,
@@ -285,7 +286,7 @@ Tensor& index_put_(
   size_t dim = 0;
   ET_KERNEL_CHECK(
       ctx,
-      check_special_case_in_place_args(in, indices, values, accumulate, &dim),
+      check_special_case_in_place_args(ctx, in, indices, values, accumulate, &dim),
       InvalidArgument,
       in);

kernels/portable/cpu/op_scatter.cpp

Lines changed: 3 additions & 5 deletions
@@ -104,22 +104,20 @@ void scatter_value_helper(
 } // namespace
 
 Tensor& scatter_src_out(
-    KernelRuntimeContext& context,
+    KernelRuntimeContext& ctx,
     const Tensor& in,
     int64_t dim,
     const Tensor& index,
     const Tensor& src,
     Tensor& out) {
-  (void)context;
-
   ET_KERNEL_CHECK(
-      context,
+      ctx,
       check_scatter_src_args(in, dim, index, src, out),
       InvalidArgument,
       out);
 
   ET_KERNEL_CHECK(
-      context,
+      ctx,
       resize_tensor(out, in.sizes()) == Error::Ok,
       InvalidArgument,
       out);

kernels/portable/cpu/op_scatter_add.cpp

Lines changed: 5 additions & 7 deletions
@@ -52,35 +52,33 @@ void scatter_add_helper(
 } // namespace
 
 Tensor& scatter_add_out(
-    KernelRuntimeContext& context,
+    KernelRuntimeContext& ctx,
     const Tensor& self,
     int64_t dim,
     const Tensor& index,
     const Tensor& src,
     Tensor& out) {
-  (void)context;
-
   ET_KERNEL_CHECK(
-      context,
+      ctx,
       check_scatter_add_args(self, dim, index, src, out),
       InvalidArgument,
       out);
 
   ET_KERNEL_CHECK(
-      context,
+      ctx,
       tensors_have_same_dim_order(self, src, out),
       InvalidArgument,
       out);
 
   ET_KERNEL_CHECK(
-      context, tensor_is_default_dim_order(index), InvalidArgument, out);
+      ctx, tensor_is_default_dim_order(index), InvalidArgument, out);
 
   if (dim < 0) {
     dim += nonzero_dim(self);
   }
 
   ET_KERNEL_CHECK(
-      context,
+      ctx,
       resize_tensor(out, self.sizes()) == Error::Ok,
       InvalidArgument,
       out);
