Skip to content

Commit 4d28038

Browse files
committed
Avoid the sum/count rewrite when not helpful
1 parent 68cff5e commit 4d28038

File tree

4 files changed

+131
-38
lines changed

4 files changed

+131
-38
lines changed

datafusion/expr/src/simplify.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ pub struct SimplifyContext {
3838
schema: DFSchemaRef,
3939
query_execution_start_time: Option<DateTime<Utc>>,
4040
config_options: Arc<ConfigOptions>,
41+
aggregate_exprs: Option<Arc<Vec<Expr>>>,
4142
}
4243

4344
impl Default for SimplifyContext {
@@ -46,6 +47,7 @@ impl Default for SimplifyContext {
4647
schema: Arc::new(DFSchema::empty()),
4748
query_execution_start_time: None,
4849
config_options: Arc::new(ConfigOptions::default()),
50+
aggregate_exprs: None,
4951
}
5052
}
5153
}
@@ -78,6 +80,12 @@ impl SimplifyContext {
7880
self
7981
}
8082

83+
/// Set aggregate expressions from the containing aggregate node, if any.
84+
pub fn with_aggregate_exprs(mut self, aggregate_exprs: Arc<Vec<Expr>>) -> Self {
85+
self.aggregate_exprs = Some(aggregate_exprs);
86+
self
87+
}
88+
8189
/// Returns the schema
8290
pub fn schema(&self) -> &DFSchemaRef {
8391
&self.schema
@@ -108,6 +116,11 @@ impl SimplifyContext {
108116
pub fn config_options(&self) -> &Arc<ConfigOptions> {
109117
&self.config_options
110118
}
119+
120+
/// Returns aggregate expressions from the containing aggregate node, if any.
121+
pub fn aggregate_exprs(&self) -> Option<&[Expr]> {
122+
self.aggregate_exprs.as_deref().map(Vec::as_slice)
123+
}
111124
}
112125

113126
/// Was the expression simplified?

datafusion/functions-aggregate/src/sum.rs

Lines changed: 60 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ impl AggregateUDFImpl for Sum {
362362
///
363363
/// Backstory: TODO
364364
///
365-
fn sum_simplifier(mut agg: AggregateFunction, _info: &SimplifyContext) -> Result<Expr> {
365+
fn sum_simplifier(mut agg: AggregateFunction, info: &SimplifyContext) -> Result<Expr> {
366366
// Explicitly destructure to ensure we check all relevant fields
367367
let AggregateFunctionParams {
368368
args,
@@ -381,23 +381,71 @@ fn sum_simplifier(mut agg: AggregateFunction, _info: &SimplifyContext) -> Result
381381
return Ok(Expr::AggregateFunction(agg));
382382
}
383383

384-
// otherwise check the arguments if they are <col> <op> scalar
385-
let (arg, lit) = match SplitResult::new(agg.params.args.swap_remove(0)) {
386-
SplitResult::Original(expr) => {
387-
agg.params.args.push(expr); // put it back
388-
return Ok(Expr::AggregateFunction(agg));
389-
}
384+
// otherwise check the arguments if they are <arg> + <literal>
385+
let (arg, lit) = match SplitResult::new(agg.params.args[0].clone()) {
386+
SplitResult::Original => return Ok(Expr::AggregateFunction(agg)),
390387
SplitResult::Split { arg, lit } => (arg, lit),
391388
};
392389

390+
if !has_common_rewrite_arg(&arg, info) {
391+
return Ok(Expr::AggregateFunction(agg));
392+
}
393+
393394
// Rewrite to SUM(arg)
394-
agg.params.args.push(arg.clone());
395+
agg.params.args = vec![arg.clone()];
395396
let sum_agg = Expr::AggregateFunction(agg);
396397

397398
// sum(arg) + scalar * COUNT(arg)
398399
Ok(sum_agg + (lit * crate::count::count(arg)))
399400
}
400401

402+
fn has_common_rewrite_arg(arg: &Expr, info: &SimplifyContext) -> bool {
403+
let Some(aggregate_exprs) = info.aggregate_exprs() else {
404+
// Only apply this rewrite in the context of an Aggregate node where
405+
// sibling aggregate expressions are known.
406+
return false;
407+
};
408+
409+
aggregate_exprs
410+
.iter()
411+
.filter_map(sum_rewrite_candidate_arg)
412+
.filter(|candidate_arg| candidate_arg == arg)
413+
.take(2)
414+
.count()
415+
> 1
416+
}
417+
418+
fn sum_rewrite_candidate_arg(expr: &Expr) -> Option<Expr> {
419+
let Expr::AggregateFunction(aggregate_fn) = expr.clone().unalias_nested().data else {
420+
return None;
421+
};
422+
if !aggregate_fn.func.name().eq_ignore_ascii_case("sum") {
423+
return None;
424+
}
425+
426+
let AggregateFunctionParams {
427+
args,
428+
distinct,
429+
filter,
430+
order_by,
431+
null_treatment,
432+
} = &aggregate_fn.params;
433+
434+
if *distinct
435+
|| filter.is_some()
436+
|| !order_by.is_empty()
437+
|| null_treatment.is_some()
438+
|| args.len() != 1
439+
{
440+
return None;
441+
}
442+
443+
match SplitResult::new(args[0].clone()) {
444+
SplitResult::Split { arg, .. } => Some(arg),
445+
SplitResult::Original => None,
446+
}
447+
}
448+
401449
/// Result of trying to split an expression into an arg and constant
402450
#[derive(Debug, Clone)]
403451
enum SplitResult {
@@ -408,16 +456,16 @@ enum SplitResult {
408456
/// When `op` is `+`
409457
Split { arg: Expr, lit: Expr },
410458
/// If the expression is something else
411-
Original(Expr),
459+
Original,
412460
}
413461

414462
impl SplitResult {
415463
fn new(expr: Expr) -> Self {
416464
let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr else {
417-
return Self::Original(expr);
465+
return Self::Original;
418466
};
419467
if op != Operator::Plus {
420-
return Self::Original(Expr::BinaryExpr(BinaryExpr { left, op, right }));
468+
return Self::Original;
421469
}
422470

423471
match (left.as_ref(), right.as_ref()) {
@@ -429,7 +477,7 @@ impl SplitResult {
429477
arg: *left,
430478
lit: *right,
431479
},
432-
_ => Self::Original(Expr::BinaryExpr(BinaryExpr { left, op, right })),
480+
_ => Self::Original,
433481
}
434482
}
435483
}

datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,12 @@ impl SimplifyExpressions {
103103
.with_config_options(config.options())
104104
.with_query_execution_start_time(config.query_execution_start_time());
105105

106+
let info = if let LogicalPlan::Aggregate(Aggregate { aggr_expr, .. }) = &plan {
107+
info.with_aggregate_exprs(Arc::new(aggr_expr.clone()))
108+
} else {
109+
info
110+
};
111+
106112
// Inputs have already been rewritten (due to bottom-up traversal handled by Optimizer)
107113
// Just need to rewrite our own expressions
108114

@@ -141,9 +147,7 @@ impl SimplifyExpressions {
141147
}
142148
})?;
143149

144-
transformed.transform_data(|plan| {
145-
rewrite_aggregate_non_aggregate_aggr_expr(plan)
146-
})
150+
transformed.transform_data(rewrite_aggregate_non_aggregate_aggr_expr)
147151
}
148152
}
149153

@@ -345,7 +349,31 @@ mod tests {
345349
assert_optimized_plan_equal!(
346350
plan,
347351
@r"
348-
Projection: sum(test.a) + Int64(2) * count(test.a) AS sum(test.a + Int64(2))
352+
Aggregate: groupBy=[[]], aggr=[[sum(test.a + Int64(2))]]
353+
TableScan: test
354+
"
355+
)?;
356+
Ok(())
357+
}
358+
359+
#[test]
360+
fn test_simplify_udaf_to_non_aggregate_expr_with_common_sum_arg() -> Result<()> {
361+
let schema = Schema::new(vec![Field::new("a", DataType::Int64, false)]);
362+
let table_scan = table_scan(Some("test"), &schema, None)?
363+
.build()
364+
.expect("building scan");
365+
366+
let plan = LogicalPlanBuilder::from(table_scan)
367+
.aggregate(
368+
Vec::<Expr>::new(),
369+
vec![sum(col("a") + lit(2i64)), sum(col("a") + lit(3i64))],
370+
)?
371+
.build()?;
372+
373+
assert_optimized_plan_equal!(
374+
plan,
375+
@r"
376+
Projection: sum(test.a) + Int64(2) * count(test.a) AS sum(test.a + Int64(2)), sum(test.a) + Int64(3) * count(test.a) AS sum(test.a + Int64(3))
349377
Aggregate: groupBy=[[]], aggr=[[sum(test.a), count(test.a)]]
350378
TableScan: test
351379
"

datafusion/sqllogictest/test_files/aggregates_simplify.slt

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,11 @@ query TT
3232
EXPLAIN SELECT SUM(column1 + 1) FROM sum_simplify_t;
3333
----
3434
logical_plan
35-
01)Projection: sum(sum_simplify_t.column1) + count(sum_simplify_t.column1) AS sum(sum_simplify_t.column1 + Int64(1))
36-
02)--Aggregate: groupBy=[[]], aggr=[[sum(sum_simplify_t.column1), count(sum_simplify_t.column1)]]
37-
03)----TableScan: sum_simplify_t projection=[column1]
35+
01)Aggregate: groupBy=[[]], aggr=[[sum(sum_simplify_t.column1 + Int64(1))]]
36+
02)--TableScan: sum_simplify_t projection=[column1]
3837
physical_plan
39-
01)ProjectionExec: expr=[sum(sum_simplify_t.column1)@0 + count(sum_simplify_t.column1)@1 as sum(sum_simplify_t.column1 + Int64(1))]
40-
02)--AggregateExec: mode=Single, gby=[], aggr=[sum(sum_simplify_t.column1), count(sum_simplify_t.column1)]
41-
03)----DataSourceExec: partitions=1, partition_sizes=[1]
38+
01)AggregateExec: mode=Single, gby=[], aggr=[sum(sum_simplify_t.column1 + Int64(1))]
39+
02)--DataSourceExec: partitions=1, partition_sizes=[1]
4240

4341

4442
# Mixed aggregate expressions with type validation
@@ -51,12 +49,12 @@ query TT
5149
EXPLAIN SELECT arrow_typeof(SUM(column1)), SUM(column1), SUM(column1 + 1) FROM sum_simplify_t;
5250
----
5351
logical_plan
54-
01)Projection: arrow_typeof(sum(sum_simplify_t.column1)), sum(sum_simplify_t.column1), sum(sum_simplify_t.column1) + count(sum_simplify_t.column1) AS sum(sum_simplify_t.column1 + Int64(1))
55-
02)--Aggregate: groupBy=[[]], aggr=[[sum(sum_simplify_t.column1), count(sum_simplify_t.column1)]]
52+
01)Projection: arrow_typeof(sum(sum_simplify_t.column1)), sum(sum_simplify_t.column1), sum(sum_simplify_t.column1 + Int64(1))
53+
02)--Aggregate: groupBy=[[]], aggr=[[sum(sum_simplify_t.column1), sum(sum_simplify_t.column1 + Int64(1))]]
5654
03)----TableScan: sum_simplify_t projection=[column1]
5755
physical_plan
58-
01)ProjectionExec: expr=[arrow_typeof(sum(sum_simplify_t.column1)@0) as arrow_typeof(sum(sum_simplify_t.column1)), sum(sum_simplify_t.column1)@0 as sum(sum_simplify_t.column1), sum(sum_simplify_t.column1)@0 + count(sum_simplify_t.column1)@1 as sum(sum_simplify_t.column1 + Int64(1))]
59-
02)--AggregateExec: mode=Single, gby=[], aggr=[sum(sum_simplify_t.column1), count(sum_simplify_t.column1)]
56+
01)ProjectionExec: expr=[arrow_typeof(sum(sum_simplify_t.column1)@0) as arrow_typeof(sum(sum_simplify_t.column1)), sum(sum_simplify_t.column1)@0 as sum(sum_simplify_t.column1), sum(sum_simplify_t.column1 + Int64(1))@1 as sum(sum_simplify_t.column1 + Int64(1))]
57+
02)--AggregateExec: mode=Single, gby=[], aggr=[sum(sum_simplify_t.column1), sum(sum_simplify_t.column1 + Int64(1))]
6058
03)----DataSourceExec: partitions=1, partition_sizes=[1]
6159

6260
# Duplicate aggregate expressions
@@ -70,14 +68,12 @@ EXPLAIN SELECT SUM(column1 + 1) AS sum_plus_1_a, SUM(column1 + 1) AS sum_plus_1_
7068
----
7169
logical_plan
7270
01)Projection: sum(sum_simplify_t.column1 + Int64(1)) AS sum_plus_1_a, sum(sum_simplify_t.column1 + Int64(1)) AS sum_plus_1_b
73-
02)--Projection: sum(sum_simplify_t.column1) + count(sum_simplify_t.column1) AS sum(sum_simplify_t.column1 + Int64(1))
74-
03)----Aggregate: groupBy=[[]], aggr=[[sum(sum_simplify_t.column1), count(sum_simplify_t.column1)]]
75-
04)------TableScan: sum_simplify_t projection=[column1]
71+
02)--Aggregate: groupBy=[[]], aggr=[[sum(sum_simplify_t.column1 + Int64(1))]]
72+
03)----TableScan: sum_simplify_t projection=[column1]
7673
physical_plan
7774
01)ProjectionExec: expr=[sum(sum_simplify_t.column1 + Int64(1))@0 as sum_plus_1_a, sum(sum_simplify_t.column1 + Int64(1))@0 as sum_plus_1_b]
78-
02)--ProjectionExec: expr=[sum(sum_simplify_t.column1)@0 + count(sum_simplify_t.column1)@1 as sum(sum_simplify_t.column1 + Int64(1))]
79-
03)----AggregateExec: mode=Single, gby=[], aggr=[sum(sum_simplify_t.column1), count(sum_simplify_t.column1)]
80-
04)------DataSourceExec: partitions=1, partition_sizes=[1]
75+
02)--AggregateExec: mode=Single, gby=[], aggr=[sum(sum_simplify_t.column1 + Int64(1))]
76+
03)----DataSourceExec: partitions=1, partition_sizes=[1]
8177

8278

8379
# constant aggregate expressions
@@ -162,13 +158,21 @@ query TT
162158
EXPLAIN SELECT SUM(DISTINCT column1 + 1), SUM(column1 + 1) FROM sum_simplify_t;
163159
----
164160
logical_plan
165-
01)Projection: sum(DISTINCT sum_simplify_t.column1 + Int64(1)), sum(sum_simplify_t.column1) + count(sum_simplify_t.column1) AS sum(sum_simplify_t.column1 + Int64(1))
166-
02)--Aggregate: groupBy=[[]], aggr=[[sum(DISTINCT sum_simplify_t.column1 + Int64(1)), sum(sum_simplify_t.column1), count(sum_simplify_t.column1)]]
167-
03)----TableScan: sum_simplify_t projection=[column1]
161+
01)Projection: sum(alias1) AS sum(DISTINCT sum_simplify_t.column1 + Int64(1)), sum(alias2) AS sum(sum_simplify_t.column1 + Int64(1))
162+
02)--Aggregate: groupBy=[[]], aggr=[[sum(alias1), sum(alias2)]]
163+
03)----Aggregate: groupBy=[[__common_expr_1 AS alias1]], aggr=[[sum(__common_expr_1) AS alias2]]
164+
04)------Projection: sum_simplify_t.column1 + Int64(1) AS __common_expr_1
165+
05)--------TableScan: sum_simplify_t projection=[column1]
168166
physical_plan
169-
01)ProjectionExec: expr=[sum(DISTINCT sum_simplify_t.column1 + Int64(1))@0 as sum(DISTINCT sum_simplify_t.column1 + Int64(1)), sum(sum_simplify_t.column1)@1 + count(sum_simplify_t.column1)@2 as sum(sum_simplify_t.column1 + Int64(1))]
170-
02)--AggregateExec: mode=Single, gby=[], aggr=[sum(DISTINCT sum_simplify_t.column1 + Int64(1)), sum(sum_simplify_t.column1), count(sum_simplify_t.column1)]
171-
03)----DataSourceExec: partitions=1, partition_sizes=[1]
167+
01)ProjectionExec: expr=[sum(alias1)@0 as sum(DISTINCT sum_simplify_t.column1 + Int64(1)), sum(alias2)@1 as sum(sum_simplify_t.column1 + Int64(1))]
168+
02)--AggregateExec: mode=Final, gby=[], aggr=[sum(alias1), sum(alias2)]
169+
03)----CoalescePartitionsExec
170+
04)------AggregateExec: mode=Partial, gby=[], aggr=[sum(alias1), sum(alias2)]
171+
05)--------AggregateExec: mode=FinalPartitioned, gby=[alias1@0 as alias1], aggr=[alias2]
172+
06)----------RepartitionExec: partitioning=Hash([alias1@0], 4), input_partitions=1
173+
07)------------AggregateExec: mode=Partial, gby=[__common_expr_1@0 as alias1], aggr=[alias2]
174+
08)--------------ProjectionExec: expr=[column1@0 + 1 as __common_expr_1]
175+
09)----------------DataSourceExec: partitions=1, partition_sizes=[1]
172176

173177
# FILTER clauses with different aggregate arguments
174178
query II

0 commit comments

Comments
 (0)