@@ -118,32 +118,39 @@ class OpTopkValuesTest : public ::testing::Test {
118118 // first.
119119 torch::executor::runtime_init ();
120120 }
121+
122+ template <ScalarType DTYPE>
123+ void run_smoke_test () {
124+ TensorFactory<DTYPE> tfDtype;
125+ TensorFactory<ScalarType::Long> tfLong;
126+
127+ Tensor input =
128+ tfDtype.make ({3 , 2 , 2 }, {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 });
129+ int64_t k = 2 ;
130+ int64_t dim = 0 ;
131+ bool largest = true ;
132+ bool sorted = true ;
133+ Tensor values = tfDtype.zeros ({2 , 2 , 2 });
134+ Tensor indices = tfLong.zeros ({2 , 2 , 2 });
135+ Tensor values_expected =
136+ tfDtype.make ({2 , 2 , 2 }, {9 , 10 , 11 , 12 , 5 , 6 , 7 , 8 });
137+ Tensor indices_expected = tfLong.make ({2 , 2 , 2 }, {2 , 2 , 2 , 2 , 1 , 1 , 1 , 1 });
138+ op_topk_values (input, k, dim, largest, sorted, values, indices);
139+ EXPECT_TENSOR_CLOSE (values, values_expected);
140+ EXPECT_TENSOR_EQ (indices, indices_expected);
141+
142+ largest = false ;
143+ values_expected = tfDtype.make ({2 , 2 , 2 }, {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 });
144+ indices_expected = tfLong.make ({2 , 2 , 2 }, {0 , 0 , 0 , 0 , 1 , 1 , 1 , 1 });
145+ op_topk_values (input, k, dim, largest, sorted, values, indices);
146+ EXPECT_TENSOR_CLOSE (values, values_expected);
147+ EXPECT_TENSOR_EQ (indices, indices_expected);
148+ }
121149};
122150
123151TEST_F (OpTopkValuesTest, SmokeTest) {
124- TensorFactory<ScalarType::Float> tfFloat;
125- TensorFactory<ScalarType::Long> tfLong;
126-
127- Tensor input =
128- tfFloat.make ({3 , 2 , 2 }, {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 });
129- int64_t k = 2 ;
130- int64_t dim = 0 ;
131- bool largest = true ;
132- bool sorted = true ;
133- Tensor values = tfFloat.zeros ({2 , 2 , 2 });
134- Tensor indices = tfLong.zeros ({2 , 2 , 2 });
135- Tensor values_expected = tfFloat.make ({2 , 2 , 2 }, {9 , 10 , 11 , 12 , 5 , 6 , 7 , 8 });
136- Tensor indices_expected = tfLong.make ({2 , 2 , 2 }, {2 , 2 , 2 , 2 , 1 , 1 , 1 , 1 });
137- op_topk_values (input, k, dim, largest, sorted, values, indices);
138- EXPECT_TENSOR_CLOSE (values, values_expected);
139- EXPECT_TENSOR_EQ (indices, indices_expected);
140-
141- largest = false ;
142- values_expected = tfFloat.make ({2 , 2 , 2 }, {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 });
143- indices_expected = tfLong.make ({2 , 2 , 2 }, {0 , 0 , 0 , 0 , 1 , 1 , 1 , 1 });
144- op_topk_values (input, k, dim, largest, sorted, values, indices);
145- EXPECT_TENSOR_CLOSE (values, values_expected);
146- EXPECT_TENSOR_EQ (indices, indices_expected);
152+ #define RUN_SMOKE_TEST (ctype, dtype ) run_smoke_test<ScalarType::dtype>();
153+ ET_FORALL_REALHBF16_TYPES (RUN_SMOKE_TEST);
147154}
148155
149156TEST_F (OpTopkValuesTest, NonPartialSort) {
0 commit comments