@@ -33,20 +33,26 @@ class OpFlipOutTest : public ::testing::Test {
3333 // first.
3434 torch::executor::runtime_init ();
3535 }
36+
37+ template <ScalarType DTYPE>
38+ void test_1d_dtype () {
39+ TensorFactory<DTYPE> tf;
40+
41+ Tensor input = tf.make ({4 , 1 , 3 }, {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 });
42+ int64_t dims_data[1 ] = {-1 };
43+ IntArrayRef dims = IntArrayRef (dims_data, 1 );
44+ Tensor out = tf.zeros ({4 , 1 , 3 });
45+ Tensor out_expected =
46+ tf.make ({4 , 1 , 3 }, {3 , 2 , 1 , 6 , 5 , 4 , 9 , 8 , 7 , 12 , 11 , 10 });
47+ op_flip_out (input, dims, out);
48+ EXPECT_TENSOR_CLOSE (out, out_expected);
49+ }
3650};
3751
3852TEST_F (OpFlipOutTest, SmokeTest1Dim) {
39- TensorFactory<ScalarType::Float> tfFloat;
40-
41- Tensor input =
42- tfFloat.make ({4 , 1 , 3 }, {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 });
43- int64_t dims_data[1 ] = {-1 };
44- IntArrayRef dims = IntArrayRef (dims_data, 1 );
45- Tensor out = tfFloat.zeros ({4 , 1 , 3 });
46- Tensor out_expected =
47- tfFloat.make ({4 , 1 , 3 }, {3 , 2 , 1 , 6 , 5 , 4 , 9 , 8 , 7 , 12 , 11 , 10 });
48- op_flip_out (input, dims, out);
49- EXPECT_TENSOR_CLOSE (out, out_expected);
53+ #define TEST_ENTRY (ctype, dtype ) test_1d_dtype<ScalarType::dtype>();
54+ ET_FORALL_REALHBF16_TYPES (TEST_ENTRY);
55+ #undef TEST_ENTRY
5056}
5157
5258TEST_F (OpFlipOutTest, SmokeTest2Dims) {
0 commit comments