
Commit 07bcd7f

swolchok authored and facebook-github-bot committed
UnaryUfuncRealHBToFloatHTest: test Half more widely (#5676)
Summary:
Pull Request resolved: pytorch/executorch#5676

The tests were trying to avoid Half in ATen mode, but that should work fine for these ops.

ghstack-source-id: 245578279
Reviewed By: mergennachin
Differential Revision: D63435866
fbshipit-source-id: 934b4453fa13df13619092fde979c06d51925663
1 parent b63c68e commit 07bcd7f

File tree

4 files changed: +39 −20 lines changed

- kernels/test/UnaryUfuncRealHBToFloatHTest.cpp
- kernels/test/UnaryUfuncRealHBToFloatHTest.h
- runtime/core/exec_aten/testing_util/tensor_util.cpp
- runtime/core/exec_aten/testing_util/tensor_util.h

kernels/test/UnaryUfuncRealHBToFloatHTest.cpp

Lines changed: 6 additions & 12 deletions
```diff
@@ -38,9 +38,6 @@ void UnaryUfuncRealHBToFloatHTest::test_mismatched_input_shapes_dies() {
 
 void UnaryUfuncRealHBToFloatHTest::
     test_all_real_input_half_output_static_dynamism_support() {
-  if (get_supported_features()->is_aten) {
-    GTEST_SKIP() << "Test Half support only for ExecuTorch mode";
-  }
 #define TEST_ENTRY(ctype, dtype)    \
   test_floating_point_op_out<       \
       exec_aten::ScalarType::dtype, \
@@ -55,7 +52,7 @@ void UnaryUfuncRealHBToFloatHTest::
   test_floating_point_op_out<       \
       exec_aten::ScalarType::dtype, \
       exec_aten::ScalarType::Float>();
-  ET_FORALL_REAL_TYPES(TEST_ENTRY);
+  ET_FORALL_REALH_TYPES(TEST_ENTRY);
 #undef TEST_ENTRY
 }
 
@@ -65,15 +62,12 @@ void UnaryUfuncRealHBToFloatHTest::
   test_floating_point_op_out<       \
       exec_aten::ScalarType::dtype, \
       exec_aten::ScalarType::Double>();
-  ET_FORALL_REAL_TYPES(TEST_ENTRY);
+  ET_FORALL_REALH_TYPES(TEST_ENTRY);
 #undef TEST_ENTRY
 }
 
 void UnaryUfuncRealHBToFloatHTest::
     test_all_real_input_half_output_bound_dynamism_support() {
-  if (get_supported_features()->is_aten) {
-    GTEST_SKIP() << "Test Half support only for ExecuTorch mode";
-  }
 #define TEST_ENTRY(ctype, dtype)    \
   test_floating_point_op_out<       \
       exec_aten::ScalarType::dtype, \
@@ -90,7 +84,7 @@ void UnaryUfuncRealHBToFloatHTest::
       exec_aten::ScalarType::dtype, \
       exec_aten::ScalarType::Float>( \
       {10, 10}, exec_aten::TensorShapeDynamism::DYNAMIC_BOUND);
-  ET_FORALL_REAL_TYPES(TEST_ENTRY);
+  ET_FORALL_REALH_TYPES(TEST_ENTRY);
 #undef TEST_ENTRY
 }
 
@@ -101,7 +95,7 @@ void UnaryUfuncRealHBToFloatHTest::
       exec_aten::ScalarType::dtype, \
       exec_aten::ScalarType::Double>( \
       {10, 10}, exec_aten::TensorShapeDynamism::DYNAMIC_BOUND);
-  ET_FORALL_REAL_TYPES(TEST_ENTRY);
+  ET_FORALL_REALH_TYPES(TEST_ENTRY);
 #undef TEST_ENTRY
 }
 
@@ -115,7 +109,7 @@ void UnaryUfuncRealHBToFloatHTest::
       exec_aten::ScalarType::dtype, \
       exec_aten::ScalarType::Float>( \
       {1, 1}, exec_aten::TensorShapeDynamism::DYNAMIC_UNBOUND);
-  ET_FORALL_REAL_TYPES(TEST_ENTRY);
+  ET_FORALL_REALH_TYPES(TEST_ENTRY);
 #undef TEST_ENTRY
 }
 
@@ -129,7 +123,7 @@ void UnaryUfuncRealHBToFloatHTest::
       exec_aten::ScalarType::dtype, \
       exec_aten::ScalarType::Double>( \
       {1, 1}, exec_aten::TensorShapeDynamism::DYNAMIC_UNBOUND);
-  ET_FORALL_REAL_TYPES(TEST_ENTRY);
+  ET_FORALL_REALH_TYPES(TEST_ENTRY);
 #undef TEST_ENTRY
 }
 
```

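The mechanical change above is the switch from ET_FORALL_REAL_TYPES to ET_FORALL_REALH_TYPES at each TEST_ENTRY expansion site, together with dropping the GTEST_SKIP guard, which is what extends Half coverage to ATen mode. For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of how these X-macro lists instantiate one statement per dtype. The type lists below are illustrative stand-ins, not the verbatim definitions (the real ones live in runtime/core/exec_aten/util/scalar_type_util.h); the key point is that the REALH variant is the REAL list plus Half.

```cpp
#include <cstdint>
#include <cstdio>

// Illustrative stand-ins for ET_FORALL_REAL_TYPES / ET_FORALL_REALH_TYPES.
#define SKETCH_FORALL_REAL_TYPES(_) \
  _(uint8_t, Byte)                  \
  _(int8_t, Char)                   \
  _(int16_t, Short)                 \
  _(int32_t, Int)                   \
  _(int64_t, Long)                  \
  _(float, Float)                   \
  _(double, Double)

// The REALH variant adds Half, so switching a call site to it widens
// dtype coverage without touching the per-dtype test body.
#define SKETCH_FORALL_REALH_TYPES(_) \
  SKETCH_FORALL_REAL_TYPES(_)        \
  _(uint16_t /* stand-in for exec_aten::Half */, Half)

int main() {
  // Like TEST_ENTRY in the diff: each expansion becomes one statement.
#define TEST_ENTRY(ctype, dtype) \
  std::printf("would instantiate test for ScalarType::%s\n", #dtype);
  SKETCH_FORALL_REALH_TYPES(TEST_ENTRY)
#undef TEST_ENTRY
  return 0;
}
```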
kernels/test/UnaryUfuncRealHBToFloatHTest.h

Lines changed: 11 additions & 1 deletion
```diff
@@ -69,7 +69,17 @@ class UnaryUfuncRealHBToFloatHTest : public OperatorTest {
     op_out(tf_in.make({1, 6}, test_vector), out);
 
     auto expected = tf_out.make({1, 6}, expected_vector);
-    EXPECT_TENSOR_CLOSE(out, expected);
+    if (IN_DTYPE == ScalarType::Half || OUT_DTYPE == ScalarType::Half) {
+      double rtol = executorch::runtime::testing::internal::kDefaultRtol;
+      // It appears we need a higher tolerance for at least some ATen
+      // tests, like aten_op_acosh_test.
+      if (get_supported_features()->is_aten) {
+        rtol = 1e-3;
+      }
+      EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, rtol, executorch::runtime::testing::internal::kDefaultHalfAtol);
+    } else {
+      EXPECT_TENSOR_CLOSE(out, expected);
+    }
     // clang-format on
   }
 
```

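A quick back-of-the-envelope check of the acosh comment above (the numbers here are illustrative and not part of the commit): acosh amplifies input rounding error near its domain boundary, since d/dx acosh(x) = 1/sqrt(x² − 1) blows up as x approaches 1, so a single-ULP Half input perturbation can move the result by far more than the default rtol of 1e-5.

```cpp
#include <cmath>
#include <cstdio>

int main() {
  // One float16 ULP in [1, 2) is 2^-10, since Half has a 10-bit mantissa.
  const double ulp = std::pow(2.0, -10); // ~9.77e-4
  const double x = 1.25;                 // exactly representable in Half
  const double y = std::acosh(x);        // ln(2) ~= 0.693147
  const double y_perturbed = std::acosh(x + ulp);
  // Relative shift from a single input ULP: ~1.9e-3 -- two orders of
  // magnitude above kDefaultRtol (1e-5), in the ballpark of rtol = 1e-3.
  std::printf("relative change: %g\n", (y_perturbed - y) / y);
  return 0;
}
```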
runtime/core/exec_aten/testing_util/tensor_util.cpp

Lines changed: 13 additions & 4 deletions
```diff
@@ -76,13 +76,19 @@ bool data_is_close(
   return true;
 }
 
+double default_atol_for_type(ScalarType t) {
+  if (t == ScalarType::Half) {
+    return internal::kDefaultHalfAtol;
+  }
+  return internal::kDefaultAtol;
+}
 } // namespace
 
 bool tensors_are_close(
     const Tensor& a,
     const Tensor& b,
     double rtol,
-    double atol) {
+    std::optional<double> opt_atol) {
   if (a.scalar_type() != b.scalar_type() || a.sizes() != b.sizes()) {
     return false;
   }
@@ -100,6 +106,8 @@ bool tensors_are_close(
   // So we can just compare the two underlying data sequentially to figure out
   // if the two tensors are same.
 
+  double atol = opt_atol.value_or(default_atol_for_type(a.scalar_type()));
+
   if (a.nbytes() == 0) {
     // Note that this case is important. It's valid for a zero-size tensor to
     // have a null data pointer, but in some environments it's invalid to pass a
@@ -149,11 +157,12 @@ bool tensor_data_is_close(
     const Tensor& a,
     const Tensor& b,
     double rtol,
-    double atol) {
+    std::optional<double> opt_atol) {
   if (a.scalar_type() != b.scalar_type() || a.numel() != b.numel()) {
     return false;
   }
 
+  double atol = opt_atol.value_or(default_atol_for_type(a.scalar_type()));
   if (a.nbytes() == 0) {
     // Note that this case is important. It's valid for a zero-size tensor to
     // have a null data pointer, but in some environments it's invalid to pass a
@@ -185,12 +194,12 @@ bool tensor_lists_are_close(
     const exec_aten::Tensor* tensors_b,
     size_t num_tensors_b,
     double rtol,
-    double atol) {
+    std::optional<double> opt_atol) {
   if (num_tensors_a != num_tensors_b) {
     return false;
   }
   for (size_t i = 0; i < num_tensors_a; i++) {
-    if (!tensors_are_close(tensors_a[i], tensors_b[i], rtol, atol)) {
+    if (!tensors_are_close(tensors_a[i], tensors_b[i], rtol, opt_atol)) {
       return false;
     }
   }
```

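To make the new default concrete, a hedged usage sketch: with opt_atol omitted, Half tensors are compared at kDefaultHalfAtol (1e-3), while an explicit atol still takes precedence. The TensorFactory helper, include paths, and namespaces below follow ExecuTorch's testing utilities as of this commit but are assumptions; adjust them if your checkout differs.

```cpp
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
#include <gtest/gtest.h>

using executorch::runtime::testing::tensors_are_close;
using torch::executor::testing::TensorFactory; // assumed namespace

TEST(TensorsAreCloseSketch, HalfPicksLooserDefaultAtol) {
  TensorFactory<exec_aten::ScalarType::Half> tf;
  // One float16 ULP apart near 1.0 (value spacing 2^-10 ~= 9.77e-4).
  auto a = tf.make({1}, {1.0});
  auto b = tf.make({1}, {1.0009765625});
  // No atol passed: Half resolves to kDefaultHalfAtol (1e-3), so this holds.
  EXPECT_TRUE(tensors_are_close(a, b));
  // An explicit atol overrides the dtype-based default and fails here.
  EXPECT_FALSE(tensors_are_close(a, b, /*rtol=*/1e-5, /*opt_atol=*/1e-8));
}
```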
runtime/core/exec_aten/testing_util/tensor_util.h

Lines changed: 9 additions & 3 deletions
```diff
@@ -11,13 +11,19 @@
 #include <executorch/runtime/core/exec_aten/exec_aten.h>
 #include <gmock/gmock.h> // For MATCHER_P
 
+#include <optional>
+
 namespace executorch {
 namespace runtime {
 namespace testing {
 
 namespace internal {
 constexpr double kDefaultRtol = 1e-5;
 constexpr double kDefaultAtol = 1e-8;
+// Per
+// https://en.wikipedia.org/wiki/Half-precision_floating-point_format,
+// float16 has about 3.3 digits of precision.
+constexpr double kDefaultHalfAtol = 1e-3;
 } // namespace internal
 
 /**
```
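The 1e-3 choice follows directly from the format. This short check (illustrative, not part of the commit) reproduces the "3.3 digits" figure and the value spacing that motivates kDefaultHalfAtol:

```cpp
#include <cmath>
#include <cstdio>

int main() {
  // float16: 10 explicit mantissa bits + 1 implicit bit.
  const double eps = std::pow(2.0, -10);      // spacing near 1.0: ~9.77e-4
  const double digits = 11 * std::log10(2.0); // ~3.31 decimal digits
  std::printf("eps = %g, decimal digits = %.2f\n", eps, digits);
  // kDefaultAtol (1e-8) sits far below eps, so even exact Half arithmetic
  // would fail it; 1e-3 sits just above one unit of spacing near 1.0.
  return 0;
}
```

The remaining hunks in this header switch the three public declarations from a defaulted `double atol` to `std::optional<double> opt_atol = std::nullopt`: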
```diff
@@ -61,7 +67,7 @@ bool tensors_are_close(
     const exec_aten::Tensor& a,
     const exec_aten::Tensor& b,
     double rtol = internal::kDefaultRtol,
-    double atol = internal::kDefaultAtol);
+    std::optional<double> opt_atol = std::nullopt);
 
 /**
  * Returns true if the tensors are of the same numel and dtype, and if all
@@ -92,7 +98,7 @@ bool tensor_data_is_close(
     const exec_aten::Tensor& a,
     const exec_aten::Tensor& b,
     double rtol = internal::kDefaultRtol,
-    double atol = internal::kDefaultAtol);
+    std::optional<double> opt_atol = std::nullopt);
 
 /**
  * Returns true if the two lists are of the same length, and
@@ -105,7 +111,7 @@ bool tensor_lists_are_close(
     const exec_aten::Tensor* tensors_b,
     size_t num_tensors_b,
     double rtol = internal::kDefaultRtol,
-    double atol = internal::kDefaultAtol);
+    std::optional<double> opt_atol = std::nullopt);
 
 /**
  * Lets gtest users write `EXPECT_THAT(tensor1, IsCloseTo(tensor2))` or
```
