Skip to content

Commit 1fd6116

Browse files
authored
Add basic support for unnest unparsing (#13129)
* Add basic support for `unnest` unparsing (#45) * Fix taplo cargo check
1 parent 132b232 commit 1fd6116

File tree

5 files changed

+163
-32
lines changed

5 files changed

+163
-32
lines changed

datafusion/sql/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ strum = { version = "0.26.1", features = ["derive"] }
5656
ctor = { workspace = true }
5757
datafusion-functions = { workspace = true, default-features = true }
5858
datafusion-functions-aggregate = { workspace = true }
59+
datafusion-functions-nested = { workspace = true }
5960
datafusion-functions-window = { workspace = true }
6061
env_logger = { workspace = true }
6162
paste = "^1.0"

datafusion/sql/src/unparser/expr.rs

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use datafusion_expr::expr::Unnest;
1819
use sqlparser::ast::Value::SingleQuotedString;
1920
use sqlparser::ast::{
2021
self, BinaryOperator, Expr as AstExpr, Function, Ident, Interval, ObjectName,
@@ -466,7 +467,7 @@ impl Unparser<'_> {
466467
Ok(ast::Expr::Value(ast::Value::Placeholder(p.id.to_string())))
467468
}
468469
Expr::OuterReferenceColumn(_, col) => self.col_to_sql(col),
469-
Expr::Unnest(_) => not_impl_err!("Unsupported Expr conversion: {expr:?}"),
470+
Expr::Unnest(unnest) => self.unnest_to_sql(unnest),
470471
}
471472
}
472473

@@ -1340,6 +1341,29 @@ impl Unparser<'_> {
13401341
}
13411342
}
13421343

1344+
/// Converts an UNNEST operation to an AST expression by wrapping it as a function call,
1345+
/// since there is no direct representation for UNNEST in the AST.
1346+
fn unnest_to_sql(&self, unnest: &Unnest) -> Result<ast::Expr> {
1347+
let args = self.function_args_to_sql(std::slice::from_ref(&unnest.expr))?;
1348+
1349+
Ok(ast::Expr::Function(Function {
1350+
name: ast::ObjectName(vec![Ident {
1351+
value: "UNNEST".to_string(),
1352+
quote_style: None,
1353+
}]),
1354+
args: ast::FunctionArguments::List(ast::FunctionArgumentList {
1355+
duplicate_treatment: None,
1356+
args,
1357+
clauses: vec![],
1358+
}),
1359+
filter: None,
1360+
null_treatment: None,
1361+
over: None,
1362+
within_group: vec![],
1363+
parameters: ast::FunctionArguments::None,
1364+
}))
1365+
}
1366+
13431367
fn arrow_dtype_to_ast_dtype(&self, data_type: &DataType) -> Result<ast::DataType> {
13441368
match data_type {
13451369
DataType::Null => {
@@ -1855,6 +1879,15 @@ mod tests {
18551879
}),
18561880
r#"CAST(a AS DECIMAL(12,0))"#,
18571881
),
1882+
(
1883+
Expr::Unnest(Unnest {
1884+
expr: Box::new(Expr::Column(Column {
1885+
relation: Some(TableReference::partial("schema", "table")),
1886+
name: "array_col".to_string(),
1887+
})),
1888+
}),
1889+
r#"UNNEST("schema"."table".array_col)"#,
1890+
),
18581891
];
18591892

18601893
for (expr, expected) in tests {

datafusion/sql/src/unparser/plan.rs

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,9 @@ use super::{
2626
subquery_alias_inner_query_and_columns, TableAliasRewriter,
2727
},
2828
utils::{
29-
find_agg_node_within_select, find_window_nodes_within_select,
30-
unproject_sort_expr, unproject_window_exprs,
29+
find_agg_node_within_select, find_unnest_node_within_select,
30+
find_window_nodes_within_select, unproject_sort_expr, unproject_unnest_expr,
31+
unproject_window_exprs,
3132
},
3233
Unparser,
3334
};
@@ -173,15 +174,24 @@ impl Unparser<'_> {
173174
p: &Projection,
174175
select: &mut SelectBuilder,
175176
) -> Result<()> {
177+
let mut exprs = p.expr.clone();
178+
179+
// If an Unnest node is found within the select, find and unproject the unnest column
180+
if let Some(unnest) = find_unnest_node_within_select(plan) {
181+
exprs = exprs
182+
.into_iter()
183+
.map(|e| unproject_unnest_expr(e, unnest))
184+
.collect::<Result<Vec<_>>>()?;
185+
};
186+
176187
match (
177188
find_agg_node_within_select(plan, true),
178189
find_window_nodes_within_select(plan, None, true),
179190
) {
180191
(Some(agg), window) => {
181192
let window_option = window.as_deref();
182-
let items = p
183-
.expr
184-
.iter()
193+
let items = exprs
194+
.into_iter()
185195
.map(|proj_expr| {
186196
let unproj = unproject_agg_exprs(proj_expr, agg, window_option)?;
187197
self.select_item_to_sql(&unproj)
@@ -198,9 +208,8 @@ impl Unparser<'_> {
198208
));
199209
}
200210
(None, Some(window)) => {
201-
let items = p
202-
.expr
203-
.iter()
211+
let items = exprs
212+
.into_iter()
204213
.map(|proj_expr| {
205214
let unproj = unproject_window_exprs(proj_expr, &window)?;
206215
self.select_item_to_sql(&unproj)
@@ -210,8 +219,7 @@ impl Unparser<'_> {
210219
select.projection(items);
211220
}
212221
_ => {
213-
let items = p
214-
.expr
222+
let items = exprs
215223
.iter()
216224
.map(|e| self.select_item_to_sql(e))
217225
.collect::<Result<Vec<_>>>()?;
@@ -318,7 +326,8 @@ impl Unparser<'_> {
318326
if let Some(agg) =
319327
find_agg_node_within_select(plan, select.already_projected())
320328
{
321-
let unprojected = unproject_agg_exprs(&filter.predicate, agg, None)?;
329+
let unprojected =
330+
unproject_agg_exprs(filter.predicate.clone(), agg, None)?;
322331
let filter_expr = self.expr_to_sql(&unprojected)?;
323332
select.having(Some(filter_expr));
324333
} else {
@@ -596,6 +605,28 @@ impl Unparser<'_> {
596605
Ok(())
597606
}
598607
LogicalPlan::Extension(_) => not_impl_err!("Unsupported operator: {plan:?}"),
608+
LogicalPlan::Unnest(unnest) => {
609+
if !unnest.struct_type_columns.is_empty() {
610+
return internal_err!(
611+
"Struct type columns are not currently supported in UNNEST: {:?}",
612+
unnest.struct_type_columns
613+
);
614+
}
615+
616+
// In the case of UNNEST, the Unnest node is followed by a duplicate Projection node that we should skip.
617+
// Otherwise, there will be a duplicate SELECT clause.
618+
// | Projection: table.col1, UNNEST(table.col2)
619+
// | Unnest: UNNEST(table.col2)
620+
// | Projection: table.col1, table.col2 AS UNNEST(table.col2)
621+
// | Filter: table.col3 = Int64(3)
622+
// | TableScan: table projection=None
623+
if let LogicalPlan::Projection(p) = unnest.input.as_ref() {
624+
// continue with projection input
625+
self.select_to_sql_recursively(&p.input, query, select, relation)
626+
} else {
627+
internal_err!("Unnest input is not a Projection: {unnest:?}")
628+
}
629+
}
599630
_ => not_impl_err!("Unsupported operator: {plan:?}"),
600631
}
601632
}

datafusion/sql/src/unparser/utils.rs

Lines changed: 68 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ use datafusion_common::{
2323
Column, Result, ScalarValue,
2424
};
2525
use datafusion_expr::{
26-
utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, Projection, SortExpr,
27-
Window,
26+
expr, utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, Projection,
27+
SortExpr, Unnest, Window,
2828
};
2929
use sqlparser::ast;
3030

@@ -62,6 +62,28 @@ pub(crate) fn find_agg_node_within_select(
6262
}
6363
}
6464

65+
/// Recursively searches children of [LogicalPlan] to find Unnest node if exist
66+
pub(crate) fn find_unnest_node_within_select(plan: &LogicalPlan) -> Option<&Unnest> {
67+
// Note that none of the nodes that have a corresponding node can have more
68+
// than 1 input node. E.g. Projection / Filter always have 1 input node.
69+
let input = plan.inputs();
70+
let input = if input.len() > 1 {
71+
return None;
72+
} else {
73+
input.first()?
74+
};
75+
76+
if let LogicalPlan::Unnest(unnest) = input {
77+
Some(unnest)
78+
} else if let LogicalPlan::TableScan(_) = input {
79+
None
80+
} else if let LogicalPlan::Projection(_) = input {
81+
None
82+
} else {
83+
find_unnest_node_within_select(input)
84+
}
85+
}
86+
6587
/// Recursively searches children of [LogicalPlan] to find Window nodes if exist
6688
/// prior to encountering a Join, TableScan, or a nested subquery (derived table factor).
6789
/// If Window node is not found prior to this or at all before reaching the end
@@ -104,26 +126,54 @@ pub(crate) fn find_window_nodes_within_select<'a>(
104126
}
105127
}
106128

129+
/// Recursively identify Column expressions and transform them into the appropriate unnest expression
130+
///
131+
/// For example, if expr contains the column expr "unnest_placeholder(make_array(Int64(1),Int64(2),Int64(2),Int64(5),NULL),depth=1)"
132+
/// it will be transformed into an actual unnest expression UNNEST([1, 2, 2, 5, NULL])
133+
pub(crate) fn unproject_unnest_expr(expr: Expr, unnest: &Unnest) -> Result<Expr> {
134+
expr.transform(|sub_expr| {
135+
if let Expr::Column(col_ref) = &sub_expr {
136+
// Check if the column is among the columns to run unnest on.
137+
// Currently, only List/Array columns (defined in `list_type_columns`) are supported for unnesting.
138+
if unnest.list_type_columns.iter().any(|e| e.1.output_column.name == col_ref.name) {
139+
if let Ok(idx) = unnest.schema.index_of_column(col_ref) {
140+
if let LogicalPlan::Projection(Projection { expr, .. }) = unnest.input.as_ref() {
141+
if let Some(unprojected_expr) = expr.get(idx) {
142+
let unnest_expr = Expr::Unnest(expr::Unnest::new(unprojected_expr.clone()));
143+
return Ok(Transformed::yes(unnest_expr));
144+
}
145+
}
146+
}
147+
return internal_err!(
148+
"Tried to unproject unnest expr for column '{}' that was not found in the provided Unnest!", &col_ref.name
149+
);
150+
}
151+
}
152+
153+
Ok(Transformed::no(sub_expr))
154+
155+
}).map(|e| e.data)
156+
}
157+
107158
/// Recursively identify all Column expressions and transform them into the appropriate
108159
/// aggregate expression contained in agg.
109160
///
110161
/// For example, if expr contains the column expr "COUNT(*)" it will be transformed
111162
/// into an actual aggregate expression COUNT(*) as identified in the aggregate node.
112163
pub(crate) fn unproject_agg_exprs(
113-
expr: &Expr,
164+
expr: Expr,
114165
agg: &Aggregate,
115166
windows: Option<&[&Window]>,
116167
) -> Result<Expr> {
117-
expr.clone()
118-
.transform(|sub_expr| {
168+
expr.transform(|sub_expr| {
119169
if let Expr::Column(c) = sub_expr {
120170
if let Some(unprojected_expr) = find_agg_expr(agg, &c)? {
121171
Ok(Transformed::yes(unprojected_expr.clone()))
122172
} else if let Some(unprojected_expr) =
123173
windows.and_then(|w| find_window_expr(w, &c.name).cloned())
124174
{
125175
// Window function can contain an aggregation columns, e.g., 'avg(sum(ss_sales_price)) over ...' that needs to be unprojected
126-
return Ok(Transformed::yes(unproject_agg_exprs(&unprojected_expr, agg, None)?));
176+
return Ok(Transformed::yes(unproject_agg_exprs(unprojected_expr, agg, None)?));
127177
} else {
128178
internal_err!(
129179
"Tried to unproject agg expr for column '{}' that was not found in the provided Aggregate!", &c.name
@@ -141,20 +191,19 @@ pub(crate) fn unproject_agg_exprs(
141191
///
142192
/// For example, if expr contains the column expr "COUNT(*) PARTITION BY id" it will be transformed
143193
/// into an actual window expression as identified in the window node.
144-
pub(crate) fn unproject_window_exprs(expr: &Expr, windows: &[&Window]) -> Result<Expr> {
145-
expr.clone()
146-
.transform(|sub_expr| {
147-
if let Expr::Column(c) = sub_expr {
148-
if let Some(unproj) = find_window_expr(windows, &c.name) {
149-
Ok(Transformed::yes(unproj.clone()))
150-
} else {
151-
Ok(Transformed::no(Expr::Column(c)))
152-
}
194+
pub(crate) fn unproject_window_exprs(expr: Expr, windows: &[&Window]) -> Result<Expr> {
195+
expr.transform(|sub_expr| {
196+
if let Expr::Column(c) = sub_expr {
197+
if let Some(unproj) = find_window_expr(windows, &c.name) {
198+
Ok(Transformed::yes(unproj.clone()))
153199
} else {
154-
Ok(Transformed::no(sub_expr))
200+
Ok(Transformed::no(Expr::Column(c)))
155201
}
156-
})
157-
.map(|e| e.data)
202+
} else {
203+
Ok(Transformed::no(sub_expr))
204+
}
205+
})
206+
.map(|e| e.data)
158207
}
159208

160209
fn find_agg_expr<'a>(agg: &'a Aggregate, column: &Column) -> Result<Option<&'a Expr>> {
@@ -218,7 +267,7 @@ pub(crate) fn unproject_sort_expr(
218267
// In case of aggregation there could be columns containing aggregation functions we need to unproject
219268
if let Some(agg) = agg {
220269
if agg.schema.is_column_from_schema(col_ref) {
221-
let new_expr = unproject_agg_exprs(&sort_expr.expr, agg, None)?;
270+
let new_expr = unproject_agg_exprs(sort_expr.expr, agg, None)?;
222271
sort_expr.expr = new_expr;
223272
return Ok(sort_expr);
224273
}

datafusion/sql/tests/cases/plan_to_sql.rs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use datafusion_expr::test::function_stub::{count_udaf, max_udaf, min_udaf, sum_u
2424
use datafusion_expr::{col, lit, table_scan, wildcard, LogicalPlanBuilder};
2525
use datafusion_functions::unicode;
2626
use datafusion_functions_aggregate::grouping::grouping_udaf;
27+
use datafusion_functions_nested::make_array::make_array_udf;
2728
use datafusion_functions_window::rank::rank_udwf;
2829
use datafusion_sql::planner::{ContextProvider, PlannerContext, SqlToRel};
2930
use datafusion_sql::unparser::dialect::{
@@ -711,7 +712,8 @@ where
711712
.with_aggregate_function(max_udaf())
712713
.with_aggregate_function(grouping_udaf())
713714
.with_window_function(rank_udwf())
714-
.with_scalar_function(Arc::new(unicode::substr().as_ref().clone())),
715+
.with_scalar_function(Arc::new(unicode::substr().as_ref().clone()))
716+
.with_scalar_function(make_array_udf()),
715717
};
716718
let sql_to_rel = SqlToRel::new(&context);
717719
let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap();
@@ -1084,3 +1086,18 @@ FROM person
10841086
GROUP BY person.id, person.first_name"#.replace("\n", " ").as_str(),
10851087
);
10861088
}
1089+
1090+
#[test]
1091+
fn test_unnest_to_sql() {
1092+
sql_round_trip(
1093+
GenericDialect {},
1094+
r#"SELECT unnest(array_col) as u1, struct_col, array_col FROM unnest_table WHERE array_col != NULL ORDER BY struct_col, array_col"#,
1095+
r#"SELECT UNNEST(unnest_table.array_col) AS u1, unnest_table.struct_col, unnest_table.array_col FROM unnest_table WHERE (unnest_table.array_col <> NULL) ORDER BY unnest_table.struct_col ASC NULLS LAST, unnest_table.array_col ASC NULLS LAST"#,
1096+
);
1097+
1098+
sql_round_trip(
1099+
GenericDialect {},
1100+
r#"SELECT unnest(make_array(1, 2, 2, 5, NULL)) as u1"#,
1101+
r#"SELECT UNNEST(make_array(1, 2, 2, 5, NULL)) AS u1"#,
1102+
);
1103+
}

0 commit comments

Comments
 (0)