Skip to content

Commit badd76e

Browse files
swolchokfacebook-github-bot
authored andcommitted
Support bfloat16 in op_index_put (pytorch#5500)
Summary: Pull Request resolved: pytorch#5500 ghstack-source-id: 243857969 Reviewed By: digantdesai, larryliu0820 Differential Revision: D63057744 fbshipit-source-id: 9e1fb6f6479adb1575c5aed61b9da3c774586ba3
1 parent 0a72cb0 commit badd76e

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

kernels/portable/cpu/op_index_put.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ Tensor& index_put_out(
5353
ET_KERNEL_CHECK(
5454
ctx, tensor_is_broadcastable_to(values, out), InvalidArgument, out);
5555

56-
ET_SWITCH_REALHB_TYPES(in_type, ctx, "index_put.out", CTYPE, [&]() {
56+
ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "index_put.out", CTYPE, [&]() {
5757
apply_binary_elementwise_fn<CTYPE, CTYPE, CTYPE>(
5858
[accumulate](const CTYPE val_in, const CTYPE val) {
5959
return accumulate ? val_in + val : val;
@@ -120,7 +120,7 @@ Tensor& index_put_out(
120120
x_numel *= x_sizes[i];
121121
}
122122

123-
ET_SWITCH_REALHB_TYPES(in_type, ctx, "index_put.out", CTYPE, [&]() {
123+
ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "index_put.out", CTYPE, [&]() {
124124
const CTYPE* const values_data = values.const_data_ptr<CTYPE>();
125125
CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
126126

kernels/test/op_index_put_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -707,7 +707,7 @@ TEST_F(OpIndexPutOutTest, AllDtypesSupportedForInput) {
707707
#define TEST_ENTRY(ctype, dtype) \
708708
test_dtype<ScalarType::dtype, ScalarType::Long>();
709709

710-
ET_FORALL_REAL_TYPES_AND(Bool, TEST_ENTRY);
710+
ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY);
711711

712712
#undef TEST_ENTRY
713713
}

0 commit comments

Comments
 (0)