@@ -25,66 +25,29 @@ class OpIsNanTest : public OperatorTest {
2525 Tensor& op_isnan_out (const Tensor& self, Tensor& out) {
2626 return torch::executor::aten::isnan_outf (context_, self, out);
2727 }
28- };
29-
30- TEST_F (OpIsNanTest, SanityCheckFloat) {
31- TensorFactory<ScalarType::Float> tf;
32- TensorFactory<ScalarType::Bool> tfb;
3328
34- Tensor in = tf. make (
35- { 1 , 5 }, {- 1.0 , 0.0 , 1.0 , NAN, std::numeric_limits< float >:: infinity ()});
36- Tensor out = tfb. zeros ({ 1 , 5 }) ;
37- Tensor expected = tfb. make ({ 1 , 5 }, { false , false , false , true , false }) ;
29+ template <ScalarType DTYPE>
30+ void test_sanity_check () {
31+ TensorFactory<DTYPE> tf ;
32+ TensorFactory<ScalarType::Bool> tfb;
3833
39- Tensor ret = op_isnan_out (in, out);
34+ using CTYPE = typename TensorFactory<DTYPE>::ctype;
35+ Tensor in = tf.make (
36+ {1 , 5 }, {-1 , 0 , 1 , NAN, std::numeric_limits<CTYPE>::infinity ()});
37+ Tensor out = tfb.zeros ({1 , 5 });
38+ Tensor expected = tfb.make ({1 , 5 }, {false , false , false , true , false });
4039
41- EXPECT_TENSOR_EQ (out, ret);
42- EXPECT_TENSOR_EQ (out, expected);
43- }
40+ Tensor ret = op_isnan_out (in, out);
4441
45- TEST_F (OpIsNanTest, SanityCheckHalf) {
46- if (torch::executor::testing::SupportedFeatures::get ()->is_aten ) {
47- GTEST_SKIP () << " Test Half support only for ExecuTorch mode" ;
42+ EXPECT_TENSOR_EQ (out, ret);
43+ EXPECT_TENSOR_EQ (out, expected);
4844 }
49- TensorFactory<ScalarType::Float> tf;
50- TensorFactory<ScalarType::Bool> tfb;
51-
52- Tensor in = tf.make (
53- {1 , 5 }, {-1.0 , 0.0 , 1.0 , NAN, std::numeric_limits<float >::infinity ()});
54- Tensor out = tfb.zeros ({1 , 5 });
55- Tensor expected = tfb.make ({1 , 5 }, {false , false , false , true , false });
56-
57- Tensor ret = op_isnan_out (in, out);
58-
59- EXPECT_TENSOR_EQ (out, ret);
60- EXPECT_TENSOR_EQ (out, expected);
61- }
62-
63- TEST_F (OpIsNanTest, SanityCheckByte) {
64- TensorFactory<ScalarType::Byte> tf;
65- TensorFactory<ScalarType::Bool> tfb;
66-
67- Tensor in = tf.make ({1 , 5 }, {1 , 2 , 3 , 4 , 5 });
68- Tensor out = tfb.zeros ({1 , 5 });
69- Tensor expected = tfb.make ({1 , 5 }, {false , false , false , false , false });
70-
71- Tensor ret = op_isnan_out (in, out);
72-
73- EXPECT_TENSOR_EQ (out, ret);
74- EXPECT_TENSOR_EQ (out, expected);
75- }
76-
77- TEST_F (OpIsNanTest, SanityCheckBool) {
78- TensorFactory<ScalarType::Bool> tfb;
79-
80- Tensor in = tfb.make ({1 , 5 }, {true , false , true , true , false });
81- Tensor out = tfb.zeros ({1 , 5 });
82- Tensor expected = tfb.make ({1 , 5 }, {false , false , false , false , false });
83-
84- Tensor ret = op_isnan_out (in, out);
45+ };
8546
86- EXPECT_TENSOR_EQ (out, ret);
87- EXPECT_TENSOR_EQ (out, expected);
47+ TEST_F (OpIsNanTest, SanityCheck) {
48+ #define TEST_ENTRY (ctype, dtype ) test_sanity_check<ScalarType::dtype>();
49+ ET_FORALL_FLOATHBF16_TYPES (TEST_ENTRY);
50+ #undef TEST_ENTRY
8851}
8952
9053TEST_F (OpIsNanTest, SanityCheckOutDtype) {
0 commit comments