Skip to content

Commit 6ac1062

Browse files
committed
fix: Add missing aggregate expressions for ORDER BY clause
Signed-off-by: Alex Qyoun-ae <[email protected]>
1 parent 0ba29ac commit 6ac1062

File tree

3 files changed

+187
-4
lines changed

3 files changed

+187
-4
lines changed

datafusion/core/src/logical_plan/builder.rs

Lines changed: 150 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,134 @@ impl LogicalPlanBuilder {
616616
}
617617
}
618618

619+
/// Add missing sort aggregate expressions to all downstream projection and aggregate
620+
#[allow(clippy::only_used_in_recursion)]
621+
fn add_missing_aggr_exprs(
622+
&self,
623+
curr_plan: LogicalPlan,
624+
missing_aggr_exprs: &[Expr],
625+
alias_map: &mut HashMap<String, String>,
626+
) -> Result<LogicalPlan> {
627+
match curr_plan {
628+
LogicalPlan::Projection(Projection {
629+
input,
630+
mut expr,
631+
schema,
632+
alias,
633+
}) => {
634+
let input = self.add_missing_aggr_exprs(
635+
input.as_ref().clone(),
636+
missing_aggr_exprs,
637+
alias_map,
638+
)?;
639+
let input_schema = input.schema();
640+
if missing_aggr_exprs.iter().any(|e| {
641+
e.name(input_schema)
642+
.and_then(|n| input_schema.field_with_unqualified_name(&n))
643+
.is_err()
644+
}) && alias_map.is_empty()
645+
{
646+
return Ok(LogicalPlan::Projection(Projection {
647+
expr,
648+
input: Arc::new(input),
649+
schema,
650+
alias,
651+
}));
652+
}
653+
654+
let mut missing_exprs: Vec<Expr> =
655+
Vec::with_capacity(missing_aggr_exprs.len());
656+
for missing_aggr_expr in missing_aggr_exprs {
657+
let normalized_expr = resolve_exprs_to_aliases(
658+
missing_aggr_expr,
659+
alias_map,
660+
input_schema,
661+
)?;
662+
let normalized_expr_name = normalized_expr.name(input_schema)?;
663+
if input_schema
664+
.field_with_unqualified_name(&normalized_expr_name)
665+
.is_err()
666+
{
667+
continue;
668+
};
669+
missing_exprs.push(normalized_expr);
670+
}
671+
672+
expr.extend(missing_exprs);
673+
674+
let new_schema = DFSchema::new_with_metadata(
675+
exprlist_to_fields(&expr, &input)?,
676+
input_schema.metadata().clone(),
677+
)?;
678+
679+
Ok(LogicalPlan::Projection(Projection {
680+
expr,
681+
input: Arc::new(input),
682+
schema: DFSchemaRef::new(new_schema),
683+
alias,
684+
}))
685+
}
686+
LogicalPlan::Aggregate(Aggregate {
687+
input,
688+
group_expr,
689+
mut aggr_expr,
690+
schema: _,
691+
}) => {
692+
let input = self.add_missing_aggr_exprs(
693+
input.as_ref().clone(),
694+
missing_aggr_exprs,
695+
alias_map,
696+
)?;
697+
let input_schema = input.schema();
698+
699+
let mut missing_exprs = Vec::with_capacity(missing_aggr_exprs.len());
700+
for missing_aggr_expr in missing_aggr_exprs {
701+
if aggr_expr.contains(missing_aggr_expr) {
702+
continue;
703+
}
704+
let expr_name = missing_aggr_expr.name(input_schema)?;
705+
alias_map.insert(expr_name.clone(), expr_name);
706+
missing_exprs.push(missing_aggr_expr.clone());
707+
}
708+
709+
aggr_expr.extend(missing_exprs);
710+
711+
let all_expr = group_expr
712+
.iter()
713+
.chain(aggr_expr.iter())
714+
.cloned()
715+
.collect::<Vec<_>>();
716+
let new_schema = DFSchema::new_with_metadata(
717+
exprlist_to_fields(&all_expr, &input)?,
718+
input_schema.metadata().clone(),
719+
)?;
720+
721+
Ok(LogicalPlan::Aggregate(Aggregate {
722+
input: Arc::new(input),
723+
group_expr,
724+
aggr_expr,
725+
schema: Arc::new(new_schema),
726+
}))
727+
}
728+
_ => {
729+
let new_inputs = curr_plan
730+
.inputs()
731+
.into_iter()
732+
.map(|input_plan| {
733+
self.add_missing_aggr_exprs(
734+
(*input_plan).clone(),
735+
missing_aggr_exprs,
736+
alias_map,
737+
)
738+
})
739+
.collect::<Result<Vec<_>>>()?;
740+
741+
let expr = curr_plan.expressions();
742+
utils::from_plan(&curr_plan, &expr, &new_inputs)
743+
}
744+
}
745+
}
746+
619747
/// Apply a sort
620748
pub fn sort(
621749
&self,
@@ -628,28 +756,46 @@ impl LogicalPlanBuilder {
628756
// Collect sort columns that are missing in the input plan's schema
629757
let mut missing_cols: Vec<Column> = vec![];
630758
let mut columns: HashSet<Column> = HashSet::new();
759+
// Collect aggregate expressions that are missing in the input plan's schema
760+
let mut missing_aggr_exprs: Vec<Expr> = vec![];
761+
let mut aggr_exprs: HashSet<Expr> = HashSet::new();
631762
exprs
632763
.clone()
633764
.into_iter()
634765
.try_for_each::<_, Result<()>>(|expr| {
635-
utils::expr_to_columns(&expr, &mut columns)
766+
utils::expr_to_columns(&expr, &mut columns)?;
767+
utils::expr_to_aggr_exprs(&expr, &mut aggr_exprs)
636768
})?;
637769
columns.into_iter().for_each(|c| {
638770
if schema.field_from_column(&c).is_err() {
639771
missing_cols.push(c);
640772
}
641773
});
774+
aggr_exprs.into_iter().for_each(|e| {
775+
if e.name(schema)
776+
.and_then(|n| schema.field_with_unqualified_name(&n))
777+
.is_err()
778+
{
779+
missing_aggr_exprs.push(e);
780+
}
781+
});
642782

643-
if missing_cols.is_empty() {
783+
if missing_cols.is_empty() && missing_aggr_exprs.is_empty() {
644784
return Ok(Self::from(LogicalPlan::Sort(Sort {
645785
expr: normalize_cols(exprs, &self.plan)?,
646786
input: Arc::new(self.plan.clone()),
647787
})));
648788
}
649789

650790
let mut alias_map = HashMap::new();
651-
let plan =
652-
self.add_missing_columns(self.plan.clone(), &missing_cols, &mut alias_map)?;
791+
let mut plan = self.plan.clone();
792+
if !missing_cols.is_empty() {
793+
plan = self.add_missing_columns(plan, &missing_cols, &mut alias_map)?;
794+
}
795+
if !missing_aggr_exprs.is_empty() {
796+
plan =
797+
self.add_missing_aggr_exprs(plan, &missing_aggr_exprs, &mut alias_map)?;
798+
}
653799
let exprs = if alias_map.is_empty() {
654800
exprs
655801
} else {

datafusion/core/src/optimizer/utils.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,32 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet<Column>) -> Result<()> {
111111
Ok(())
112112
}
113113

114+
/// Recursively walk an expression tree, collecting the unique set of aggregate expressions
115+
/// referenced in the expression
116+
struct AggregateExprVisitor<'a> {
117+
accum: &'a mut HashSet<Expr>,
118+
}
119+
120+
impl ExpressionVisitor for AggregateExprVisitor<'_> {
121+
fn pre_visit(self, expr: &Expr) -> Result<Recursion<Self>> {
122+
match expr {
123+
Expr::AggregateFunction { .. } | Expr::AggregateUDF { .. } => {
124+
self.accum.insert(expr.clone());
125+
return Ok(Recursion::Stop(self));
126+
}
127+
_ => {}
128+
}
129+
Ok(Recursion::Continue(self))
130+
}
131+
}
132+
133+
/// Recursively walk an expression tree, collecting the unique set of aggregate expressions
134+
/// referenced in the expression
135+
pub fn expr_to_aggr_exprs(expr: &Expr, accum: &mut HashSet<Expr>) -> Result<()> {
136+
expr.accept(AggregateExprVisitor { accum })?;
137+
Ok(())
138+
}
139+
114140
/// Convenience rule for writing optimizers: recursively invoke
115141
/// optimize on plan's children and then return a node of the same
116142
/// type. Useful for optimizer rules which want to leave the type

datafusion/core/src/sql/planner.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4349,6 +4349,17 @@ mod tests {
43494349
);
43504350
}
43514351

4352+
/// `ORDER BY` on an aggregate that is absent from the `SELECT` list must still
/// plan: the aggregate is added to the `Aggregate` node and carried through an
/// inner projection for the sort, then dropped again by the outer projection.
#[test]
fn select_order_by_missing_aggr() {
    let sql = "SELECT state FROM person GROUP BY state ORDER BY SUM(age)";
    // Each nesting level of the rendered plan is indented by two further spaces.
    let expected = "Projection: #person.state\
                    \n  Sort: #SUM(person.age) ASC NULLS LAST\
                    \n    Projection: #person.state, #SUM(person.age)\
                    \n      Aggregate: groupBy=[[#person.state]], aggr=[[SUM(#person.age)]]\
                    \n        TableScan: person projection=None";
    quick_test(sql, expected);
}
4362+
43524363
#[test]
43534364
fn select_group_by() {
43544365
let sql = "SELECT state FROM person GROUP BY state";

0 commit comments

Comments
 (0)