Skip to content

Commit c86a2f6

Browse files
authored
Support BFloat16 in diagonal_copy (#7808)
Partial fix for #7748.
1 parent f633347 commit c86a2f6

File tree

2 files changed

+15
-8
lines changed

2 files changed

+15
-8
lines changed

kernels/portable/cpu/op_diagonal_copy.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ Tensor& diagonal_copy_out(
9898

9999
constexpr auto name = "diagonal_copy.out";
100100

101-
ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, name, CTYPE, [&] {
101+
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, name, CTYPE, [&] {
102102
diagonal_copy_impl<CTYPE>(in, offset, dim1, dim2, out);
103103
});
104104

kernels/test/op_diagonal_copy_test.cpp

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,23 @@ class OpDiagonalCopyOutTest : public ::testing::Test {
3939
// first.
4040
torch::executor::runtime_init();
4141
}
42+
43+
template <ScalarType DTYPE>
44+
void test_2d_dtype() {
45+
TensorFactory<DTYPE> tf;
46+
47+
Tensor input = tf.make({3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
48+
Tensor out = tf.zeros({2});
49+
Tensor out_expected = tf.make({2}, {5, 10});
50+
op_diagonal_copy_out(input, 1, 1, 0, out);
51+
EXPECT_TENSOR_CLOSE(out, out_expected);
52+
}
4253
};
4354

4455
TEST_F(OpDiagonalCopyOutTest, SmokeTest2D) {
45-
TensorFactory<ScalarType::Float> tfFloat;
46-
47-
Tensor input = tfFloat.make({3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
48-
Tensor out = tfFloat.zeros({2});
49-
Tensor out_expected = tfFloat.make({2}, {5, 10});
50-
op_diagonal_copy_out(input, 1, 1, 0, out);
51-
EXPECT_TENSOR_CLOSE(out, out_expected);
56+
#define TEST_ENTRY(ctype, dtype) test_2d_dtype<ScalarType::dtype>();
57+
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
58+
#undef TEST_ENTRY
5259
}
5360

5461
TEST_F(OpDiagonalCopyOutTest, SmokeTest3D) {

0 commit comments

Comments
 (0)