diff --git a/kernels/optimized/cpu/op_bmm.cpp b/kernels/optimized/cpu/op_bmm.cpp index 51e86d54e60..9cbd30cb6e1 100644 --- a/kernels/optimized/cpu/op_bmm.cpp +++ b/kernels/optimized/cpu/op_bmm.cpp @@ -158,7 +158,7 @@ Tensor& opt_bmm_out( bmm_kernel(self, mat2, out); }); } else { - ET_SWITCH_REALH_TYPES(self_type, ctx, name, CTYPE, [&]() { + ET_SWITCH_REALHBF16_TYPES(self_type, ctx, name, CTYPE, [&]() { bmm_kernel(self, mat2, out); }); } diff --git a/kernels/portable/cpu/op_bmm.cpp b/kernels/portable/cpu/op_bmm.cpp index 060b92a0da2..ea4c1b52740 100644 --- a/kernels/portable/cpu/op_bmm.cpp +++ b/kernels/portable/cpu/op_bmm.cpp @@ -46,7 +46,7 @@ Tensor& bmm_out( internal::bmm_out_impl(in, mat2, out); }); } else { - ET_SWITCH_REALH_TYPES(in_type, ctx, op_name, CTYPE, [&]() { + ET_SWITCH_REALHBF16_TYPES(in_type, ctx, op_name, CTYPE, [&]() { internal::bmm_out_impl(in, mat2, out); }); } diff --git a/kernels/portable/cpu/op_max.cpp b/kernels/portable/cpu/op_max.cpp index 3f4a1d27c0e..09eceb7eb2d 100644 --- a/kernels/portable/cpu/op_max.cpp +++ b/kernels/portable/cpu/op_max.cpp @@ -78,8 +78,8 @@ std::tuple max_out( dim = dim < 0 ? dim + in.dim() : dim; - ET_SWITCH_REAL_TYPES_AND( - Bool, in.scalar_type(), ctx, "max.dim_max", CTYPE, [&]() { + ET_SWITCH_REALHBBF16_TYPES( + in.scalar_type(), ctx, "max.dim_max", CTYPE, [&]() { CTYPE* max_data = max.mutable_data_ptr(); long* max_indices_data = max_indices.mutable_data_ptr(); diff --git a/kernels/portable/cpu/op_min.cpp b/kernels/portable/cpu/op_min.cpp index 8b70bcd40f5..a045898c81e 100644 --- a/kernels/portable/cpu/op_min.cpp +++ b/kernels/portable/cpu/op_min.cpp @@ -78,8 +78,8 @@ std::tuple min_out( dim = dim < 0 ? dim + in.dim() : dim; - ET_SWITCH_REAL_TYPES_AND( - Bool, in.scalar_type(), ctx, "min.dim_min", CTYPE, [&]() { + ET_SWITCH_REALHBBF16_TYPES( + in.scalar_type(), ctx, "min.dim_min", CTYPE, [&]() { CTYPE* min_data = min.mutable_data_ptr(); long* min_indices_data = min_indices.mutable_data_ptr(); diff --git a/kernels/portable/cpu/op_scatter_add.cpp b/kernels/portable/cpu/op_scatter_add.cpp index 22fb3d161a8..f9c1f7677b6 100644 --- a/kernels/portable/cpu/op_scatter_add.cpp +++ b/kernels/portable/cpu/op_scatter_add.cpp @@ -79,24 +79,24 @@ Tensor& scatter_add_out( ScalarType self_type = self.scalar_type(); - ET_SWITCH_REAL_TYPES_AND( - Bool, self_type, ctx, "scatter_add.out", CTYPE, [&]() { - const CTYPE* self_data = self.const_data_ptr(); - const long* index_data = index.const_data_ptr(); - const CTYPE* src_data = src.const_data_ptr(); - CTYPE* out_data = out.mutable_data_ptr(); - - memcpy(out_data, self_data, self.nbytes()); - - if (index.numel() != 0) { - if (self.dim() == 0) { - out_data[0] += nonempty_size(index, 0) * src_data[0]; - } else { - scatter_add_helper( - src_data, index_data, out_data, src, index, out, dim); - } - } - }); + ET_SWITCH_REALHBBF16_TYPES(self_type, ctx, "scatter_add.out", CTYPE, [&]() { + const CTYPE* self_data = self.const_data_ptr(); + const long* index_data = index.const_data_ptr(); + const CTYPE* src_data = src.const_data_ptr(); + CTYPE* out_data = out.mutable_data_ptr(); + + memcpy(out_data, self_data, self.nbytes()); + + if (index.numel() != 0) { + if (self.dim() == 0) { + out_data[0] += + static_cast(nonempty_size(index, 0)) * src_data[0]; + } else { + scatter_add_helper( + src_data, index_data, out_data, src, index, out, dim); + } + } + }); return out; } diff --git a/kernels/test/op_bmm_test.cpp b/kernels/test/op_bmm_test.cpp index 70a5f37946d..edf2703e393 100644 --- a/kernels/test/op_bmm_test.cpp +++ b/kernels/test/op_bmm_test.cpp @@ -189,7 +189,7 @@ TEST_F(OpBmmOutTest, OutputDimFloat) { /// zeros(). TEST_F(OpBmmOutTest, AllRealDtypesSupported) { #define TEST_ENTRY(ctype, dtype) test_dtype(); - ET_FORALL_REAL_TYPES(TEST_ENTRY); + ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY // TODO: Also add tests for half, complex, quantized, and other types. Easiest // way to do that would be to make TensorFactory support zeros() and ones() diff --git a/kernels/test/op_max_test.cpp b/kernels/test/op_max_test.cpp index 72de22e60d6..53c90ae909c 100644 --- a/kernels/test/op_max_test.cpp +++ b/kernels/test/op_max_test.cpp @@ -316,7 +316,7 @@ TEST_F(OpMaxOutTest, MismatchedDTypesDies) { TEST_F(OpMaxOutTest, AllRealInputLongOutputPasses) { #define TEST_ENTRY(ctype, dtype) test_max_out_dtype(); - ET_FORALL_REAL_TYPES_AND(Bool, TEST_ENTRY); + ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY } diff --git a/kernels/test/op_min_test.cpp b/kernels/test/op_min_test.cpp index 3d5e01c44d9..ebbca989051 100644 --- a/kernels/test/op_min_test.cpp +++ b/kernels/test/op_min_test.cpp @@ -312,7 +312,7 @@ TEST_F(OpMinOutTest, MismatchedDTypesDies) { TEST_F(OpMinOutTest, AllRealInputLongOutputPasses) { #define TEST_ENTRY(ctype, dtype) test_min_out_dtype(); - ET_FORALL_REAL_TYPES_AND(Bool, TEST_ENTRY); + ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY } diff --git a/kernels/test/op_scatter_add_test.cpp b/kernels/test/op_scatter_add_test.cpp index 82a5e858136..d5511b72683 100644 --- a/kernels/test/op_scatter_add_test.cpp +++ b/kernels/test/op_scatter_add_test.cpp @@ -281,7 +281,7 @@ class OpScatterAddOutTest : public OperatorTest { TEST_F(OpScatterAddOutTest, AllValidInputOutputSupport) { #define TEST_ENTRY(CTYPE, DTYPE) test_scatter_add_out(); - ET_FORALL_REAL_TYPES(TEST_ENTRY); + ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY }