diff --git a/kernels/optimized/cpu/op_div.cpp b/kernels/optimized/cpu/op_div.cpp index c1ca946156a..11066ab2d44 100644 --- a/kernels/optimized/cpu/op_div.cpp +++ b/kernels/optimized/cpu/op_div.cpp @@ -144,7 +144,8 @@ Tensor& opt_div_scalar_out( auto error = resize_tensor(out, a.sizes()); ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor."); - if (a_type == common_type && a_type == out_type) { + if (a_type == common_type && a_type == out_type && + a_type != ScalarType::Half && a_type != ScalarType::BFloat16) { ET_SWITCH_REAL_TYPES(a_type, ctx, "div.Scalar_out", CTYPE, [&]() { ET_SWITCH_REAL_TYPES_AND( Bool, b_type, ctx, "div.Scalar_out", CTYPE_B, [&]() { diff --git a/kernels/optimized/cpu/op_le.cpp b/kernels/optimized/cpu/op_le.cpp index 8a23c94e419..a559512d934 100644 --- a/kernels/optimized/cpu/op_le.cpp +++ b/kernels/optimized/cpu/op_le.cpp @@ -88,7 +88,8 @@ Tensor& opt_le_scalar_out( ScalarType common_type = promoteTypes(a_type, b_type); ScalarType out_type = out.scalar_type(); - if (a_type == common_type && a_type == out_type) { + if (a_type == common_type && a_type == out_type && + a_type != ScalarType::Half && a_type != ScalarType::BFloat16) { ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "le.Scalar_out", CTYPE, [&]() { ET_SWITCH_REAL_TYPES_AND( Bool, b_type, ctx, "le.Scalar_out", CTYPE_B, [&]() { diff --git a/kernels/test/op_div_test.cpp b/kernels/test/op_div_test.cpp index 8f41419a8e0..722d7dd92e2 100644 --- a/kernels/test/op_div_test.cpp +++ b/kernels/test/op_div_test.cpp @@ -54,7 +54,7 @@ class OpDivOutTest : public OperatorTest { #define ENUMERATE_TEST_ENTRY(ctype, dtype) \ test_div(); - ET_FORALL_FLOAT_TYPES(ENUMERATE_TEST_ENTRY) + ET_FORALL_FLOATHBF16_TYPES(ENUMERATE_TEST_ENTRY) #undef ENUMERATE_TEST_ENTRY } @@ -64,7 +64,7 @@ class OpDivOutTest : public OperatorTest { #define ENUMERATE_TEST_ENTRY(ctype, dtype) \ test_div_enumerate_out_types(); - ET_FORALL_REAL_TYPES(ENUMERATE_TEST_ENTRY) + ET_FORALL_REALHBF16_TYPES(ENUMERATE_TEST_ENTRY) #undef ENUMERATE_TEST_ENTRY } @@ -183,7 +183,7 @@ void OpDivOutTest::test_div_enumerate_a_types() { #define ENUMERATE_TEST_ENTRY(ctype, dtype) \ test_div_enumerate_b_types(); - ET_FORALL_REAL_TYPES(ENUMERATE_TEST_ENTRY) + ET_FORALL_REALHBF16_TYPES(ENUMERATE_TEST_ENTRY) test_div(); @@ -506,9 +506,8 @@ TEST_F(OpDivOutTest, DynamicShapeUpperBoundLargerThanExpected) { TEST_F(OpDivOutTest, BroadcastNDTest) { // Test 3D tensors test_broadcast_3D(); - // half and bfloat16 are not supported for div quite yet - // test_broadcast_3D(); - // test_broadcast_3D(); + test_broadcast_3D(); + test_broadcast_3D(); } TEST_F(OpDivOutTest, DynamicShapeUnbound) { diff --git a/kernels/test/op_eq_test.cpp b/kernels/test/op_eq_test.cpp index 24cf9e6cf8d..539fb172f85 100644 --- a/kernels/test/op_eq_test.cpp +++ b/kernels/test/op_eq_test.cpp @@ -63,7 +63,7 @@ class OpEqScalarOutTest : public OperatorTest { TEST_F(OpEqScalarOutTest, AllRealInputBoolOutputSupport) { #define TEST_ENTRY(ctype, dtype) test_eq_scalar_out(); - ET_FORALL_REAL_TYPES(TEST_ENTRY); + ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY } @@ -100,7 +100,7 @@ TEST_F(OpEqScalarOutTest, AllRealOutputDTypes) { GTEST_SKIP() << "ATen kernel can handle non-bool output dtype"; } #define TEST_ENTRY(ctype, dtype) test_eq_all_output_dtypes(); - ET_FORALL_REAL_TYPES(TEST_ENTRY); + ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY } diff --git a/kernels/test/op_ge_test.cpp b/kernels/test/op_ge_test.cpp index a79502b266e..4fd0aa515b3 100644 --- a/kernels/test/op_ge_test.cpp +++ b/kernels/test/op_ge_test.cpp @@ -67,11 +67,11 @@ TEST_F(OpGeScalarOutTest, AllRealInputBoolOutputSupport) { #define TEST_ENTRY(ctype_in, dtype_in, ctype_out, dtype_out) \ test_ge_scalar_out(); -#define TEST_FORALL_OUT_TYPES(ctype_in, dtype_in) \ - ET_FORALL_REAL_TYPES_WITH2(ctype_in, dtype_in, TEST_ENTRY) \ +#define TEST_FORALL_OUT_TYPES(ctype_in, dtype_in) \ + ET_FORALL_REALHBF16_TYPES_WITH2(ctype_in, dtype_in, TEST_ENTRY) \ test_ge_scalar_out(); - ET_FORALL_REAL_TYPES(TEST_FORALL_OUT_TYPES) + ET_FORALL_REALHBF16_TYPES(TEST_FORALL_OUT_TYPES) #undef TEST_FORALL_OUT_TYPES #undef TEST_ENTRY @@ -124,11 +124,11 @@ TEST_F(OpGeTensorOutTest, AllDtypesSupported) { #define TEST_ENTRY(ctype_in, dtype_in, ctype_out, dtype_out) \ test_dtype(); -#define TEST_FORALL_OUT_TYPES(ctype_in, dtype_in) \ - ET_FORALL_REAL_TYPES_WITH2(ctype_in, dtype_in, TEST_ENTRY) \ +#define TEST_FORALL_OUT_TYPES(ctype_in, dtype_in) \ + ET_FORALL_REALHBF16_TYPES_WITH2(ctype_in, dtype_in, TEST_ENTRY) \ test_dtype(); - ET_FORALL_REAL_TYPES(TEST_FORALL_OUT_TYPES); + ET_FORALL_REALHBF16_TYPES(TEST_FORALL_OUT_TYPES); #undef TEST_FORALL_OUT_TYPES #undef TEST_ENTRY diff --git a/kernels/test/op_gt_test.cpp b/kernels/test/op_gt_test.cpp index 96c0e95f950..028e7d16878 100644 --- a/kernels/test/op_gt_test.cpp +++ b/kernels/test/op_gt_test.cpp @@ -67,11 +67,11 @@ TEST_F(OpGtScalarOutTest, AllRealInputBoolOutputSupport) { #define TEST_ENTRY(ctype_in, dtype_in, ctype_out, dtype_out) \ test_gt_scalar_out(); -#define TEST_FORALL_OUT_TYPES(ctype_in, dtype_in) \ - ET_FORALL_REAL_TYPES_WITH2(ctype_in, dtype_in, TEST_ENTRY) \ +#define TEST_FORALL_OUT_TYPES(ctype_in, dtype_in) \ + ET_FORALL_REALHBF16_TYPES_WITH2(ctype_in, dtype_in, TEST_ENTRY) \ test_gt_scalar_out(); - ET_FORALL_REAL_TYPES(TEST_FORALL_OUT_TYPES) + ET_FORALL_REALHBF16_TYPES(TEST_FORALL_OUT_TYPES) #undef TEST_FORALL_OUT_TYPES #undef TEST_ENTRY @@ -124,11 +124,11 @@ TEST_F(OpGtTensorOutTest, AllDtypesSupported) { #define TEST_ENTRY(ctype_in, dtype_in, ctype_out, dtype_out) \ test_dtype(); -#define TEST_FORALL_OUT_TYPES(ctype_in, dtype_in) \ - ET_FORALL_REAL_TYPES_WITH2(ctype_in, dtype_in, TEST_ENTRY) \ +#define TEST_FORALL_OUT_TYPES(ctype_in, dtype_in) \ + ET_FORALL_REALHBF16_TYPES_WITH2(ctype_in, dtype_in, TEST_ENTRY) \ test_dtype(); - ET_FORALL_REAL_TYPES(TEST_FORALL_OUT_TYPES); + ET_FORALL_REALHBF16_TYPES(TEST_FORALL_OUT_TYPES); #undef TEST_FORALL_OUT_TYPES #undef TEST_ENTRY diff --git a/kernels/test/op_le_test.cpp b/kernels/test/op_le_test.cpp index 4a9b97dfe8a..1baf098f9dd 100644 --- a/kernels/test/op_le_test.cpp +++ b/kernels/test/op_le_test.cpp @@ -67,11 +67,11 @@ TEST_F(OpLeScalarOutTest, AllRealInputBoolOutputSupport) { #define TEST_ENTRY(ctype_in, dtype_in, ctype_out, dtype_out) \ test_le_scalar_out(); -#define TEST_FORALL_OUT_TYPES(ctype_in, dtype_in) \ - ET_FORALL_REAL_TYPES_WITH2(ctype_in, dtype_in, TEST_ENTRY) \ +#define TEST_FORALL_OUT_TYPES(ctype_in, dtype_in) \ + ET_FORALL_REALHBF16_TYPES_WITH2(ctype_in, dtype_in, TEST_ENTRY) \ test_le_scalar_out(); - ET_FORALL_REAL_TYPES(TEST_FORALL_OUT_TYPES) + ET_FORALL_REALHBF16_TYPES(TEST_FORALL_OUT_TYPES) #undef TEST_FORALL_OUT_TYPES #undef TEST_ENTRY @@ -124,11 +124,11 @@ TEST_F(OpLeTensorOutTest, AllDtypesSupported) { #define TEST_ENTRY(ctype_in, dtype_in, ctype_out, dtype_out) \ test_dtype(); -#define TEST_FORALL_OUT_TYPES(ctype_in, dtype_in) \ - ET_FORALL_REAL_TYPES_WITH2(ctype_in, dtype_in, TEST_ENTRY) \ +#define TEST_FORALL_OUT_TYPES(ctype_in, dtype_in) \ + ET_FORALL_REALHBF16_TYPES_WITH2(ctype_in, dtype_in, TEST_ENTRY) \ test_dtype(); - ET_FORALL_REAL_TYPES(TEST_FORALL_OUT_TYPES); + ET_FORALL_REALHBF16_TYPES(TEST_FORALL_OUT_TYPES); #undef TEST_FORALL_OUT_TYPES #undef TEST_ENTRY diff --git a/kernels/test/op_lt_test.cpp b/kernels/test/op_lt_test.cpp index eee12c50521..c17d69c37da 100644 --- a/kernels/test/op_lt_test.cpp +++ b/kernels/test/op_lt_test.cpp @@ -67,11 +67,11 @@ TEST_F(OpLtScalarOutTest, AllRealInputBoolOutputSupport) { #define TEST_ENTRY(ctype_in, dtype_in, ctype_out, dtype_out) \ test_lt_scalar_out(); -#define TEST_FORALL_OUT_TYPES(ctype_in, dtype_in) \ - ET_FORALL_REAL_TYPES_WITH2(ctype_in, dtype_in, TEST_ENTRY) \ +#define TEST_FORALL_OUT_TYPES(ctype_in, dtype_in) \ + ET_FORALL_REALHBF16_TYPES_WITH2(ctype_in, dtype_in, TEST_ENTRY) \ test_lt_scalar_out(); - ET_FORALL_REAL_TYPES(TEST_FORALL_OUT_TYPES) + ET_FORALL_REALHBF16_TYPES(TEST_FORALL_OUT_TYPES) #undef TEST_FORALL_OUT_TYPES #undef TEST_ENTRY @@ -124,11 +124,11 @@ TEST_F(OpLtTensorOutTest, AllDtypesSupported) { #define TEST_ENTRY(ctype_in, dtype_in, ctype_out, dtype_out) \ test_dtype(); -#define TEST_FORALL_OUT_TYPES(ctype_in, dtype_in) \ - ET_FORALL_REAL_TYPES_WITH2(ctype_in, dtype_in, TEST_ENTRY) \ +#define TEST_FORALL_OUT_TYPES(ctype_in, dtype_in) \ + ET_FORALL_REALHBF16_TYPES_WITH2(ctype_in, dtype_in, TEST_ENTRY) \ test_dtype(); - ET_FORALL_REAL_TYPES(TEST_FORALL_OUT_TYPES); + ET_FORALL_REALHBF16_TYPES(TEST_FORALL_OUT_TYPES); #undef TEST_FORALL_OUT_TYPES #undef TEST_ENTRY diff --git a/kernels/test/op_ne_test.cpp b/kernels/test/op_ne_test.cpp index fe4e6c3621c..46681b02725 100644 --- a/kernels/test/op_ne_test.cpp +++ b/kernels/test/op_ne_test.cpp @@ -83,7 +83,7 @@ class OpNeScalarOutTest : public OperatorTest { TEST_F(OpNeScalarOutTest, AllRealInputBoolOutputSupport) { #define TEST_ENTRY(ctype, dtype) test_ne_scalar_out(); - ET_FORALL_REAL_TYPES(TEST_ENTRY); + ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY } @@ -117,13 +117,13 @@ TEST_F(OpNeScalarOutTest, MismatchedShapesDies) { TEST_F(OpNeScalarOutTest, AllRealOutputDTypesSupported) { #define TEST_ENTRY(ctype, dtype) test_ne_all_output_dtypes(); - ET_FORALL_REAL_TYPES(TEST_ENTRY); + ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY } TEST_F(OpNeTest, AllDtypesSupported) { #define TEST_ENTRY(ctype, dtype) test_dtype(); - ET_FORALL_REAL_TYPES(TEST_ENTRY); + ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY } diff --git a/kernels/test/op_sub_test.cpp b/kernels/test/op_sub_test.cpp index aa7d4d51e4e..c8e7c69c443 100644 --- a/kernels/test/op_sub_test.cpp +++ b/kernels/test/op_sub_test.cpp @@ -73,7 +73,7 @@ class OpSubOutTest : public OperatorTest { #define ENUMERATE_TEST_ENTRY(ctype, dtype) \ test_sub_enumerate_out_types(); - ET_FORALL_REAL_TYPES_AND(Half, ENUMERATE_TEST_ENTRY) + ET_FORALL_REALHBF16_TYPES(ENUMERATE_TEST_ENTRY) #undef ENUMERATE_TEST_ENTRY } @@ -208,7 +208,7 @@ class OpSubOutTest : public OperatorTest { #define ENUMERATE_TEST_ENTRY(ctype, dtype) \ test_sub_enumerate_b_types(); - ET_FORALL_REAL_TYPES_AND(Half, ENUMERATE_TEST_ENTRY) + ET_FORALL_REALHBF16_TYPES(ENUMERATE_TEST_ENTRY) #undef ENUMERATE_TEST_ENTRY }