Skip to content

Commit 9402202

Browse files
committed
Make Sum use PrimitiveGroupsAccumulator
1 parent 9cbd467 commit 9402202

File tree

10 files changed

+1449
-129
lines changed

10 files changed

+1449
-129
lines changed

datafusion/src/cube_ext/joinagg.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,8 @@ impl ExecutionPlan for CrossJoinAggExec {
245245
&AggregateMode::Full,
246246
self.group_expr.len(),
247247
)?;
248-
let mut accumulators = create_accumulation_state(&self.agg_expr)?;
248+
let mut accumulators: hash_aggregate::AccumulationState =
249+
create_accumulation_state(&self.agg_expr)?;
249250
for partition in 0..self.join.right.output_partitioning().partition_count() {
250251
let mut batches = self.join.right.execute(partition).await?;
251252
while let Some(right) = batches.next().await {
@@ -273,7 +274,7 @@ impl ExecutionPlan for CrossJoinAggExec {
273274
let out_schema = self.schema.clone();
274275
let r = hash_aggregate::create_batch_from_map(
275276
&AggregateMode::Full,
276-
&accumulators,
277+
accumulators,
277278
self.group_expr.len(),
278279
&out_schema,
279280
)?;

datafusion/src/physical_plan/aggregates.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ pub fn create_aggregate_expr(
144144
))
145145
}
146146
(AggregateFunction::Sum, false) => {
147-
Arc::new(expressions::Sum::new(arg, name, return_type))
147+
Arc::new(expressions::Sum::new(arg, name, return_type, &arg_types[0]))
148148
}
149149
(AggregateFunction::Sum, true) => {
150150
return Err(DataFusionError::NotImplemented(

datafusion/src/physical_plan/expressions/sum.rs

Lines changed: 83 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use std::sync::Arc;
2424
use crate::error::{DataFusionError, Result};
2525
use crate::physical_plan::groups_accumulator::GroupsAccumulator;
2626
use crate::physical_plan::groups_accumulator_flat_adapter::GroupsAccumulatorFlatAdapter;
27+
use crate::physical_plan::groups_accumulator_prim_op::PrimitiveGroupsAccumulator;
2728
use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr};
2829
use crate::scalar::ScalarValue;
2930
use arrow::compute;
@@ -49,6 +50,7 @@ use smallvec::SmallVec;
4950
pub struct Sum {
5051
name: String,
5152
data_type: DataType,
53+
input_data_type: DataType,
5254
expr: Arc<dyn PhysicalExpr>,
5355
nullable: bool,
5456
}
@@ -80,11 +82,16 @@ impl Sum {
8082
expr: Arc<dyn PhysicalExpr>,
8183
name: impl Into<String>,
8284
data_type: DataType,
85+
input_data_type: &DataType,
8386
) -> Self {
87+
// Note: data_type = sum_return_type(input_data_type) in the actual caller, so we don't
88+
// really need two params. But, we keep the four params to break symmetry with other
89+
// accumulators and any code that might use 3 params, such as the generic_test_op macro.
8490
Self {
8591
name: name.into(),
8692
expr,
8793
data_type,
94+
input_data_type: input_data_type.clone(),
8895
nullable: true,
8996
}
9097
}
@@ -127,12 +134,64 @@ impl AggregateExpr for Sum {
127134
fn create_groups_accumulator(
128135
&self,
129136
) -> arrow::error::Result<Option<Box<dyn GroupsAccumulator>>> {
130-
let data_type = self.data_type.clone();
131-
Ok(Some(Box::new(
132-
GroupsAccumulatorFlatAdapter::<SumAccumulator>::new(move || {
133-
SumAccumulator::try_new(&data_type)
134-
}),
135-
)))
137+
use arrow::datatypes::ArrowPrimitiveType;
138+
139+
macro_rules! make_accumulator {
140+
($T:ty, $U:ty) => { Box::new(PrimitiveGroupsAccumulator::<
141+
$T,
142+
$U,
143+
_,
144+
_,
145+
>::new(& <$T as ArrowPrimitiveType>::DATA_TYPE, |x: &mut <$T as ArrowPrimitiveType>::Native, y: <$U as ArrowPrimitiveType>::Native| {
146+
*x = *x + (y as <$T as ArrowPrimitiveType>::Native);
147+
}, |x: &mut <$T as ArrowPrimitiveType>::Native, y: <$T as ArrowPrimitiveType>::Native| { *x = *x + y; })) };
148+
}
149+
150+
// Note that upstream uses x.add_wrapping(y) for the sum functions -- but here we just mimic
151+
// the current datafusion Sum accumulator implementation using native +. (That native +
152+
// specifically is the one in the expressions *x = *x + ... above.)
153+
Ok(Some(match (&self.data_type, &self.input_data_type) {
154+
(DataType::Int64, DataType::Int64) => make_accumulator!(arrow::datatypes::Int64Type, arrow::datatypes::Int64Type),
155+
(DataType::Int64, DataType::Int32) => make_accumulator!(arrow::datatypes::Int64Type, arrow::datatypes::Int32Type),
156+
(DataType::Int64, DataType::Int16) => make_accumulator!(arrow::datatypes::Int64Type, arrow::datatypes::Int16Type),
157+
(DataType::Int64, DataType::Int8) => make_accumulator!(arrow::datatypes::Int64Type, arrow::datatypes::Int8Type),
158+
159+
(DataType::Int96, DataType::Int96) => make_accumulator!(arrow::datatypes::Int96Type, arrow::datatypes::Int96Type),
160+
161+
(DataType::Int64Decimal(0), DataType::Int64Decimal(0)) => make_accumulator!(arrow::datatypes::Int64Decimal0Type, arrow::datatypes::Int64Decimal0Type),
162+
(DataType::Int64Decimal(1), DataType::Int64Decimal(1)) => make_accumulator!(arrow::datatypes::Int64Decimal1Type, arrow::datatypes::Int64Decimal1Type),
163+
(DataType::Int64Decimal(2), DataType::Int64Decimal(2)) => make_accumulator!(arrow::datatypes::Int64Decimal2Type, arrow::datatypes::Int64Decimal2Type),
164+
(DataType::Int64Decimal(3), DataType::Int64Decimal(3)) => make_accumulator!(arrow::datatypes::Int64Decimal3Type, arrow::datatypes::Int64Decimal3Type),
165+
(DataType::Int64Decimal(4), DataType::Int64Decimal(4)) => make_accumulator!(arrow::datatypes::Int64Decimal4Type, arrow::datatypes::Int64Decimal4Type),
166+
(DataType::Int64Decimal(5), DataType::Int64Decimal(5)) => make_accumulator!(arrow::datatypes::Int64Decimal5Type, arrow::datatypes::Int64Decimal5Type),
167+
(DataType::Int64Decimal(10), DataType::Int64Decimal(10)) => make_accumulator!(arrow::datatypes::Int64Decimal10Type, arrow::datatypes::Int64Decimal10Type),
168+
169+
(DataType::Int96Decimal(0), DataType::Int96Decimal(0)) => make_accumulator!(arrow::datatypes::Int96Decimal0Type, arrow::datatypes::Int96Decimal0Type),
170+
(DataType::Int96Decimal(1), DataType::Int96Decimal(1)) => make_accumulator!(arrow::datatypes::Int96Decimal1Type, arrow::datatypes::Int96Decimal1Type),
171+
(DataType::Int96Decimal(2), DataType::Int96Decimal(2)) => make_accumulator!(arrow::datatypes::Int96Decimal2Type, arrow::datatypes::Int96Decimal2Type),
172+
(DataType::Int96Decimal(3), DataType::Int96Decimal(3)) => make_accumulator!(arrow::datatypes::Int96Decimal3Type, arrow::datatypes::Int96Decimal3Type),
173+
(DataType::Int96Decimal(4), DataType::Int96Decimal(4)) => make_accumulator!(arrow::datatypes::Int96Decimal4Type, arrow::datatypes::Int96Decimal4Type),
174+
(DataType::Int96Decimal(5), DataType::Int96Decimal(5)) => make_accumulator!(arrow::datatypes::Int96Decimal5Type, arrow::datatypes::Int96Decimal5Type),
175+
(DataType::Int96Decimal(10), DataType::Int96Decimal(10)) => make_accumulator!(arrow::datatypes::Int96Decimal10Type, arrow::datatypes::Int96Decimal10Type),
176+
177+
(DataType::UInt64, DataType::UInt64) => make_accumulator!(arrow::datatypes::UInt64Type, arrow::datatypes::UInt64Type),
178+
(DataType::UInt64, DataType::UInt32) => make_accumulator!(arrow::datatypes::UInt64Type, arrow::datatypes::UInt32Type),
179+
(DataType::UInt64, DataType::UInt16) => make_accumulator!(arrow::datatypes::UInt64Type, arrow::datatypes::UInt16Type),
180+
(DataType::UInt64, DataType::UInt8) => make_accumulator!(arrow::datatypes::UInt64Type, arrow::datatypes::UInt8Type),
181+
182+
(DataType::Float32, DataType::Float32) => make_accumulator!(arrow::datatypes::Float32Type, arrow::datatypes::Float32Type),
183+
(DataType::Float64, DataType::Float64) => make_accumulator!(arrow::datatypes::Float64Type, arrow::datatypes::Float64Type),
184+
185+
_ => {
186+
// This case should never be reached because we've handled all sum_return_type
187+
// arg_type values. Nonetheless:
188+
let data_type = self.data_type.clone();
189+
190+
Box::new(GroupsAccumulatorFlatAdapter::<SumAccumulator>::new(
191+
move || SumAccumulator::try_new(&data_type),
192+
))
193+
}
194+
}))
136195
}
137196

138197
fn name(&self) -> &str {
@@ -416,13 +475,25 @@ mod tests {
416475
use arrow::datatypes::*;
417476
use arrow::record_batch::RecordBatch;
418477

478+
// A wrapper to make Sum::new, which now has an input_type argument, work with
479+
// generic_test_op!.
480+
struct SumTestStandin;
481+
impl SumTestStandin {
482+
fn new(expr: Arc<dyn PhysicalExpr>,
483+
name: impl Into<String>,
484+
data_type: DataType) -> Sum {
485+
Sum::new(expr, name, data_type.clone(), &data_type)
486+
}
487+
}
488+
419489
#[test]
420490
fn sum_i32() -> Result<()> {
421491
let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
492+
422493
generic_test_op!(
423494
a,
424495
DataType::Int32,
425-
Sum,
496+
SumTestStandin,
426497
ScalarValue::from(15i64),
427498
DataType::Int64
428499
)
@@ -440,7 +511,7 @@ mod tests {
440511
generic_test_op!(
441512
a,
442513
DataType::Int32,
443-
Sum,
514+
SumTestStandin,
444515
ScalarValue::from(13i64),
445516
DataType::Int64
446517
)
@@ -452,7 +523,7 @@ mod tests {
452523
generic_test_op!(
453524
a,
454525
DataType::Int32,
455-
Sum,
526+
SumTestStandin,
456527
ScalarValue::Int64(None),
457528
DataType::Int64
458529
)
@@ -465,7 +536,7 @@ mod tests {
465536
generic_test_op!(
466537
a,
467538
DataType::UInt32,
468-
Sum,
539+
SumTestStandin,
469540
ScalarValue::from(15u64),
470541
DataType::UInt64
471542
)
@@ -478,7 +549,7 @@ mod tests {
478549
generic_test_op!(
479550
a,
480551
DataType::Float32,
481-
Sum,
552+
SumTestStandin,
482553
ScalarValue::from(15_f32),
483554
DataType::Float32
484555
)
@@ -491,7 +562,7 @@ mod tests {
491562
generic_test_op!(
492563
a,
493564
DataType::Float64,
494-
Sum,
565+
SumTestStandin,
495566
ScalarValue::from(15_f64),
496567
DataType::Float64
497568
)

datafusion/src/physical_plan/groups_accumulator.rs

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -194,10 +194,6 @@ pub trait GroupsAccumulator: Send {
194194
/// `n`. See [`EmitTo::First`] for more details.
195195
fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef>;
196196

197-
// TODO: Remove this?
198-
/// evaluate for a particular group index.
199-
fn peek_evaluate(&self, group_index: usize) -> Result<ScalarValue>;
200-
201197
/// Returns the intermediate aggregate state for this accumulator,
202198
/// used for multi-phase grouping, resetting its internal state.
203199
///
@@ -216,10 +212,6 @@ pub trait GroupsAccumulator: Send {
216212
/// [`Accumulator::state`]: crate::accumulator::Accumulator::state
217213
fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>>;
218214

219-
// TODO: Remove this?
220-
/// Looks at the state for a particular group index.
221-
fn peek_state(&self, group_index: usize) -> Result<SmallVec<[ScalarValue; 2]>>;
222-
223215
/// Merges intermediate state (the output from [`Self::state`])
224216
/// into this accumulator's current state.
225217
///

datafusion/src/physical_plan/groups_accumulator_adapter.rs

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -345,10 +345,6 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter {
345345
result
346346
}
347347

348-
fn peek_evaluate(&self, group_index: usize) -> Result<ScalarValue> {
349-
self.states[group_index].accumulator.evaluate()
350-
}
351-
352348
// filtered_null_mask(opt_filter, &values);
353349
fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
354350
let vec_size_pre = self.states.allocated_size();
@@ -385,10 +381,6 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter {
385381
Ok(arrays)
386382
}
387383

388-
fn peek_state(&self, group_index: usize) -> Result<SmallVec<[ScalarValue; 2]>> {
389-
self.states[group_index].accumulator.state()
390-
}
391-
392384
fn merge_batch(
393385
&mut self,
394386
values: &[ArrayRef],

datafusion/src/physical_plan/groups_accumulator_flat_adapter.rs

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -387,10 +387,6 @@ impl<AccumulatorType: Accumulator> GroupsAccumulator
387387
result
388388
}
389389

390-
fn peek_evaluate(&self, group_index: usize) -> Result<ScalarValue> {
391-
self.accumulators[group_index].evaluate()
392-
}
393-
394390
// filtered_null_mask(opt_filter, &values);
395391
fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
396392
let vec_size_pre = self.accumulators.allocated_size();
@@ -428,10 +424,6 @@ impl<AccumulatorType: Accumulator> GroupsAccumulator
428424
Ok(arrays)
429425
}
430426

431-
fn peek_state(&self, group_index: usize) -> Result<SmallVec<[ScalarValue; 2]>> {
432-
self.accumulators[group_index].state()
433-
}
434-
435427
fn merge_batch(
436428
&mut self,
437429
values: &[ArrayRef],

0 commit comments

Comments
 (0)