Skip to content

Commit cb5aee4

Browse files
committed
Actually accumulating (badly) into a groups accumulator
1 parent 72076f7 commit cb5aee4

File tree

8 files changed

+916
-45
lines changed

8 files changed

+916
-45
lines changed

datafusion/src/cube_ext/joinagg.rs

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,9 @@ use crate::execution::context::{ExecutionContextState, ExecutionProps};
2525
use crate::logical_plan::{DFSchemaRef, Expr, LogicalPlan, UserDefinedLogicalNode};
2626
use crate::optimizer::optimizer::OptimizerRule;
2727
use crate::optimizer::utils::from_plan;
28-
use crate::physical_plan::hash_aggregate::{Accumulators, AggregateMode};
28+
use crate::physical_plan::hash_aggregate::{
29+
create_accumulation_state, AccumulationState, Accumulators, AggregateMode,
30+
};
2931
use crate::physical_plan::planner::{physical_name, ExtensionPlanner};
3032
use crate::physical_plan::{hash_aggregate, PhysicalPlanner};
3133
use crate::physical_plan::{
@@ -245,7 +247,7 @@ impl ExecutionPlan for CrossJoinAggExec {
245247
&AggregateMode::Full,
246248
self.group_expr.len(),
247249
)?;
248-
let mut accumulators = Accumulators::new();
250+
let mut accumulators = create_accumulation_state(&self.agg_expr)?;
249251
for partition in 0..self.join.right.output_partitioning().partition_count() {
250252
let mut batches = self.join.right.execute(partition).await?;
251253
while let Some(right) = batches.next().await {

datafusion/src/physical_plan/expressions/average.rs

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,9 @@ use std::convert::TryFrom;
2222
use std::sync::Arc;
2323

2424
use crate::error::{DataFusionError, Result};
25+
use crate::physical_plan::groups_accumulator::GroupsAccumulator;
26+
use crate::physical_plan::groups_accumulator_adapter::GroupsAccumulatorAdapter;
27+
use crate::physical_plan::groups_accumulator_flat_adapter::GroupsAccumulatorFlatAdapter;
2528
use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr};
2629
use crate::scalar::ScalarValue;
2730
use arrow::compute;
@@ -112,6 +115,23 @@ impl AggregateExpr for Avg {
112115
)?))
113116
}
114117

118+
fn uses_groups_accumulator(&self) -> bool {
119+
return true;
120+
}
121+
122+
/// Returns the groups accumulator used to accumulate values from the expression. If this returns
123+
/// `None`, `create_accumulator` must be used instead. This implementation always returns `Some`,
/// wrapping `AvgAccumulator` in a `GroupsAccumulatorFlatAdapter`.
124+
fn create_groups_accumulator(
125+
&self,
126+
) -> arrow::error::Result<Option<Box<dyn GroupsAccumulator>>> {
127+
Ok(Some(Box::new(
128+
GroupsAccumulatorFlatAdapter::<AvgAccumulator>::new(|| {
129+
// avg is f64 (as in create_accumulator)
130+
AvgAccumulator::try_new(&DataType::Float64)
131+
}),
132+
)))
133+
}
134+
115135
fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
116136
vec![self.expr.clone()]
117137
}

datafusion/src/physical_plan/expressions/sum.rs

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,9 @@ use std::convert::TryFrom;
2222
use std::sync::Arc;
2323

2424
use crate::error::{DataFusionError, Result};
25+
use crate::physical_plan::groups_accumulator::GroupsAccumulator;
26+
use crate::physical_plan::groups_accumulator_adapter::GroupsAccumulatorAdapter;
27+
use crate::physical_plan::groups_accumulator_flat_adapter::GroupsAccumulatorFlatAdapter;
2528
use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr};
2629
use crate::scalar::ScalarValue;
2730
use arrow::compute;
@@ -118,6 +121,23 @@ impl AggregateExpr for Sum {
118121
Ok(Box::new(SumAccumulator::try_new(&self.data_type)?))
119122
}
120123

124+
fn uses_groups_accumulator(&self) -> bool {
125+
return true;
126+
}
127+
128+
/// Returns the groups accumulator used to accumulate values from the expression. If this returns
129+
/// `None`, `create_accumulator` must be used instead. This implementation always returns `Some`,
/// wrapping `SumAccumulator` in a `GroupsAccumulatorFlatAdapter`.
130+
fn create_groups_accumulator(
131+
&self,
132+
) -> arrow::error::Result<Option<Box<dyn GroupsAccumulator>>> {
133+
// Clone so the `move` closure below owns the data type instead of borrowing `self`.
let data_type = self.data_type.clone();
134+
Ok(Some(Box::new(
135+
GroupsAccumulatorFlatAdapter::<SumAccumulator>::new(move || {
136+
SumAccumulator::try_new(&data_type)
137+
}),
138+
)))
139+
}
140+
121141
fn name(&self) -> &str {
122142
&self.name
123143
}

datafusion/src/physical_plan/groups_accumulator.rs

Lines changed: 22 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,9 @@
1818
//! Vectorized [`GroupsAccumulator`]
1919
2020
use crate::error::{DataFusionError, Result};
21+
use crate::scalar::ScalarValue;
2122
use arrow::array::{ArrayRef, BooleanArray};
23+
use smallvec::SmallVec;
2224

2325
/// From upstream: This replaces a datafusion_common::{not_impl_err} import.
2426
macro_rules! not_impl_err {
@@ -27,6 +29,18 @@ macro_rules! not_impl_err {
2729
};
2830
}
2931

32+
// TODO: Probably drop the #[macro_export] that was copy/pasted in
33+
// From upstream datafusion. We don't pass the backtrace:
34+
/// Creates a `DataFusionError::ArrowError` wrapping `$ERR`. No backtrace is attached here.
///
/// The expansion names `DataFusionError` unqualified, so it must be in scope at the call site.
35+
#[macro_export]
36+
macro_rules! arrow_datafusion_err {
37+
($ERR:expr) => {
38+
DataFusionError::ArrowError(
39+
$ERR, /* , Some(DataFusionError::get_back_trace() */
40+
)
41+
};
42+
}
43+
3044
/// Describes how many rows should be emitted during grouping.
3145
#[derive(Debug, Clone, Copy)]
3246
pub enum EmitTo {
@@ -164,6 +178,10 @@ pub trait GroupsAccumulator: Send {
164178
/// `n`. See [`EmitTo::First`] for more details.
165179
fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef>;
166180

181+
// TODO: Remove this?
182+
/// Evaluates the aggregate for the single group at `group_index`. Takes `&self`, whereas
/// [`Self::evaluate`] takes `&mut self` and an [`EmitTo`].
183+
fn peek_evaluate(&self, group_index: usize) -> Result<ScalarValue>;
184+
167185
/// Returns the intermediate aggregate state for this accumulator,
168186
/// used for multi-phase grouping, resetting its internal state.
169187
///
@@ -182,6 +200,10 @@ pub trait GroupsAccumulator: Send {
182200
/// [`Accumulator::state`]: crate::accumulator::Accumulator::state
183201
fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>>;
184202

203+
// TODO: Remove this?
204+
/// Returns the intermediate state for the single group at `group_index` as scalar values.
/// Takes `&self`, whereas [`Self::state`] takes `&mut self` and an [`EmitTo`].
205+
fn peek_state(&self, group_index: usize) -> Result<SmallVec<[ScalarValue; 2]>>;
206+
185207
/// Merges intermediate state (the output from [`Self::state`])
186208
/// into this accumulator's current state.
187209
///

datafusion/src/physical_plan/groups_accumulator_adapter.rs

Lines changed: 16 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
2121
use std::mem::{size_of, size_of_val};
2222

23+
use crate::arrow_datafusion_err;
2324
use crate::error::{DataFusionError, Result};
2425
use crate::physical_plan::groups_accumulator::{EmitTo, GroupsAccumulator};
2526
use crate::physical_plan::Accumulator;
@@ -31,17 +32,7 @@ use arrow::{
3132
compute,
3233
datatypes::UInt32Type,
3334
};
34-
35-
// From upstream datafusion. We don't pass the backtrace:
36-
/// Exposes a macro to create `DataFusionError::ArrowError` with optional backtrace
37-
#[macro_export]
38-
macro_rules! arrow_datafusion_err {
39-
($ERR:expr) => {
40-
DataFusionError::ArrowError(
41-
$ERR, /* , Some(DataFusionError::get_back_trace() */
42-
)
43-
};
44-
}
35+
use smallvec::SmallVec;
4536

4637
/// An adapter that implements [`GroupsAccumulator`] for any [`Accumulator`]
4738
///
@@ -102,6 +93,7 @@ pub struct GroupsAccumulatorAdapter {
10293
/// state for each group, stored in group_index order
10394
states: Vec<AccumulatorState>,
10495

96+
// TODO: Code maintaining this is commented out.
10597
/// Current memory usage, in bytes.
10698
///
10799
/// Note this is incrementally updated with deltas to avoid the
@@ -209,7 +201,11 @@ impl GroupsAccumulatorAdapter {
209201
{
210202
self.make_accumulators_if_needed(total_num_groups)?;
211203

212-
assert_eq!(values[0].len(), group_indices.len());
204+
assert_eq!(
205+
values[0].len(),
206+
group_indices.len(),
207+
"asserting values[0].len() == group_indices.len()"
208+
);
213209

214210
// figure out which input rows correspond to which groups.
215211
// Note that self.state.indices starts empty for all groups
@@ -348,6 +344,10 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter {
348344
result
349345
}
350346

347+
/// Evaluates the wrapped accumulator stored for `group_index`.
///
/// Indexes `self.states` directly, so an out-of-range `group_index` panics.
fn peek_evaluate(&self, group_index: usize) -> Result<ScalarValue> {
348+
self.states[group_index].accumulator.evaluate()
349+
}
350+
351351
// filtered_null_mask(opt_filter, &values);
352352
fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
353353
let vec_size_pre = self.states.allocated_size();
@@ -384,6 +384,10 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter {
384384
Ok(arrays)
385385
}
386386

387+
/// Returns the state of the wrapped accumulator stored for `group_index`.
///
/// Indexes `self.states` directly, so an out-of-range `group_index` panics.
fn peek_state(&self, group_index: usize) -> Result<SmallVec<[ScalarValue; 2]>> {
388+
self.states[group_index].accumulator.state()
389+
}
390+
387391
fn merge_batch(
388392
&mut self,
389393
values: &[ArrayRef],

0 commit comments

Comments
 (0)