Skip to content

Commit 9f3cc7b

Browse files
authored
move min_batch/max_batch to functions-aggregate-common (#16593)
1 parent 8d772e5 commit 9f3cc7b

File tree

6 files changed

+372
-343
lines changed

6 files changed

+372
-343
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/functions-aggregate-common/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
pub mod accumulator;
3535
pub mod aggregate;
3636
pub mod merge_arrays;
37+
pub mod min_max;
3738
pub mod order;
3839
pub mod stats;
3940
pub mod tdigest;
Lines changed: 353 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,353 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! Basic min/max functionality shared across DataFusion aggregate functions
19+
20+
use arrow::array::{
21+
ArrayRef, AsArray as _, BinaryArray, BinaryViewArray, BooleanArray, Date32Array,
22+
Date64Array, Decimal128Array, Decimal256Array, DurationMicrosecondArray,
23+
DurationMillisecondArray, DurationNanosecondArray, DurationSecondArray, Float16Array,
24+
Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array,
25+
IntervalDayTimeArray, IntervalMonthDayNanoArray, IntervalYearMonthArray,
26+
LargeBinaryArray, LargeStringArray, StringArray, StringViewArray,
27+
Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray,
28+
Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray,
29+
TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array,
30+
UInt64Array, UInt8Array,
31+
};
32+
use arrow::compute;
33+
use arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
34+
use datafusion_common::{downcast_value, Result, ScalarValue};
35+
use std::cmp::Ordering;
36+
37+
// Statically-typed version of min/max(array) -> ScalarValue for string types
38+
macro_rules! typed_min_max_batch_string {
39+
($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{
40+
let array = downcast_value!($VALUES, $ARRAYTYPE);
41+
let value = compute::$OP(array);
42+
let value = value.and_then(|e| Some(e.to_string()));
43+
ScalarValue::$SCALAR(value)
44+
}};
45+
}
46+
47+
// Statically-typed version of min/max(array) -> ScalarValue for binary types.
48+
macro_rules! typed_min_max_batch_binary {
49+
($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{
50+
let array = downcast_value!($VALUES, $ARRAYTYPE);
51+
let value = compute::$OP(array);
52+
let value = value.and_then(|e| Some(e.to_vec()));
53+
ScalarValue::$SCALAR(value)
54+
}};
55+
}
56+
57+
// Statically-typed version of min/max(array) -> ScalarValue for non-string types.
58+
macro_rules! typed_min_max_batch {
59+
($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident $(, $EXTRA_ARGS:ident)*) => {{
60+
let array = downcast_value!($VALUES, $ARRAYTYPE);
61+
let value = compute::$OP(array);
62+
ScalarValue::$SCALAR(value, $($EXTRA_ARGS.clone()),*)
63+
}};
64+
}
65+
66+
// Statically-typed version of min/max(array) -> ScalarValue for non-string types.
67+
// this is a macro to support both operations (min and max).
68+
macro_rules! min_max_batch {
69+
($VALUES:expr, $OP:ident) => {{
70+
match $VALUES.data_type() {
71+
DataType::Null => ScalarValue::Null,
72+
DataType::Decimal128(precision, scale) => {
73+
typed_min_max_batch!(
74+
$VALUES,
75+
Decimal128Array,
76+
Decimal128,
77+
$OP,
78+
precision,
79+
scale
80+
)
81+
}
82+
DataType::Decimal256(precision, scale) => {
83+
typed_min_max_batch!(
84+
$VALUES,
85+
Decimal256Array,
86+
Decimal256,
87+
$OP,
88+
precision,
89+
scale
90+
)
91+
}
92+
// all types that have a natural order
93+
DataType::Float64 => {
94+
typed_min_max_batch!($VALUES, Float64Array, Float64, $OP)
95+
}
96+
DataType::Float32 => {
97+
typed_min_max_batch!($VALUES, Float32Array, Float32, $OP)
98+
}
99+
DataType::Float16 => {
100+
typed_min_max_batch!($VALUES, Float16Array, Float16, $OP)
101+
}
102+
DataType::Int64 => typed_min_max_batch!($VALUES, Int64Array, Int64, $OP),
103+
DataType::Int32 => typed_min_max_batch!($VALUES, Int32Array, Int32, $OP),
104+
DataType::Int16 => typed_min_max_batch!($VALUES, Int16Array, Int16, $OP),
105+
DataType::Int8 => typed_min_max_batch!($VALUES, Int8Array, Int8, $OP),
106+
DataType::UInt64 => typed_min_max_batch!($VALUES, UInt64Array, UInt64, $OP),
107+
DataType::UInt32 => typed_min_max_batch!($VALUES, UInt32Array, UInt32, $OP),
108+
DataType::UInt16 => typed_min_max_batch!($VALUES, UInt16Array, UInt16, $OP),
109+
DataType::UInt8 => typed_min_max_batch!($VALUES, UInt8Array, UInt8, $OP),
110+
DataType::Timestamp(TimeUnit::Second, tz_opt) => {
111+
typed_min_max_batch!(
112+
$VALUES,
113+
TimestampSecondArray,
114+
TimestampSecond,
115+
$OP,
116+
tz_opt
117+
)
118+
}
119+
DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => typed_min_max_batch!(
120+
$VALUES,
121+
TimestampMillisecondArray,
122+
TimestampMillisecond,
123+
$OP,
124+
tz_opt
125+
),
126+
DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => typed_min_max_batch!(
127+
$VALUES,
128+
TimestampMicrosecondArray,
129+
TimestampMicrosecond,
130+
$OP,
131+
tz_opt
132+
),
133+
DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => typed_min_max_batch!(
134+
$VALUES,
135+
TimestampNanosecondArray,
136+
TimestampNanosecond,
137+
$OP,
138+
tz_opt
139+
),
140+
DataType::Date32 => typed_min_max_batch!($VALUES, Date32Array, Date32, $OP),
141+
DataType::Date64 => typed_min_max_batch!($VALUES, Date64Array, Date64, $OP),
142+
DataType::Time32(TimeUnit::Second) => {
143+
typed_min_max_batch!($VALUES, Time32SecondArray, Time32Second, $OP)
144+
}
145+
DataType::Time32(TimeUnit::Millisecond) => {
146+
typed_min_max_batch!(
147+
$VALUES,
148+
Time32MillisecondArray,
149+
Time32Millisecond,
150+
$OP
151+
)
152+
}
153+
DataType::Time64(TimeUnit::Microsecond) => {
154+
typed_min_max_batch!(
155+
$VALUES,
156+
Time64MicrosecondArray,
157+
Time64Microsecond,
158+
$OP
159+
)
160+
}
161+
DataType::Time64(TimeUnit::Nanosecond) => {
162+
typed_min_max_batch!(
163+
$VALUES,
164+
Time64NanosecondArray,
165+
Time64Nanosecond,
166+
$OP
167+
)
168+
}
169+
DataType::Interval(IntervalUnit::YearMonth) => {
170+
typed_min_max_batch!(
171+
$VALUES,
172+
IntervalYearMonthArray,
173+
IntervalYearMonth,
174+
$OP
175+
)
176+
}
177+
DataType::Interval(IntervalUnit::DayTime) => {
178+
typed_min_max_batch!($VALUES, IntervalDayTimeArray, IntervalDayTime, $OP)
179+
}
180+
DataType::Interval(IntervalUnit::MonthDayNano) => {
181+
typed_min_max_batch!(
182+
$VALUES,
183+
IntervalMonthDayNanoArray,
184+
IntervalMonthDayNano,
185+
$OP
186+
)
187+
}
188+
DataType::Duration(TimeUnit::Second) => {
189+
typed_min_max_batch!($VALUES, DurationSecondArray, DurationSecond, $OP)
190+
}
191+
DataType::Duration(TimeUnit::Millisecond) => {
192+
typed_min_max_batch!(
193+
$VALUES,
194+
DurationMillisecondArray,
195+
DurationMillisecond,
196+
$OP
197+
)
198+
}
199+
DataType::Duration(TimeUnit::Microsecond) => {
200+
typed_min_max_batch!(
201+
$VALUES,
202+
DurationMicrosecondArray,
203+
DurationMicrosecond,
204+
$OP
205+
)
206+
}
207+
DataType::Duration(TimeUnit::Nanosecond) => {
208+
typed_min_max_batch!(
209+
$VALUES,
210+
DurationNanosecondArray,
211+
DurationNanosecond,
212+
$OP
213+
)
214+
}
215+
other => {
216+
// This should have been handled before
217+
return datafusion_common::internal_err!(
218+
"Min/Max accumulator not implemented for type {:?}",
219+
other
220+
);
221+
}
222+
}
223+
}};
224+
}
225+
226+
/// dynamically-typed min(array) -> ScalarValue
227+
pub fn min_batch(values: &ArrayRef) -> Result<ScalarValue> {
228+
Ok(match values.data_type() {
229+
DataType::Utf8 => {
230+
typed_min_max_batch_string!(values, StringArray, Utf8, min_string)
231+
}
232+
DataType::LargeUtf8 => {
233+
typed_min_max_batch_string!(values, LargeStringArray, LargeUtf8, min_string)
234+
}
235+
DataType::Utf8View => {
236+
typed_min_max_batch_string!(
237+
values,
238+
StringViewArray,
239+
Utf8View,
240+
min_string_view
241+
)
242+
}
243+
DataType::Boolean => {
244+
typed_min_max_batch!(values, BooleanArray, Boolean, min_boolean)
245+
}
246+
DataType::Binary => {
247+
typed_min_max_batch_binary!(&values, BinaryArray, Binary, min_binary)
248+
}
249+
DataType::LargeBinary => {
250+
typed_min_max_batch_binary!(
251+
&values,
252+
LargeBinaryArray,
253+
LargeBinary,
254+
min_binary
255+
)
256+
}
257+
DataType::BinaryView => {
258+
typed_min_max_batch_binary!(
259+
&values,
260+
BinaryViewArray,
261+
BinaryView,
262+
min_binary_view
263+
)
264+
}
265+
DataType::Struct(_) => min_max_batch_generic(values, Ordering::Greater)?,
266+
DataType::List(_) => min_max_batch_generic(values, Ordering::Greater)?,
267+
DataType::LargeList(_) => min_max_batch_generic(values, Ordering::Greater)?,
268+
DataType::FixedSizeList(_, _) => {
269+
min_max_batch_generic(values, Ordering::Greater)?
270+
}
271+
DataType::Dictionary(_, _) => {
272+
let values = values.as_any_dictionary().values();
273+
min_batch(values)?
274+
}
275+
_ => min_max_batch!(values, min),
276+
})
277+
}
278+
279+
/// Generic min/max implementation for complex types
280+
fn min_max_batch_generic(array: &ArrayRef, ordering: Ordering) -> Result<ScalarValue> {
281+
if array.len() == array.null_count() {
282+
return ScalarValue::try_from(array.data_type());
283+
}
284+
let mut extreme = ScalarValue::try_from_array(array, 0)?;
285+
for i in 1..array.len() {
286+
let current = ScalarValue::try_from_array(array, i)?;
287+
if current.is_null() {
288+
continue;
289+
}
290+
if extreme.is_null() {
291+
extreme = current;
292+
continue;
293+
}
294+
if let Some(cmp) = extreme.partial_cmp(&current) {
295+
if cmp == ordering {
296+
extreme = current;
297+
}
298+
}
299+
}
300+
301+
Ok(extreme)
302+
}
303+
304+
/// dynamically-typed max(array) -> ScalarValue
305+
pub fn max_batch(values: &ArrayRef) -> Result<ScalarValue> {
306+
Ok(match values.data_type() {
307+
DataType::Utf8 => {
308+
typed_min_max_batch_string!(values, StringArray, Utf8, max_string)
309+
}
310+
DataType::LargeUtf8 => {
311+
typed_min_max_batch_string!(values, LargeStringArray, LargeUtf8, max_string)
312+
}
313+
DataType::Utf8View => {
314+
typed_min_max_batch_string!(
315+
values,
316+
StringViewArray,
317+
Utf8View,
318+
max_string_view
319+
)
320+
}
321+
DataType::Boolean => {
322+
typed_min_max_batch!(values, BooleanArray, Boolean, max_boolean)
323+
}
324+
DataType::Binary => {
325+
typed_min_max_batch_binary!(&values, BinaryArray, Binary, max_binary)
326+
}
327+
DataType::BinaryView => {
328+
typed_min_max_batch_binary!(
329+
&values,
330+
BinaryViewArray,
331+
BinaryView,
332+
max_binary_view
333+
)
334+
}
335+
DataType::LargeBinary => {
336+
typed_min_max_batch_binary!(
337+
&values,
338+
LargeBinaryArray,
339+
LargeBinary,
340+
max_binary
341+
)
342+
}
343+
DataType::Struct(_) => min_max_batch_generic(values, Ordering::Less)?,
344+
DataType::List(_) => min_max_batch_generic(values, Ordering::Less)?,
345+
DataType::LargeList(_) => min_max_batch_generic(values, Ordering::Less)?,
346+
DataType::FixedSizeList(_, _) => min_max_batch_generic(values, Ordering::Less)?,
347+
DataType::Dictionary(_, _) => {
348+
let values = values.as_any_dictionary().values();
349+
max_batch(values)?
350+
}
351+
_ => min_max_batch!(values, max),
352+
})
353+
}

0 commit comments

Comments
 (0)