@@ -25,33 +25,28 @@ class OpFloorTest : public OperatorTest {
2525 Tensor& op_floor_out (const Tensor& self, Tensor& out) {
2626 return torch::executor::aten::floor_outf (context_, self, out);
2727 }
28- };
2928
30- TEST_F (OpFloorTest, SanityCheck) {
31- TensorFactory<ScalarType::Float> tf;
29+ template <ScalarType DTYPE>
30+ void test_floor_float_dtype () {
31+ TensorFactory<DTYPE> tf;
3232
33- Tensor in = tf.make ({1 , 7 }, {-3.0 , -2.99 , -1.01 , 0.0 , 1.01 , 2.99 , 3.0 });
34- Tensor out = tf.zeros ({1 , 7 });
35- Tensor expected = tf.make ({1 , 7 }, {-3.0 , -3.0 , -2.0 , 0.0 , 1.0 , 2.0 , 3.0 });
33+ Tensor in = tf.make ({1 , 7 }, {-3.0 , -2.99 , -1.01 , 0.0 , 1.01 , 2.99 , 3.0 });
34+ Tensor out = tf.zeros ({1 , 7 });
35+ Tensor expected = tf.make ({1 , 7 }, {-3.0 , -3.0 , -2.0 , 0.0 , 1.0 , 2.0 , 3.0 });
3636
37- Tensor ret = op_floor_out (in, out);
37+ Tensor ret = op_floor_out (in, out);
3838
39- EXPECT_TENSOR_EQ (out, ret);
40- EXPECT_TENSOR_EQ (out, expected);
41- }
39+ EXPECT_TENSOR_EQ (out, ret);
40+ EXPECT_TENSOR_EQ (out, expected);
41+ }
42+ };
4243
43- TEST_F (OpFloorTest, HalfSupport) {
44+ TEST_F (OpFloorTest, AllFloatDtypeSupport) {
45+ #define TEST_ENTRY (ctype, dtype ) test_floor_float_dtype<ScalarType::dtype>();
4446 if (torch::executor::testing::SupportedFeatures::get ()->is_aten ) {
45- GTEST_SKIP () << " Test Half support only for ExecuTorch mode" ;
47+ ET_FORALL_FLOAT_TYPES (TEST_ENTRY);
48+ } else {
49+ ET_FORALL_FLOATHBF16_TYPES (TEST_ENTRY);
4650 }
47- TensorFactory<ScalarType::Half> tf;
48-
49- Tensor in = tf.make ({1 , 7 }, {-3.0 , -2.99 , -1.01 , 0.0 , 1.01 , 2.99 , 3.0 });
50- Tensor out = tf.zeros ({1 , 7 });
51- Tensor expected = tf.make ({1 , 7 }, {-3.0 , -3.0 , -2.0 , 0.0 , 1.0 , 2.0 , 3.0 });
52-
53- Tensor ret = op_floor_out (in, out);
54-
55- EXPECT_TENSOR_EQ (out, ret);
56- EXPECT_TENSOR_EQ (out, expected);
51+ #undef TEST_ENTRY
5752}
0 commit comments