Skip to content

Commit b6b7a16

Browse files
add safety check to prim kernels
Differential Revision: D79124770 Pull Request resolved: #12987
1 parent d118a63 commit b6b7a16

File tree

4 files changed

+415
-20
lines changed

4 files changed

+415
-20
lines changed

kernels/prim_ops/et_copy_index.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,14 @@ constexpr size_t kTensorDimensionLimit = 16;
6565
// The output of each iteration (copy_from) is copied into the copy_to tensor at
6666
// the specified index. This operator is supported in both ATen and lean modes.
6767
void et_copy_index(KernelRuntimeContext& context, Span<EValue*> stack) {
68-
(void)context;
68+
ET_KERNEL_CHECK_MSG(
69+
context,
70+
stack.size() == 3,
71+
InvalidProgram,
72+
/* void */,
73+
"Expected %zu args, got %zu",
74+
(size_t)3,
75+
stack.size());
6976
SizesType expected_output_size[kTensorDimensionLimit];
7077

7178
auto copy_to = (*stack[0]).toTensor();

kernels/prim_ops/et_view.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,14 @@ bool get_view_target_size(
6666
} // namespace
6767

6868
void et_view(KernelRuntimeContext& context, Span<EValue*> stack) {
69-
(void)context;
69+
ET_KERNEL_CHECK_MSG(
70+
context,
71+
stack.size() == 3,
72+
InvalidProgram,
73+
/* void */,
74+
"Expected %zu args, got %zu",
75+
(size_t)3,
76+
stack.size());
7077

7178
auto self = (*stack[0]).toTensor();
7279
auto size = (*stack[1]).toIntList();

kernels/prim_ops/register_prim_ops.cpp

Lines changed: 133 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ namespace {
3636
}
3737

3838
#define __NUMBER_ET_PRIM_OP_IMPL(operator, stack, context) \
39-
(void)context; \
4039
EValue& a = *stack[0]; \
4140
EValue& b = *stack[1]; \
4241
EValue& out = *stack[2]; \
@@ -50,11 +49,23 @@ namespace {
5049
out = EValue(a.toDouble() operator b.toInt()); \
5150
}
5251

52+
#define __ET_PRIM_OP_NUM_ARGS_CHECK_IMPL(stack, context) \
53+
ET_KERNEL_CHECK_MSG( \
54+
context, \
55+
stack.size() == 3, \
56+
InvalidProgram, \
57+
/* void */, \
58+
"Expected %zu args, got %zu", \
59+
(size_t)3, \
60+
stack.size());
61+
5362
#define ALGEBRA_ET_PRIM_OP(operator, stack, context) \
63+
__ET_PRIM_OP_NUM_ARGS_CHECK_IMPL(stack, context) \
5464
__NUMBER_ET_PRIM_OP_IMPL(operator, stack, context) \
5565
__ET_PRIM_OP_ERROR_IMPL(a, b, context)
5666

5767
#define BOOLEAN_ET_PRIM_OP(operator, stack, context) \
68+
__ET_PRIM_OP_NUM_ARGS_CHECK_IMPL(stack, context) \
5869
__NUMBER_ET_PRIM_OP_IMPL(operator, stack, context) \
5970
else if (a.isBool() && b.isBool()) { \
6071
out = EValue(a.toBool() operator b.toBool()); \
@@ -80,7 +91,14 @@ static Kernel prim_ops[] = {
8091
Kernel(
8192
"aten::sym_size.int",
8293
[](KernelRuntimeContext& context, Span<EValue*> stack) {
83-
(void)context;
94+
ET_KERNEL_CHECK_MSG(
95+
context,
96+
stack.size() == 3,
97+
InvalidProgram,
98+
/* void */,
99+
"Expected %zu args, got %zu",
100+
(size_t)3,
101+
stack.size());
84102
EValue& self = *stack[0];
85103
EValue& dim = *stack[1];
86104
EValue& out = *stack[2];
@@ -94,7 +112,14 @@ static Kernel prim_ops[] = {
94112
Kernel(
95113
"aten::_local_scalar_dense",
96114
[](KernelRuntimeContext& context, Span<EValue*> stack) {
97-
(void)context;
115+
ET_KERNEL_CHECK_MSG(
116+
context,
117+
stack.size() == 2,
118+
InvalidProgram,
119+
/* void */,
120+
"Expected %zu args, got %zu",
121+
(size_t)2,
122+
stack.size());
98123
EValue& self = *stack[0];
99124
EValue& out = *stack[1];
100125
executorch::aten::Tensor self_tensor =
@@ -113,7 +138,14 @@ static Kernel prim_ops[] = {
113138
Kernel(
114139
"aten::sym_numel",
115140
[](KernelRuntimeContext& context, Span<EValue*> stack) {
116-
(void)context;
141+
ET_KERNEL_CHECK_MSG(
142+
context,
143+
stack.size() == 2,
144+
InvalidProgram,
145+
/* void */,
146+
"Expected %zu args, got %zu",
147+
(size_t)2,
148+
stack.size());
117149
EValue& self = *stack[0];
118150
EValue& out = *stack[1];
119151
executorch::aten::Tensor self_tensor =
@@ -125,7 +157,15 @@ static Kernel prim_ops[] = {
125157
Kernel(
126158
"executorch_prim::sym_max.Scalar",
127159
[](KernelRuntimeContext& context, Span<EValue*> stack) {
128-
(void)context;
160+
ET_KERNEL_CHECK_MSG(
161+
context,
162+
stack.size() == 3,
163+
InvalidProgram,
164+
/* void */,
165+
"Expected %zu args, got %zu",
166+
(size_t)3,
167+
stack.size());
168+
129169
EValue& a = *stack[0];
130170
EValue& b = *stack[1];
131171
EValue& out = *stack[2];
@@ -146,7 +186,14 @@ static Kernel prim_ops[] = {
146186
Kernel(
147187
"executorch_prim::sym_min.Scalar",
148188
[](KernelRuntimeContext& context, Span<EValue*> stack) {
149-
(void)context;
189+
ET_KERNEL_CHECK_MSG(
190+
context,
191+
stack.size() == 3,
192+
InvalidProgram,
193+
/* void */,
194+
"Expected %zu args, got %zu",
195+
(size_t)3,
196+
stack.size());
150197
EValue& a = *stack[0];
151198
EValue& b = *stack[1];
152199
EValue& out = *stack[2];
@@ -167,7 +214,6 @@ static Kernel prim_ops[] = {
167214
Kernel(
168215
"executorch_prim::add.Scalar",
169216
[](KernelRuntimeContext& context, Span<EValue*> stack) {
170-
(void)context;
171217
ALGEBRA_ET_PRIM_OP(+, stack, context);
172218
}),
173219

@@ -197,7 +243,14 @@ static Kernel prim_ops[] = {
197243
Kernel(
198244
"executorch_prim::floordiv.Scalar",
199245
[](KernelRuntimeContext& context, Span<EValue*> stack) {
200-
(void)context;
246+
ET_KERNEL_CHECK_MSG(
247+
context,
248+
stack.size() == 3,
249+
InvalidProgram,
250+
/* void */,
251+
"Expected %zu args, got %zu",
252+
(size_t)3,
253+
stack.size());
201254
EValue& a = *stack[0];
202255
EValue& b = *stack[1];
203256
EValue& out = *stack[2];
@@ -233,7 +286,14 @@ static Kernel prim_ops[] = {
233286
"executorch_prim::truediv.Scalar",
234287
[](KernelRuntimeContext& context, Span<EValue*> stack) {
235288
// can't use macro because of custom casting behavior
236-
(void)context;
289+
ET_KERNEL_CHECK_MSG(
290+
context,
291+
stack.size() == 3,
292+
InvalidProgram,
293+
/* void */,
294+
"Expected %zu args, got %zu",
295+
(size_t)3,
296+
stack.size());
237297
EValue& a = *stack[0];
238298
EValue& b = *stack[1];
239299
EValue& out = *stack[2];
@@ -266,7 +326,14 @@ static Kernel prim_ops[] = {
266326
// can't use macro because of custom casting behavior
267327
// TODO: Now that we are reliably generating conversion operators,
268328
// we can remove the mixed type handling for other operators
269-
(void)context;
329+
ET_KERNEL_CHECK_MSG(
330+
context,
331+
stack.size() == 2,
332+
InvalidProgram,
333+
/* void */,
334+
"Expected %zu args, got %zu",
335+
(size_t)2,
336+
stack.size());
270337
EValue& a = *stack[0];
271338
EValue& out = *stack[1];
272339
if (a.isInt()) {
@@ -318,7 +385,14 @@ static Kernel prim_ops[] = {
318385
Kernel(
319386
"executorch_prim::neg.Scalar",
320387
[](KernelRuntimeContext& context, Span<EValue*> stack) {
321-
(void)context;
388+
ET_KERNEL_CHECK_MSG(
389+
context,
390+
stack.size() == 2,
391+
InvalidProgram,
392+
/* void */,
393+
"Expected %zu args, got %zu",
394+
(size_t)2,
395+
stack.size());
322396
EValue& a = *stack[0];
323397
EValue& out = *stack[1];
324398
if (a.isInt()) {
@@ -335,7 +409,14 @@ static Kernel prim_ops[] = {
335409
Kernel(
336410
"executorch_prim::floordiv.int",
337411
[](KernelRuntimeContext& context, Span<EValue*> stack) {
338-
(void)context;
412+
ET_KERNEL_CHECK_MSG(
413+
context,
414+
stack.size() == 3,
415+
InvalidProgram,
416+
/* void */,
417+
"Expected %zu args, got %zu",
418+
(size_t)3,
419+
stack.size());
339420
EValue& a = *stack[0];
340421
EValue& b = *stack[1];
341422
EValue& out = *stack[2];
@@ -346,7 +427,14 @@ static Kernel prim_ops[] = {
346427
Kernel(
347428
"executorch_prim::mod.int",
348429
[](KernelRuntimeContext& context, Span<EValue*> stack) {
349-
(void)context;
430+
ET_KERNEL_CHECK_MSG(
431+
context,
432+
stack.size() == 3,
433+
InvalidProgram,
434+
/* void */,
435+
"Expected %zu args, got %zu",
436+
(size_t)3,
437+
stack.size());
350438
EValue& a = *stack[0];
351439
EValue& b = *stack[1];
352440
EValue& out = *stack[2];
@@ -357,7 +445,14 @@ static Kernel prim_ops[] = {
357445
Kernel(
358446
"executorch_prim::mod.Scalar",
359447
[](KernelRuntimeContext& context, Span<EValue*> stack) {
360-
(void)context;
448+
ET_KERNEL_CHECK_MSG(
449+
context,
450+
stack.size() == 3,
451+
InvalidProgram,
452+
/* void */,
453+
"Expected %zu args, got %zu",
454+
(size_t)3,
455+
stack.size());
361456
EValue& a = *stack[0];
362457
EValue& b = *stack[1];
363458
EValue& out = *stack[2];
@@ -379,7 +474,14 @@ static Kernel prim_ops[] = {
379474
Kernel(
380475
"executorch_prim::ceil.Scalar",
381476
[](KernelRuntimeContext& context, Span<EValue*> stack) {
382-
(void)context;
477+
ET_KERNEL_CHECK_MSG(
478+
context,
479+
stack.size() == 2,
480+
InvalidProgram,
481+
/* void */,
482+
"Expected %zu args, got %zu",
483+
(size_t)2,
484+
stack.size());
383485
EValue& a = *stack[0];
384486
EValue& out = *stack[1];
385487
if (a.isDouble()) {
@@ -399,7 +501,14 @@ static Kernel prim_ops[] = {
399501
Kernel(
400502
"executorch_prim::round.Scalar",
401503
[](KernelRuntimeContext& context, Span<EValue*> stack) {
402-
(void)context;
504+
ET_KERNEL_CHECK_MSG(
505+
context,
506+
stack.size() == 2,
507+
InvalidProgram,
508+
/* void */,
509+
"Expected %zu args, got %zu",
510+
(size_t)2,
511+
stack.size());
403512
EValue& a = *stack[0];
404513
EValue& out = *stack[1];
405514
if (a.isDouble()) {
@@ -436,7 +545,14 @@ static Kernel prim_ops[] = {
436545
Kernel(
437546
"executorch_prim::trunc.Scalar",
438547
[](KernelRuntimeContext& context, Span<EValue*> stack) {
439-
(void)context;
548+
ET_KERNEL_CHECK_MSG(
549+
context,
550+
stack.size() == 2,
551+
InvalidProgram,
552+
/* void */,
553+
"Expected %zu args, got %zu",
554+
(size_t)2,
555+
stack.size());
440556
EValue& a = *stack[0];
441557
EValue& out = *stack[1];
442558
if (a.isDouble()) {

0 commit comments

Comments
 (0)