Skip to content

Commit f40d62e

Browse files
andishgar and pitrou authored
GH-48123 [C++][Float16] Reimplement arrow::WithinUlp and Enable it for float16 (#48224)
### Rationale for this change

Refer to this comment on issue #48123. Additionally, this change enables `arrow::WithinUlp` for `float16`.

### What changes are included in this PR?

Re-implement `arrow::WithinUlp` and enable it for `float16`, including relevant tests for corner cases around powers of two and `Float16`.

### Are these changes tested?

Yes, I ran the relevant unit tests.

### Are there any user-facing changes?

No.

* GitHub Issue: #48123

Lead-authored-by: arash andishgar <[email protected]>
Co-authored-by: Antoine Pitrou <[email protected]>
Signed-off-by: Antoine Pitrou <[email protected]>
1 parent fa4e593 commit f40d62e

File tree

3 files changed

+161
-31
lines changed

3 files changed

+161
-31
lines changed

cpp/src/arrow/testing/gtest_util_test.cc

Lines changed: 82 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
// under the License.
1717

1818
#include <cmath>
19+
#include <memory>
20+
#include <type_traits>
21+
#include <vector>
1922

2023
#include <gtest/gtest-spi.h>
2124
#include <gtest/gtest.h>
@@ -32,8 +35,9 @@
3235
#include "arrow/type.h"
3336
#include "arrow/type_traits.h"
3437
#include "arrow/util/checked_cast.h"
38+
#include "arrow/util/float16.h"
3539

36-
namespace arrow {
40+
namespace arrow::util {
3741

3842
// Test basic cases for contains NaN.
3943
class TestAssertContainsNaN : public ::testing::Test {};
@@ -198,8 +202,15 @@ void CheckWithinUlp(Float x, Float y, int n_ulp) {
198202
CheckWithinUlpSingle(-y, -x, n_ulp);
199203

200204
for (int exp : {1, -1, 10, -10}) {
201-
Float x_scaled = std::ldexp(x, exp);
202-
Float y_scaled = std::ldexp(y, exp);
205+
Float x_scaled(0);
206+
Float y_scaled(0);
207+
if constexpr (std::is_same_v<Float, Float16>) {
208+
x_scaled = Float16(std::ldexp(x.ToFloat(), exp));
209+
y_scaled = Float16(std::ldexp(y.ToFloat(), exp));
210+
} else {
211+
x_scaled = std::ldexp(x, exp);
212+
y_scaled = std::ldexp(y, exp);
213+
}
203214
CheckWithinUlpSingle(x_scaled, y_scaled, n_ulp);
204215
CheckWithinUlpSingle(y_scaled, x_scaled, n_ulp);
205216
}
@@ -219,8 +230,15 @@ void CheckNotWithinUlp(Float x, Float y, int n_ulp) {
219230
}
220231

221232
for (int exp : {1, -1, 10, -10}) {
222-
Float x_scaled = std::ldexp(x, exp);
223-
Float y_scaled = std::ldexp(y, exp);
233+
Float x_scaled(0);
234+
Float y_scaled(0);
235+
if constexpr (std::is_same_v<Float, Float16>) {
236+
x_scaled = Float16(std::ldexp(x.ToFloat(), exp));
237+
y_scaled = Float16(std::ldexp(y.ToFloat(), exp));
238+
} else {
239+
x_scaled = std::ldexp(x, exp);
240+
y_scaled = std::ldexp(y, exp);
241+
}
224242
CheckNotWithinUlpSingle(x_scaled, y_scaled, n_ulp);
225243
CheckNotWithinUlpSingle(y_scaled, x_scaled, n_ulp);
226244
}
@@ -242,6 +260,10 @@ TEST(TestWithinUlp, Double) {
242260
CheckWithinUlp(1.0, 0.9999999999999999, 1);
243261
CheckWithinUlp(1.0, 0.9999999999999988, 11);
244262
CheckNotWithinUlp(1.0, 0.9999999999999988, 10);
263+
CheckWithinUlp(1.0000000000000002, 0.9999999999999999, 2);
264+
CheckNotWithinUlp(1.0000000000000002, 0.9999999999999999, 1);
265+
CheckWithinUlp(0.9999999999999988, 1.0000000000000007, 14);
266+
CheckNotWithinUlp(0.9999999999999988, 1.0000000000000007, 13);
245267

246268
CheckWithinUlp(123.4567, 123.45670000000015, 11);
247269
CheckNotWithinUlp(123.4567, 123.45670000000015, 10);
@@ -271,6 +293,10 @@ TEST(TestWithinUlp, Float) {
271293
CheckWithinUlp(1.0f, 0.99999994f, 1);
272294
CheckWithinUlp(1.0f, 0.99999934f, 11);
273295
CheckNotWithinUlp(1.0f, 0.99999934f, 10);
296+
CheckWithinUlp(1.0000001f, 0.99999994f, 2);
297+
CheckNotWithinUlp(1.0000001f, 0.99999994f, 1);
298+
CheckWithinUlp(1.0000013f, 0.99999934f, 22);
299+
CheckNotWithinUlp(1.0000013f, 0.99999934f, 21);
274300

275301
CheckWithinUlp(123.456f, 123.456085f, 11);
276302
CheckNotWithinUlp(123.456f, 123.456085f, 10);
@@ -284,15 +310,65 @@ TEST(TestWithinUlp, Float) {
284310
CheckNotWithinUlp(12.34f, -12.34f, 10);
285311
}
286312

313+
std::vector<Float16> ConvertToFloat16Vector(const std::vector<float>& float_values) {
314+
std::vector<Float16> float16_vector;
315+
float16_vector.reserve(float_values.size());
316+
for (auto& value : float_values) {
317+
float16_vector.emplace_back(value);
318+
}
319+
return float16_vector;
320+
}
321+
322+
TEST(TestWithinUlp, Float16) {
323+
for (Float16 f : ConvertToFloat16Vector({0.0f, 1e-8f, 1.0f, 123.456f})) {
324+
CheckWithinUlp(f, f, 0);
325+
CheckWithinUlp(f, f, 1);
326+
CheckWithinUlp(f, f, 42);
327+
}
328+
CheckWithinUlp(Float16(-0.0f), Float16(0.0f), 1);
329+
CheckWithinUlp(Float16(1.0f), Float16(1.00097656f), 1);
330+
CheckWithinUlp(Float16(1.0f), Float16(1.01074219f), 11);
331+
CheckNotWithinUlp(Float16(1.0f), Float16(1.00097656f), 0);
332+
CheckNotWithinUlp(Float16(1.0f), Float16(1.01074219f), 10);
333+
// left and right have a different exponent but are still very close
334+
CheckWithinUlp(Float16(1.0f), Float16(0.999511719f), 1);
335+
CheckWithinUlp(Float16(1.0f), Float16(0.994628906f), 11);
336+
CheckNotWithinUlp(Float16(1.0f), Float16(0.994628906f), 10);
337+
CheckWithinUlp(Float16(1.00097656), Float16(0.999511719f), 2);
338+
CheckNotWithinUlp(Float16(1.00097656), Float16(0.999511719f), 1);
339+
CheckWithinUlp(Float16(1.01074219f), Float16(0.994628906f), 22);
340+
CheckNotWithinUlp(Float16(1.01074219f), Float16(0.994628906f), 21);
341+
342+
CheckWithinUlp(Float16(123.456f), Float16(124.143501f), 11);
343+
// The assertion below does not work because ldexp(Float16(124.143501f), 10)
344+
// results in inf in Float16.
345+
// CheckNotWithinUlp(Float16(123.456f), Float16(124.143501f), 10);
346+
347+
CheckWithinUlp(std::numeric_limits<Float16>::infinity(),
348+
std::numeric_limits<Float16>::infinity(), 10);
349+
CheckWithinUlp(-std::numeric_limits<Float16>::infinity(),
350+
-std::numeric_limits<Float16>::infinity(), 10);
351+
CheckWithinUlp(std::numeric_limits<Float16>::quiet_NaN(),
352+
std::numeric_limits<Float16>::quiet_NaN(), 10);
353+
CheckNotWithinUlp(std::numeric_limits<Float16>::infinity(),
354+
-std::numeric_limits<Float16>::infinity(), 10);
355+
CheckNotWithinUlp(Float16(12.34f), -std::numeric_limits<Float16>::infinity(), 10);
356+
CheckNotWithinUlp(Float16(12.34f), std::numeric_limits<Float16>::quiet_NaN(), 10);
357+
CheckNotWithinUlp(Float16(12.34f), Float16(-12.34f), 10);
358+
}
359+
287360
TEST(AssertTestWithinUlp, Basics) {
288361
AssertWithinUlp(123.4567, 123.45670000000015, 11);
289362
AssertWithinUlp(123.456f, 123.456085f, 11);
363+
AssertWithinUlp(Float16(123.456f), Float16(124.143501f), 11);
290364
#ifndef _WIN32
291365
// GH-47442
292366
EXPECT_FATAL_FAILURE(AssertWithinUlp(123.4567, 123.45670000000015, 10),
293367
"not within 10 ulps");
294368
EXPECT_FATAL_FAILURE(AssertWithinUlp(123.456f, 123.456085f, 10), "not within 10 ulps");
369+
EXPECT_FATAL_FAILURE(AssertWithinUlp(Float16(123.456f), Float16(124.143501f), 10),
370+
"not within 10 ulps");
295371
#endif
296372
}
297373

298-
} // namespace arrow
374+
} // namespace arrow::util

cpp/src/arrow/testing/math.cc

Lines changed: 74 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,53 +17,94 @@
1717

1818
#include "arrow/testing/math.h"
1919

20+
#include <algorithm>
2021
#include <cmath>
2122
#include <limits>
23+
#include <type_traits>
2224

2325
#include <gtest/gtest.h>
2426

27+
#include "arrow/util/float16.h"
2528
#include "arrow/util/logging_internal.h"
29+
#include "arrow/util/ubsan.h"
2630

2731
namespace arrow {
2832
namespace {
2933

3034
template <typename Float>
31-
bool WithinUlpOneWay(Float left, Float right, int n_ulps) {
32-
// The delta between 1.0 and the FP value immediately before it.
33-
// We're using this value because `frexp` returns a mantissa between 0.5 and 1.0.
34-
static const Float kOneUlp = Float(1.0) - std::nextafter(Float(1.0), Float(0.0));
35+
struct FloatToUInt;
3536

36-
DCHECK_GE(n_ulps, 1);
37+
template <>
38+
struct FloatToUInt<double> {
39+
using Type = uint64_t;
40+
};
3741

38-
if (left == 0) {
39-
return left == right;
40-
}
41-
if (left < 0) {
42-
left = -left;
43-
right = -right;
42+
template <>
43+
struct FloatToUInt<float> {
44+
using Type = uint32_t;
45+
};
46+
47+
template <>
48+
struct FloatToUInt<util::Float16> {
49+
using Type = uint16_t;
50+
};
51+
52+
template <typename Float>
53+
struct UlpDistanceUtil {
54+
public:
55+
using UIntType = typename FloatToUInt<Float>::Type;
56+
static constexpr UIntType kNumberOfBits = sizeof(Float) * 8;
57+
static constexpr UIntType kSignMask = static_cast<UIntType>(1) << (kNumberOfBits - 1);
58+
59+
// This implementation is inspired by:
60+
// https://randomascii.wordpress.com/2012/02/25/comparing-floating-point-numbers-2012-edition/
61+
static UIntType UlpDistance(Float left, Float right) {
62+
auto unsigned_left = util::SafeCopy<UIntType>(left);
63+
auto unsigned_right = util::SafeCopy<UIntType>(right);
64+
auto biased_left = ConvertSignAndMagnitudeToBiased(unsigned_left);
65+
auto biased_right = ConvertSignAndMagnitudeToBiased(unsigned_right);
66+
if (biased_left > biased_right) {
67+
std::swap(biased_left, biased_right);
68+
}
69+
return biased_right - biased_left;
4470
}
4571

46-
int left_exp;
47-
Float left_mant = std::frexp(left, &left_exp);
48-
Float delta = static_cast<Float>(n_ulps) * kOneUlp;
49-
Float lower_bound = std::ldexp(left_mant - delta, left_exp);
50-
Float upper_bound = std::ldexp(left_mant + delta, left_exp);
51-
return right >= lower_bound && right <= upper_bound;
52-
}
72+
private:
73+
// Source reference (GoogleTest):
74+
// https://github.com/google/googletest/blob/1b96fa13f549387b7549cc89e1a785cf143a1a50/googletest/include/gtest/internal/gtest-internal.h#L345-L368
75+
static UIntType ConvertSignAndMagnitudeToBiased(UIntType value) {
76+
if (value & kSignMask) {
77+
return ~value + 1;
78+
} else {
79+
return value | kSignMask;
80+
}
81+
}
82+
};
5383

5484
template <typename Float>
5585
bool WithinUlpGeneric(Float left, Float right, int n_ulps) {
56-
if (std::isnan(left) || std::isnan(right)) {
57-
return std::isnan(left) == std::isnan(right);
58-
}
59-
if (!std::isfinite(left) || !std::isfinite(right)) {
60-
return left == right;
86+
if constexpr (std::is_same_v<Float, util::Float16>) {
87+
if (left.is_nan() || right.is_nan()) {
88+
return left.is_nan() == right.is_nan();
89+
} else if (left.is_infinity() || right.is_infinity()) {
90+
return left == right;
91+
}
92+
} else {
93+
if (std::isnan(left) || std::isnan(right)) {
94+
return std::isnan(left) == std::isnan(right);
95+
}
96+
if (!std::isfinite(left) || !std::isfinite(right)) {
97+
return left == right;
98+
}
6199
}
100+
62101
if (n_ulps == 0) {
63102
return left == right;
64103
}
65-
return (std::abs(left) <= std::abs(right)) ? WithinUlpOneWay(left, right, n_ulps)
66-
: WithinUlpOneWay(right, left, n_ulps);
104+
105+
DCHECK_GE(n_ulps, 1);
106+
return UlpDistanceUtil<Float>::UlpDistance(left, right) <=
107+
static_cast<uint64_t>(n_ulps);
67108
}
68109

69110
template <typename Float>
@@ -75,6 +116,10 @@ void AssertWithinUlpGeneric(Float left, Float right, int n_ulps) {
75116

76117
} // namespace
77118

119+
bool WithinUlp(util::Float16 left, util::Float16 right, int n_ulps) {
120+
return WithinUlpGeneric(left, right, n_ulps);
121+
}
122+
78123
bool WithinUlp(float left, float right, int n_ulps) {
79124
return WithinUlpGeneric(left, right, n_ulps);
80125
}
@@ -83,6 +128,10 @@ bool WithinUlp(double left, double right, int n_ulps) {
83128
return WithinUlpGeneric(left, right, n_ulps);
84129
}
85130

131+
void AssertWithinUlp(util::Float16 left, util::Float16 right, int n_ulps) {
132+
AssertWithinUlpGeneric(left, right, n_ulps);
133+
}
134+
86135
void AssertWithinUlp(float left, float right, int n_ulps) {
87136
AssertWithinUlpGeneric(left, right, n_ulps);
88137
}

cpp/src/arrow/testing/math.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,19 @@
1818
#pragma once
1919

2020
#include "arrow/testing/visibility.h"
21+
#include "arrow/type_fwd.h"
2122

2223
namespace arrow {
2324

25+
ARROW_TESTING_EXPORT
26+
bool WithinUlp(util::Float16 left, util::Float16 right, int n_ulps);
2427
ARROW_TESTING_EXPORT
2528
bool WithinUlp(float left, float right, int n_ulps);
2629
ARROW_TESTING_EXPORT
2730
bool WithinUlp(double left, double right, int n_ulps);
2831

32+
ARROW_TESTING_EXPORT
33+
void AssertWithinUlp(util::Float16 left, util::Float16 right, int n_ulps);
2934
ARROW_TESTING_EXPORT
3035
void AssertWithinUlp(float left, float right, int n_ulps);
3136
ARROW_TESTING_EXPORT

0 commit comments

Comments
 (0)