Skip to content

Commit 99e6349

Browse files
[ET][Kernels] Increase Half/Bfloat16 support (#13719)
Add Half/Bfloat16 dtype support for the following ops: - bmm.out - max.dim_max - min.dim_min - scatter_add.out Differential Revision: D80963875 @diff-train-skip-merge Co-authored-by: Manuel Candales <[email protected]>
1 parent 3eb7947 commit 99e6349

File tree

9 files changed

+28
-28
lines changed

9 files changed

+28
-28
lines changed

kernels/optimized/cpu/op_bmm.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ Tensor& opt_bmm_out(
158158
bmm_kernel<CTYPE>(self, mat2, out);
159159
});
160160
} else {
161-
ET_SWITCH_REALH_TYPES(self_type, ctx, name, CTYPE, [&]() {
161+
ET_SWITCH_REALHBF16_TYPES(self_type, ctx, name, CTYPE, [&]() {
162162
bmm_kernel<CTYPE>(self, mat2, out);
163163
});
164164
}

kernels/portable/cpu/op_bmm.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ Tensor& bmm_out(
4646
internal::bmm_out_impl<CTYPE>(in, mat2, out);
4747
});
4848
} else {
49-
ET_SWITCH_REALH_TYPES(in_type, ctx, op_name, CTYPE, [&]() {
49+
ET_SWITCH_REALHBF16_TYPES(in_type, ctx, op_name, CTYPE, [&]() {
5050
internal::bmm_out_impl<CTYPE>(in, mat2, out);
5151
});
5252
}

kernels/portable/cpu/op_max.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@ std::tuple<Tensor&, Tensor&> max_out(
7979

8080
dim = dim < 0 ? dim + in.dim() : dim;
8181

82-
ET_SWITCH_REAL_TYPES_AND(
83-
Bool, in.scalar_type(), ctx, "max.dim_max", CTYPE, [&]() {
82+
ET_SWITCH_REALHBBF16_TYPES(
83+
in.scalar_type(), ctx, "max.dim_max", CTYPE, [&]() {
8484
CTYPE* max_data = max.mutable_data_ptr<CTYPE>();
8585
long* max_indices_data = max_indices.mutable_data_ptr<long>();
8686

kernels/portable/cpu/op_min.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@ std::tuple<Tensor&, Tensor&> min_out(
7979

8080
dim = dim < 0 ? dim + in.dim() : dim;
8181

82-
ET_SWITCH_REAL_TYPES_AND(
83-
Bool, in.scalar_type(), ctx, "min.dim_min", CTYPE, [&]() {
82+
ET_SWITCH_REALHBBF16_TYPES(
83+
in.scalar_type(), ctx, "min.dim_min", CTYPE, [&]() {
8484
CTYPE* min_data = min.mutable_data_ptr<CTYPE>();
8585
long* min_indices_data = min_indices.mutable_data_ptr<long>();
8686

kernels/portable/cpu/op_scatter_add.cpp

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -79,24 +79,24 @@ Tensor& scatter_add_out(
7979

8080
ScalarType self_type = self.scalar_type();
8181

82-
ET_SWITCH_REAL_TYPES_AND(
83-
Bool, self_type, ctx, "scatter_add.out", CTYPE, [&]() {
84-
const CTYPE* self_data = self.const_data_ptr<CTYPE>();
85-
const long* index_data = index.const_data_ptr<long>();
86-
const CTYPE* src_data = src.const_data_ptr<CTYPE>();
87-
CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
88-
89-
memcpy(out_data, self_data, self.nbytes());
90-
91-
if (index.numel() != 0) {
92-
if (self.dim() == 0) {
93-
out_data[0] += nonempty_size(index, 0) * src_data[0];
94-
} else {
95-
scatter_add_helper<CTYPE>(
96-
src_data, index_data, out_data, src, index, out, dim);
97-
}
98-
}
99-
});
82+
ET_SWITCH_REALHBBF16_TYPES(self_type, ctx, "scatter_add.out", CTYPE, [&]() {
83+
const CTYPE* self_data = self.const_data_ptr<CTYPE>();
84+
const long* index_data = index.const_data_ptr<long>();
85+
const CTYPE* src_data = src.const_data_ptr<CTYPE>();
86+
CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
87+
88+
memcpy(out_data, self_data, self.nbytes());
89+
90+
if (index.numel() != 0) {
91+
if (self.dim() == 0) {
92+
out_data[0] +=
93+
static_cast<CTYPE>(nonempty_size(index, 0)) * src_data[0];
94+
} else {
95+
scatter_add_helper<CTYPE>(
96+
src_data, index_data, out_data, src, index, out, dim);
97+
}
98+
}
99+
});
100100

101101
return out;
102102
}

kernels/test/op_bmm_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ TEST_F(OpBmmOutTest, OutputDimFloat) {
189189
/// zeros().
190190
TEST_F(OpBmmOutTest, AllRealDtypesSupported) {
191191
#define TEST_ENTRY(ctype, dtype) test_dtype<ctype, ScalarType::dtype>();
192-
ET_FORALL_REAL_TYPES(TEST_ENTRY);
192+
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
193193
#undef TEST_ENTRY
194194
// TODO: Also add tests for half, complex, quantized, and other types. Easiest
195195
// way to do that would be to make TensorFactory support zeros() and ones()

kernels/test/op_max_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ TEST_F(OpMaxOutTest, MismatchedDTypesDies) {
316316

317317
TEST_F(OpMaxOutTest, AllRealInputLongOutputPasses) {
318318
#define TEST_ENTRY(ctype, dtype) test_max_out_dtype<ScalarType::dtype>();
319-
ET_FORALL_REAL_TYPES_AND(Bool, TEST_ENTRY);
319+
ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY);
320320
#undef TEST_ENTRY
321321
}
322322

kernels/test/op_min_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ TEST_F(OpMinOutTest, MismatchedDTypesDies) {
312312

313313
TEST_F(OpMinOutTest, AllRealInputLongOutputPasses) {
314314
#define TEST_ENTRY(ctype, dtype) test_min_out_dtype<ScalarType::dtype>();
315-
ET_FORALL_REAL_TYPES_AND(Bool, TEST_ENTRY);
315+
ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY);
316316
#undef TEST_ENTRY
317317
}
318318

kernels/test/op_scatter_add_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ class OpScatterAddOutTest : public OperatorTest {
281281

282282
TEST_F(OpScatterAddOutTest, AllValidInputOutputSupport) {
283283
#define TEST_ENTRY(CTYPE, DTYPE) test_scatter_add_out<ScalarType::DTYPE>();
284-
ET_FORALL_REAL_TYPES(TEST_ENTRY);
284+
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
285285
#undef TEST_ENTRY
286286
}
287287

0 commit comments

Comments (0)