@@ -37,18 +37,26 @@ class OpRollOutTest : public ::testing::Test {
3737 // first.
3838 torch::executor::runtime_init ();
3939 }
40+
41+ template <ScalarType DTYPE>
42+ void test_dtype () {
43+ TensorFactory<DTYPE> tf;
44+
45+ Tensor input = tf.make ({4 , 2 }, {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 });
46+ int64_t shifts_data[2 ] = {2 , 1 };
47+ ArrayRef<int64_t > shifts = ArrayRef<int64_t >(shifts_data, 2 );
48+ int64_t dims_data[2 ] = {0 , 1 };
49+ ArrayRef<int64_t > dims = ArrayRef<int64_t >(dims_data, 2 );
50+ Tensor out = tf.zeros ({4 , 2 });
51+ Tensor out_expected = tf.make ({4 , 2 }, {6 , 5 , 8 , 7 , 2 , 1 , 4 , 3 });
52+ op_roll_out (input, shifts, dims, out);
53+ EXPECT_TENSOR_CLOSE (out, out_expected);
54+ }
4055};
4156
4257TEST_F (OpRollOutTest, SmokeTest) {
43- TensorFactory<ScalarType::Float> tfFloat;
44-
45- Tensor input = tfFloat.make ({4 , 2 }, {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 });
46- int64_t shifts_data[2 ] = {2 , 1 };
47- ArrayRef<int64_t > shifts = ArrayRef<int64_t >(shifts_data, 2 );
48- int64_t dims_data[2 ] = {0 , 1 };
49- ArrayRef<int64_t > dims = ArrayRef<int64_t >(dims_data, 2 );
50- Tensor out = tfFloat.zeros ({4 , 2 });
51- Tensor out_expected = tfFloat.make ({4 , 2 }, {6 , 5 , 8 , 7 , 2 , 1 , 4 , 3 });
52- op_roll_out (input, shifts, dims, out);
53- EXPECT_TENSOR_CLOSE (out, out_expected);
58+ #define TEST_ENTRY (ctype, dtype ) test_dtype<ScalarType::dtype>();
59+ // TODO: enable bool test after #7856 lands.
60+ ET_FORALL_REALHBF16_TYPES (TEST_ENTRY);
61+ #undef TEST_ENTRY
5462}
0 commit comments