Skip to content

Commit caf4f70

Browse files
authored
GH-46739: [C++] Fix Float16 signed zero/NaN equality comparisons (#46973)
### Rationale for this change

Equality comparisons between half-floats (used in their scalar/array `Equals` methods) do not properly handle `EqualOptions::nans_equal` and `EqualOptions::signed_zeros_equal`.

### What changes are included in this PR?

- Internal fixes to the current comparison behavior and additional tests as needed
- Prevents Float16 NaNs from being randomly generated by test utilities by default (matching behavior for float/double)

### Are these changes tested?

Yes

### Are there any user-facing changes?

No

* GitHub Issue: #46739

Authored-by: Benjamin Harkins <[email protected]>
Signed-off-by: Antoine Pitrou <[email protected]>
1 parent 2987165 commit caf4f70

File tree

9 files changed

+253
-83
lines changed

9 files changed

+253
-83
lines changed

cpp/src/arrow/array/array_test.cc

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2120,16 +2120,21 @@ void CheckSliceApproxEquals() {
21202120
ASSERT_TRUE(slice1->ApproxEquals(slice2));
21212121
}
21222122

2123+
template <typename ArrowType>
2124+
using NumericArgType = std::conditional_t<is_half_float_type<ArrowType>::value, Float16,
2125+
typename ArrowType::c_type>;
2126+
21232127
template <typename TYPE>
21242128
void CheckFloatingNanEquality() {
2129+
using V = NumericArgType<TYPE>;
21252130
std::shared_ptr<Array> a, b;
21262131
std::shared_ptr<DataType> type = TypeTraits<TYPE>::type_singleton();
21272132

2128-
const auto nan_value = static_cast<typename TYPE::c_type>(NAN);
2133+
const auto nan_value = std::numeric_limits<V>::quiet_NaN();
21292134

21302135
// NaN in a null entry
2131-
ArrayFromVector<TYPE>(type, {true, false}, {0.5, nan_value}, &a);
2132-
ArrayFromVector<TYPE>(type, {true, false}, {0.5, nan_value}, &b);
2136+
ArrayFromVector<TYPE, V>(type, {true, false}, {V(0.5), nan_value}, &a);
2137+
ArrayFromVector<TYPE, V>(type, {true, false}, {V(0.5), nan_value}, &b);
21332138
ASSERT_TRUE(a->Equals(b));
21342139
ASSERT_TRUE(b->Equals(a));
21352140
ASSERT_TRUE(a->ApproxEquals(b));
@@ -2140,8 +2145,8 @@ void CheckFloatingNanEquality() {
21402145
ASSERT_TRUE(b->RangeEquals(a, 1, 2, 1));
21412146

21422147
// NaN in a valid entry
2143-
ArrayFromVector<TYPE>(type, {false, true}, {0.5, nan_value}, &a);
2144-
ArrayFromVector<TYPE>(type, {false, true}, {0.5, nan_value}, &b);
2148+
ArrayFromVector<TYPE, V>(type, {false, true}, {V(0.5), nan_value}, &a);
2149+
ArrayFromVector<TYPE, V>(type, {false, true}, {V(0.5), nan_value}, &b);
21452150
ASSERT_FALSE(a->Equals(b));
21462151
ASSERT_FALSE(b->Equals(a));
21472152
ASSERT_TRUE(a->Equals(b, EqualOptions().nans_equal(true)));
@@ -2160,8 +2165,8 @@ void CheckFloatingNanEquality() {
21602165
ASSERT_TRUE(b->RangeEquals(a, 0, 1, 0));
21612166

21622167
// NaN != non-NaN
2163-
ArrayFromVector<TYPE>(type, {false, true}, {0.5, nan_value}, &a);
2164-
ArrayFromVector<TYPE>(type, {false, true}, {0.5, 0.0}, &b);
2168+
ArrayFromVector<TYPE, V>(type, {false, true}, {V(0.5), nan_value}, &a);
2169+
ArrayFromVector<TYPE, V>(type, {false, true}, {V(0.5), V(0.0)}, &b);
21652170
ASSERT_FALSE(a->Equals(b));
21662171
ASSERT_FALSE(b->Equals(a));
21672172
ASSERT_FALSE(a->Equals(b, EqualOptions().nans_equal(true)));
@@ -2182,15 +2187,16 @@ void CheckFloatingNanEquality() {
21822187

21832188
template <typename TYPE>
21842189
void CheckFloatingInfinityEquality() {
2190+
using V = NumericArgType<TYPE>;
21852191
std::shared_ptr<Array> a, b;
21862192
std::shared_ptr<DataType> type = TypeTraits<TYPE>::type_singleton();
21872193

2188-
const auto infinity = std::numeric_limits<typename TYPE::c_type>::infinity();
2194+
const auto infinity = std::numeric_limits<V>::infinity();
21892195

21902196
for (auto nans_equal : {false, true}) {
21912197
// Infinity in a null entry
2192-
ArrayFromVector<TYPE>(type, {true, false}, {0.5, infinity}, &a);
2193-
ArrayFromVector<TYPE>(type, {true, false}, {0.5, -infinity}, &b);
2198+
ArrayFromVector<TYPE, V>(type, {true, false}, {V(0.5), infinity}, &a);
2199+
ArrayFromVector<TYPE, V>(type, {true, false}, {V(0.5), -infinity}, &b);
21942200
ASSERT_TRUE(a->Equals(b));
21952201
ASSERT_TRUE(b->Equals(a));
21962202
ASSERT_TRUE(a->ApproxEquals(b, EqualOptions().atol(1e-5).nans_equal(nans_equal)));
@@ -2201,8 +2207,8 @@ void CheckFloatingInfinityEquality() {
22012207
ASSERT_TRUE(b->RangeEquals(a, 1, 2, 1));
22022208

22032209
// Infinity in a valid entry
2204-
ArrayFromVector<TYPE>(type, {false, true}, {0.5, infinity}, &a);
2205-
ArrayFromVector<TYPE>(type, {false, true}, {0.5, infinity}, &b);
2210+
ArrayFromVector<TYPE, V>(type, {false, true}, {V(0.5), infinity}, &a);
2211+
ArrayFromVector<TYPE, V>(type, {false, true}, {V(0.5), infinity}, &b);
22062212
ASSERT_TRUE(a->Equals(b));
22072213
ASSERT_TRUE(b->Equals(a));
22082214
ASSERT_TRUE(a->ApproxEquals(b, EqualOptions().atol(1e-5).nans_equal(nans_equal)));
@@ -2219,17 +2225,17 @@ void CheckFloatingInfinityEquality() {
22192225
ASSERT_TRUE(b->RangeEquals(a, 0, 1, 0));
22202226

22212227
// Infinity != non-infinity
2222-
ArrayFromVector<TYPE>(type, {false, true}, {0.5, -infinity}, &a);
2223-
ArrayFromVector<TYPE>(type, {false, true}, {0.5, 0.0}, &b);
2228+
ArrayFromVector<TYPE, V>(type, {false, true}, {V(0.5), -infinity}, &a);
2229+
ArrayFromVector<TYPE, V>(type, {false, true}, {V(0.5), V(0.0)}, &b);
22242230
ASSERT_FALSE(a->Equals(b));
22252231
ASSERT_FALSE(b->Equals(a));
22262232
ASSERT_FALSE(a->ApproxEquals(b, EqualOptions().atol(1e-5).nans_equal(nans_equal)));
22272233
ASSERT_FALSE(b->ApproxEquals(a));
22282234
ASSERT_FALSE(a->ApproxEquals(b, EqualOptions().atol(1e-5).nans_equal(nans_equal)));
22292235
ASSERT_FALSE(b->ApproxEquals(a, EqualOptions().atol(1e-5).nans_equal(nans_equal)));
22302236
// Infinity != Negative infinity
2231-
ArrayFromVector<TYPE>(type, {true, true}, {0.5, -infinity}, &a);
2232-
ArrayFromVector<TYPE>(type, {true, true}, {0.5, infinity}, &b);
2237+
ArrayFromVector<TYPE, V>(type, {true, true}, {V(0.5), -infinity}, &a);
2238+
ArrayFromVector<TYPE, V>(type, {true, true}, {V(0.5), infinity}, &b);
22332239
ASSERT_FALSE(a->Equals(b));
22342240
ASSERT_FALSE(b->Equals(a));
22352241
ASSERT_FALSE(a->ApproxEquals(b));
@@ -2249,11 +2255,12 @@ void CheckFloatingInfinityEquality() {
22492255

22502256
template <typename TYPE>
22512257
void CheckFloatingZeroEquality() {
2258+
using V = NumericArgType<TYPE>;
22522259
std::shared_ptr<Array> a, b;
22532260
std::shared_ptr<DataType> type = TypeTraits<TYPE>::type_singleton();
22542261

2255-
ArrayFromVector<TYPE>(type, {true, false}, {0.0, 1.0}, &a);
2256-
ArrayFromVector<TYPE>(type, {true, false}, {0.0, 1.0}, &b);
2262+
ArrayFromVector<TYPE, V>(type, {true, false}, {V(0.0), V(1.0)}, &a);
2263+
ArrayFromVector<TYPE, V>(type, {true, false}, {V(0.0), V(1.0)}, &b);
22572264
ASSERT_TRUE(a->Equals(b));
22582265
ASSERT_TRUE(b->Equals(a));
22592266
for (auto nans_equal : {false, true}) {
@@ -2269,8 +2276,8 @@ void CheckFloatingZeroEquality() {
22692276
}
22702277
}
22712278

2272-
ArrayFromVector<TYPE>(type, {true, false}, {0.0, 1.0}, &a);
2273-
ArrayFromVector<TYPE>(type, {true, false}, {-0.0, 1.0}, &b);
2279+
ArrayFromVector<TYPE, V>(type, {true, false}, {V(0.0), V(1.0)}, &a);
2280+
ArrayFromVector<TYPE, V>(type, {true, false}, {V(-0.0), V(1.0)}, &b);
22742281
for (auto nans_equal : {false, true}) {
22752282
auto opts = EqualOptions().nans_equal(nans_equal);
22762283
ASSERT_TRUE(a->Equals(b, opts));
@@ -2306,16 +2313,19 @@ TEST(TestPrimitiveAdHoc, FloatingSliceApproxEquals) {
23062313
TEST(TestPrimitiveAdHoc, FloatingNanEquality) {
23072314
CheckFloatingNanEquality<FloatType>();
23082315
CheckFloatingNanEquality<DoubleType>();
2316+
CheckFloatingNanEquality<HalfFloatType>();
23092317
}
23102318

23112319
TEST(TestPrimitiveAdHoc, FloatingInfinityEquality) {
23122320
CheckFloatingInfinityEquality<FloatType>();
23132321
CheckFloatingInfinityEquality<DoubleType>();
2322+
CheckFloatingInfinityEquality<HalfFloatType>();
23142323
}
23152324

23162325
TEST(TestPrimitiveAdHoc, FloatingZeroEquality) {
23172326
CheckFloatingZeroEquality<FloatType>();
23182327
CheckFloatingZeroEquality<DoubleType>();
2328+
CheckFloatingZeroEquality<HalfFloatType>();
23192329
}
23202330

23212331
// ----------------------------------------------------------------------

cpp/src/arrow/compare.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ struct FloatingEquality<uint16_t, Flags> {
110110
bool operator()(uint16_t x, uint16_t y) const {
111111
Float16 f_x = Float16::FromBits(x);
112112
Float16 f_y = Float16::FromBits(y);
113-
if (x == y) {
113+
if (f_x == f_y) {
114114
return Flags::signed_zeros_equal || (f_x.signbit() == f_y.signbit());
115115
}
116116
if (Flags::nans_equal && f_x.is_nan() && f_y.is_nan()) {
@@ -171,7 +171,8 @@ void VisitFloatingEquality(const EqualOptions& options, bool floating_approximat
171171
}
172172

173173
inline bool IdentityImpliesEqualityNansNotEqual(const DataType& type) {
174-
if (type.id() == Type::FLOAT || type.id() == Type::DOUBLE) {
174+
if (type.id() == Type::FLOAT || type.id() == Type::DOUBLE ||
175+
type.id() == Type::HALF_FLOAT) {
175176
return false;
176177
}
177178
for (const auto& child : type.fields()) {

cpp/src/arrow/scalar.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
#include "arrow/type_traits.h"
3838
#include "arrow/util/compare.h"
3939
#include "arrow/util/decimal.h"
40+
#include "arrow/util/float16.h"
4041
#include "arrow/util/visibility.h"
4142
#include "arrow/visit_type_inline.h"
4243

@@ -245,6 +246,12 @@ struct ARROW_EXPORT UInt64Scalar : public NumericScalar<UInt64Type> {
245246

246247
struct ARROW_EXPORT HalfFloatScalar : public NumericScalar<HalfFloatType> {
247248
using NumericScalar<HalfFloatType>::NumericScalar;
249+
250+
explicit HalfFloatScalar(util::Float16 value)
251+
: NumericScalar(value.bits(), float16()) {}
252+
253+
HalfFloatScalar(util::Float16 value, std::shared_ptr<DataType> type)
254+
: NumericScalar(value.bits(), std::move(type)) {}
248255
};
249256

250257
struct ARROW_EXPORT FloatScalar : public NumericScalar<FloatType> {
@@ -969,6 +976,18 @@ struct MakeScalarImpl {
969976
return Status::OK();
970977
}
971978

979+
// This isn't captured by the generic case above because `util::Float16` isn't implicitly
980+
// convertible to `uint16_t` (HalfFloat's ValueType)
981+
template <typename T>
982+
std::enable_if_t<std::is_same_v<std::decay_t<ValueRef>, util::Float16> &&
983+
is_half_float_type<T>::value,
984+
Status>
985+
Visit(const T& t) {
986+
out_ = std::make_shared<HalfFloatScalar>(static_cast<ValueRef>(value_),
987+
std::move(type_));
988+
return Status::OK();
989+
}
990+
972991
Status Visit(const ExtensionType& t) {
973992
ARROW_ASSIGN_OR_RAISE(auto storage,
974993
MakeScalar(t.storage_type(), static_cast<ValueRef>(value_)));

cpp/src/arrow/scalar_test.cc

Lines changed: 51 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,15 @@
3939
#include "arrow/testing/random.h"
4040
#include "arrow/testing/util.h"
4141
#include "arrow/type_traits.h"
42+
#include "arrow/util/float16.h"
4243

4344
namespace arrow {
4445

4546
using compute::Cast;
4647
using compute::CastOptions;
4748
using internal::checked_cast;
4849
using internal::checked_pointer_cast;
50+
using util::Float16;
4951

5052
std::shared_ptr<Scalar> CheckMakeNullScalar(const std::shared_ptr<DataType>& type) {
5153
const auto scalar = MakeNullScalar(type);
@@ -201,22 +203,33 @@ TEST(TestScalar, IdentityCast) {
201203
*/
202204
}
203205

206+
template <typename ArrowType>
207+
using NumericArgType = std::conditional_t<is_half_float_type<ArrowType>::value, Float16,
208+
typename ArrowType::c_type>;
209+
204210
template <typename T>
205211
class TestNumericScalar : public ::testing::Test {
206212
public:
207213
TestNumericScalar() = default;
208214
};
209215

210-
TYPED_TEST_SUITE(TestNumericScalar, NumericArrowTypes);
216+
using NumericArrowTypesPlusHalfFloat =
217+
testing::Types<UInt8Type, UInt16Type, UInt32Type, UInt64Type, Int8Type, Int16Type,
218+
Int32Type, Int64Type, FloatType, DoubleType, HalfFloatType>;
219+
TYPED_TEST_SUITE(TestNumericScalar, NumericArrowTypesPlusHalfFloat);
211220

212221
TYPED_TEST(TestNumericScalar, Basics) {
213-
using T = typename TypeParam::c_type;
222+
using T = NumericArgType<TypeParam>;
214223
using ScalarType = typename TypeTraits<TypeParam>::ScalarType;
215224

216225
T value = static_cast<T>(1);
217226

218227
auto scalar_val = std::make_shared<ScalarType>(value);
219-
ASSERT_EQ(value, scalar_val->value);
228+
if constexpr (is_half_float_type<TypeParam>::value) {
229+
ASSERT_EQ(value, Float16::FromBits(scalar_val->value));
230+
} else {
231+
ASSERT_EQ(value, scalar_val->value);
232+
}
220233
ASSERT_TRUE(scalar_val->is_valid);
221234
ASSERT_OK(scalar_val->ValidateFull());
222235

@@ -227,8 +240,13 @@ TYPED_TEST(TestNumericScalar, Basics) {
227240
auto scalar_other = std::make_shared<ScalarType>(other_value);
228241
ASSERT_NE(*scalar_other, *scalar_val);
229242

230-
scalar_val->value = other_value;
231-
ASSERT_EQ(other_value, scalar_val->value);
243+
if constexpr (is_half_float_type<TypeParam>::value) {
244+
scalar_val->value = other_value.bits();
245+
ASSERT_EQ(other_value, Float16::FromBits(scalar_val->value));
246+
} else {
247+
scalar_val->value = other_value;
248+
ASSERT_EQ(other_value, scalar_val->value);
249+
}
232250
ASSERT_EQ(*scalar_other, *scalar_val);
233251

234252
ScalarType stack_val;
@@ -255,72 +273,72 @@ TYPED_TEST(TestNumericScalar, Basics) {
255273
ASSERT_OK(two->ValidateFull());
256274

257275
ASSERT_TRUE(null->Equals(*null_value));
258-
ASSERT_TRUE(one->Equals(ScalarType(1)));
259-
ASSERT_FALSE(one->Equals(ScalarType(2)));
260-
ASSERT_TRUE(two->Equals(ScalarType(2)));
261-
ASSERT_FALSE(two->Equals(ScalarType(3)));
276+
ASSERT_TRUE(one->Equals(ScalarType(static_cast<T>(1))));
277+
ASSERT_FALSE(one->Equals(ScalarType(static_cast<T>(2))));
278+
ASSERT_TRUE(two->Equals(ScalarType(static_cast<T>(2))));
279+
ASSERT_FALSE(two->Equals(ScalarType(static_cast<T>(3))));
262280

263281
ASSERT_TRUE(null->ApproxEquals(*null_value));
264-
ASSERT_TRUE(one->ApproxEquals(ScalarType(1)));
265-
ASSERT_FALSE(one->ApproxEquals(ScalarType(2)));
266-
ASSERT_TRUE(two->ApproxEquals(ScalarType(2)));
267-
ASSERT_FALSE(two->ApproxEquals(ScalarType(3)));
282+
ASSERT_TRUE(one->ApproxEquals(ScalarType(static_cast<T>(1))));
283+
ASSERT_FALSE(one->ApproxEquals(ScalarType(static_cast<T>(2))));
284+
ASSERT_TRUE(two->ApproxEquals(ScalarType(static_cast<T>(2))));
285+
ASSERT_FALSE(two->ApproxEquals(ScalarType(static_cast<T>(3))));
268286
}
269287

270288
TYPED_TEST(TestNumericScalar, Hashing) {
271-
using T = typename TypeParam::c_type;
289+
using T = NumericArgType<TypeParam>;
272290
using ScalarType = typename TypeTraits<TypeParam>::ScalarType;
273291

274292
std::unordered_set<std::shared_ptr<Scalar>, Scalar::Hash, Scalar::PtrsEqual> set;
275293
set.emplace(std::make_shared<ScalarType>());
276-
for (T i = 0; i < 10; ++i) {
277-
set.emplace(std::make_shared<ScalarType>(i));
294+
for (int i = 0; i < 10; ++i) {
295+
ASSERT_TRUE(set.emplace(std::make_shared<ScalarType>(static_cast<T>(i))).second);
278296
}
279297

280298
ASSERT_FALSE(set.emplace(std::make_shared<ScalarType>()).second);
281-
for (T i = 0; i < 10; ++i) {
282-
ASSERT_FALSE(set.emplace(std::make_shared<ScalarType>(i)).second);
299+
for (int i = 0; i < 10; ++i) {
300+
ASSERT_FALSE(set.emplace(std::make_shared<ScalarType>(static_cast<T>(i))).second);
283301
}
284302
}
285303

286304
TYPED_TEST(TestNumericScalar, MakeScalar) {
287-
using T = typename TypeParam::c_type;
305+
using T = NumericArgType<TypeParam>;
288306
using ScalarType = typename TypeTraits<TypeParam>::ScalarType;
289307
auto type = TypeTraits<TypeParam>::type_singleton();
290308

291309
std::shared_ptr<Scalar> three = MakeScalar(static_cast<T>(3));
292310
ASSERT_OK(three->ValidateFull());
293-
ASSERT_EQ(ScalarType(3), *three);
311+
ASSERT_EQ(ScalarType(static_cast<T>(3)), *three);
294312

295-
AssertMakeScalar(ScalarType(3), type, static_cast<T>(3));
313+
AssertMakeScalar(ScalarType(static_cast<T>(3)), type, static_cast<T>(3));
296314

297-
AssertParseScalar(type, "3", ScalarType(3));
315+
AssertParseScalar(type, "3", ScalarType(static_cast<T>(3)));
298316
}
299317

300318
template <typename T>
301319
class TestRealScalar : public ::testing::Test {
302320
public:
303-
using CType = typename T::c_type;
321+
using ValueType = NumericArgType<T>;
304322
using ScalarType = typename TypeTraits<T>::ScalarType;
305323

306324
void SetUp() {
307325
type_ = TypeTraits<T>::type_singleton();
308326

309-
scalar_val_ = std::make_shared<ScalarType>(static_cast<CType>(1));
327+
scalar_val_ = std::make_shared<ScalarType>(static_cast<ValueType>(1));
310328
ASSERT_TRUE(scalar_val_->is_valid);
311329

312-
scalar_other_ = std::make_shared<ScalarType>(static_cast<CType>(1.1));
330+
scalar_other_ = std::make_shared<ScalarType>(static_cast<ValueType>(1.1));
313331
ASSERT_TRUE(scalar_other_->is_valid);
314332

315-
scalar_zero_ = std::make_shared<ScalarType>(static_cast<CType>(0.0));
316-
scalar_other_zero_ = std::make_shared<ScalarType>(static_cast<CType>(0.0));
317-
scalar_neg_zero_ = std::make_shared<ScalarType>(static_cast<CType>(-0.0));
333+
scalar_zero_ = std::make_shared<ScalarType>(static_cast<ValueType>(0.0));
334+
scalar_other_zero_ = std::make_shared<ScalarType>(static_cast<ValueType>(0.0));
335+
scalar_neg_zero_ = std::make_shared<ScalarType>(static_cast<ValueType>(-0.0));
318336

319-
const CType nan_value = std::numeric_limits<CType>::quiet_NaN();
337+
const auto nan_value = std::numeric_limits<ValueType>::quiet_NaN();
320338
scalar_nan_ = std::make_shared<ScalarType>(nan_value);
321339
ASSERT_TRUE(scalar_nan_->is_valid);
322340

323-
const CType other_nan_value = std::numeric_limits<CType>::quiet_NaN();
341+
const auto other_nan_value = std::numeric_limits<ValueType>::quiet_NaN();
324342
scalar_other_nan_ = std::make_shared<ScalarType>(other_nan_value);
325343
ASSERT_TRUE(scalar_other_nan_->is_valid);
326344
}
@@ -522,7 +540,9 @@ class TestRealScalar : public ::testing::Test {
522540
scalar_zero_, scalar_other_zero_, scalar_neg_zero_;
523541
};
524542

525-
TYPED_TEST_SUITE(TestRealScalar, RealArrowTypes);
543+
using RealArrowTypesPlusHalfFloat =
544+
::testing::Types<FloatType, DoubleType, HalfFloatType>;
545+
TYPED_TEST_SUITE(TestRealScalar, RealArrowTypesPlusHalfFloat);
526546

527547
TYPED_TEST(TestRealScalar, NanEquals) { this->TestNanEquals(); }
528548

@@ -1181,8 +1201,6 @@ TEST(TestDayTimeIntervalScalars, Basics) {
11811201
ASSERT_TRUE(first->Equals(ts_val2));
11821202
}
11831203

1184-
// TODO test HalfFloatScalar
1185-
11861204
TYPED_TEST(TestNumericScalar, Cast) {
11871205
auto type = TypeTraits<TypeParam>::type_singleton();
11881206

0 commit comments

Comments
 (0)