Skip to content

Commit 783ce95

Browse files
committed
feat(cube): Cube type_coercion hack for topk aggregate, with upper_expressions
1 parent fd6f74f commit 783ce95

File tree

2 files changed

+47
-5
lines changed

2 files changed

+47
-5
lines changed

datafusion/expr/src/logical_plan/extension.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,20 @@ pub trait UserDefinedLogicalNode: fmt::Debug + Send + Sync {
6060
/// passes and rewrites. See [`LogicalPlan::expressions`] for more details.
6161
fn expressions(&self) -> Vec<Expr>;
6262

63+
/// Cube extension: Returns expressions defined on the output schema. This should be removed
64+
/// (to avoid diverging from upstream DF) and such logical nodes should be split into two nodes.
65+
fn upper_expressions(&self) -> Vec<Expr> {
66+
Vec::new()
67+
}
68+
69+
// Cube extension: Replaces upper_expressions(). Returns None if no replacement is needed (as a
70+
// hack allowing default implementation without making `UserDefinedLogicalNode` derive from
71+
// `Clone`).
72+
fn with_upper_expressions(&self, upper_exprs: Vec<Expr>) -> Result<Option<Arc<dyn UserDefinedLogicalNode>>> {
73+
assert_eq!(upper_exprs.len(), 0);
74+
Ok(None)
75+
}
76+
6377
/// A list of output columns (e.g. the names of columns in
6478
/// self.schema()) for which predicates can not be pushed below
6579
/// this node without changing the output.

datafusion/optimizer/src/analyzer/type_coercion.rs

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ use arrow::datatypes::{DataType, Field, IntervalUnit, Schema};
2626
use crate::analyzer::AnalyzerRule;
2727
use crate::utils::NamePreserver;
2828
use datafusion_common::config::ConfigOptions;
29-
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
29+
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeIterator, TreeNodeRewriter};
3030
use datafusion_common::{
3131
exec_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, Column,
3232
DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, TableReference,
@@ -50,9 +50,7 @@ use datafusion_expr::type_coercion::other::{
5050
use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_large_utf8};
5151
use datafusion_expr::utils::merge_schema;
5252
use datafusion_expr::{
53-
is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not,
54-
AggregateUDF, Expr, ExprFunctionExt, ExprSchemable, Join, LogicalPlan, Operator,
55-
Projection, ScalarUDF, Union, WindowFrame, WindowFrameBound, WindowFrameUnits,
53+
is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not, AggregateUDF, Expr, ExprFunctionExt, ExprSchemable, Extension, Join, LogicalPlan, Operator, Projection, ScalarUDF, Union, WindowFrame, WindowFrameBound, WindowFrameUnits
5654
};
5755

5856
/// Performs type coercion by determining the schema
@@ -146,7 +144,37 @@ fn analyze_internal(
146144
// some plans need extra coercion after their expressions are coerced
147145
.map_data(|plan| expr_rewrite.coerce_plan(plan))?
148146
// recompute the schema after the expressions have been rewritten as the types may have changed
149-
.map_data(|plan| plan.recompute_schema())
147+
.map_data(|plan| plan.recompute_schema())?
148+
// Cube extension: Map "upper" expressions (after this node's output schema has been recomputed)
149+
.transform_data(|plan| {
150+
match &plan {
151+
LogicalPlan::Extension(Extension { node }) => {
152+
let upper_expressions = node.upper_expressions();
153+
if upper_expressions.is_empty() {
154+
Ok(Transformed::no(plan))
155+
} else {
156+
let output_schema = plan.schema().clone();
157+
let mut upper_expr_rewrite = TypeCoercionRewriter::new(&output_schema);
158+
upper_expressions.into_iter().map_until_stop_and_collect(|expr| {
159+
// No need for name preserver on upper expressions. (Why? Because
160+
// upper_expressions cannot change the output schema -- they use the output
161+
// schema, already defined by the node, as input. (They are filter
162+
// expressions.))
163+
expr.rewrite(&mut upper_expr_rewrite)
164+
})?
165+
.map_data(|upper_expressions| {
166+
let new_node = node.with_upper_expressions(upper_expressions)?;
167+
if let Some(new_node) = new_node {
168+
Ok(LogicalPlan::Extension(Extension { node: new_node }))
169+
} else {
170+
internal_err!("with_upper_expressions must not return None when upper_expressions() was non-empty")
171+
}
172+
})
173+
}
174+
},
175+
_ => Ok(Transformed::no(plan))
176+
}
177+
})
150178
}
151179

152180
/// Rewrite expressions to apply type coercion.

0 commit comments

Comments
 (0)