@@ -39,16 +39,23 @@ class OpDiagonalCopyOutTest : public ::testing::Test {
3939 // first.
4040 torch::executor::runtime_init ();
4141 }
42+
43+ template <ScalarType DTYPE>
44+ void test_2d_dtype () {
45+ TensorFactory<DTYPE> tf;
46+
47+ Tensor input = tf.make ({3 , 4 }, {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 });
48+ Tensor out = tf.zeros ({2 });
49+ Tensor out_expected = tf.make ({2 }, {5 , 10 });
50+ op_diagonal_copy_out (input, 1 , 1 , 0 , out);
51+ EXPECT_TENSOR_CLOSE (out, out_expected);
52+ }
4253};
4354
4455TEST_F (OpDiagonalCopyOutTest, SmokeTest2D) {
45- TensorFactory<ScalarType::Float> tfFloat;
46-
47- Tensor input = tfFloat.make ({3 , 4 }, {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 });
48- Tensor out = tfFloat.zeros ({2 });
49- Tensor out_expected = tfFloat.make ({2 }, {5 , 10 });
50- op_diagonal_copy_out (input, 1 , 1 , 0 , out);
51- EXPECT_TENSOR_CLOSE (out, out_expected);
56+ #define TEST_ENTRY (ctype, dtype ) test_2d_dtype<ScalarType::dtype>();
57+ ET_FORALL_REALHBF16_TYPES (TEST_ENTRY);
58+ #undef TEST_ENTRY
5259}
5360
5461TEST_F (OpDiagonalCopyOutTest, SmokeTest3D) {
0 commit comments