diff --git a/datafusion/core/src/logical_plan/builder.rs b/datafusion/core/src/logical_plan/builder.rs index d13d542ef833..72787e192763 100644 --- a/datafusion/core/src/logical_plan/builder.rs +++ b/datafusion/core/src/logical_plan/builder.rs @@ -616,6 +616,134 @@ impl LogicalPlanBuilder { } } + /// Add missing sort aggregate expressions to all downstream projection and aggregate + #[allow(clippy::only_used_in_recursion)] + fn add_missing_aggr_exprs( + &self, + curr_plan: LogicalPlan, + missing_aggr_exprs: &[Expr], + alias_map: &mut HashMap, + ) -> Result { + match curr_plan { + LogicalPlan::Projection(Projection { + input, + mut expr, + schema, + alias, + }) => { + let input = self.add_missing_aggr_exprs( + input.as_ref().clone(), + missing_aggr_exprs, + alias_map, + )?; + let input_schema = input.schema(); + if missing_aggr_exprs.iter().any(|e| { + e.name(input_schema) + .and_then(|n| input_schema.field_with_unqualified_name(&n)) + .is_err() + }) && alias_map.is_empty() + { + return Ok(LogicalPlan::Projection(Projection { + expr, + input: Arc::new(input), + schema, + alias, + })); + } + + let mut missing_exprs: Vec = + Vec::with_capacity(missing_aggr_exprs.len()); + for missing_aggr_expr in missing_aggr_exprs { + let normalized_expr = resolve_exprs_to_aliases( + missing_aggr_expr, + alias_map, + input_schema, + )?; + let normalized_expr_name = normalized_expr.name(input_schema)?; + if input_schema + .field_with_unqualified_name(&normalized_expr_name) + .is_err() + { + continue; + }; + missing_exprs.push(normalized_expr); + } + + expr.extend(missing_exprs); + + let new_schema = DFSchema::new_with_metadata( + exprlist_to_fields(&expr, &input)?, + input_schema.metadata().clone(), + )?; + + Ok(LogicalPlan::Projection(Projection { + expr, + input: Arc::new(input), + schema: DFSchemaRef::new(new_schema), + alias, + })) + } + LogicalPlan::Aggregate(Aggregate { + input, + group_expr, + mut aggr_expr, + schema: _, + }) => { + let input = self.add_missing_aggr_exprs( + input.as_ref().clone(), + missing_aggr_exprs, + alias_map, + )?; + let input_schema = input.schema(); + + let mut missing_exprs = Vec::with_capacity(missing_aggr_exprs.len()); + for missing_aggr_expr in missing_aggr_exprs { + if aggr_expr.contains(missing_aggr_expr) { + continue; + } + let expr_name = missing_aggr_expr.name(input_schema)?; + alias_map.insert(expr_name.clone(), expr_name); + missing_exprs.push(missing_aggr_expr.clone()); + } + + aggr_expr.extend(missing_exprs); + + let all_expr = group_expr + .iter() + .chain(aggr_expr.iter()) + .cloned() + .collect::>(); + let new_schema = DFSchema::new_with_metadata( + exprlist_to_fields(&all_expr, &input)?, + input_schema.metadata().clone(), + )?; + + Ok(LogicalPlan::Aggregate(Aggregate { + input: Arc::new(input), + group_expr, + aggr_expr, + schema: Arc::new(new_schema), + })) + } + _ => { + let new_inputs = curr_plan + .inputs() + .into_iter() + .map(|input_plan| { + self.add_missing_aggr_exprs( + (*input_plan).clone(), + missing_aggr_exprs, + alias_map, + ) + }) + .collect::>>()?; + + let expr = curr_plan.expressions(); + utils::from_plan(&curr_plan, &expr, &new_inputs) + } + } + } + /// Apply a sort pub fn sort( &self, @@ -628,19 +756,31 @@ impl LogicalPlanBuilder { // Collect sort columns that are missing in the input plan's schema let mut missing_cols: Vec = vec![]; let mut columns: HashSet = HashSet::new(); + // Collect aggregate expressions that are missing in the input plan's schema + let mut missing_aggr_exprs: Vec = vec![]; + let mut aggr_exprs: HashSet = HashSet::new(); exprs .clone() .into_iter() .try_for_each::<_, Result<()>>(|expr| { - utils::expr_to_columns(&expr, &mut columns) + utils::expr_to_columns(&expr, &mut columns)?; + utils::expr_to_aggr_exprs(&expr, &mut aggr_exprs) })?; columns.into_iter().for_each(|c| { if schema.field_from_column(&c).is_err() { missing_cols.push(c); } }); + aggr_exprs.into_iter().for_each(|e| { + if e.name(schema) + .and_then(|n| schema.field_with_unqualified_name(&n)) + .is_err() + { + missing_aggr_exprs.push(e); + } + }); - if missing_cols.is_empty() { + if missing_cols.is_empty() && missing_aggr_exprs.is_empty() { return Ok(Self::from(LogicalPlan::Sort(Sort { expr: normalize_cols(exprs, &self.plan)?, input: Arc::new(self.plan.clone()), @@ -648,8 +788,14 @@ impl LogicalPlanBuilder { } let mut alias_map = HashMap::new(); - let plan = - self.add_missing_columns(self.plan.clone(), &missing_cols, &mut alias_map)?; + let mut plan = self.plan.clone(); + if !missing_cols.is_empty() { + plan = self.add_missing_columns(plan, &missing_cols, &mut alias_map)?; + } + if !missing_aggr_exprs.is_empty() { + plan = + self.add_missing_aggr_exprs(plan, &missing_aggr_exprs, &mut alias_map)?; + } let exprs = if alias_map.is_empty() { exprs } else { diff --git a/datafusion/core/src/optimizer/utils.rs b/datafusion/core/src/optimizer/utils.rs index 37f62925d350..bbdd3e702612 100644 --- a/datafusion/core/src/optimizer/utils.rs +++ b/datafusion/core/src/optimizer/utils.rs @@ -111,6 +111,32 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { Ok(()) } +/// Recursively walk an expression tree, collecting the unique set of aggregate expressions +/// referenced in the expression +struct AggregateExprVisitor<'a> { + accum: &'a mut HashSet, +} + +impl ExpressionVisitor for AggregateExprVisitor<'_> { + fn pre_visit(self, expr: &Expr) -> Result> { + match expr { + Expr::AggregateFunction { .. } | Expr::AggregateUDF { .. } => { + self.accum.insert(expr.clone()); + return Ok(Recursion::Stop(self)); + } + _ => {} + } + Ok(Recursion::Continue(self)) + } +} + +/// Recursively walk an expression tree, collecting the unique set of aggregate expressions +/// referenced in the expression +pub fn expr_to_aggr_exprs(expr: &Expr, accum: &mut HashSet) -> Result<()> { + expr.accept(AggregateExprVisitor { accum })?; + Ok(()) +} + /// Convenience rule for writing optimizers: recursively invoke /// optimize on plan's children and then return a node of the same /// type. Useful for optimizer rules which want to leave the type diff --git a/datafusion/core/src/sql/planner.rs b/datafusion/core/src/sql/planner.rs index cf23479d30b6..b39093daafbb 100644 --- a/datafusion/core/src/sql/planner.rs +++ b/datafusion/core/src/sql/planner.rs @@ -4349,6 +4349,17 @@ mod tests { ); } + #[test] + fn select_order_by_missing_aggr() { + let sql = "SELECT state FROM person GROUP BY state ORDER BY SUM(age)"; + let expected = "Projection: #person.state\ + \n Sort: #SUM(person.age) ASC NULLS LAST\ + \n Projection: #person.state, #SUM(person.age)\ + \n Aggregate: groupBy=[[#person.state]], aggr=[[SUM(#person.age)]]\ + \n TableScan: person projection=None"; + quick_test(sql, expected); + } + #[test] fn select_group_by() { let sql = "SELECT state FROM person GROUP BY state";