@@ -36,7 +36,6 @@ namespace {
3636 }
3737
3838#define __NUMBER_ET_PRIM_OP_IMPL (operator, stack, context ) \
39- (void )context; \
4039 EValue& a = *stack[0 ]; \
4140 EValue& b = *stack[1 ]; \
4241 EValue& out = *stack[2 ]; \
@@ -50,11 +49,23 @@ namespace {
5049 out = EValue (a.toDouble () operator b.toInt ()); \
5150 }
5251
52+ #define __ET_PRIM_OP_NUM_ARGS_CHECK_IMPL (stack, context ) \
53+ ET_KERNEL_CHECK_MSG ( \
54+ context, \
55+ stack.size() == 3 , \
56+ InvalidProgram, \
57+ /* void */ , \
58+ " Expected %zu args, got %zu" , \
59+ (size_t )3 , \
60+ stack.size());
61+
5362#define ALGEBRA_ET_PRIM_OP (operator, stack, context ) \
63+ __ET_PRIM_OP_NUM_ARGS_CHECK_IMPL (stack, context) \
5464 __NUMBER_ET_PRIM_OP_IMPL (operator , stack, context) \
5565 __ET_PRIM_OP_ERROR_IMPL (a, b, context)
5666
5767#define BOOLEAN_ET_PRIM_OP (operator, stack, context ) \
68+ __ET_PRIM_OP_NUM_ARGS_CHECK_IMPL (stack, context) \
5869 __NUMBER_ET_PRIM_OP_IMPL (operator , stack, context) \
5970 else if (a.isBool() && b.isBool()) { \
6071 out = EValue (a.toBool () operator b.toBool ()); \
@@ -80,7 +91,14 @@ static Kernel prim_ops[] = {
8091 Kernel (
8192 " aten::sym_size.int" ,
8293 [](KernelRuntimeContext& context, Span<EValue*> stack) {
83- (void )context;
94+ ET_KERNEL_CHECK_MSG (
95+ context,
96+ stack.size () == 3 ,
97+ InvalidProgram,
98+ /* void */ ,
99+ " Expected %zu args, got %zu" ,
100+ (size_t )3 ,
101+ stack.size ());
84102 EValue& self = *stack[0 ];
85103 EValue& dim = *stack[1 ];
86104 EValue& out = *stack[2 ];
@@ -94,7 +112,14 @@ static Kernel prim_ops[] = {
94112 Kernel (
95113 " aten::_local_scalar_dense" ,
96114 [](KernelRuntimeContext& context, Span<EValue*> stack) {
97- (void )context;
115+ ET_KERNEL_CHECK_MSG (
116+ context,
117+ stack.size () == 2 ,
118+ InvalidProgram,
119+ /* void */ ,
120+ " Expected %zu args, got %zu" ,
121+ (size_t )2 ,
122+ stack.size ());
98123 EValue& self = *stack[0 ];
99124 EValue& out = *stack[1 ];
100125 executorch::aten::Tensor self_tensor =
@@ -113,7 +138,14 @@ static Kernel prim_ops[] = {
113138 Kernel (
114139 " aten::sym_numel" ,
115140 [](KernelRuntimeContext& context, Span<EValue*> stack) {
116- (void )context;
141+ ET_KERNEL_CHECK_MSG (
142+ context,
143+ stack.size () == 2 ,
144+ InvalidProgram,
145+ /* void */ ,
146+ " Expected %zu args, got %zu" ,
147+ (size_t )2 ,
148+ stack.size ());
117149 EValue& self = *stack[0 ];
118150 EValue& out = *stack[1 ];
119151 executorch::aten::Tensor self_tensor =
@@ -125,7 +157,15 @@ static Kernel prim_ops[] = {
125157 Kernel (
126158 " executorch_prim::sym_max.Scalar" ,
127159 [](KernelRuntimeContext& context, Span<EValue*> stack) {
128- (void )context;
160+ ET_KERNEL_CHECK_MSG (
161+ context,
162+ stack.size () == 3 ,
163+ InvalidProgram,
164+ /* void */ ,
165+ " Expected %zu args, got %zu" ,
166+ (size_t )3 ,
167+ stack.size ());
168+
129169 EValue& a = *stack[0 ];
130170 EValue& b = *stack[1 ];
131171 EValue& out = *stack[2 ];
@@ -146,7 +186,14 @@ static Kernel prim_ops[] = {
146186 Kernel (
147187 " executorch_prim::sym_min.Scalar" ,
148188 [](KernelRuntimeContext& context, Span<EValue*> stack) {
149- (void )context;
189+ ET_KERNEL_CHECK_MSG (
190+ context,
191+ stack.size () == 3 ,
192+ InvalidProgram,
193+ /* void */ ,
194+ " Expected %zu args, got %zu" ,
195+ (size_t )3 ,
196+ stack.size ());
150197 EValue& a = *stack[0 ];
151198 EValue& b = *stack[1 ];
152199 EValue& out = *stack[2 ];
@@ -167,7 +214,6 @@ static Kernel prim_ops[] = {
167214 Kernel (
168215 " executorch_prim::add.Scalar" ,
169216 [](KernelRuntimeContext& context, Span<EValue*> stack) {
170- (void )context;
171217 ALGEBRA_ET_PRIM_OP (+, stack, context);
172218 }),
173219
@@ -197,7 +243,14 @@ static Kernel prim_ops[] = {
197243 Kernel (
198244 " executorch_prim::floordiv.Scalar" ,
199245 [](KernelRuntimeContext& context, Span<EValue*> stack) {
200- (void )context;
246+ ET_KERNEL_CHECK_MSG (
247+ context,
248+ stack.size () == 3 ,
249+ InvalidProgram,
250+ /* void */ ,
251+ " Expected %zu args, got %zu" ,
252+ (size_t )3 ,
253+ stack.size ());
201254 EValue& a = *stack[0 ];
202255 EValue& b = *stack[1 ];
203256 EValue& out = *stack[2 ];
@@ -233,7 +286,14 @@ static Kernel prim_ops[] = {
233286 " executorch_prim::truediv.Scalar" ,
234287 [](KernelRuntimeContext& context, Span<EValue*> stack) {
235288 // can't use macro because of custom casting behavior
236- (void )context;
289+ ET_KERNEL_CHECK_MSG (
290+ context,
291+ stack.size () == 3 ,
292+ InvalidProgram,
293+ /* void */ ,
294+ " Expected %zu args, got %zu" ,
295+ (size_t )3 ,
296+ stack.size ());
237297 EValue& a = *stack[0 ];
238298 EValue& b = *stack[1 ];
239299 EValue& out = *stack[2 ];
@@ -266,7 +326,14 @@ static Kernel prim_ops[] = {
266326 // can't use macro because of custom casting behavior
267327 // TODO: Now that we are reliably generating conversion operators,
268328 // we can remove the mixed type handling for other operators
269- (void )context;
329+ ET_KERNEL_CHECK_MSG (
330+ context,
331+ stack.size () == 2 ,
332+ InvalidProgram,
333+ /* void */ ,
334+ " Expected %zu args, got %zu" ,
335+ (size_t )2 ,
336+ stack.size ());
270337 EValue& a = *stack[0 ];
271338 EValue& out = *stack[1 ];
272339 if (a.isInt ()) {
@@ -318,7 +385,14 @@ static Kernel prim_ops[] = {
318385 Kernel (
319386 " executorch_prim::neg.Scalar" ,
320387 [](KernelRuntimeContext& context, Span<EValue*> stack) {
321- (void )context;
388+ ET_KERNEL_CHECK_MSG (
389+ context,
390+ stack.size () == 2 ,
391+ InvalidProgram,
392+ /* void */ ,
393+ " Expected %zu args, got %zu" ,
394+ (size_t )2 ,
395+ stack.size ());
322396 EValue& a = *stack[0 ];
323397 EValue& out = *stack[1 ];
324398 if (a.isInt ()) {
@@ -335,7 +409,14 @@ static Kernel prim_ops[] = {
335409 Kernel (
336410 " executorch_prim::floordiv.int" ,
337411 [](KernelRuntimeContext& context, Span<EValue*> stack) {
338- (void )context;
412+ ET_KERNEL_CHECK_MSG (
413+ context,
414+ stack.size () == 3 ,
415+ InvalidProgram,
416+ /* void */ ,
417+ " Expected %zu args, got %zu" ,
418+ (size_t )3 ,
419+ stack.size ());
339420 EValue& a = *stack[0 ];
340421 EValue& b = *stack[1 ];
341422 EValue& out = *stack[2 ];
@@ -346,7 +427,14 @@ static Kernel prim_ops[] = {
346427 Kernel (
347428 " executorch_prim::mod.int" ,
348429 [](KernelRuntimeContext& context, Span<EValue*> stack) {
349- (void )context;
430+ ET_KERNEL_CHECK_MSG (
431+ context,
432+ stack.size () == 3 ,
433+ InvalidProgram,
434+ /* void */ ,
435+ " Expected %zu args, got %zu" ,
436+ (size_t )3 ,
437+ stack.size ());
350438 EValue& a = *stack[0 ];
351439 EValue& b = *stack[1 ];
352440 EValue& out = *stack[2 ];
@@ -357,7 +445,14 @@ static Kernel prim_ops[] = {
357445 Kernel (
358446 " executorch_prim::mod.Scalar" ,
359447 [](KernelRuntimeContext& context, Span<EValue*> stack) {
360- (void )context;
448+ ET_KERNEL_CHECK_MSG (
449+ context,
450+ stack.size () == 3 ,
451+ InvalidProgram,
452+ /* void */ ,
453+ " Expected %zu args, got %zu" ,
454+ (size_t )3 ,
455+ stack.size ());
361456 EValue& a = *stack[0 ];
362457 EValue& b = *stack[1 ];
363458 EValue& out = *stack[2 ];
@@ -379,7 +474,14 @@ static Kernel prim_ops[] = {
379474 Kernel (
380475 " executorch_prim::ceil.Scalar" ,
381476 [](KernelRuntimeContext& context, Span<EValue*> stack) {
382- (void )context;
477+ ET_KERNEL_CHECK_MSG (
478+ context,
479+ stack.size () == 2 ,
480+ InvalidProgram,
481+ /* void */ ,
482+ " Expected %zu args, got %zu" ,
483+ (size_t )2 ,
484+ stack.size ());
383485 EValue& a = *stack[0 ];
384486 EValue& out = *stack[1 ];
385487 if (a.isDouble ()) {
@@ -399,7 +501,14 @@ static Kernel prim_ops[] = {
399501 Kernel (
400502 " executorch_prim::round.Scalar" ,
401503 [](KernelRuntimeContext& context, Span<EValue*> stack) {
402- (void )context;
504+ ET_KERNEL_CHECK_MSG (
505+ context,
506+ stack.size () == 2 ,
507+ InvalidProgram,
508+ /* void */ ,
509+ " Expected %zu args, got %zu" ,
510+ (size_t )2 ,
511+ stack.size ());
403512 EValue& a = *stack[0 ];
404513 EValue& out = *stack[1 ];
405514 if (a.isDouble ()) {
@@ -436,7 +545,14 @@ static Kernel prim_ops[] = {
436545 Kernel (
437546 " executorch_prim::trunc.Scalar" ,
438547 [](KernelRuntimeContext& context, Span<EValue*> stack) {
439- (void )context;
548+ ET_KERNEL_CHECK_MSG (
549+ context,
550+ stack.size () == 2 ,
551+ InvalidProgram,
552+ /* void */ ,
553+ " Expected %zu args, got %zu" ,
554+ (size_t )2 ,
555+ stack.size ());
440556 EValue& a = *stack[0 ];
441557 EValue& out = *stack[1 ];
442558 if (a.isDouble ()) {
0 commit comments