Skip to content

Commit e6e50de

Browse files
authored
GH-36753: [C++] Properly pretty-print and diff HalfFloatArrays (#46857)
### Rationale for this change

#36753 asked for this to be implemented now that a half-float library is available.

### What changes are included in this PR?

Pretty-printing and diffing of HalfFloatArrays now display floating-point values instead of the underlying uint16 bit patterns.

### Are these changes tested?

Yes, with tests in C++ and Python.

### Are there any user-facing changes?

Yes. Pretty-printed and diffed float16 values are now displayed as floating point rather than as uint16.

* GitHub Issue: #36753

Authored-by: Eric Dinse <293818+dinse@users.noreply.github.com>
Signed-off-by: Antoine Pitrou <antoine@python.org>
1 parent 03520f1 commit e6e50de

File tree

6 files changed

+78
-10
lines changed

6 files changed

+78
-10
lines changed

cpp/src/arrow/array/diff.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
#include "arrow/type_traits.h"
4444
#include "arrow/util/bit_util.h"
4545
#include "arrow/util/checked_cast.h"
46+
#include "arrow/util/float16.h"
4647
#include "arrow/util/logging_internal.h"
4748
#include "arrow/util/range.h"
4849
#include "arrow/util/ree_util.h"
@@ -627,6 +628,14 @@ class MakeFormatterImpl {
627628
return Status::OK();
628629
}
629630

631+
Status Visit(const HalfFloatType&) {
632+
impl_ = [](const Array& array, int64_t index, std::ostream* os) {
633+
const auto& float16_arr = checked_cast<const HalfFloatArray&>(array);
634+
*os << arrow::util::Float16::FromBits(float16_arr.Value(index));
635+
};
636+
return Status::OK();
637+
}
638+
630639
// format Numerics with std::ostream defaults
631640
template <typename T>
632641
enable_if_number<T, Status> Visit(const T&) {

cpp/src/arrow/array/diff_test.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include "arrow/testing/random.h"
3636
#include "arrow/testing/util.h"
3737
#include "arrow/type.h"
38+
#include "arrow/util/float16.h"
3839
#include "arrow/util/logging.h"
3940

4041
namespace arrow {
@@ -815,4 +816,19 @@ TEST_F(DiffTest, CompareRandomStruct) {
815816
}
816817
}
817818

819+
TEST_F(DiffTest, CompareHalfFloat) {
820+
auto first = ArrayFromJSON(float16(), "[1.1, 2.0, 2.5, 3.3]");
821+
auto second = ArrayFromJSON(float16(), "[1.1, 4.0, 3.5, 3.3]");
822+
auto expected_diff = R"(
823+
@@ -1, +1 @@
824+
-2
825+
-2.5
826+
+4
827+
+3.5
828+
)";
829+
830+
auto diff = first->Diff(*second);
831+
ASSERT_EQ(diff, expected_diff);
832+
}
833+
818834
} // namespace arrow

cpp/src/arrow/pretty_print.cc

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -239,12 +239,6 @@ class ArrayPrinter : public PrettyPrinter {
239239
return WritePrimitiveValues(array);
240240
}
241241

242-
Status WriteDataValues(const HalfFloatArray& array) {
243-
// XXX do not know how to format half floats yet
244-
StringFormatter<Int16Type> formatter{array.type().get()};
245-
return WritePrimitiveValues(array, &formatter);
246-
}
247-
248242
template <typename ArrayType, typename T = typename ArrayType::TypeClass>
249243
enable_if_has_string_view<T, Status> WriteDataValues(const ArrayType& array) {
250244
return WriteValues(array, [&](int64_t i) {

cpp/src/arrow/pretty_print_test.cc

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
#include <gtest/gtest.h>
2121

22+
#include <cmath>
2223
#include <cstdint>
2324
#include <cstring>
2425
#include <limits>
@@ -33,10 +34,13 @@
3334
#include "arrow/testing/builder.h"
3435
#include "arrow/testing/gtest_util.h"
3536
#include "arrow/type.h"
37+
#include "arrow/util/float16.h"
3638
#include "arrow/util/key_value_metadata.h"
3739

3840
namespace arrow {
3941

42+
using util::Float16;
43+
4044
class TestPrettyPrint : public ::testing::Test {
4145
public:
4246
void SetUp() {}
@@ -330,6 +334,37 @@ TEST_F(TestPrettyPrint, UInt64) {
330334
expected);
331335
}
332336

337+
TEST_F(TestPrettyPrint, HalfFloat) {
338+
static const char* expected = R"expected([
339+
-inf,
340+
-1234,
341+
-0,
342+
0,
343+
1,
344+
1.2001953125,
345+
2.5,
346+
3.9921875,
347+
4.125,
348+
10000,
349+
12344,
350+
inf,
351+
nan,
352+
null
353+
])expected";
354+
355+
std::vector<uint16_t> values = {
356+
Float16(-1e10f).bits(), Float16(-1234.0f).bits(), Float16(-0.0f).bits(),
357+
Float16(0.0f).bits(), Float16(1.0f).bits(), Float16(1.2f).bits(),
358+
Float16(2.5f).bits(), Float16(3.9921875f).bits(), Float16(4.125f).bits(),
359+
Float16(1e4f).bits(), Float16(12345.0f).bits(), Float16(1e5f).bits(),
360+
Float16(NAN).bits(), Float16(6.10f).bits()};
361+
362+
std::vector<bool> is_valid(values.size(), true);
363+
is_valid.back() = false;
364+
365+
CheckPrimitive<HalfFloatType, uint16_t>({0, 10}, is_valid, values, expected);
366+
}
367+
333368
TEST_F(TestPrettyPrint, DateTimeTypes) {
334369
std::vector<bool> is_valid = {true, true, false, true, false};
335370

python/pyarrow/tests/test_array.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -568,6 +568,8 @@ def test_array_diff():
568568
arr2 = pa.array(['foo', 'bar', None], type=pa.utf8())
569569
arr3 = pa.array([1, 2, 3])
570570
arr4 = pa.array([[], [1], None], type=pa.list_(pa.int64()))
571+
arr5 = pa.array([1.5, 3, 6], type=pa.float16())
572+
arr6 = pa.array([1, 3], type=pa.float16())
571573

572574
assert arr1.diff(arr1) == ''
573575
assert arr1.diff(arr2) == '''
@@ -579,6 +581,14 @@ def test_array_diff():
579581
assert arr1.diff(arr3).strip() == '# Array types differed: string vs int64'
580582
assert arr1.diff(arr4).strip() == ('# Array types differed: string vs '
581583
'list<item: int64>')
584+
assert arr5.diff(arr5) == ''
585+
assert arr5.diff(arr6) == '''
586+
@@ -0, +0 @@
587+
-1.5
588+
+1
589+
@@ -2, +2 @@
590+
-6
591+
'''
582592

583593

584594
def test_array_iter():
@@ -1706,9 +1716,13 @@ def test_floating_point_truncate_unsafe():
17061716

17071717
def test_half_float_array_from_python():
17081718
# GH-46611
1709-
arr = pa.array([1.0, 2.0, 3, None, 12345.6789, 1.234567], type=pa.float16())
1719+
vals = [-5, 0, 1.0, 2.0, 3, None, 12345.6789, 1.234567, float('inf')]
1720+
arr = pa.array(vals, type=pa.float16())
17101721
assert arr.type == pa.float16()
1711-
assert arr.to_pylist() == [1.0, 2.0, 3.0, None, 12344.0, 1.234375]
1722+
assert arr.to_pylist() == [-5, 0, 1.0, 2.0, 3, None, 12344.0,
1723+
1.234375, float('inf')]
1724+
assert str(arr) == ("[\n -5,\n 0,\n 1,\n 2,\n 3,\n null,\n 12344,"
1725+
"\n 1.234375,\n inf\n]")
17121726
msg1 = "Could not convert 'a' with type str: tried to convert to float16"
17131727
with pytest.raises(pa.ArrowInvalid, match=msg1):
17141728
arr = pa.array(['a', 3, None], type=pa.float16())

python/pyarrow/types.pxi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4459,8 +4459,8 @@ def float16():
44594459
>>> a
44604460
<pyarrow.lib.HalfFloatArray object at ...>
44614461
[
4462-
15872,
4463-
32256
4462+
1.5,
4463+
nan
44644464
]
44654465
44664466
Note that unlike other float types, if you convert this array

0 commit comments

Comments
 (0)