Skip to content

Commit 8651d31

Browse files
Add Half / Bfloat16 Tests
Differential Revision: D79374276 Pull Request resolved: #13048
1 parent 4197fc1 commit 8651d31

File tree

4 files changed

+59
-13
lines changed

4 files changed

+59
-13
lines changed

kernels/optimized/cpu/op_sub.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,11 @@ Tensor& opt_sub_out(
8585
ScalarType b_type = b.scalar_type();
8686
ScalarType out_type = out.scalar_type();
8787

88-
ET_KERNEL_CHECK(ctx, tensor_is_realh_type(out), InvalidArgument, out);
88+
ET_KERNEL_CHECK(
89+
ctx,
90+
executorch::runtime::tensor_is_realhbf16_type(out),
91+
InvalidArgument,
92+
out);
8993
if (a.numel() == 1 || b.numel() == 1) {
9094
if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half) {
9195
const Tensor* tensor;
@@ -169,7 +173,7 @@ Tensor& opt_sub_scalar_out(
169173
ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");
170174

171175
if (a_type == common_type && a_type == out_type &&
172-
a_type != ScalarType::Half) {
176+
a_type != ScalarType::Half && a_type != ScalarType::BFloat16) {
173177
ET_SWITCH_REAL_TYPES(a_type, ctx, "sub.Scalar_out", CTYPE, [&]() {
174178
CTYPE b_casted = utils::scalar_to<CTYPE>(b);
175179
CTYPE alpha_val;
@@ -186,9 +190,9 @@ Tensor& opt_sub_scalar_out(
186190
out.numel());
187191
});
188192
} else {
189-
ET_SWITCH_REALH_TYPES(a_type, ctx, "sub.Scalar_out", CTYPE_A, [&]() {
193+
ET_SWITCH_REALHBF16_TYPES(a_type, ctx, "sub.Scalar_out", CTYPE_A, [&]() {
190194
ET_SWITCH_REAL_TYPES(common_type, ctx, "sub.Scalar_out", CTYPE_IN, [&]() {
191-
ET_SWITCH_REALH_TYPES(
195+
ET_SWITCH_REALHBF16_TYPES(
192196
out_type, ctx, "sub.Scalar_out", CTYPE_OUT, [&]() {
193197
CTYPE_IN b_casted = utils::scalar_to<CTYPE_IN>(b);
194198
CTYPE_IN alpha_val;

kernels/test/op_floor_divide_test.cpp

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,9 @@ class OpFloorDivideTest : public OperatorTest {
5757
Tensor out = tf.zeros(sizes);
5858

5959
// floor_divide two tensors.
60-
// std::floor(-0.5 / -0.1) == 5.0, but -0.5 // -0.1 yeilds 4.0
6160
op_floor_divide_out(
62-
tf.make(sizes, /*data=*/{-5.3, 1.1, 2.2, 4.4, 6.8, -0.5}),
63-
tf.make(sizes, /*data=*/{2.7, 2.0, 2.0, 2.0, 2.0, -0.1}),
61+
tf.make(sizes, /*data=*/{-5.3, 1.1, 2.2, 4.4, 6.8, -0.9}),
62+
tf.make(sizes, /*data=*/{2.7, 2.0, 2.0, 2.0, 2.0, -0.2}),
6463
out);
6564

6665
// Check that it matches the expected output.
@@ -113,6 +112,14 @@ TEST_F(OpFloorDivideTest, DoubleTensors) {
113112
test_floating_point_floor_divide<ScalarType::Double>();
114113
}
115114

115+
TEST_F(OpFloorDivideTest, HalfTensors) {
116+
test_floating_point_floor_divide<ScalarType::Half>();
117+
}
118+
119+
TEST_F(OpFloorDivideTest, BFloat16Tensors) {
120+
test_floating_point_floor_divide<ScalarType::BFloat16>();
121+
}
122+
116123
TEST_F(OpFloorDivideTest, UnhandledDtypeDies) {
117124
// floor_divide() doesn't handle Bool.
118125
TensorFactory<ScalarType::Bool> tf;
@@ -331,3 +338,17 @@ TEST_F(OpFloorDivideTest, DynamicShapeUnbound) {
331338
Tensor ret = op_floor_divide_out(x, y, out);
332339
EXPECT_TENSOR_CLOSE(out, expected_result);
333340
}
341+
342+
// std::floor(0.5 / 0.1) == 5.0, but 0.5 // 0.1 yeilds 4.0
343+
TEST_F(OpFloorDivideTest, FloatFloorDivideEdgeCase) {
344+
TensorFactory<ScalarType::Float> tf;
345+
346+
Tensor x = tf.make({1, 2}, {0.5, -0.5});
347+
Tensor y = tf.make({1, 2}, {0.1, -0.1});
348+
Tensor expected_result = tf.make({1, 2}, {4.0, 4.0});
349+
350+
Tensor out = tf.zeros({1, 2});
351+
Tensor ret = op_floor_divide_out(x, y, out);
352+
EXPECT_TENSOR_EQ(ret, out);
353+
EXPECT_TENSOR_CLOSE(out, expected_result);
354+
}

kernels/test/op_rsub_test.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,17 @@ class OpRSubScalarOutTest : public OperatorTest {
6464
Tensor out = tf.zeros(sizes);
6565

6666
// Performs substraction of tensor from scalar.
67+
// Values selected to be exactly representable to avoid throwing off
68+
// half/bfloat16 tests.
6769
op_rsub_scalar_out(
68-
tf.make(sizes, /*data=*/{1.1, 2.2, 4.4, 8.8}),
69-
1.1,
70+
tf.make(sizes, /*data=*/{1.25, 2.25, 4.5, 8.875}),
71+
1.0,
7072
/*alpha=*/1,
7173
out);
7274

7375
// Check that it matches the expected output.
74-
EXPECT_TENSOR_CLOSE(out, tf.make(sizes, /*data=*/{0.0, -1.1, -3.3, -7.7}));
76+
EXPECT_TENSOR_CLOSE(
77+
out, tf.make(sizes, /*data=*/{-0.25, -1.25, -3.5, -7.875}));
7578
}
7679

7780
/* %python
@@ -168,6 +171,14 @@ TEST_F(OpRSubScalarOutTest, DoubleTensors) {
168171
test_floating_point_rsub_scalar_out<ScalarType::Double>();
169172
}
170173

174+
TEST_F(OpRSubScalarOutTest, HalfTensors) {
175+
test_floating_point_rsub_scalar_out<ScalarType::Half>();
176+
}
177+
178+
TEST_F(OpRSubScalarOutTest, BFloat16Tensors) {
179+
test_floating_point_rsub_scalar_out<ScalarType::BFloat16>();
180+
}
181+
171182
TEST_F(OpRSubScalarOutTest, UnhandledDtypeDies) {
172183
// op_rsub_scalar_out() doesn't handle Bool.
173184
TensorFactory<ScalarType::Bool> tf;

kernels/test/op_sub_test.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,13 +90,15 @@ class OpSubOutTest : public OperatorTest {
9090

9191
// Performs substraction on two tensors.
9292
op_sub_out(
93-
tf.make(sizes, /*data=*/{1.1, 2.2, 4.4, 8.8}),
93+
tf.make(sizes, /*data=*/{1.25, 2.25, 4.5, 8.875}),
9494
tf.ones(sizes),
9595
/*alpha=*/1,
9696
out);
9797

98-
// Check that it matches the expected output.
99-
EXPECT_TENSOR_CLOSE(out, tf.make(sizes, /*data=*/{0.1, 1.2, 3.4, 7.8}));
98+
// Check that it matches the expected output. Values selected to
99+
// be exactly representable to avoid throwing off half/bfloat16
100+
// tests.
101+
EXPECT_TENSOR_CLOSE(out, tf.make(sizes, /*data=*/{0.25, 1.25, 3.5, 7.875}));
100102
}
101103

102104
template <ScalarType DTYPE>
@@ -260,6 +262,14 @@ TEST_F(OpSubOutTest, DoubleTensors) {
260262
test_floating_point_sub_out<ScalarType::Double>();
261263
}
262264

265+
TEST_F(OpSubOutTest, HalfTensors) {
266+
test_floating_point_sub_out<ScalarType::Half>();
267+
}
268+
269+
TEST_F(OpSubOutTest, BFloat16Tensors) {
270+
test_floating_point_sub_out<ScalarType::BFloat16>();
271+
}
272+
263273
TEST_F(OpSubOutTest, BroadcastSupported) {
264274
TensorFactory<ScalarType::Float> tf;
265275

0 commit comments

Comments
 (0)