1- use arrow:: compute:: kernels:: filter;
21use indexmap:: IndexSet ;
32
43use crate :: common:: join_type:: JoinType ;
54use crate :: common:: table_schema:: qualified_name;
6- use crate :: common:: transformed:: Transformed ;
7- use crate :: error:: { Result , Error } ;
5+ use crate :: common:: transformed:: { TransformNode , Transformed , TreeNodeRecursion } ;
6+ use crate :: error:: { Error , Result } ;
87use crate :: internal_err;
98use crate :: logical:: expr:: { Column , LogicalExpr } ;
10- use crate :: logical:: plan:: { Filter , Join , LogicalPlan , SubqueryAlias } ;
9+ use crate :: logical:: plan:: { Filter , Join , LogicalPlan , Projection , SubqueryAlias } ;
1110use crate :: logical:: LogicalPlanBuilder ;
1211use crate :: optimizer:: rule:: rule_optimizer:: OptimizerRule ;
1312use crate :: utils:: expr:: { conjunction, replace_col, replace_cols_by_name, split_conjunctive_predicates} ;
@@ -32,7 +31,31 @@ impl OptimizerRule for PushdownFilter {
3231 } ;
3332
3433 match * filter. input {
35- LogicalPlan :: Projection ( projection) => { }
34+ LogicalPlan :: Projection ( mut projection) => {
35+ let predicates = split_conjunctive_predicates ( filter. expr ) ;
36+ let ( push_expr, keep_expr) = rewrite_predicates ( predicates, & projection) ?;
37+ let is_transformed = push_expr. is_some ( ) ;
38+ let new_projection = if let Some ( push_expr) = push_expr {
39+ let new_filter = LogicalPlanBuilder :: filter ( * projection. input , push_expr) . map ( Box :: new) ?;
40+ projection. input = new_filter;
41+
42+ LogicalPlan :: Projection ( projection)
43+ } else {
44+ LogicalPlan :: Projection ( projection)
45+ } ;
46+
47+ if let Some ( keep_expr) = keep_expr {
48+ return Ok ( Transformed :: yes (
49+ LogicalPlanBuilder :: from ( new_projection) . add_filter ( keep_expr) ?. build ( ) ,
50+ ) ) ;
51+ }
52+
53+ if is_transformed {
54+ Ok ( Transformed :: yes ( new_projection) )
55+ } else {
56+ Ok ( Transformed :: no ( new_projection) )
57+ }
58+ }
3659 LogicalPlan :: Filter ( child_ilter) => {
3760 let parent_predicates = split_conjunctive_predicates ( filter. expr ) ;
3861 let child_predicates = split_conjunctive_predicates ( child_ilter. expr ) ;
@@ -173,14 +196,55 @@ impl PushdownFilter {
173196 }
174197}
175198
199+ fn rewrite_predicates (
200+ predicates : Vec < LogicalExpr > ,
201+ projection : & Projection ,
202+ ) -> Result < ( Option < LogicalExpr > , Option < LogicalExpr > ) > {
203+ let mut push_predicates = vec ! [ ] ;
204+ let mut keep_predicates = vec ! [ ] ;
205+
206+ let projection_map = projection
207+ . schema
208+ . iter ( )
209+ . zip ( projection. exprs . iter ( ) )
210+ . map ( |( ( qualifier, field) , expr) | ( qualified_name ( qualifier, field. name ( ) ) , expr) )
211+ . collect :: < HashMap < _ , _ > > ( ) ;
212+
213+ for expr in predicates {
214+ if contain ( & expr, & projection_map) ? {
215+ push_predicates. push ( expr) ;
216+ } else {
217+ keep_predicates. push ( expr) ;
218+ }
219+ }
220+
221+ Ok ( ( conjunction ( push_predicates) , conjunction ( keep_predicates) ) )
222+ }
223+
224+ fn contain ( expr : & LogicalExpr , projection_map : & HashMap < String , & LogicalExpr > ) -> Result < bool > {
225+ let mut is_contain = false ;
226+
227+ expr. apply ( |expr| {
228+ if let LogicalExpr :: Column ( col) = expr {
229+ if projection_map. contains_key ( & col. qualified_name ( ) ) {
230+ is_contain = true ;
231+ return Ok ( TreeNodeRecursion :: Stop ) ;
232+ }
233+ }
234+
235+ Ok ( TreeNodeRecursion :: Continue )
236+ } ) ?;
237+
238+ Ok ( is_contain)
239+ }
240+
176241#[ cfg( test) ]
177242mod tests {
178243 use crate :: {
179244 build_mem_datasource,
180- datatypes:: { operator:: Operator , scalar:: ScalarValue } ,
181245 error:: Result ,
182246 logical:: {
183- expr:: { col, literal, BinaryExpr , Column , LogicalExpr } ,
247+ expr:: { col, literal} ,
184248 plan:: { LogicalPlan , TableScan } ,
185249 LogicalPlanBuilder ,
186250 } ,
@@ -192,7 +256,7 @@ mod tests {
192256 use std:: collections:: HashMap ;
193257
194258 #[ test]
195- fn test_filter_before_projection ( ) {
259+ fn test_filter_before_table_scan ( ) {
196260 assert_after_optimizer (
197261 "SELECT id,name FROM users WHERE id = 1" ,
198262 vec ! [ Box :: new( PushdownFilter ) ] ,
@@ -243,7 +307,14 @@ mod tests {
243307 . add_filter ( col ( "name" ) . gt ( literal ( "test" ) ) ) ?
244308 . build ( ) ;
245309
246- assert_after_optimizer_with_plan ( plan, vec ! [ Box :: new( PushdownFilter ) ] , vec ! [ ] ) ;
310+ assert_after_optimizer_with_plan (
311+ plan,
312+ vec ! [ Box :: new( PushdownFilter ) ] ,
313+ vec ! [
314+ "Projection: (id, name, email)" ,
315+ " TableScan: users, full_filter=[name > Utf8('test') AND id = Int64(10)]" ,
316+ ] ,
317+ ) ;
247318
248319 Ok ( ( ) )
249320 }
0 commit comments