Skip to content

Commit 0a72cb0

Browse files
swolchokfacebook-github-bot
authored andcommitted
Support bfloat16 in op_index (pytorch#5499)
Summary: Pull Request resolved: pytorch#5499 Seems to block bfloat16 stories110M as exported by torchchat (and we should have op coverage for bfloat16 anyway). ghstack-source-id: 243857968 Reviewed By: larryliu0820 Differential Revision: D63054001 fbshipit-source-id: 530b479872643f878912592c7b260d71e6e05804
1 parent 2eae7a9 commit 0a72cb0

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

kernels/portable/cpu/op_index.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ Tensor& index_Tensor_out(
8989
compute_dim_map(in, indices, dim_map, block_count == 1);
9090
compute_index_map(in, indices, ix_map);
9191

92-
ET_SWITCH_REALHB_TYPES(in_type, ctx, "index.Tensor_out", CTYPE, [&]() {
92+
ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "index.Tensor_out", CTYPE, [&]() {
9393
const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
9494
CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
9595

kernels/test/op_index_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ class OpIndexTensorOutTest : public OperatorTest {
107107
#define TEST_ENTRY(ctype, dtype) \
108108
test_dtype<ScalarType::dtype, ScalarType::Long, ScalarType::dtype>();
109109

110-
ET_FORALL_REAL_TYPES_AND(Bool, TEST_ENTRY);
110+
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
111111

112112
#undef TEST_ENTRY
113113
}

0 commit comments

Comments
 (0)