Skip to content

Commit e4dcf0c

Browse files
authored
Fix error planning aggregates with duplicated names in select list (#18831)
## Which issue does this PR close? - Closes #18683 ``` ./target/debug/datafusion-cli DataFusion CLI v51.0.0 > create table t(a int, b int) as values (1,3), (3,4), (4,5); 0 row(s) fetched. Elapsed 0.021 seconds. > select a from t group by a order by min(b); +---+ | a | +---+ | 1 | | 3 | | 4 | +---+ 3 row(s) fetched. Elapsed 0.017 seconds. > EXPLAIN select a from t group by a order by min(b); +---------------+-------------------------------+ | plan_type | plan | +---------------+-------------------------------+ | physical_plan | ┌───────────────────────────┐ | | | │ ProjectionExec │ | | | │ -------------------- │ | | | │ a: a │ | | | └─────────────┬─────────────┘ | | | ┌─────────────┴─────────────┐ | | | │ SortPreservingMergeExec │ | | | │ -------------------- │ | | | │ min(t.b) ASC NULLS LAST │ | | | └─────────────┬─────────────┘ | | | ┌─────────────┴─────────────┐ | | | │ SortExec │ | | | │ -------------------- │ | | | │ min(t.b)@1 ASC NULLS LAST │ | | | └─────────────┬─────────────┘ | | | ┌─────────────┴─────────────┐ | | | │ AggregateExec │ | | | │ -------------------- │ | | | │ aggr: min(t.b) │ | | | │ group_by: a │ | | | │ │ | | | │ mode: │ | | | │ FinalPartitioned │ | | | └─────────────┬─────────────┘ | | | ┌─────────────┴─────────────┐ | | | │ CoalesceBatchesExec │ | | | │ -------------------- │ | | | │ target_batch_size: │ | | | │ 8192 │ | | | └─────────────┬─────────────┘ | | | ┌─────────────┴─────────────┐ | | | │ RepartitionExec │ | | | │ -------------------- │ | | | │ partition_count(in->out): │ | | | │ 1 -> 8 │ | | | │ │ | | | │ partitioning_scheme: │ | | | │ Hash([a@0], 8) │ | | | └─────────────┬─────────────┘ | | | ┌─────────────┴─────────────┐ | | | │ AggregateExec │ | | | │ -------------------- │ | | | │ aggr: min(t.b) │ | | | │ group_by: a │ | | | │ mode: Partial │ | | | └─────────────┬─────────────┘ | | | ┌─────────────┴─────────────┐ | | | │ DataSourceExec │ | | | │ -------------------- │ | | | │ bytes: 224 │ | | | │ format: memory │ | | | │ 
rows: 1 │ | | | └───────────────────────────┘ | | | | +---------------+-------------------------------+ 1 row(s) fetched. Elapsed 0.019 seconds. ``` ## Rationale for this change This PR fixes a bug where a query fails to plan when its ORDER BY contains an aggregate expression that is not included in the projection, because that expression isn't propagated to the aggregate plan. ## What changes are included in this PR? Fixes the bug by making sure the ORDER BY expressions are also passed to the planner's `aggregate` function, so that aggregates referenced only in ORDER BY are computed and the ORDER BY expressions are rewritten to reference the aggregate's output columns. ## Are these changes tested? Yes, tests are added to ensure the query compiles and produces the proper plan. ## Are there any user-facing changes? No
1 parent 4cfe20f commit e4dcf0c

File tree

6 files changed

+204
-39
lines changed

6 files changed

+204
-39
lines changed

datafusion/sql/src/select.rs

Lines changed: 114 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ use datafusion_expr::utils::{
4141
};
4242
use datafusion_expr::{
4343
Aggregate, Expr, Filter, GroupingSet, LogicalPlan, LogicalPlanBuilder,
44-
LogicalPlanBuilderOptions, Partitioning,
44+
LogicalPlanBuilderOptions, Partitioning, SortExpr,
4545
};
4646

4747
use indexmap::IndexMap;
@@ -51,6 +51,21 @@ use sqlparser::ast::{
5151
};
5252
use sqlparser::ast::{NamedWindowDefinition, Select, SelectItem, TableWithJoins};
5353

54+
/// Result of the `aggregate` function, containing the aggregate plan and
55+
/// rewritten expressions that reference the aggregate output columns.
56+
struct AggregatePlanResult {
57+
/// The aggregate logical plan
58+
plan: LogicalPlan,
59+
/// SELECT expressions rewritten to reference aggregate output columns
60+
select_exprs: Vec<Expr>,
61+
/// HAVING expression rewritten to reference aggregate output columns
62+
having_expr: Option<Expr>,
63+
/// QUALIFY expression rewritten to reference aggregate output columns
64+
qualify_expr: Option<Expr>,
65+
/// ORDER BY expressions rewritten to reference aggregate output columns
66+
order_by_exprs: Vec<SortExpr>,
67+
}
68+
5469
impl<S: ContextProvider> SqlToRel<'_, S> {
5570
/// Generate a logical plan from an SQL select
5671
pub(super) fn select_to_plan(
@@ -219,33 +234,53 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
219234
.transpose()?;
220235

221236
// The outer expressions we will search through for aggregates.
222-
// Aggregates may be sourced from the SELECT list or from the HAVING expression.
223-
let aggr_expr_haystack = select_exprs
224-
.iter()
225-
.chain(having_expr_opt.iter())
226-
.chain(qualify_expr_opt.iter());
227-
// All of the aggregate expressions (deduplicated).
228-
let aggr_exprs = find_aggregate_exprs(aggr_expr_haystack);
237+
// First, find aggregates in SELECT, HAVING, and QUALIFY
238+
let select_having_qualify_aggrs = find_aggregate_exprs(
239+
select_exprs
240+
.iter()
241+
.chain(having_expr_opt.iter())
242+
.chain(qualify_expr_opt.iter()),
243+
);
244+
245+
// Find aggregates in ORDER BY
246+
let order_by_aggrs = find_aggregate_exprs(order_by_rex.iter().map(|s| &s.expr));
247+
248+
// Combine: all aggregates from SELECT/HAVING/QUALIFY, plus ORDER BY aggregates
249+
// that aren't already in SELECT/HAVING/QUALIFY
250+
let mut aggr_exprs = select_having_qualify_aggrs;
251+
for order_by_aggr in order_by_aggrs {
252+
if !aggr_exprs.iter().any(|e| e == &order_by_aggr) {
253+
aggr_exprs.push(order_by_aggr);
254+
}
255+
}
229256

230257
// Process group by, aggregation or having
231-
let (
258+
let AggregatePlanResult {
232259
plan,
233-
mut select_exprs_post_aggr,
234-
having_expr_post_aggr,
235-
qualify_expr_post_aggr,
236-
) = if !group_by_exprs.is_empty() || !aggr_exprs.is_empty() {
260+
select_exprs: mut select_exprs_post_aggr,
261+
having_expr: having_expr_post_aggr,
262+
qualify_expr: qualify_expr_post_aggr,
263+
order_by_exprs: order_by_rex,
264+
} = if !group_by_exprs.is_empty() || !aggr_exprs.is_empty() {
237265
self.aggregate(
238266
&base_plan,
239267
&select_exprs,
240268
having_expr_opt.as_ref(),
241269
qualify_expr_opt.as_ref(),
270+
&order_by_rex,
242271
&group_by_exprs,
243272
&aggr_exprs,
244273
)?
245274
} else {
246275
match having_expr_opt {
247276
Some(having_expr) => return plan_err!("HAVING clause references: {having_expr} must appear in the GROUP BY clause or be used in an aggregate function"),
248-
None => (base_plan.clone(), select_exprs.clone(), having_expr_opt, qualify_expr_opt)
277+
None => AggregatePlanResult {
278+
plan: base_plan.clone(),
279+
select_exprs: select_exprs.clone(),
280+
having_expr: having_expr_opt,
281+
qualify_expr: qualify_expr_opt,
282+
order_by_exprs: order_by_rex,
283+
}
249284
}
250285
};
251286

@@ -366,7 +401,8 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
366401
plan
367402
};
368403

369-
self.order_by(plan, order_by_rex)
404+
let plan = self.order_by(plan, order_by_rex)?;
405+
Ok(plan)
370406
}
371407

372408
/// Try converting Expr(Unnest(Expr)) to Projection/Unnest/Projection
@@ -872,16 +908,19 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
872908
/// the aggregate
873909
/// * `qualify_expr_post_aggr` - The "qualify" expression rewritten to reference a column from
874910
/// the aggregate
875-
#[allow(clippy::type_complexity)]
911+
/// * `order_by_post_aggr` - The ORDER BY expressions rewritten to reference columns from
912+
/// the aggregate
913+
#[allow(clippy::too_many_arguments)]
876914
fn aggregate(
877915
&self,
878916
input: &LogicalPlan,
879917
select_exprs: &[Expr],
880918
having_expr_opt: Option<&Expr>,
881919
qualify_expr_opt: Option<&Expr>,
920+
order_by_exprs: &[SortExpr],
882921
group_by_exprs: &[Expr],
883922
aggr_exprs: &[Expr],
884-
) -> Result<(LogicalPlan, Vec<Expr>, Option<Expr>, Option<Expr>)> {
923+
) -> Result<AggregatePlanResult> {
885924
// create the aggregate plan
886925
let options =
887926
LogicalPlanBuilderOptions::new().with_add_implicit_group_by_exprs(true);
@@ -988,12 +1027,65 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
9881027
None
9891028
};
9901029

991-
Ok((
1030+
// Rewrite the ORDER BY expressions to use the columns produced by the
1031+
// aggregation. If an ORDER BY expression matches a SELECT expression
1032+
// (ignoring aliases), use the SELECT's output column name to avoid
1033+
// duplication when the SELECT expression has an alias.
1034+
let order_by_post_aggr = order_by_exprs
1035+
.iter()
1036+
.map(|sort_expr| {
1037+
let rewritten_expr =
1038+
rebase_expr(&sort_expr.expr, &aggr_projection_exprs, input)?;
1039+
1040+
// Check if this ORDER BY expression matches any aliased SELECT expression
1041+
// If so, use the SELECT's alias instead of the raw expression
1042+
let final_expr = select_exprs_post_aggr
1043+
.iter()
1044+
.find_map(|select_expr| {
1045+
// Only consider aliased expressions
1046+
if let Expr::Alias(alias) = select_expr {
1047+
if alias.expr.as_ref() == &rewritten_expr {
1048+
// Use the alias name
1049+
return Some(Expr::Column(alias.name.clone().into()));
1050+
}
1051+
}
1052+
None
1053+
})
1054+
.unwrap_or(rewritten_expr);
1055+
1056+
Ok(sort_expr.with_expr(final_expr))
1057+
})
1058+
.collect::<Result<Vec<SortExpr>>>()?;
1059+
1060+
let all_valid_exprs: Vec<Expr> = column_exprs_post_aggr
1061+
.iter()
1062+
.cloned()
1063+
.chain(select_exprs_post_aggr.iter().filter_map(|e| {
1064+
if let Expr::Alias(alias) = e {
1065+
Some(Expr::Column(alias.name.clone().into()))
1066+
} else {
1067+
None
1068+
}
1069+
}))
1070+
.collect();
1071+
1072+
let order_by_exprs_only: Vec<Expr> =
1073+
order_by_post_aggr.iter().map(|s| s.expr.clone()).collect();
1074+
check_columns_satisfy_exprs(
1075+
&all_valid_exprs,
1076+
&order_by_exprs_only,
1077+
CheckColumnsSatisfyExprsPurpose::Aggregate(
1078+
CheckColumnsMustReferenceAggregatePurpose::OrderBy,
1079+
),
1080+
)?;
1081+
1082+
Ok(AggregatePlanResult {
9921083
plan,
993-
select_exprs_post_aggr,
994-
having_expr_post_aggr,
995-
qualify_expr_post_aggr,
996-
))
1084+
select_exprs: select_exprs_post_aggr,
1085+
having_expr: having_expr_post_aggr,
1086+
qualify_expr: qualify_expr_post_aggr,
1087+
order_by_exprs: order_by_post_aggr,
1088+
})
9971089
}
9981090

9991091
// If the projection is done over a named window, that window

datafusion/sql/src/utils.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ pub(crate) enum CheckColumnsMustReferenceAggregatePurpose {
9797
Projection,
9898
Having,
9999
Qualify,
100+
OrderBy,
100101
}
101102

102103
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -116,6 +117,9 @@ impl CheckColumnsSatisfyExprsPurpose {
116117
Self::Aggregate(CheckColumnsMustReferenceAggregatePurpose::Qualify) => {
117118
"Column in QUALIFY must be in GROUP BY or an aggregate function"
118119
}
120+
Self::Aggregate(CheckColumnsMustReferenceAggregatePurpose::OrderBy) => {
121+
"Column in ORDER BY must be in GROUP BY or an aggregate function"
122+
}
119123
}
120124
}
121125

datafusion/sql/tests/cases/plan_to_sql.rs

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -334,9 +334,7 @@ fn roundtrip_statement_with_dialect_1() -> Result<(), DataFusionError> {
334334
sql: "select min(ta.j1_id) as j1_min from j1 ta order by min(ta.j1_id) limit 10;",
335335
parser_dialect: MySqlDialect {},
336336
unparser_dialect: UnparserMySqlDialect {},
337-
// top projection sort gets derived into a subquery
338-
// for MySQL, this subquery needs an alias
339-
expected: @"SELECT `j1_min` FROM (SELECT min(`ta`.`j1_id`) AS `j1_min`, min(`ta`.`j1_id`) FROM `j1` AS `ta` ORDER BY min(`ta`.`j1_id`) ASC) AS `derived_sort` LIMIT 10",
337+
expected: @"SELECT min(`ta`.`j1_id`) AS `j1_min` FROM `j1` AS `ta` ORDER BY `j1_min` ASC LIMIT 10",
340338
);
341339
Ok(())
342340
}
@@ -347,9 +345,7 @@ fn roundtrip_statement_with_dialect_2() -> Result<(), DataFusionError> {
347345
sql: "select min(ta.j1_id) as j1_min from j1 ta order by min(ta.j1_id) limit 10;",
348346
parser_dialect: GenericDialect {},
349347
unparser_dialect: UnparserDefaultDialect {},
350-
// top projection sort still gets derived into a subquery in default dialect
351-
// except for the default dialect, the subquery is left non-aliased
352-
expected: @"SELECT j1_min FROM (SELECT min(ta.j1_id) AS j1_min, min(ta.j1_id) FROM j1 AS ta ORDER BY min(ta.j1_id) ASC NULLS LAST) LIMIT 10",
348+
expected: @"SELECT min(ta.j1_id) AS j1_min FROM j1 AS ta ORDER BY j1_min ASC NULLS LAST LIMIT 10",
353349
);
354350
Ok(())
355351
}
@@ -360,7 +356,7 @@ fn roundtrip_statement_with_dialect_3() -> Result<(), DataFusionError> {
360356
sql: "select min(ta.j1_id) as j1_min, max(tb.j1_max) from j1 ta, (select distinct max(ta.j1_id) as j1_max from j1 ta order by max(ta.j1_id)) tb order by min(ta.j1_id) limit 10;",
361357
parser_dialect: MySqlDialect {},
362358
unparser_dialect: UnparserMySqlDialect {},
363-
expected: @"SELECT `j1_min`, `max(tb.j1_max)` FROM (SELECT min(`ta`.`j1_id`) AS `j1_min`, max(`tb`.`j1_max`), min(`ta`.`j1_id`) FROM `j1` AS `ta` CROSS JOIN (SELECT `j1_max` FROM (SELECT DISTINCT max(`ta`.`j1_id`) AS `j1_max` FROM `j1` AS `ta`) AS `derived_distinct`) AS `tb` ORDER BY min(`ta`.`j1_id`) ASC) AS `derived_sort` LIMIT 10",
359+
expected: @"SELECT min(`ta`.`j1_id`) AS `j1_min`, max(`tb`.`j1_max`) FROM `j1` AS `ta` CROSS JOIN (SELECT DISTINCT max(`ta`.`j1_id`) AS `j1_max` FROM `j1` AS `ta`) AS `tb` ORDER BY `j1_min` ASC LIMIT 10",
364360
);
365361
Ok(())
366362
}
@@ -1909,7 +1905,7 @@ fn test_complex_order_by_with_grouping() -> Result<()> {
19091905
}, {
19101906
assert_snapshot!(
19111907
sql,
1912-
@r#"SELECT j1.j1_id, j1.j1_string, lochierarchy FROM (SELECT j1.j1_id, j1.j1_string, (grouping(j1.j1_id) + grouping(j1.j1_string)) AS lochierarchy, grouping(j1.j1_string), grouping(j1.j1_id) FROM j1 GROUP BY ROLLUP (j1.j1_id, j1.j1_string) ORDER BY (grouping(j1.j1_id) + grouping(j1.j1_string)) DESC NULLS FIRST, CASE WHEN ((grouping(j1.j1_id) + grouping(j1.j1_string)) = 0) THEN j1.j1_id END ASC NULLS LAST) LIMIT 100"#
1908+
@r#"SELECT j1.j1_id, j1.j1_string, lochierarchy FROM (SELECT j1.j1_id, j1.j1_string, (grouping(j1.j1_id) + grouping(j1.j1_string)) AS lochierarchy, grouping(j1.j1_string), grouping(j1.j1_id) FROM j1 GROUP BY ROLLUP (j1.j1_id, j1.j1_string) ORDER BY lochierarchy DESC NULLS FIRST, CASE WHEN ((grouping(j1.j1_id) + grouping(j1.j1_string)) = 0) THEN j1.j1_id END ASC NULLS LAST) LIMIT 100"#
19131909
);
19141910
});
19151911

datafusion/sql/tests/sql_integration.rs

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3261,6 +3261,55 @@ Sort: avg(person.age) + avg(person.age) ASC NULLS LAST
32613261
);
32623262
}
32633263

3264+
#[test]
3265+
fn select_groupby_orderby_aggregate_on_non_selected_column() {
3266+
let sql = "SELECT state FROM person GROUP BY state ORDER BY MIN(age)";
3267+
let plan = logical_plan(sql).unwrap();
3268+
assert_snapshot!(
3269+
plan,
3270+
@r#"
3271+
Projection: person.state
3272+
Sort: min(person.age) ASC NULLS LAST
3273+
Projection: person.state, min(person.age)
3274+
Aggregate: groupBy=[[person.state]], aggr=[[min(person.age)]]
3275+
TableScan: person
3276+
"#
3277+
);
3278+
}
3279+
3280+
#[test]
3281+
fn select_groupby_orderby_multiple_aggregates_on_non_selected_columns() {
3282+
let sql =
3283+
"SELECT state FROM person GROUP BY state ORDER BY MIN(age), MAX(salary) DESC";
3284+
let plan = logical_plan(sql).unwrap();
3285+
assert_snapshot!(
3286+
plan,
3287+
@r#"
3288+
Projection: person.state
3289+
Sort: min(person.age) ASC NULLS LAST, max(person.salary) DESC NULLS FIRST
3290+
Projection: person.state, min(person.age), max(person.salary)
3291+
Aggregate: groupBy=[[person.state]], aggr=[[min(person.age), max(person.salary)]]
3292+
TableScan: person
3293+
"#
3294+
);
3295+
}
3296+
3297+
#[test]
3298+
fn select_groupby_orderby_aggregate_on_non_selected_column_original_issue() {
3299+
let sql = "SELECT id FROM person GROUP BY id ORDER BY min(age)";
3300+
let plan = logical_plan(sql).unwrap();
3301+
assert_snapshot!(
3302+
plan,
3303+
@r#"
3304+
Projection: person.id
3305+
Sort: min(person.age) ASC NULLS LAST
3306+
Projection: person.id, min(person.age)
3307+
Aggregate: groupBy=[[person.id]], aggr=[[min(person.age)]]
3308+
TableScan: person
3309+
"#
3310+
);
3311+
}
3312+
32643313
fn logical_plan(sql: &str) -> Result<LogicalPlan> {
32653314
logical_plan_with_options(sql, ParserOptions::default())
32663315
}
@@ -3830,12 +3879,11 @@ fn order_by_unaliased_name() {
38303879
assert_snapshot!(
38313880
plan,
38323881
@r#"
3833-
Projection: z, q
3834-
Sort: p.state ASC NULLS LAST
3835-
Projection: p.state AS z, sum(p.age) AS q, p.state
3836-
Aggregate: groupBy=[[p.state]], aggr=[[sum(p.age)]]
3837-
SubqueryAlias: p
3838-
TableScan: person
3882+
Sort: z ASC NULLS LAST
3883+
Projection: p.state AS z, sum(p.age) AS q
3884+
Aggregate: groupBy=[[p.state]], aggr=[[sum(p.age)]]
3885+
SubqueryAlias: p
3886+
TableScan: person
38393887
"#
38403888
);
38413889
}

datafusion/sqllogictest/test_files/join.slt.part

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1504,7 +1504,7 @@ drop table t0;
15041504
statement ok
15051505
create table t1(v1 int, v2 int);
15061506

1507-
query error DataFusion error: Schema error: No field named tt1.v2. Valid fields are tt1.v1.
1507+
query error DataFusion error: Error during planning: Column in ORDER BY must be in GROUP BY or an aggregate function
15081508
select v1 from t1 as tt1 natural join t1 as tt2 group by v1 order by v2;
15091509

15101510
statement ok

datafusion/sqllogictest/test_files/order.slt

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,31 @@ select column1 + column2 from foo group by column1, column2 ORDER BY column2 des
419419
7
420420
3
421421

422+
# Test ordering by aggregate on non-selected column (issue #18683)
423+
# Previously failed with "Schema error: No field named foo.column2"
424+
query I
425+
select column1 from foo group by column1 order by min(column2);
426+
----
427+
1
428+
3
429+
5
430+
431+
# Test ordering by aggregate expression on non-selected columns
432+
query I
433+
select column1 from foo group by column1 order by min(column2) + max(column2);
434+
----
435+
1
436+
3
437+
5
438+
439+
# Test ordering by multiple aggregates on non-selected columns
440+
query I
441+
select column1 from foo group by column1 order by min(column2), max(column2);
442+
----
443+
1
444+
3
445+
5
446+
422447
# Test issue: https://github.com/apache/datafusion/issues/11549
423448
query I
424449
select column1 from foo order by log(column2);
@@ -1055,8 +1080,8 @@ SELECT SUM(column1) FROM foo ORDER BY SUM(column1)
10551080
----
10561081
16
10571082

1058-
# Order by unprojected aggregate expressions is not supported
1059-
query error DataFusion error: This feature is not implemented: Physical plan does not support logical expression AggregateFunction
1083+
# Order by unprojected aggregate expressions requires GROUP BY
1084+
query error DataFusion error: Error during planning: Column in SELECT must be in GROUP BY or an aggregate function
10601085
SELECT column2 FROM foo ORDER BY SUM(column1)
10611086

10621087
statement ok

0 commit comments

Comments
 (0)