@@ -24,8 +24,44 @@ class OpAbsTest : public OperatorTest {
2424 Tensor& op_abs_out (const Tensor& self, Tensor& out) {
2525 return torch::executor::aten::abs_outf (context_, self, out);
2626 }
27+
28+ template <ScalarType DTYPE>
29+ void test_dtype () {
30+ TensorFactory<DTYPE> tf;
31+
32+ Tensor in = tf.make ({2 , 3 }, {-3 , -2 , -1 , 0 , 1 , 2 });
33+ Tensor out = tf.zeros ({2 , 3 });
34+ Tensor expected = tf.make ({2 , 3 }, {3 , 2 , 1 , 0 , 1 , 2 });
35+
36+ Tensor ret = op_abs_out (in, out);
37+
38+ EXPECT_TENSOR_EQ (out, ret);
39+ EXPECT_TENSOR_EQ (out, expected);
40+ }
41+
42+ template <>
43+ void test_dtype<ScalarType::Byte>() {
44+ TensorFactory<ScalarType::Byte> tf;
45+
46+ Tensor in = tf.make ({2 , 3 }, {253 , 254 , 255 , 0 , 1 , 2 });
47+ Tensor out = tf.zeros ({2 , 3 });
48+ Tensor expected = tf.make ({2 , 3 }, {253 , 254 , 255 , 0 , 1 , 2 });
49+
50+ Tensor ret = op_abs_out (in, out);
51+
52+ EXPECT_TENSOR_EQ (out, ret);
53+ EXPECT_TENSOR_EQ (out, expected);
54+ }
2755};
2856
57+ TEST_F (OpAbsTest, AllRealHBF16Input) {
58+ #define TEST_KERNEL (INPUT_CTYPE, INPUT_DTYPE ) \
59+ test_dtype<ScalarType::INPUT_DTYPE>();
60+
61+ ET_FORALL_REALHBF16_TYPES (TEST_KERNEL);
62+ #undef TEST_KERNEL
63+ }
64+
2965TEST_F (OpAbsTest, SanityCheck) {
3066 TensorFactory<ScalarType::Float> tf;
3167
0 commit comments