Skip to content

Commit abbcd53

Browse files
authored
GH-46063: [C++][Compute] Fix the issue that MinMax kernel emits -inf/inf for all-NaN input (#48459)
### Rationale for this change Our MinMax kernels emit -inf/inf for all-NaN input array, which doesn't make sense. ### What changes are included in this PR? Initialize the running min/max value from -inf/inf to NaN, so we can leverage the nice property that: `std::fmin/fmax(all-NaN) = NaN` `std::fmin/fmax(NaN, non-NaN) = non-NaN` ### Are these changes tested? Test included. ### Are there any user-facing changes? None. * GitHub Issue: #46063 Authored-by: Rossi Sun <[email protected]> Signed-off-by: Rossi Sun <[email protected]>
1 parent 5bda712 commit abbcd53

File tree

4 files changed

+134
-42
lines changed

4 files changed

+134
-42
lines changed

cpp/src/arrow/acero/hash_aggregate_test.cc

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2000,6 +2000,74 @@ TEST_P(GroupBy, MinMaxScalar) {
20002000
}
20012001
}
20022002

2003+
TEST_P(GroupBy, MinMaxWithNaN) {
2004+
auto in_schema = schema({
2005+
field("argument1", float32()),
2006+
field("argument2", float64()),
2007+
field("key", int64()),
2008+
});
2009+
for (bool use_threads : {true, false}) {
2010+
SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
2011+
2012+
auto table = TableFromJSON(in_schema, {R"([
2013+
[NaN, NaN, 1],
2014+
[NaN, NaN, 2],
2015+
[NaN, NaN, 3]
2016+
])",
2017+
R"([
2018+
[NaN, NaN, 1],
2019+
[-Inf, -Inf, 2],
2020+
[Inf, Inf, 3]
2021+
])"});
2022+
2023+
ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
2024+
GroupByTest(
2025+
{
2026+
table->GetColumnByName("argument1"),
2027+
table->GetColumnByName("argument1"),
2028+
table->GetColumnByName("argument1"),
2029+
table->GetColumnByName("argument2"),
2030+
table->GetColumnByName("argument2"),
2031+
table->GetColumnByName("argument2"),
2032+
},
2033+
{table->GetColumnByName("key")},
2034+
{
2035+
{"hash_min", nullptr},
2036+
{"hash_max", nullptr},
2037+
{"hash_min_max", nullptr},
2038+
{"hash_min", nullptr},
2039+
{"hash_max", nullptr},
2040+
{"hash_min_max", nullptr},
2041+
},
2042+
use_threads));
2043+
ValidateOutput(aggregated_and_grouped);
2044+
SortBy({"key_0"}, &aggregated_and_grouped);
2045+
2046+
AssertDatumsEqual(ArrayFromJSON(struct_({
2047+
field("key_0", int64()),
2048+
field("hash_min", float32()),
2049+
field("hash_max", float32()),
2050+
field("hash_min_max", struct_({
2051+
field("min", float32()),
2052+
field("max", float32()),
2053+
})),
2054+
field("hash_min", float64()),
2055+
field("hash_max", float64()),
2056+
field("hash_min_max", struct_({
2057+
field("min", float64()),
2058+
field("max", float64()),
2059+
})),
2060+
}),
2061+
R"([
2062+
[1, NaN, NaN, {"min": NaN, "max": NaN}, NaN, NaN, {"min": NaN, "max": NaN}],
2063+
[2, -Inf, -Inf, {"min": -Inf, "max": -Inf}, -Inf, -Inf, {"min": -Inf, "max": -Inf}],
2064+
[3, Inf, Inf, {"min": Inf, "max": Inf}, Inf, Inf, {"min": Inf, "max": Inf}]
2065+
])"),
2066+
aggregated_and_grouped,
2067+
/*verbose=*/true);
2068+
}
2069+
}
2070+
20032071
TEST_P(GroupBy, AnyAndAll) {
20042072
for (bool use_threads : {true, false}) {
20052073
SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");

cpp/src/arrow/compute/kernels/aggregate_basic.inc.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -694,8 +694,8 @@ struct MinMaxState<ArrowType, SimdLevel, enable_if_floating_point<ArrowType>> {
694694
this->max = std::fmax(this->max, value);
695695
}
696696

697-
T min = std::numeric_limits<T>::infinity();
698-
T max = -std::numeric_limits<T>::infinity();
697+
T min = std::numeric_limits<T>::quiet_NaN();
698+
T max = std::numeric_limits<T>::quiet_NaN();
699699
bool has_nulls = false;
700700
};
701701

cpp/src/arrow/compute/kernels/aggregate_test.cc

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1841,6 +1841,24 @@ class TestPrimitiveMinMaxKernel : public ::testing::Test {
18411841
AssertMinMaxIsNull(array, options);
18421842
}
18431843

1844+
void AssertMinMaxIsNaN(const Datum& array, const ScalarAggregateOptions& options) {
1845+
ASSERT_OK_AND_ASSIGN(Datum out, MinMax(array, options));
1846+
for (const auto& val : out.scalar_as<StructScalar>().value) {
1847+
ASSERT_TRUE(std::isnan(checked_cast<const ScalarType&>(*val).value));
1848+
}
1849+
}
1850+
1851+
void AssertMinMaxIsNaN(const std::string& json, const ScalarAggregateOptions& options) {
1852+
auto array = ArrayFromJSON(type_singleton(), json);
1853+
AssertMinMaxIsNaN(array, options);
1854+
}
1855+
1856+
void AssertMinMaxIsNaN(const std::vector<std::string>& json,
1857+
const ScalarAggregateOptions& options) {
1858+
auto array = ChunkedArrayFromJSON(type_singleton(), json);
1859+
AssertMinMaxIsNaN(array, options);
1860+
}
1861+
18441862
std::shared_ptr<DataType> type_singleton() {
18451863
return default_type_instance<ArrowType>();
18461864
}
@@ -1963,6 +1981,9 @@ TYPED_TEST(TestFloatingMinMaxKernel, Floats) {
19631981
this->AssertMinMaxIs("[5, Inf, 2, 3, 4]", 2.0, INFINITY, options);
19641982
this->AssertMinMaxIs("[5, NaN, 2, 3, 4]", 2, 5, options);
19651983
this->AssertMinMaxIs("[5, -Inf, 2, 3, 4]", -INFINITY, 5, options);
1984+
this->AssertMinMaxIs("[NaN, null, 42]", 42, 42, options);
1985+
this->AssertMinMaxIsNaN("[NaN, NaN]", options);
1986+
this->AssertMinMaxIsNaN("[NaN, null]", options);
19661987
this->AssertMinMaxIs(chunked_input1, 1, 9, options);
19671988
this->AssertMinMaxIs(chunked_input2, 1, 9, options);
19681989
this->AssertMinMaxIs(chunked_input3, 1, 9, options);
@@ -1980,6 +2001,7 @@ TYPED_TEST(TestFloatingMinMaxKernel, Floats) {
19802001
this->AssertMinMaxIs("[5, -Inf, 2, 3, 4]", -INFINITY, 5, options);
19812002
this->AssertMinMaxIsNull("[5, null, 2, 3, 4]", options);
19822003
this->AssertMinMaxIsNull("[5, -Inf, null, 3, 4]", options);
2004+
this->AssertMinMaxIsNull("[NaN, null]", options);
19832005
this->AssertMinMaxIsNull(chunked_input1, options);
19842006
this->AssertMinMaxIsNull(chunked_input2, options);
19852007
this->AssertMinMaxIsNull(chunked_input3, options);

cpp/src/arrow/compute/kernels/hash_aggregate.cc

Lines changed: 42 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
// under the License.
1717

1818
#include <cmath>
19+
#include <concepts>
1920
#include <functional>
2021
#include <memory>
2122
#include <string>
@@ -270,52 +271,53 @@ struct GroupedCountImpl : public GroupedAggregator {
270271
// ----------------------------------------------------------------------
271272
// MinMax implementation
272273

274+
// XXX: Consider making these concepts complete and moving to public header.
275+
276+
template <typename T>
277+
concept CBooleanConcept = std::same_as<T, bool>;
278+
279+
// XXX: Ideally we want to have std::floating_point<Float16> = true.
280+
template <typename T>
281+
concept CFloatingPointConcept = std::floating_point<T> || std::same_as<T, util::Float16>;
282+
283+
template <typename T>
284+
concept CDecimalConcept = std::same_as<T, Decimal32> || std::same_as<T, Decimal64> ||
285+
std::same_as<T, Decimal128> || std::same_as<T, Decimal256>;
286+
273287
template <typename CType>
274288
struct AntiExtrema {
275289
static constexpr CType anti_min() { return std::numeric_limits<CType>::max(); }
276290
static constexpr CType anti_max() { return std::numeric_limits<CType>::min(); }
277291
};
278292

279-
template <>
280-
struct AntiExtrema<bool> {
281-
static constexpr bool anti_min() { return true; }
282-
static constexpr bool anti_max() { return false; }
283-
};
284-
285-
template <>
286-
struct AntiExtrema<float> {
287-
static constexpr float anti_min() { return std::numeric_limits<float>::infinity(); }
288-
static constexpr float anti_max() { return -std::numeric_limits<float>::infinity(); }
293+
template <CBooleanConcept CType>
294+
struct AntiExtrema<CType> {
295+
static constexpr CType anti_min() { return true; }
296+
static constexpr CType anti_max() { return false; }
289297
};
290298

291-
template <>
292-
struct AntiExtrema<double> {
293-
static constexpr double anti_min() { return std::numeric_limits<double>::infinity(); }
294-
static constexpr double anti_max() { return -std::numeric_limits<double>::infinity(); }
299+
template <CFloatingPointConcept CType>
300+
struct AntiExtrema<CType> {
301+
static constexpr CType anti_min() { return std::numeric_limits<CType>::quiet_NaN(); }
302+
static constexpr CType anti_max() { return std::numeric_limits<CType>::quiet_NaN(); }
295303
};
296304

297-
template <>
298-
struct AntiExtrema<Decimal32> {
299-
static constexpr Decimal32 anti_min() { return BasicDecimal32::GetMaxSentinel(); }
300-
static constexpr Decimal32 anti_max() { return BasicDecimal32::GetMinSentinel(); }
305+
template <CDecimalConcept CType>
306+
struct AntiExtrema<CType> {
307+
static constexpr CType anti_min() { return CType::GetMaxSentinel(); }
308+
static constexpr CType anti_max() { return CType::GetMinSentinel(); }
301309
};
302310

303-
template <>
304-
struct AntiExtrema<Decimal64> {
305-
static constexpr Decimal64 anti_min() { return BasicDecimal64::GetMaxSentinel(); }
306-
static constexpr Decimal64 anti_max() { return BasicDecimal64::GetMinSentinel(); }
307-
};
308-
309-
template <>
310-
struct AntiExtrema<Decimal128> {
311-
static constexpr Decimal128 anti_min() { return BasicDecimal128::GetMaxSentinel(); }
312-
static constexpr Decimal128 anti_max() { return BasicDecimal128::GetMinSentinel(); }
311+
template <typename CType>
312+
struct MinMaxOp {
313+
static constexpr CType min(CType a, CType b) { return std::min(a, b); }
314+
static constexpr CType max(CType a, CType b) { return std::max(a, b); }
313315
};
314316

315-
template <>
316-
struct AntiExtrema<Decimal256> {
317-
static constexpr Decimal256 anti_min() { return BasicDecimal256::GetMaxSentinel(); }
318-
static constexpr Decimal256 anti_max() { return BasicDecimal256::GetMinSentinel(); }
317+
template <CFloatingPointConcept CType>
318+
struct MinMaxOp<CType> {
319+
static constexpr CType min(CType a, CType b) { return std::fmin(a, b); }
320+
static constexpr CType max(CType a, CType b) { return std::fmax(a, b); }
319321
};
320322

321323
template <typename Type, typename Enable = void>
@@ -352,8 +354,8 @@ struct GroupedMinMaxImpl final : public GroupedAggregator {
352354
VisitGroupedValues<Type>(
353355
batch,
354356
[&](uint32_t g, CType val) {
355-
GetSet::Set(raw_mins, g, std::min(GetSet::Get(raw_mins, g), val));
356-
GetSet::Set(raw_maxes, g, std::max(GetSet::Get(raw_maxes, g), val));
357+
GetSet::Set(raw_mins, g, MinMaxOp<CType>::min(GetSet::Get(raw_mins, g), val));
358+
GetSet::Set(raw_maxes, g, MinMaxOp<CType>::max(GetSet::Get(raw_maxes, g), val));
357359
bit_util::SetBit(has_values_.mutable_data(), g);
358360
},
359361
[&](uint32_t g) { bit_util::SetBit(has_nulls_.mutable_data(), g); });
@@ -373,12 +375,12 @@ struct GroupedMinMaxImpl final : public GroupedAggregator {
373375
auto g = group_id_mapping.GetValues<uint32_t>(1);
374376
for (uint32_t other_g = 0; static_cast<int64_t>(other_g) < group_id_mapping.length;
375377
++other_g, ++g) {
376-
GetSet::Set(
377-
raw_mins, *g,
378-
std::min(GetSet::Get(raw_mins, *g), GetSet::Get(other_raw_mins, other_g)));
379-
GetSet::Set(
380-
raw_maxes, *g,
381-
std::max(GetSet::Get(raw_maxes, *g), GetSet::Get(other_raw_maxes, other_g)));
378+
GetSet::Set(raw_mins, *g,
379+
MinMaxOp<CType>::min(GetSet::Get(raw_mins, *g),
380+
GetSet::Get(other_raw_mins, other_g)));
381+
GetSet::Set(raw_maxes, *g,
382+
MinMaxOp<CType>::max(GetSet::Get(raw_maxes, *g),
383+
GetSet::Get(other_raw_maxes, other_g)));
382384

383385
if (bit_util::GetBit(other->has_values_.data(), other_g)) {
384386
bit_util::SetBit(has_values_.mutable_data(), *g);

0 commit comments

Comments
 (0)