Skip to content

Commit b14f5d3

Browse files
committed
Fix order hints of inplace aggregations
The sort-order hints were incorrect when the grouping columns were listed in a different order than the corresponding columns of the input's sort key.
1 parent 93f32cc commit b14f5d3

File tree

2 files changed

+62
-19
lines changed

2 files changed

+62
-19
lines changed

rust/datafusion/src/physical_plan/hash_aggregate.rs

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ pub enum AggregateStrategy {
111111
#[derive(Debug)]
112112
pub struct HashAggregateExec {
113113
strategy: AggregateStrategy,
114+
output_sort_order: Option<Vec<usize>>,
114115
/// Aggregation mode (full, partial)
115116
mode: AggregateMode,
116117
/// Grouping expressions
@@ -166,6 +167,7 @@ impl HashAggregateExec {
166167
/// Create a new hash aggregate execution plan
167168
pub fn try_new(
168169
strategy: AggregateStrategy,
170+
output_sort_order: Option<Vec<usize>>,
169171
mode: AggregateMode,
170172
group_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
171173
aggr_expr: Vec<Arc<dyn AggregateExpr>>,
@@ -178,8 +180,24 @@ impl HashAggregateExec {
178180

179181
let output_rows = SQLMetric::counter("outputRows");
180182

183+
match strategy {
184+
AggregateStrategy::Hash => assert!(output_sort_order.is_none()),
185+
AggregateStrategy::InplaceSorted => {
186+
assert!(output_sort_order.is_some());
187+
assert!(
188+
output_sort_order
189+
.as_ref()
190+
.unwrap()
191+
.iter()
192+
.all(|i| *i < group_expr.len()),
193+
"sort_order mentions value columns"
194+
);
195+
}
196+
}
197+
181198
Ok(HashAggregateExec {
182199
strategy,
200+
output_sort_order,
183201
mode,
184202
group_expr,
185203
aggr_expr,
@@ -281,6 +299,7 @@ impl ExecutionPlan for HashAggregateExec {
281299
match children.len() {
282300
1 => Ok(Arc::new(HashAggregateExec::try_new(
283301
self.strategy,
302+
self.output_sort_order.clone(),
284303
self.mode,
285304
self.group_expr.clone(),
286305
self.aggr_expr.clone(),
@@ -296,9 +315,7 @@ impl ExecutionPlan for HashAggregateExec {
296315
fn output_hints(&self) -> OptimizerHints {
297316
let sort_order = match self.strategy {
298317
AggregateStrategy::Hash => None,
299-
AggregateStrategy::InplaceSorted => {
300-
Some((0..self.group_expr.len()).collect_vec())
301-
}
318+
AggregateStrategy::InplaceSorted => self.output_sort_order.clone(),
302319
};
303320
OptimizerHints {
304321
sort_order,
@@ -1814,6 +1831,7 @@ mod tests {
18141831
let input_schema = input.schema();
18151832
let partial_aggregate = Arc::new(HashAggregateExec::try_new(
18161833
AggregateStrategy::Hash,
1834+
None,
18171835
AggregateMode::Partial,
18181836
groups.clone(),
18191837
aggregates.clone(),
@@ -1841,6 +1859,7 @@ mod tests {
18411859

18421860
let merged_aggregate = Arc::new(HashAggregateExec::try_new(
18431861
AggregateStrategy::Hash,
1862+
None,
18441863
AggregateMode::Final,
18451864
final_group
18461865
.iter()

rust/datafusion/src/physical_plan/planner.rs

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -197,12 +197,14 @@ impl DefaultPhysicalPlanner {
197197
})
198198
.collect::<Result<Vec<_>>>()?;
199199

200-
let strategy = compute_aggregation_strategy(input_exec.as_ref(), &groups);
200+
let (strategy, order) =
201+
compute_aggregation_strategy(input_exec.as_ref(), &groups);
201202
// TODO: fix cubestore planning and re-enable.
202203
if false && input_exec.output_partitioning().partition_count() == 1 {
203204
// A single pass is enough for 1 partition.
204205
return Ok(Arc::new(HashAggregateExec::try_new(
205206
strategy,
207+
order,
206208
AggregateMode::Full,
207209
groups,
208210
aggregates,
@@ -214,6 +216,7 @@ impl DefaultPhysicalPlanner {
214216
let mut initial_aggr: Arc<dyn ExecutionPlan> =
215217
Arc::new(HashAggregateExec::try_new(
216218
strategy,
219+
order.clone(),
217220
AggregateMode::Partial,
218221
groups.clone(),
219222
aggregates.clone(),
@@ -238,6 +241,7 @@ impl DefaultPhysicalPlanner {
238241
// and the expressions corresponding to the respective aggregate
239242
Ok(Arc::new(HashAggregateExec::try_new(
240243
strategy,
244+
order,
241245
AggregateMode::Final,
242246
final_group
243247
.iter()
@@ -957,19 +961,25 @@ pub fn evaluate_const(expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExp
957961
pub fn compute_aggregation_strategy(
958962
input: &dyn ExecutionPlan,
959963
group_key: &[(Arc<dyn PhysicalExpr>, String)],
960-
) -> AggregateStrategy {
961-
if !group_key.is_empty() && input_sorted_by_group_key(input, &group_key) {
962-
AggregateStrategy::InplaceSorted
964+
) -> (AggregateStrategy, /*sort_order*/ Option<Vec<usize>>) {
965+
let mut sort_order = Vec::new();
966+
if !group_key.is_empty()
967+
&& input_sorted_by_group_key(input, &group_key, &mut sort_order)
968+
{
969+
(AggregateStrategy::InplaceSorted, Some(sort_order))
963970
} else {
964-
AggregateStrategy::Hash
971+
(AggregateStrategy::Hash, None)
965972
}
966973
}
967974

968975
fn input_sorted_by_group_key(
969976
input: &dyn ExecutionPlan,
970977
group_key: &[(Arc<dyn PhysicalExpr>, String)],
978+
sort_order: &mut Vec<usize>,
971979
) -> bool {
972980
assert!(!group_key.is_empty());
981+
sort_order.clear();
982+
973983
let hints = input.output_hints();
974984
// We check the group key is a prefix of the sort key.
975985
let sort_key = hints.sort_order;
@@ -979,7 +989,8 @@ fn input_sorted_by_group_key(
979989
let sort_key = sort_key.unwrap();
980990
// Tracks which elements of sort key are used in the group key or have a single value.
981991
let mut sort_key_hit = vec![false; sort_key.len()];
982-
for (g, _) in group_key {
992+
let mut sort_to_group = vec![usize::MAX; sort_key.len()];
993+
for (group_i, (g, _)) in group_key.iter().enumerate() {
983994
let col = g.as_any().downcast_ref::<Column>();
984995
if col.is_none() {
985996
return false;
@@ -989,11 +1000,15 @@ fn input_sorted_by_group_key(
9891000
return false;
9901001
}
9911002
let input_col = input_col.unwrap();
992-
let sort_key_pos = sort_key.iter().find_position(|i| **i == input_col);
993-
if sort_key_pos.is_none() {
994-
return false;
1003+
let sort_key_pos = match sort_key.iter().find_position(|i| **i == input_col) {
1004+
None => return false,
1005+
Some((p, _)) => p,
1006+
};
1007+
sort_key_hit[sort_key_pos] = true;
1008+
if sort_to_group[sort_key_pos] != usize::MAX {
1009+
return false; // Bail out to simplify code a bit. This should not happen in practice.
9951010
}
996-
sort_key_hit[sort_key_pos.unwrap().0] = true;
1011+
sort_to_group[sort_key_pos] = group_i;
9971012
}
9981013
for i in 0..sort_key.len() {
9991014
if hints.single_value_columns.contains(&sort_key[i]) {
@@ -1003,12 +1018,21 @@ fn input_sorted_by_group_key(
10031018

10041019
// At this point all elements of the group key mapped into some column of the sort key.
10051020
// This checks the group key is mapped into a prefix of the sort key.
1006-
sort_key_hit
1007-
.iter()
1008-
.skip_while(|present| **present)
1009-
.skip_while(|present| !**present)
1010-
.next()
1011-
.is_none()
1021+
let pref_len = sort_key_hit.iter().take_while(|present| **present).count();
1022+
if sort_key_hit[pref_len..].iter().any(|present| *present) {
1023+
return false;
1024+
}
1025+
1026+
assert!(sort_order.is_empty()); // Cleared at the beginning of the function.
1027+
1028+
// Note that single-value columns might not have a mapping to the grouping key.
1029+
sort_order.extend(
1030+
sort_to_group
1031+
.iter()
1032+
.take(pref_len)
1033+
.filter(|i| **i != usize::MAX),
1034+
);
1035+
true
10121036
}
10131037

10141038
fn tuple_err<T, R>(value: (Result<T>, Result<R>)) -> Result<(T, R)> {

0 commit comments

Comments
 (0)