@@ -23,8 +23,8 @@ use datafusion_common::{
2323 Column , Result , ScalarValue ,
2424} ;
2525use datafusion_expr:: {
26- utils:: grouping_set_to_exprlist, Aggregate , Expr , LogicalPlan , Projection , SortExpr ,
27- Window ,
26+ expr , utils:: grouping_set_to_exprlist, Aggregate , Expr , LogicalPlan , Projection ,
27+ SortExpr , Unnest , Window ,
2828} ;
2929use sqlparser:: ast;
3030
@@ -62,6 +62,28 @@ pub(crate) fn find_agg_node_within_select(
6262 }
6363}
6464
65+ /// Recursively searches children of [LogicalPlan] to find Unnest node if exist
66+ pub ( crate ) fn find_unnest_node_within_select ( plan : & LogicalPlan ) -> Option < & Unnest > {
67+ // Note that none of the nodes that have a corresponding node can have more
68+ // than 1 input node. E.g. Projection / Filter always have 1 input node.
69+ let input = plan. inputs ( ) ;
70+ let input = if input. len ( ) > 1 {
71+ return None ;
72+ } else {
73+ input. first ( ) ?
74+ } ;
75+
76+ if let LogicalPlan :: Unnest ( unnest) = input {
77+ Some ( unnest)
78+ } else if let LogicalPlan :: TableScan ( _) = input {
79+ None
80+ } else if let LogicalPlan :: Projection ( _) = input {
81+ None
82+ } else {
83+ find_unnest_node_within_select ( input)
84+ }
85+ }
86+
6587/// Recursively searches children of [LogicalPlan] to find Window nodes if exist
6688/// prior to encountering a Join, TableScan, or a nested subquery (derived table factor).
6789/// If Window node is not found prior to this or at all before reaching the end
@@ -104,26 +126,54 @@ pub(crate) fn find_window_nodes_within_select<'a>(
104126 }
105127}
106128
129+ /// Recursively identify Column expressions and transform them into the appropriate unnest expression
130+ ///
131+ /// For example, if expr contains the column expr "unnest_placeholder(make_array(Int64(1),Int64(2),Int64(2),Int64(5),NULL),depth=1)"
132+ /// it will be transformed into an actual unnest expression UNNEST([1, 2, 2, 5, NULL])
133+ pub ( crate ) fn unproject_unnest_expr ( expr : Expr , unnest : & Unnest ) -> Result < Expr > {
134+ expr. transform ( |sub_expr| {
135+ if let Expr :: Column ( col_ref) = & sub_expr {
136+ // Check if the column is among the columns to run unnest on.
137+ // Currently, only List/Array columns (defined in `list_type_columns`) are supported for unnesting.
138+ if unnest. list_type_columns . iter ( ) . any ( |e| e. 1 . output_column . name == col_ref. name ) {
139+ if let Ok ( idx) = unnest. schema . index_of_column ( col_ref) {
140+ if let LogicalPlan :: Projection ( Projection { expr, .. } ) = unnest. input . as_ref ( ) {
141+ if let Some ( unprojected_expr) = expr. get ( idx) {
142+ let unnest_expr = Expr :: Unnest ( expr:: Unnest :: new ( unprojected_expr. clone ( ) ) ) ;
143+ return Ok ( Transformed :: yes ( unnest_expr) ) ;
144+ }
145+ }
146+ }
147+ return internal_err ! (
148+ "Tried to unproject unnest expr for column '{}' that was not found in the provided Unnest!" , & col_ref. name
149+ ) ;
150+ }
151+ }
152+
153+ Ok ( Transformed :: no ( sub_expr) )
154+
155+ } ) . map ( |e| e. data )
156+ }
157+
107158/// Recursively identify all Column expressions and transform them into the appropriate
108159/// aggregate expression contained in agg.
109160///
110161/// For example, if expr contains the column expr "COUNT(*)" it will be transformed
111162/// into an actual aggregate expression COUNT(*) as identified in the aggregate node.
112163pub ( crate ) fn unproject_agg_exprs (
113- expr : & Expr ,
164+ expr : Expr ,
114165 agg : & Aggregate ,
115166 windows : Option < & [ & Window ] > ,
116167) -> Result < Expr > {
117- expr. clone ( )
118- . transform ( |sub_expr| {
168+ expr. transform ( |sub_expr| {
119169 if let Expr :: Column ( c) = sub_expr {
120170 if let Some ( unprojected_expr) = find_agg_expr ( agg, & c) ? {
121171 Ok ( Transformed :: yes ( unprojected_expr. clone ( ) ) )
122172 } else if let Some ( unprojected_expr) =
123173 windows. and_then ( |w| find_window_expr ( w, & c. name ) . cloned ( ) )
124174 {
125175 // Window function can contain an aggregation columns, e.g., 'avg(sum(ss_sales_price)) over ...' that needs to be unprojected
126- return Ok ( Transformed :: yes ( unproject_agg_exprs ( & unprojected_expr, agg, None ) ?) ) ;
176+ return Ok ( Transformed :: yes ( unproject_agg_exprs ( unprojected_expr, agg, None ) ?) ) ;
127177 } else {
128178 internal_err ! (
129179 "Tried to unproject agg expr for column '{}' that was not found in the provided Aggregate!" , & c. name
@@ -141,20 +191,19 @@ pub(crate) fn unproject_agg_exprs(
141191///
142192/// For example, if expr contains the column expr "COUNT(*) PARTITION BY id" it will be transformed
143193/// into an actual window expression as identified in the window node.
144- pub ( crate ) fn unproject_window_exprs ( expr : & Expr , windows : & [ & Window ] ) -> Result < Expr > {
145- expr. clone ( )
146- . transform ( |sub_expr| {
147- if let Expr :: Column ( c) = sub_expr {
148- if let Some ( unproj) = find_window_expr ( windows, & c. name ) {
149- Ok ( Transformed :: yes ( unproj. clone ( ) ) )
150- } else {
151- Ok ( Transformed :: no ( Expr :: Column ( c) ) )
152- }
194+ pub ( crate ) fn unproject_window_exprs ( expr : Expr , windows : & [ & Window ] ) -> Result < Expr > {
195+ expr. transform ( |sub_expr| {
196+ if let Expr :: Column ( c) = sub_expr {
197+ if let Some ( unproj) = find_window_expr ( windows, & c. name ) {
198+ Ok ( Transformed :: yes ( unproj. clone ( ) ) )
153199 } else {
154- Ok ( Transformed :: no ( sub_expr ) )
200+ Ok ( Transformed :: no ( Expr :: Column ( c ) ) )
155201 }
156- } )
157- . map ( |e| e. data )
202+ } else {
203+ Ok ( Transformed :: no ( sub_expr) )
204+ }
205+ } )
206+ . map ( |e| e. data )
158207}
159208
160209fn find_agg_expr < ' a > ( agg : & ' a Aggregate , column : & Column ) -> Result < Option < & ' a Expr > > {
@@ -218,7 +267,7 @@ pub(crate) fn unproject_sort_expr(
218267 // In case of aggregation there could be columns containing aggregation functions we need to unproject
219268 if let Some ( agg) = agg {
220269 if agg. schema . is_column_from_schema ( col_ref) {
221- let new_expr = unproject_agg_exprs ( & sort_expr. expr , agg, None ) ?;
270+ let new_expr = unproject_agg_exprs ( sort_expr. expr , agg, None ) ?;
222271 sort_expr. expr = new_expr;
223272 return Ok ( sort_expr) ;
224273 }
0 commit comments