Skip to content

Commit abc08a3

Browse files
swolchokYIWENX14
authored andcommitted
Support BFloat16 in gather (#7823)
Partial fix for #7748.
1 parent 314baeb commit abc08a3

File tree

3 files changed

+3
-2
lines changed

3 files changed

+3
-2
lines changed

kernels/portable/cpu/op_gather.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ Tensor& gather_out(
8686

8787
constexpr auto name = "gather.out";
8888

89-
ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
89+
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
9090
gather_helper<CTYPE>(in, index, out, dim);
9191
});
9292

kernels/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ set(all_test_sources
139139
"op_fmod_test.cpp"
140140
"op_full_like_test.cpp"
141141
"op_full_test.cpp"
142+
"op_gather_test.cpp"
142143
"op_ge_test.cpp"
143144
"op_gelu_test.cpp"
144145
"op_glu_test.cpp"

kernels/test/op_gather_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ class OpGatherOutTest : public OperatorTest {
194194

195195
TEST_F(OpGatherOutTest, AllValidInputOutputSupport) {
196196
#define TEST_ENTRY(CTYPE, DTYPE) test_gather_out<ScalarType::DTYPE>();
197-
ET_FORALL_REAL_TYPES(TEST_ENTRY);
197+
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
198198
#undef TEST_ENTRY
199199
}
200200

0 commit comments

Comments
 (0)