Skip to content

Commit d0a9ebe

Browse files
swolchokYIWENX14
authored andcommitted
Support BFloat16 in cat (#7795)
Partial fix for #7748.
1 parent af99313 commit d0a9ebe

File tree

2 files changed

+26
-20
lines changed

2 files changed

+26
-20
lines changed

kernels/portable/cpu/op_cat.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,12 @@ Tensor& cat_out(
5656
const size_t ninputs = tensors.size();
5757

5858
const auto out_type = out.scalar_type();
59-
ET_SWITCH_REALHB_TYPES(out_type, ctx, "cat.out", CTYPE_OUT, [&] {
59+
ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "cat.out", CTYPE_OUT, [&] {
6060
CTYPE_OUT* out_ptr = out.mutable_data_ptr<CTYPE_OUT>();
6161
for (size_t i = 0; i < outer; ++i) {
6262
for (size_t j = 0; j < ninputs; ++j) {
6363
const auto in_type = tensors[j].scalar_type();
64-
ET_SWITCH_REALHB_TYPES(in_type, ctx, "cat.out", CTYPE_IN, [&] {
64+
ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "cat.out", CTYPE_IN, [&] {
6565
if (tensors[j].numel() == 0) {
6666
return;
6767
}

kernels/test/op_cat_test.cpp

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,26 @@ class OpCatOutTest : public OperatorTest {
5353

5454
EXPECT_TENSOR_EQ(out, expected);
5555
}
56+
57+
template <ScalarType DTYPE>
58+
void test_16bit_dtype() {
59+
TensorFactory<DTYPE> tf;
60+
61+
Tensor x = tf.make({2, 3}, {1.5, -2.0, 3.25, 4.0, -5.5, 6.5});
62+
Tensor y = tf.make({2, 1}, {10.0, 20.0});
63+
64+
std::vector<Tensor> inputs = {x, y};
65+
66+
Tensor out = tf.zeros({2, 4});
67+
68+
// Concatenate along dim[1].
69+
Tensor ret = op_cat_out(
70+
ArrayRef<Tensor>(inputs.data(), inputs.size()), /*dim=*/1, out);
71+
72+
Tensor expected =
73+
tf.make({2, 4}, {1.5, -2.0, 3.25, 10.0, 4.0, -5.5, 6.5, 20.0});
74+
EXPECT_TENSOR_EQ(out, expected);
75+
}
5676
};
5777

5878
TEST_F(OpCatOutTest, SmokeDim1) {
@@ -105,26 +125,12 @@ TEST_F(OpCatOutTest, SmokeDim1) {
105125
EXPECT_TENSOR_EQ(out, expected);
106126
}
107127

108-
TEST_F(OpCatOutTest, HalfSupport) {
128+
TEST_F(OpCatOutTest, SixteenBitFloatSupport) {
109129
if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
110-
GTEST_SKIP() << "Test Half support only for ExecuTorch mode";
130+
GTEST_SKIP() << "Test Half/BF16 support only for ExecuTorch mode";
111131
}
112-
TensorFactory<ScalarType::Half> tf;
113-
114-
Tensor x = tf.make({2, 3}, {1.5, -2.0, 3.25, 4.0, -5.5, 6.5});
115-
Tensor y = tf.make({2, 1}, {10.0, 20.0});
116-
117-
std::vector<Tensor> inputs = {x, y};
118-
119-
Tensor out = tf.zeros({2, 4});
120-
121-
// Concatenate along dim[1].
122-
Tensor ret = op_cat_out(
123-
ArrayRef<Tensor>(inputs.data(), inputs.size()), /*dim=*/1, out);
124-
125-
Tensor expected =
126-
tf.make({2, 4}, {1.5, -2.0, 3.25, 10.0, 4.0, -5.5, 6.5, 20.0});
127-
EXPECT_TENSOR_EQ(out, expected);
132+
test_16bit_dtype<ScalarType::Half>();
133+
test_16bit_dtype<ScalarType::BFloat16>();
128134
}
129135

130136
TEST_F(OpCatOutTest, NegativeDims) {

0 commit comments

Comments
 (0)