Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 150 additions & 4 deletions datafusion/core/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,134 @@ impl LogicalPlanBuilder {
}
}

/// Add missing sort aggregate expressions to all downstream projection and aggregate
#[allow(clippy::only_used_in_recursion)]
fn add_missing_aggr_exprs(
&self,
curr_plan: LogicalPlan,
missing_aggr_exprs: &[Expr],
alias_map: &mut HashMap<String, String>,
) -> Result<LogicalPlan> {
match curr_plan {
LogicalPlan::Projection(Projection {
input,
mut expr,
schema,
alias,
}) => {
let input = self.add_missing_aggr_exprs(
input.as_ref().clone(),
missing_aggr_exprs,
alias_map,
)?;
let input_schema = input.schema();
if missing_aggr_exprs.iter().any(|e| {
e.name(input_schema)
.and_then(|n| input_schema.field_with_unqualified_name(&n))
.is_err()
}) && alias_map.is_empty()
{
return Ok(LogicalPlan::Projection(Projection {
expr,
input: Arc::new(input),
schema,
alias,
}));
}

let mut missing_exprs: Vec<Expr> =
Vec::with_capacity(missing_aggr_exprs.len());
for missing_aggr_expr in missing_aggr_exprs {
let normalized_expr = resolve_exprs_to_aliases(
missing_aggr_expr,
alias_map,
input_schema,
)?;
let normalized_expr_name = normalized_expr.name(input_schema)?;
if input_schema
.field_with_unqualified_name(&normalized_expr_name)
.is_err()
{
continue;
};
missing_exprs.push(normalized_expr);
}

expr.extend(missing_exprs);

let new_schema = DFSchema::new_with_metadata(
exprlist_to_fields(&expr, &input)?,
input_schema.metadata().clone(),
)?;

Ok(LogicalPlan::Projection(Projection {
expr,
input: Arc::new(input),
schema: DFSchemaRef::new(new_schema),
alias,
}))
}
LogicalPlan::Aggregate(Aggregate {
input,
group_expr,
mut aggr_expr,
schema: _,
}) => {
let input = self.add_missing_aggr_exprs(
input.as_ref().clone(),
missing_aggr_exprs,
alias_map,
)?;
let input_schema = input.schema();

let mut missing_exprs = Vec::with_capacity(missing_aggr_exprs.len());
for missing_aggr_expr in missing_aggr_exprs {
if aggr_expr.contains(missing_aggr_expr) {
continue;
}
let expr_name = missing_aggr_expr.name(input_schema)?;
alias_map.insert(expr_name.clone(), expr_name);
missing_exprs.push(missing_aggr_expr.clone());
}

aggr_expr.extend(missing_exprs);

let all_expr = group_expr
.iter()
.chain(aggr_expr.iter())
.cloned()
.collect::<Vec<_>>();
let new_schema = DFSchema::new_with_metadata(
exprlist_to_fields(&all_expr, &input)?,
input_schema.metadata().clone(),
)?;

Ok(LogicalPlan::Aggregate(Aggregate {
input: Arc::new(input),
group_expr,
aggr_expr,
schema: Arc::new(new_schema),
}))
}
_ => {
let new_inputs = curr_plan
.inputs()
.into_iter()
.map(|input_plan| {
self.add_missing_aggr_exprs(
(*input_plan).clone(),
missing_aggr_exprs,
alias_map,
)
})
.collect::<Result<Vec<_>>>()?;

let expr = curr_plan.expressions();
utils::from_plan(&curr_plan, &expr, &new_inputs)
}
}
}

/// Apply a sort
pub fn sort(
&self,
Expand All @@ -628,28 +756,46 @@ impl LogicalPlanBuilder {
// Collect sort columns that are missing in the input plan's schema
let mut missing_cols: Vec<Column> = vec![];
let mut columns: HashSet<Column> = HashSet::new();
// Collect aggregate expressions that are missing in the input plan's schema
let mut missing_aggr_exprs: Vec<Expr> = vec![];
let mut aggr_exprs: HashSet<Expr> = HashSet::new();
exprs
.clone()
.into_iter()
.try_for_each::<_, Result<()>>(|expr| {
utils::expr_to_columns(&expr, &mut columns)
utils::expr_to_columns(&expr, &mut columns)?;
utils::expr_to_aggr_exprs(&expr, &mut aggr_exprs)
})?;
columns.into_iter().for_each(|c| {
if schema.field_from_column(&c).is_err() {
missing_cols.push(c);
}
});
aggr_exprs.into_iter().for_each(|e| {
if e.name(schema)
.and_then(|n| schema.field_with_unqualified_name(&n))
.is_err()
{
missing_aggr_exprs.push(e);
}
});

if missing_cols.is_empty() {
if missing_cols.is_empty() && missing_aggr_exprs.is_empty() {
return Ok(Self::from(LogicalPlan::Sort(Sort {
expr: normalize_cols(exprs, &self.plan)?,
input: Arc::new(self.plan.clone()),
})));
}

let mut alias_map = HashMap::new();
let plan =
self.add_missing_columns(self.plan.clone(), &missing_cols, &mut alias_map)?;
let mut plan = self.plan.clone();
if !missing_cols.is_empty() {
plan = self.add_missing_columns(plan, &missing_cols, &mut alias_map)?;
}
if !missing_aggr_exprs.is_empty() {
plan =
self.add_missing_aggr_exprs(plan, &missing_aggr_exprs, &mut alias_map)?;
}
let exprs = if alias_map.is_empty() {
exprs
} else {
Expand Down
26 changes: 26 additions & 0 deletions datafusion/core/src/optimizer/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,32 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet<Column>) -> Result<()> {
Ok(())
}

/// Recursively walk an expression tree, collecting the unique set of aggregate expressions
/// referenced in the expression
struct AggregateExprVisitor<'a> {
accum: &'a mut HashSet<Expr>,
}

impl ExpressionVisitor for AggregateExprVisitor<'_> {
fn pre_visit(self, expr: &Expr) -> Result<Recursion<Self>> {
match expr {
Expr::AggregateFunction { .. } | Expr::AggregateUDF { .. } => {
self.accum.insert(expr.clone());
return Ok(Recursion::Stop(self));
}
_ => {}
}
Ok(Recursion::Continue(self))
}
}

/// Recursively walk an expression tree, collecting the unique set of aggregate expressions
/// referenced in the expression
pub fn expr_to_aggr_exprs(expr: &Expr, accum: &mut HashSet<Expr>) -> Result<()> {
expr.accept(AggregateExprVisitor { accum })?;
Ok(())
}

/// Convenience rule for writing optimizers: recursively invoke
/// optimize on plan's children and then return a node of the same
/// type. Useful for optimizer rules which want to leave the type
Expand Down
11 changes: 11 additions & 0 deletions datafusion/core/src/sql/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4349,6 +4349,17 @@ mod tests {
);
}

#[test]
fn select_order_by_missing_aggr() {
let sql = "SELECT state FROM person GROUP BY state ORDER BY SUM(age)";
let expected = "Projection: #person.state\
\n Sort: #SUM(person.age) ASC NULLS LAST\
\n Projection: #person.state, #SUM(person.age)\
\n Aggregate: groupBy=[[#person.state]], aggr=[[SUM(#person.age)]]\
\n TableScan: person projection=None";
quick_test(sql, expected);
}

#[test]
fn select_group_by() {
let sql = "SELECT state FROM person GROUP BY state";
Expand Down
Loading