Skip to content

Commit d3fa8ef

Browse files
committed
perf: min/max groups accumulator
1 parent 64ae03e commit d3fa8ef

File tree

1 file changed

+171
-13
lines changed
  • datafusion/src/physical_plan/expressions

1 file changed

+171
-13
lines changed

datafusion/src/physical_plan/expressions/min_max.rs

Lines changed: 171 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,11 @@ use std::sync::Arc;
2424
use crate::error::{DataFusionError, Result};
2525
use crate::physical_plan::groups_accumulator::GroupsAccumulator;
2626
use crate::physical_plan::groups_accumulator_flat_adapter::GroupsAccumulatorFlatAdapter;
27+
use crate::physical_plan::groups_accumulator_prim_op::PrimitiveGroupsAccumulator;
2728
use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr};
2829
use crate::scalar::ScalarValue;
2930
use arrow::compute;
30-
use arrow::datatypes::{DataType, TimeUnit};
31+
use arrow::datatypes::{ArrowPrimitiveType, DataType, TimeUnit};
3132
use arrow::{
3233
array::{
3334
ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array,
@@ -108,12 +109,90 @@ impl AggregateExpr for Max {
108109
fn create_groups_accumulator(
109110
&self,
110111
) -> arrow::error::Result<Option<Box<dyn GroupsAccumulator>>> {
111-
let data_type = self.data_type.clone();
112-
Ok(Some(Box::new(
113-
GroupsAccumulatorFlatAdapter::<MaxAccumulator>::new(move || {
114-
MaxAccumulator::try_new(&data_type)
115-
}),
116-
)))
112+
macro_rules! make_max_accumulator {
113+
($T:ty) => {
114+
Box::new(
115+
PrimitiveGroupsAccumulator::<$T, $T, _, _>::new(
116+
&<$T as ArrowPrimitiveType>::DATA_TYPE,
117+
|x: &mut <$T as ArrowPrimitiveType>::Native,
118+
y: <$T as ArrowPrimitiveType>::Native| {
119+
*x = (*x).max(y);
120+
},
121+
|x: &mut <$T as ArrowPrimitiveType>::Native,
122+
y: <$T as ArrowPrimitiveType>::Native| {
123+
*x = (*x).max(y);
124+
},
125+
)
126+
.with_starting_value(<$T as ArrowPrimitiveType>::Native::MIN),
127+
)
128+
};
129+
}
130+
let acc: Box<dyn GroupsAccumulator> = match &self.data_type {
131+
DataType::Float64 => make_max_accumulator!(arrow::datatypes::Float64Type),
132+
DataType::Float32 => make_max_accumulator!(arrow::datatypes::Float32Type),
133+
DataType::Int64 => make_max_accumulator!(arrow::datatypes::Int64Type),
134+
DataType::Int96 => make_max_accumulator!(arrow::datatypes::Int96Type),
135+
DataType::Int64Decimal(0) => {
136+
make_max_accumulator!(arrow::datatypes::Int64Decimal0Type)
137+
}
138+
DataType::Int64Decimal(1) => {
139+
make_max_accumulator!(arrow::datatypes::Int64Decimal1Type)
140+
}
141+
DataType::Int64Decimal(2) => {
142+
make_max_accumulator!(arrow::datatypes::Int64Decimal2Type)
143+
}
144+
DataType::Int64Decimal(3) => {
145+
make_max_accumulator!(arrow::datatypes::Int64Decimal3Type)
146+
}
147+
DataType::Int64Decimal(4) => {
148+
make_max_accumulator!(arrow::datatypes::Int64Decimal4Type)
149+
}
150+
DataType::Int64Decimal(5) => {
151+
make_max_accumulator!(arrow::datatypes::Int64Decimal5Type)
152+
}
153+
DataType::Int64Decimal(10) => {
154+
make_max_accumulator!(arrow::datatypes::Int64Decimal10Type)
155+
}
156+
DataType::Int96Decimal(0) => {
157+
make_max_accumulator!(arrow::datatypes::Int96Decimal0Type)
158+
}
159+
DataType::Int96Decimal(1) => {
160+
make_max_accumulator!(arrow::datatypes::Int96Decimal1Type)
161+
}
162+
DataType::Int96Decimal(2) => {
163+
make_max_accumulator!(arrow::datatypes::Int96Decimal2Type)
164+
}
165+
DataType::Int96Decimal(3) => {
166+
make_max_accumulator!(arrow::datatypes::Int96Decimal3Type)
167+
}
168+
DataType::Int96Decimal(4) => {
169+
make_max_accumulator!(arrow::datatypes::Int96Decimal4Type)
170+
}
171+
DataType::Int96Decimal(5) => {
172+
make_max_accumulator!(arrow::datatypes::Int96Decimal5Type)
173+
}
174+
DataType::Int96Decimal(10) => {
175+
make_max_accumulator!(arrow::datatypes::Int96Decimal10Type)
176+
}
177+
DataType::Int32 => make_max_accumulator!(arrow::datatypes::Int32Type),
178+
DataType::Int16 => make_max_accumulator!(arrow::datatypes::Int16Type),
179+
DataType::Int8 => make_max_accumulator!(arrow::datatypes::Int8Type),
180+
DataType::UInt64 => make_max_accumulator!(arrow::datatypes::UInt64Type),
181+
DataType::UInt32 => make_max_accumulator!(arrow::datatypes::UInt32Type),
182+
DataType::UInt16 => make_max_accumulator!(arrow::datatypes::UInt16Type),
183+
DataType::UInt8 => make_max_accumulator!(arrow::datatypes::UInt8Type),
184+
_ => {
185+
// Not all types (strings) can use primitive accumulators. And strings use
186+
// max_string as the $OP in typed_min_match_batch.
187+
188+
// Timestamps presently take this branch.
189+
let data_type = self.data_type.clone();
190+
Box::new(GroupsAccumulatorFlatAdapter::<MaxAccumulator>::new(
191+
move || MaxAccumulator::try_new(&data_type),
192+
))
193+
}
194+
};
195+
Ok(Some(acc))
117196
}
118197

119198
fn name(&self) -> &str {
@@ -547,12 +626,91 @@ impl AggregateExpr for Min {
547626
fn create_groups_accumulator(
548627
&self,
549628
) -> arrow::error::Result<Option<Box<dyn GroupsAccumulator>>> {
550-
let data_type = self.data_type.clone();
551-
Ok(Some(Box::new(
552-
GroupsAccumulatorFlatAdapter::<MinAccumulator>::new(move || {
553-
MinAccumulator::try_new(&data_type)
554-
}),
555-
)))
629+
macro_rules! make_min_accumulator {
630+
($T:ty) => {
631+
Box::new(
632+
PrimitiveGroupsAccumulator::<$T, $T, _, _>::new(
633+
&<$T as ArrowPrimitiveType>::DATA_TYPE,
634+
|x: &mut <$T as ArrowPrimitiveType>::Native,
635+
y: <$T as ArrowPrimitiveType>::Native| {
636+
*x = (*x).min(y);
637+
},
638+
|x: &mut <$T as ArrowPrimitiveType>::Native,
639+
y: <$T as ArrowPrimitiveType>::Native| {
640+
*x = (*x).min(y);
641+
},
642+
)
643+
.with_starting_value(<$T as ArrowPrimitiveType>::Native::MAX),
644+
)
645+
};
646+
}
647+
648+
let acc: Box<dyn GroupsAccumulator> = match &self.data_type {
649+
DataType::Float64 => make_min_accumulator!(arrow::datatypes::Float64Type),
650+
DataType::Float32 => make_min_accumulator!(arrow::datatypes::Float32Type),
651+
DataType::Int64 => make_min_accumulator!(arrow::datatypes::Int64Type),
652+
DataType::Int96 => make_min_accumulator!(arrow::datatypes::Int96Type),
653+
DataType::Int64Decimal(0) => {
654+
make_min_accumulator!(arrow::datatypes::Int64Decimal0Type)
655+
}
656+
DataType::Int64Decimal(1) => {
657+
make_min_accumulator!(arrow::datatypes::Int64Decimal1Type)
658+
}
659+
DataType::Int64Decimal(2) => {
660+
make_min_accumulator!(arrow::datatypes::Int64Decimal2Type)
661+
}
662+
DataType::Int64Decimal(3) => {
663+
make_min_accumulator!(arrow::datatypes::Int64Decimal3Type)
664+
}
665+
DataType::Int64Decimal(4) => {
666+
make_min_accumulator!(arrow::datatypes::Int64Decimal4Type)
667+
}
668+
DataType::Int64Decimal(5) => {
669+
make_min_accumulator!(arrow::datatypes::Int64Decimal5Type)
670+
}
671+
DataType::Int64Decimal(10) => {
672+
make_min_accumulator!(arrow::datatypes::Int64Decimal10Type)
673+
}
674+
DataType::Int96Decimal(0) => {
675+
make_min_accumulator!(arrow::datatypes::Int96Decimal0Type)
676+
}
677+
DataType::Int96Decimal(1) => {
678+
make_min_accumulator!(arrow::datatypes::Int96Decimal1Type)
679+
}
680+
DataType::Int96Decimal(2) => {
681+
make_min_accumulator!(arrow::datatypes::Int96Decimal2Type)
682+
}
683+
DataType::Int96Decimal(3) => {
684+
make_min_accumulator!(arrow::datatypes::Int96Decimal3Type)
685+
}
686+
DataType::Int96Decimal(4) => {
687+
make_min_accumulator!(arrow::datatypes::Int96Decimal4Type)
688+
}
689+
DataType::Int96Decimal(5) => {
690+
make_min_accumulator!(arrow::datatypes::Int96Decimal5Type)
691+
}
692+
DataType::Int96Decimal(10) => {
693+
make_min_accumulator!(arrow::datatypes::Int96Decimal10Type)
694+
}
695+
DataType::Int32 => make_min_accumulator!(arrow::datatypes::Int32Type),
696+
DataType::Int16 => make_min_accumulator!(arrow::datatypes::Int16Type),
697+
DataType::Int8 => make_min_accumulator!(arrow::datatypes::Int8Type),
698+
DataType::UInt64 => make_min_accumulator!(arrow::datatypes::UInt64Type),
699+
DataType::UInt32 => make_min_accumulator!(arrow::datatypes::UInt32Type),
700+
DataType::UInt16 => make_min_accumulator!(arrow::datatypes::UInt16Type),
701+
DataType::UInt8 => make_min_accumulator!(arrow::datatypes::UInt8Type),
702+
_ => {
703+
// Not all types (strings) can use primitive accumulators. And strings use
704+
// min_string as the $OP in typed_min_match_batch.
705+
706+
// Timestamps presently take this branch.
707+
let data_type = self.data_type.clone();
708+
Box::new(GroupsAccumulatorFlatAdapter::<MinAccumulator>::new(
709+
move || MinAccumulator::try_new(&data_type),
710+
))
711+
}
712+
};
713+
Ok(Some(acc))
556714
}
557715

558716
fn name(&self) -> &str {

0 commit comments

Comments
 (0)