88
99#include < gtest/gtest.h>
1010
11+ #include < executorch/kernels/test/TestUtil.h>
1112#include < executorch/runtime/core/evalue.h>
1213#include < executorch/runtime/core/exec_aten/exec_aten.h>
1314#include < executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
1617#include < executorch/runtime/kernel/kernel_runtime_context.h>
1718#include < executorch/runtime/kernel/operator_registry.h>
1819#include < executorch/runtime/platform/runtime.h>
19- #include < executorch/test/utils/DeathTest.h>
2020#include < cstdint>
2121#include < cstdio>
2222
@@ -27,12 +27,10 @@ using torch::executor::resize_tensor;
2727namespace torch {
2828namespace executor {
2929
30- class RegisterPrimOpsTest : public ::testing::Test {
30+ class RegisterPrimOpsTest : public OperatorTest {
3131 protected:
32- KernelRuntimeContext context;
3332 void SetUp () override {
34- torch::executor::runtime_init ();
35- context = KernelRuntimeContext ();
33+ context_ = KernelRuntimeContext ();
3634 }
3735};
3836
@@ -57,7 +55,7 @@ TEST_F(RegisterPrimOpsTest, SymSizeReturnsCorrectValue) {
5755 stack[i] = &values[i];
5856 }
5957
60- getOpsFn (" aten::sym_size.int" )(context , stack);
58+ getOpsFn (" aten::sym_size.int" )(context_ , stack);
6159
6260 int64_t expected = 5 ;
6361 EXPECT_EQ (stack[2 ]->toInt (), expected);
@@ -77,7 +75,7 @@ TEST_F(RegisterPrimOpsTest, SymNumelReturnsCorrectValue) {
7775 stack[i] = &values[i];
7876 }
7977
80- getOpsFn (" aten::sym_numel" )(context , stack);
78+ getOpsFn (" aten::sym_numel" )(context_ , stack);
8179
8280 int64_t expected = 15 ;
8381 EXPECT_EQ (stack[1 ]->toInt (), expected);
@@ -97,28 +95,28 @@ TEST_F(RegisterPrimOpsTest, TestAlgebraOps) {
9795 stack[i] = &values[i];
9896 }
9997
100- getOpsFn (" executorch_prim::add.Scalar" )(context , stack);
98+ getOpsFn (" executorch_prim::add.Scalar" )(context_ , stack);
10199 EXPECT_EQ (stack[2 ]->toInt (), 7 );
102100
103- getOpsFn (" executorch_prim::sub.Scalar" )(context , stack);
101+ getOpsFn (" executorch_prim::sub.Scalar" )(context_ , stack);
104102 EXPECT_EQ (stack[2 ]->toInt (), -1 );
105103
106- getOpsFn (" executorch_prim::mul.Scalar" )(context , stack);
104+ getOpsFn (" executorch_prim::mul.Scalar" )(context_ , stack);
107105 EXPECT_EQ (stack[2 ]->toInt (), 12 );
108106
109- getOpsFn (" executorch_prim::floordiv.Scalar" )(context , stack);
107+ getOpsFn (" executorch_prim::floordiv.Scalar" )(context_ , stack);
110108 EXPECT_EQ (stack[2 ]->toInt (), 0 );
111109
112- getOpsFn (" executorch_prim::truediv.Scalar" )(context , stack);
110+ getOpsFn (" executorch_prim::truediv.Scalar" )(context_ , stack);
113111 EXPECT_FLOAT_EQ (stack[2 ]->toDouble (), 0.75 );
114112
115- getOpsFn (" executorch_prim::mod.int" )(context , stack);
113+ getOpsFn (" executorch_prim::mod.int" )(context_ , stack);
116114 EXPECT_EQ (stack[2 ]->toInt (), 3 );
117115
118- getOpsFn (" executorch_prim::mod.Scalar" )(context , stack);
116+ getOpsFn (" executorch_prim::mod.Scalar" )(context_ , stack);
119117 EXPECT_EQ (stack[2 ]->toInt (), 3 );
120118
121- getOpsFn (" executorch_prim::sym_float.Scalar" )(context , stack);
119+ getOpsFn (" executorch_prim::sym_float.Scalar" )(context_ , stack);
122120 EXPECT_FLOAT_EQ (stack[1 ]->toDouble (), 3.0 );
123121}
124122
@@ -155,7 +153,7 @@ TEST_F(RegisterPrimOpsTest, TestETCopyIndex) {
155153 stack[2 ] = &values[2 ];
156154
157155 // Simple test to copy to index 0.
158- getOpsFn (" executorch_prim::et_copy_index.tensor" )(context , stack);
156+ getOpsFn (" executorch_prim::et_copy_index.tensor" )(context_ , stack);
159157
160158 EXPECT_EQ (copy_to.sizes ()[0 ], 1 );
161159 EXPECT_EQ (copy_to.sizes ()[1 ], 2 );
@@ -164,7 +162,7 @@ TEST_F(RegisterPrimOpsTest, TestETCopyIndex) {
164162 values[1 ] = tf.make ({2 }, {5 , 6 });
165163 values[2 ] = EValue ((int64_t )1 );
166164 // Copy to the next index, 1.
167- getOpsFn (" executorch_prim::et_copy_index.tensor" )(context , stack);
165+ getOpsFn (" executorch_prim::et_copy_index.tensor" )(context_ , stack);
168166
169167 EXPECT_EQ (copy_to.sizes ()[0 ], 2 );
170168 EXPECT_EQ (copy_to.sizes ()[1 ], 2 );
@@ -193,7 +191,7 @@ TEST_F(RegisterPrimOpsTest, TestETCopyIndexMismatchShape) {
193191 // copy_to.sizes[1:] and to_copy.sizes[:] don't match each other
194192 // which is a pre-requisite for this operator.
195193 ET_EXPECT_DEATH (
196- getOpsFn (" executorch_prim::et_copy_index.tensor" )(context , stack), " " );
194+ getOpsFn (" executorch_prim::et_copy_index.tensor" )(context_ , stack), " " );
197195}
198196
199197TEST_F (RegisterPrimOpsTest, TestETCopyIndexStaticShape) {
@@ -217,7 +215,7 @@ TEST_F(RegisterPrimOpsTest, TestETCopyIndexStaticShape) {
217215 stack[2 ] = &values[2 ];
218216
219217 // Copy and replace at index 1.
220- getOpsFn (" executorch_prim::et_copy_index.tensor" )(context , stack);
218+ getOpsFn (" executorch_prim::et_copy_index.tensor" )(context_ , stack);
221219 EXPECT_EQ (copy_to.sizes ()[0 ], 2 );
222220 EXPECT_EQ (copy_to.sizes ()[1 ], 2 );
223221 EXPECT_TENSOR_EQ (copy_to, tf.make ({2 , 2 }, {1 , 2 , 5 , 6 }));
@@ -228,7 +226,7 @@ TEST_F(RegisterPrimOpsTest, TestETCopyIndexStaticShape) {
228226 index = 2 ;
229227 values[2 ] = EValue (index);
230228 ET_EXPECT_DEATH (
231- getOpsFn (" executorch_prim::et_copy_index.tensor" )(context , stack), " " );
229+ getOpsFn (" executorch_prim::et_copy_index.tensor" )(context_ , stack), " " );
232230#endif
233231}
234232
@@ -246,19 +244,19 @@ TEST_F(RegisterPrimOpsTest, TestBooleanOps) {
246244 stack[i] = &values[i];
247245 }
248246
249- getOpsFn (" executorch_prim::ge.Scalar" )(context , stack);
247+ getOpsFn (" executorch_prim::ge.Scalar" )(context_ , stack);
250248 EXPECT_EQ (stack[2 ]->toBool (), false );
251249
252- getOpsFn (" executorch_prim::gt.Scalar" )(context , stack);
250+ getOpsFn (" executorch_prim::gt.Scalar" )(context_ , stack);
253251 EXPECT_EQ (stack[2 ]->toBool (), false );
254252
255- getOpsFn (" executorch_prim::le.Scalar" )(context , stack);
253+ getOpsFn (" executorch_prim::le.Scalar" )(context_ , stack);
256254 EXPECT_EQ (stack[2 ]->toBool (), true );
257255
258- getOpsFn (" executorch_prim::lt.Scalar" )(context , stack);
256+ getOpsFn (" executorch_prim::lt.Scalar" )(context_ , stack);
259257 EXPECT_EQ (stack[2 ]->toBool (), true );
260258
261- getOpsFn (" executorch_prim::eq.Scalar" )(context , stack);
259+ getOpsFn (" executorch_prim::eq.Scalar" )(context_ , stack);
262260 EXPECT_EQ (stack[2 ]->toBool (), false );
263261}
264262
@@ -277,7 +275,7 @@ TEST_F(RegisterPrimOpsTest, LocalScalarDenseReturnsCorrectValue) {
277275 stack[i] = &values[i];
278276 }
279277
280- getOpsFn (" aten::_local_scalar_dense" )(context , stack);
278+ getOpsFn (" aten::_local_scalar_dense" )(context_ , stack);
281279
282280 int64_t expected = 1 ;
283281 EXPECT_EQ (stack[1 ]->toInt (), expected);
@@ -295,7 +293,7 @@ TEST_F(RegisterPrimOpsTest, NegScalarReturnsCorrectValue) {
295293 stack[i] = &values[i];
296294 }
297295
298- getOpsFn (" executorch_prim::neg.Scalar" )(context , stack);
296+ getOpsFn (" executorch_prim::neg.Scalar" )(context_ , stack);
299297
300298 EXPECT_EQ (stack[1 ]->toDouble (), -5 .0f );
301299
@@ -305,7 +303,7 @@ TEST_F(RegisterPrimOpsTest, NegScalarReturnsCorrectValue) {
305303 values[0 ] = EValue (a);
306304 values[1 ] = EValue (b);
307305
308- getOpsFn (" executorch_prim::neg.Scalar" )(context , stack);
306+ getOpsFn (" executorch_prim::neg.Scalar" )(context_ , stack);
309307
310308 EXPECT_EQ (stack[1 ]->toInt (), -5l );
311309}
@@ -327,7 +325,7 @@ TEST_F(RegisterPrimOpsTest, TestNegScalarWithTensorDies) {
327325 }
328326
329327 // Try to negate a tensor, which should cause a runtime error.
330- ET_EXPECT_DEATH (getOpsFn (" executorch_prim::neg.Scalar" )(context , stack), " " );
328+ ET_EXPECT_DEATH (getOpsFn (" executorch_prim::neg.Scalar" )(context_ , stack), " " );
331329}
332330
333331TEST_F (RegisterPrimOpsTest, TestETView) {
@@ -410,9 +408,9 @@ TEST_F(RegisterPrimOpsTest, TestETView) {
410408
411409 // Bad stacks expect death
412410 for (int i = 0 ; i < N_BAD_STACKS; i++) {
413- ET_EXPECT_DEATH (
414- getOpsFn ( " executorch_prim::et_view.default " )(context, bad_stacks[i]) ,
415- " " );
411+ ET_EXPECT_KERNEL_FAILURE (
412+ context_ ,
413+ getOpsFn ( " executorch_prim::et_view.default " )(context_, bad_stacks[i]) );
416414 }
417415
418416 constexpr int N_GOOD_STACKS = N_GOOD_OUTS;
@@ -422,7 +420,7 @@ TEST_F(RegisterPrimOpsTest, TestETView) {
422420
423421 // Good outs expect no death and correct output
424422 for (int i = 0 ; i < N_GOOD_STACKS; i++) {
425- getOpsFn (" executorch_prim::et_view.default" )(context , good_out_stacks[i]);
423+ getOpsFn (" executorch_prim::et_view.default" )(context_ , good_out_stacks[i]);
426424 EXPECT_TENSOR_EQ (good_outs[i], tf.make ({1 , 3 , 2 }, {1 , 2 , 3 , 4 , 5 , 6 }));
427425 EXPECT_EQ (good_outs[i].const_data_ptr (), self.const_data_ptr ());
428426 }
@@ -456,7 +454,7 @@ TEST_F(RegisterPrimOpsTest, TestETViewDynamic) {
456454
457455 EValue* stack[3 ] = {&self_evalue, &size_int_list_evalue, &out_evalue};
458456
459- getOpsFn (" executorch_prim::et_view.default" )(context , stack);
457+ getOpsFn (" executorch_prim::et_view.default" )(context_ , stack);
460458
461459 EXPECT_TENSOR_EQ (out, tf.make ({1 , 3 , 1 }, {1 , 2 , 3 }));
462460 EXPECT_EQ (out.const_data_ptr (), self.const_data_ptr ());
@@ -493,14 +491,15 @@ TEST_F(RegisterPrimOpsTest, TestETViewEmpty) {
493491
494492 // good size test
495493 EValue* stack[3 ] = {&self_evalue, &size_int_list_evalue, &out_evalue};
496- getOpsFn (" executorch_prim::et_view.default" )(context , stack);
494+ getOpsFn (" executorch_prim::et_view.default" )(context_ , stack);
497495 EXPECT_TENSOR_EQ (out, tf.make ({3 , 1 , 0 }, {}));
498496 EXPECT_EQ (out.const_data_ptr (), self.const_data_ptr ());
499497
500498 // bad size test
501499 EValue* bad_stack[3 ] = {&self_evalue, &bad_size_int_list_evalue, &out_evalue};
502- ET_EXPECT_DEATH (
503- getOpsFn (" executorch_prim::et_view.default" )(context, bad_stack), " " );
500+ ET_EXPECT_KERNEL_FAILURE (
501+ context_,
502+ getOpsFn (" executorch_prim::et_view.default" )(context_, bad_stack));
504503}
505504
506505TEST_F (RegisterPrimOpsTest, TestCeil) {
@@ -518,7 +517,7 @@ TEST_F(RegisterPrimOpsTest, TestCeil) {
518517 stack[j] = &values[j];
519518 }
520519
521- getOpsFn (" executorch_prim::ceil.Scalar" )(context , stack);
520+ getOpsFn (" executorch_prim::ceil.Scalar" )(context_ , stack);
522521 EXPECT_EQ (stack[1 ]->toInt (), expected[i]);
523522 }
524523}
@@ -539,7 +538,7 @@ TEST_F(RegisterPrimOpsTest, TestRound) {
539538 stack[j] = &values[j];
540539 }
541540
542- getOpsFn (" executorch_prim::round.Scalar" )(context , stack);
541+ getOpsFn (" executorch_prim::round.Scalar" )(context_ , stack);
543542 EXPECT_EQ (stack[1 ]->toInt (), expected[i]);
544543 }
545544}
@@ -559,7 +558,7 @@ TEST_F(RegisterPrimOpsTest, TestTrunc) {
559558 stack[j] = &values[j];
560559 }
561560
562- getOpsFn (" executorch_prim::trunc.Scalar" )(context , stack);
561+ getOpsFn (" executorch_prim::trunc.Scalar" )(context_ , stack);
563562 EXPECT_EQ (stack[1 ]->toInt (), expected[i]);
564563 }
565564}
0 commit comments