Skip to content

Commit 82e7249

Browse files
Test Half/Bfloat16: div, lt, ge, gt, eq, ne
Differential Revision: D80963873 Pull Request resolved: #13762
1 parent b8b2ecb commit 82e7249

File tree

6 files changed

+25
-26
lines changed

6 files changed

+25
-26
lines changed

kernels/test/op_div_test.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -522,9 +522,8 @@ TEST_F(OpDivOutTest, DynamicShapeUpperBoundLargerThanExpected) {
522522
TEST_F(OpDivOutTest, BroadcastNDTest) {
523523
// Test 3D tensors
524524
test_broadcast_3D<ScalarType::Float>();
525-
// half and bfloat16 are not supported for div quite yet
526-
// test_broadcast_3D<ScalarType::Half>();
527-
// test_broadcast_3D<ScalarType::BFloat16>();
525+
test_broadcast_3D<ScalarType::Half>();
526+
test_broadcast_3D<ScalarType::BFloat16>();
528527
}
529528

530529
TEST_F(OpDivOutTest, DynamicShapeUnbound) {

kernels/test/op_eq_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ class OpEqScalarOutTest : public OperatorTest {
6363

6464
TEST_F(OpEqScalarOutTest, AllRealInputBoolOutputSupport) {
6565
#define TEST_ENTRY(ctype, dtype) test_eq_scalar_out<ScalarType::dtype>();
66-
ET_FORALL_REAL_TYPES(TEST_ENTRY);
66+
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
6767
#undef TEST_ENTRY
6868
}
6969

@@ -100,7 +100,7 @@ TEST_F(OpEqScalarOutTest, AllRealOutputDTypes) {
100100
GTEST_SKIP() << "ATen kernel can handle non-bool output dtype";
101101
}
102102
#define TEST_ENTRY(ctype, dtype) test_eq_all_output_dtypes<ScalarType::dtype>();
103-
ET_FORALL_REAL_TYPES(TEST_ENTRY);
103+
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
104104
#undef TEST_ENTRY
105105
}
106106

kernels/test/op_ge_test.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,11 @@ TEST_F(OpGeScalarOutTest, AllRealInputBoolOutputSupport) {
6767
#define TEST_ENTRY(ctype_in, dtype_in, ctype_out, dtype_out) \
6868
test_ge_scalar_out<ScalarType::dtype_in, ScalarType::dtype_out>();
6969

70-
#define TEST_FORALL_OUT_TYPES(ctype_in, dtype_in) \
71-
ET_FORALL_REAL_TYPES_WITH2(ctype_in, dtype_in, TEST_ENTRY) \
70+
#define TEST_FORALL_OUT_TYPES(ctype_in, dtype_in) \
71+
ET_FORALL_REALHBF16_TYPES_WITH2(ctype_in, dtype_in, TEST_ENTRY) \
7272
test_ge_scalar_out<ScalarType::dtype_in, ScalarType::Bool>();
7373

74-
ET_FORALL_REAL_TYPES(TEST_FORALL_OUT_TYPES)
74+
ET_FORALL_REALHBF16_TYPES(TEST_FORALL_OUT_TYPES)
7575

7676
#undef TEST_FORALL_OUT_TYPES
7777
#undef TEST_ENTRY
@@ -124,11 +124,11 @@ TEST_F(OpGeTensorOutTest, AllDtypesSupported) {
124124
#define TEST_ENTRY(ctype_in, dtype_in, ctype_out, dtype_out) \
125125
test_dtype<ScalarType::dtype_in, ScalarType::dtype_out>();
126126

127-
#define TEST_FORALL_OUT_TYPES(ctype_in, dtype_in) \
128-
ET_FORALL_REAL_TYPES_WITH2(ctype_in, dtype_in, TEST_ENTRY) \
127+
#define TEST_FORALL_OUT_TYPES(ctype_in, dtype_in) \
128+
ET_FORALL_REALHBF16_TYPES_WITH2(ctype_in, dtype_in, TEST_ENTRY) \
129129
test_dtype<ScalarType::dtype_in, ScalarType::Bool>();
130130

131-
ET_FORALL_REAL_TYPES(TEST_FORALL_OUT_TYPES);
131+
ET_FORALL_REALHBF16_TYPES(TEST_FORALL_OUT_TYPES);
132132

133133
#undef TEST_FORALL_OUT_TYPES
134134
#undef TEST_ENTRY

kernels/test/op_gt_test.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,11 @@ TEST_F(OpGtScalarOutTest, AllRealInputBoolOutputSupport) {
6767
#define TEST_ENTRY(ctype_in, dtype_in, ctype_out, dtype_out) \
6868
test_gt_scalar_out<ScalarType::dtype_in, ScalarType::dtype_out>();
6969

70-
#define TEST_FORALL_OUT_TYPES(ctype_in, dtype_in) \
71-
ET_FORALL_REAL_TYPES_WITH2(ctype_in, dtype_in, TEST_ENTRY) \
70+
#define TEST_FORALL_OUT_TYPES(ctype_in, dtype_in) \
71+
ET_FORALL_REALHBF16_TYPES_WITH2(ctype_in, dtype_in, TEST_ENTRY) \
7272
test_gt_scalar_out<ScalarType::dtype_in, ScalarType::Bool>();
7373

74-
ET_FORALL_REAL_TYPES(TEST_FORALL_OUT_TYPES)
74+
ET_FORALL_REALHBF16_TYPES(TEST_FORALL_OUT_TYPES)
7575

7676
#undef TEST_FORALL_OUT_TYPES
7777
#undef TEST_ENTRY
@@ -124,11 +124,11 @@ TEST_F(OpGtTensorOutTest, AllDtypesSupported) {
124124
#define TEST_ENTRY(ctype_in, dtype_in, ctype_out, dtype_out) \
125125
test_dtype<ScalarType::dtype_in, ScalarType::dtype_out>();
126126

127-
#define TEST_FORALL_OUT_TYPES(ctype_in, dtype_in) \
128-
ET_FORALL_REAL_TYPES_WITH2(ctype_in, dtype_in, TEST_ENTRY) \
127+
#define TEST_FORALL_OUT_TYPES(ctype_in, dtype_in) \
128+
ET_FORALL_REALHBF16_TYPES_WITH2(ctype_in, dtype_in, TEST_ENTRY) \
129129
test_dtype<ScalarType::dtype_in, ScalarType::Bool>();
130130

131-
ET_FORALL_REAL_TYPES(TEST_FORALL_OUT_TYPES);
131+
ET_FORALL_REALHBF16_TYPES(TEST_FORALL_OUT_TYPES);
132132

133133
#undef TEST_FORALL_OUT_TYPES
134134
#undef TEST_ENTRY

kernels/test/op_lt_test.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,11 @@ TEST_F(OpLtScalarOutTest, AllRealInputBoolOutputSupport) {
6767
#define TEST_ENTRY(ctype_in, dtype_in, ctype_out, dtype_out) \
6868
test_lt_scalar_out<ScalarType::dtype_in, ScalarType::dtype_out>();
6969

70-
#define TEST_FORALL_OUT_TYPES(ctype_in, dtype_in) \
71-
ET_FORALL_REAL_TYPES_WITH2(ctype_in, dtype_in, TEST_ENTRY) \
70+
#define TEST_FORALL_OUT_TYPES(ctype_in, dtype_in) \
71+
ET_FORALL_REALHBF16_TYPES_WITH2(ctype_in, dtype_in, TEST_ENTRY) \
7272
test_lt_scalar_out<ScalarType::dtype_in, ScalarType::Bool>();
7373

74-
ET_FORALL_REAL_TYPES(TEST_FORALL_OUT_TYPES)
74+
ET_FORALL_REALHBF16_TYPES(TEST_FORALL_OUT_TYPES)
7575

7676
#undef TEST_FORALL_OUT_TYPES
7777
#undef TEST_ENTRY
@@ -124,11 +124,11 @@ TEST_F(OpLtTensorOutTest, AllDtypesSupported) {
124124
#define TEST_ENTRY(ctype_in, dtype_in, ctype_out, dtype_out) \
125125
test_dtype<ScalarType::dtype_in, ScalarType::dtype_out>();
126126

127-
#define TEST_FORALL_OUT_TYPES(ctype_in, dtype_in) \
128-
ET_FORALL_REAL_TYPES_WITH2(ctype_in, dtype_in, TEST_ENTRY) \
127+
#define TEST_FORALL_OUT_TYPES(ctype_in, dtype_in) \
128+
ET_FORALL_REALHBF16_TYPES_WITH2(ctype_in, dtype_in, TEST_ENTRY) \
129129
test_dtype<ScalarType::dtype_in, ScalarType::Bool>();
130130

131-
ET_FORALL_REAL_TYPES(TEST_FORALL_OUT_TYPES);
131+
ET_FORALL_REALHBF16_TYPES(TEST_FORALL_OUT_TYPES);
132132

133133
#undef TEST_FORALL_OUT_TYPES
134134
#undef TEST_ENTRY

kernels/test/op_ne_test.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ class OpNeScalarOutTest : public OperatorTest {
8383

8484
TEST_F(OpNeScalarOutTest, AllRealInputBoolOutputSupport) {
8585
#define TEST_ENTRY(ctype, dtype) test_ne_scalar_out<ScalarType::dtype>();
86-
ET_FORALL_REAL_TYPES(TEST_ENTRY);
86+
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
8787
#undef TEST_ENTRY
8888
}
8989

@@ -117,13 +117,13 @@ TEST_F(OpNeScalarOutTest, MismatchedShapesDies) {
117117

118118
TEST_F(OpNeScalarOutTest, AllRealOutputDTypesSupported) {
119119
#define TEST_ENTRY(ctype, dtype) test_ne_all_output_dtypes<ScalarType::dtype>();
120-
ET_FORALL_REAL_TYPES(TEST_ENTRY);
120+
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
121121
#undef TEST_ENTRY
122122
}
123123

124124
TEST_F(OpNeTest, AllDtypesSupported) {
125125
#define TEST_ENTRY(ctype, dtype) test_dtype<ctype, ScalarType::dtype>();
126-
ET_FORALL_REAL_TYPES(TEST_ENTRY);
126+
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
127127
#undef TEST_ENTRY
128128
}
129129

0 commit comments

Comments
 (0)