@@ -25,22 +25,31 @@ class OpSignTest : public OperatorTest {
2525 Tensor& op_sign_out (const Tensor& self, Tensor& out) {
2626 return torch::executor::aten::sign_outf (context_, self, out);
2727 }
28+
29+ template <typename CTYPE, ScalarType DTYPE>
30+ void test_et_dtype () {
31+ TensorFactory<DTYPE> tf;
32+
33+ const auto infinity = std::numeric_limits<CTYPE>::infinity ();
34+ const auto nan = std::numeric_limits<CTYPE>::quiet_NaN ();
35+ Tensor in = tf.make ({1 , 7 }, {-infinity, -3 ., -1.5 , 0 ., 1.5 , nan, infinity});
36+ Tensor out = tf.zeros ({1 , 7 });
37+ Tensor expected = tf.make ({1 , 7 }, {-1 ., -1 ., -1 ., 0 ., 1 ., nan, 1 .});
38+
39+ Tensor ret = op_sign_out (in, out);
40+
41+ EXPECT_TENSOR_EQ (out, ret);
42+ EXPECT_TENSOR_CLOSE (out, expected);
43+ }
2844};
2945
3046TEST_F (OpSignTest, ETSanityCheckFloat) {
3147 if (torch::executor::testing::SupportedFeatures::get ()->is_aten ) {
3248 GTEST_SKIP () << " ATen returns 0 on NAN input" ;
3349 }
34- TensorFactory<ScalarType::Float> tf;
35-
36- Tensor in = tf.make ({1 , 7 }, {-INFINITY, -3 ., -1.5 , 0 ., 1.5 , NAN, INFINITY});
37- Tensor out = tf.zeros ({1 , 7 });
38- Tensor expected = tf.make ({1 , 7 }, {-1 ., -1 ., -1 ., 0 ., 1 ., NAN, 1 .});
39-
40- Tensor ret = op_sign_out (in, out);
41-
42- EXPECT_TENSOR_EQ (out, ret);
43- EXPECT_TENSOR_CLOSE (out, expected);
50+ #define TEST_ENTRY (ctype, dtype ) test_et_dtype<ctype, ScalarType::dtype>();
51+ ET_FORALL_FLOATHBF16_TYPES (TEST_ENTRY);
52+ #undef TEST_ENTRY
4453}
4554
4655TEST_F (OpSignTest, ATenSanityCheckFloat) {
0 commit comments