@@ -53,6 +53,26 @@ class OpCatOutTest : public OperatorTest {
5353
5454 EXPECT_TENSOR_EQ (out, expected);
5555 }
56+
57+ template <ScalarType DTYPE>
58+ void test_16bit_dtype () {
59+ TensorFactory<DTYPE> tf;
60+
61+ Tensor x = tf.make ({2 , 3 }, {1.5 , -2.0 , 3.25 , 4.0 , -5.5 , 6.5 });
62+ Tensor y = tf.make ({2 , 1 }, {10.0 , 20.0 });
63+
64+ std::vector<Tensor> inputs = {x, y};
65+
66+ Tensor out = tf.zeros ({2 , 4 });
67+
68+ // Concatenate along dim[1].
69+ Tensor ret = op_cat_out (
70+ ArrayRef<Tensor>(inputs.data (), inputs.size ()), /* dim=*/ 1 , out);
71+
72+ Tensor expected =
73+ tf.make ({2 , 4 }, {1.5 , -2.0 , 3.25 , 10.0 , 4.0 , -5.5 , 6.5 , 20.0 });
74+ EXPECT_TENSOR_EQ (out, expected);
75+ }
5676};
5777
5878TEST_F (OpCatOutTest, SmokeDim1) {
@@ -105,26 +125,12 @@ TEST_F(OpCatOutTest, SmokeDim1) {
105125 EXPECT_TENSOR_EQ (out, expected);
106126}
107127
108- TEST_F (OpCatOutTest, HalfSupport ) {
128+ TEST_F (OpCatOutTest, SixteenBitFloatSupport ) {
109129 if (torch::executor::testing::SupportedFeatures::get ()->is_aten ) {
110- GTEST_SKIP () << " Test Half support only for ExecuTorch mode" ;
130+ GTEST_SKIP () << " Test Half/BF16 support only for ExecuTorch mode" ;
111131 }
112- TensorFactory<ScalarType::Half> tf;
113-
114- Tensor x = tf.make ({2 , 3 }, {1.5 , -2.0 , 3.25 , 4.0 , -5.5 , 6.5 });
115- Tensor y = tf.make ({2 , 1 }, {10.0 , 20.0 });
116-
117- std::vector<Tensor> inputs = {x, y};
118-
119- Tensor out = tf.zeros ({2 , 4 });
120-
121- // Concatenate along dim[1].
122- Tensor ret = op_cat_out (
123- ArrayRef<Tensor>(inputs.data (), inputs.size ()), /* dim=*/ 1 , out);
124-
125- Tensor expected =
126- tf.make ({2 , 4 }, {1.5 , -2.0 , 3.25 , 10.0 , 4.0 , -5.5 , 6.5 , 20.0 });
127- EXPECT_TENSOR_EQ (out, expected);
132+ test_16bit_dtype<ScalarType::Half>();
133+ test_16bit_dtype<ScalarType::BFloat16>();
128134}
129135
130136TEST_F (OpCatOutTest, NegativeDims) {
0 commit comments