1- use indexmap:: IndexSet ;
1+ use std:: { collections:: HashMap , sync:: Arc } ;
2+
3+ use indexmap:: { Equivalent , IndexSet } ;
24
35use crate :: {
4- common:: { join_type:: JoinType , table_schema:: TableSchemaRef , transformed:: Transformed } ,
6+ common:: {
7+ join_type:: JoinType ,
8+ table_schema:: TableSchemaRef ,
9+ transformed:: { TransformNode , Transformed , TransformedResult , TreeNodeRecursion } ,
10+ } ,
511 datatypes:: operator:: Operator ,
612 error:: Result ,
713 logical:: {
814 expr:: { BinaryExpr , LogicalExpr } ,
9- plan:: { Filter , Join , LogicalPlan } ,
15+ plan:: { Filter , LogicalPlan } ,
1016 LogicalPlanBuilder ,
1117 } ,
1218 optimizer:: rule:: OptimizerRule ,
1319} ;
1420
/// Eliminate cross joins by rewriting them to inner joins when possible.
///
/// The rule only acts on a `Filter` node: every cross join found beneath the filter is
/// replaced by an inner join when the filter predicate yields join-key pairs for that
/// cross join's two inputs. The equality comparisons that became join keys are then
/// removed from the filter predicate.
pub struct EliminateCrossJoin;
1623
1724impl OptimizerRule for EliminateCrossJoin {
@@ -20,40 +27,74 @@ impl OptimizerRule for EliminateCrossJoin {
2027 }
2128
2229 fn rewrite ( & self , plan : LogicalPlan ) -> Result < Transformed < LogicalPlan > > {
23- match plan {
24- LogicalPlan :: Filter ( filter) if matches ! ( filter. input. as_ref( ) , LogicalPlan :: CrossJoin ( _) ) => {
25- let LogicalPlan :: CrossJoin ( cross_join) = * filter. input else {
26- return Ok ( Transformed :: no ( LogicalPlan :: Filter ( filter) ) ) ;
27- } ;
30+ let LogicalPlan :: Filter ( Filter { input, expr : predicate } ) = plan else {
31+ return Ok ( Transformed :: no ( plan) ) ;
32+ } ;
2833
29- let left_schema = cross_join. left . table_schema ( ) ;
30- let right_schema = cross_join. right . table_schema ( ) ;
34+ let mut all_corss_joins = vec ! [ ] ;
3135
32- let ( join_keys, remaining_predicate) = extract_join_pairs ( & filter. expr , & left_schema, & right_schema) ;
36+ // collect all cross joins and filter predicates in order
37+ input. apply ( |plan| {
38+ if let LogicalPlan :: CrossJoin ( cross_join) = plan {
39+ all_corss_joins. push ( cross_join) ;
40+ }
41+ Ok ( TreeNodeRecursion :: Continue )
42+ } ) ?;
3343
34- if join_keys. is_empty ( ) {
35- Ok ( Transformed :: no ( LogicalPlan :: Filter ( Filter {
36- input : Box :: new ( LogicalPlan :: CrossJoin ( cross_join) ) ,
37- expr : filter. expr ,
38- } ) ) )
39- } else {
40- let inner_join_plan = LogicalPlan :: Join ( Join {
41- left : cross_join. left ,
42- right : cross_join. right ,
43- join_type : JoinType :: Inner ,
44- on : join_keys. into_iter ( ) . map ( |( l, r) | ( l. clone ( ) , r. clone ( ) ) ) . collect ( ) ,
45- filter : None ,
46- schema : cross_join. schema ,
47- } ) ;
48-
49- if let Some ( predicate) = remaining_predicate {
50- LogicalPlanBuilder :: filter ( inner_join_plan, predicate) . map ( Transformed :: yes)
51- } else {
52- Ok ( Transformed :: yes ( inner_join_plan) )
53- }
54- }
44+ if all_corss_joins. is_empty ( ) {
45+ return Ok ( Transformed :: no ( LogicalPlan :: Filter ( Filter { input, expr : predicate } ) ) ) ;
46+ }
47+
48+ let mut all_join_keys = IndexSet :: new ( ) ;
49+ let mut replaced_cross_joins = HashMap :: new ( ) ;
50+ let len = all_corss_joins. len ( ) ;
51+ // iteratively rewrite cross joins to inner joins from bottom to top
52+ for ( index, cross_join) in all_corss_joins. into_iter ( ) . rev ( ) . enumerate ( ) {
53+ let left_schema = cross_join. left . table_schema ( ) ;
54+ let right_schema = cross_join. right . table_schema ( ) ;
55+
56+ let join_keys = extract_join_pairs ( & predicate, & left_schema, & right_schema) ;
57+
58+ all_join_keys. extend ( join_keys. clone ( ) ) ;
59+
60+ if !join_keys. is_empty ( ) {
61+ let inner_join_plan = LogicalPlanBuilder :: from ( Arc :: unwrap_or_clone ( cross_join. left . clone ( ) ) )
62+ . join (
63+ Arc :: unwrap_or_clone ( cross_join. right . clone ( ) ) ,
64+ JoinType :: Inner ,
65+ join_keys. into_iter ( ) . collect ( ) ,
66+ None ,
67+ ) ?
68+ . build ( ) ;
69+
70+ // this index should be from the top to the bottom
71+ replaced_cross_joins. insert ( ( len - 1 ) - index, inner_join_plan) ;
5572 }
56- _ => Ok ( Transformed :: no ( plan) ) ,
73+ }
74+
75+ if replaced_cross_joins. is_empty ( ) {
76+ return Ok ( Transformed :: no ( LogicalPlan :: Filter ( Filter { input, expr : predicate } ) ) ) ;
77+ }
78+
79+ // combine all predicates and replaced cross joins
80+ let mut index = 0 ;
81+ let new_input = input
82+ . transform ( |plan| {
83+ let result = if let Some ( replaced_join) = replaced_cross_joins. remove ( & index) {
84+ Ok ( Transformed :: yes ( replaced_join) )
85+ } else {
86+ Ok ( Transformed :: no ( plan) )
87+ } ;
88+ index += 1 ;
89+ result
90+ } )
91+ . data ( ) ?;
92+
93+ // remove all join keys from original predicates
94+ if let Some ( predicate) = remove_join_keys ( predicate, & all_join_keys) {
95+ LogicalPlanBuilder :: filter ( new_input, predicate) . map ( Transformed :: yes)
96+ } else {
97+ Ok ( Transformed :: yes ( new_input) )
5798 }
5899 }
59100}
@@ -62,7 +103,7 @@ fn extract_join_pairs<'a>(
62103 expr : & ' a LogicalExpr ,
63104 left_schema : & TableSchemaRef ,
64105 right_schema : & TableSchemaRef ,
65- ) -> ( IndexSet < ( & ' a LogicalExpr , & ' a LogicalExpr ) > , Option < LogicalExpr > ) {
106+ ) -> IndexSet < ( LogicalExpr , LogicalExpr ) > {
66107 let mut join_keys = IndexSet :: new ( ) ;
67108
68109 match expr {
@@ -75,69 +116,92 @@ fn extract_join_pairs<'a>(
75116 let right_col = right. try_as_column ( ) ;
76117
77118 if let ( Some ( left_col) , Some ( right_col) ) = ( left_col, right_col) {
78- if ( left_schema. has_column ( left_col) || right_schema. has_column ( left_col) )
79- && ( left_schema. has_column ( right_col) || right_schema. has_column ( right_col) )
80- {
81- join_keys. insert ( ( left. as_ref ( ) , right. as_ref ( ) ) ) ;
82-
83- return ( join_keys, None ) ;
119+ if left_schema. has_column ( left_col) && right_schema. has_column ( right_col) {
120+ join_keys. insert ( ( left. as_ref ( ) . clone ( ) , right. as_ref ( ) . clone ( ) ) ) ;
121+ } else if right_schema. has_column ( left_col) && left_schema. has_column ( right_col) {
122+ join_keys. insert ( ( right. as_ref ( ) . clone ( ) , left. as_ref ( ) . clone ( ) ) ) ;
84123 }
85124 }
86-
87- ( join_keys, Some ( expr. clone ( ) ) )
88125 }
89126 LogicalExpr :: BinaryExpr ( BinaryExpr {
90127 left,
91128 op : Operator :: And ,
92129 right,
93130 } ) => {
94- let ( left_join_keys, left_predicate ) = extract_join_pairs ( left, left_schema, right_schema) ;
95- let ( right_join_keys, right_predicate ) = extract_join_pairs ( right, left_schema, right_schema) ;
131+ let left_join_keys = extract_join_pairs ( left, left_schema, right_schema) ;
132+ let right_join_keys = extract_join_pairs ( right, left_schema, right_schema) ;
96133
97134 join_keys. extend ( left_join_keys) ;
98135 join_keys. extend ( right_join_keys) ;
99-
100- let predicate = match ( left_predicate, right_predicate) {
101- ( Some ( left_predicate) , Some ( right_predicate) ) => Some ( LogicalExpr :: BinaryExpr ( BinaryExpr {
102- left : Box :: new ( left_predicate) ,
103- op : Operator :: And ,
104- right : Box :: new ( right_predicate) ,
105- } ) ) ,
106- ( l, r) => l. or ( r) ,
107- } ;
108-
109- ( join_keys, predicate)
110136 }
111137 LogicalExpr :: BinaryExpr ( BinaryExpr {
112138 left,
113139 op : Operator :: Or ,
114140 right,
115141 } ) => {
116- let ( left_join_keys, left_predicate ) = extract_join_pairs ( left, left_schema, right_schema) ;
117- let ( right_join_keys, right_predicate ) = extract_join_pairs ( right, left_schema, right_schema) ;
142+ let left_join_keys = extract_join_pairs ( left, left_schema, right_schema) ;
143+ let right_join_keys = extract_join_pairs ( right, left_schema, right_schema) ;
118144
119145 for ( l, r) in left_join_keys {
120- if right_join_keys. contains ( & ( l, r) ) || right_join_keys. contains ( & ( r, l) ) {
146+ if right_join_keys. contains ( & ExprPair :: new ( & l, & r) ) || right_join_keys. contains ( & ExprPair :: new ( & r, & l) )
147+ {
121148 join_keys. insert ( ( l, r) ) ;
122149 }
123150 }
151+ }
152+ _ => { }
153+ }
124154
125- let predicate = match ( left_predicate, right_predicate) {
126- ( Some ( l) , Some ( r) ) => Some ( LogicalExpr :: BinaryExpr ( BinaryExpr {
127- left : Box :: new ( l) ,
128- op : Operator :: Or ,
129- right : Box :: new ( r) ,
130- } ) ) ,
131- ( l, r) => l. or ( r) ,
132- } ;
155+ join_keys
156+ }
133157
134- ( join_keys, predicate)
158+ fn remove_join_keys ( expr : LogicalExpr , join_keys : & IndexSet < ( LogicalExpr , LogicalExpr ) > ) -> Option < LogicalExpr > {
159+ match expr {
160+ LogicalExpr :: BinaryExpr ( BinaryExpr {
161+ left,
162+ op : Operator :: Eq ,
163+ right,
164+ } ) if join_keys. contains ( & ExprPair :: new ( left. as_ref ( ) , right. as_ref ( ) ) )
165+ || join_keys. contains ( & ExprPair :: new ( right. as_ref ( ) , left. as_ref ( ) ) ) =>
166+ {
167+ None
168+ }
169+ LogicalExpr :: BinaryExpr ( BinaryExpr { left, op, right } ) if op == Operator :: And => {
170+ let l = remove_join_keys ( * left, join_keys) ;
171+ let r = remove_join_keys ( * right, join_keys) ;
172+ match ( l, r) {
173+ ( Some ( ll) , Some ( rr) ) => Some ( LogicalExpr :: BinaryExpr ( BinaryExpr :: new ( ll, op, rr) ) ) ,
174+ ( Some ( ll) , _) => Some ( ll) ,
175+ ( _, Some ( rr) ) => Some ( rr) ,
176+ _ => None ,
177+ }
135178 }
179+ LogicalExpr :: BinaryExpr ( BinaryExpr { left, op, right } ) if op == Operator :: Or => {
180+ let l = remove_join_keys ( * left, join_keys) ;
181+ let r = remove_join_keys ( * right, join_keys) ;
182+ match ( l, r) {
183+ ( Some ( ll) , Some ( rr) ) => Some ( LogicalExpr :: BinaryExpr ( BinaryExpr :: new ( ll, op, rr) ) ) ,
184+ _ => None ,
185+ }
186+ }
187+ _ => Some ( expr) ,
188+ }
189+ }
190+
/// Borrowed pair of expressions, used to probe an `IndexSet<(LogicalExpr, LogicalExpr)>`
/// with two `&LogicalExpr` instead of building an owned `(LogicalExpr, LogicalExpr)` tuple.
// NOTE(review): set lookups through `Equivalent` presumably need this derived `Hash` to
// agree with the hash of the owned tuple — confirm against the indexmap documentation
// before changing the field order or the derive.
#[derive(Debug, Eq, PartialEq, Hash)]
struct ExprPair<'a>(&'a LogicalExpr, &'a LogicalExpr);
136193
137- _ => ( join_keys, Some ( expr. clone ( ) ) ) ,
194+ impl < ' a > ExprPair < ' a > {
195+ fn new ( left : & ' a LogicalExpr , right : & ' a LogicalExpr ) -> Self {
196+ Self ( left, right)
138197 }
139198}
140199
200+ impl Equivalent < ( LogicalExpr , LogicalExpr ) > for ExprPair < ' _ > {
201+ fn equivalent ( & self , other : & ( LogicalExpr , LogicalExpr ) ) -> bool {
202+ self . 0 == & other. 0 && self . 1 == & other. 1
203+ }
204+ }
141205#[ cfg( test) ]
142206mod tests {
143207 use crate :: { optimizer:: rule:: eliminate_cross_join:: EliminateCrossJoin , test_utils:: assert_after_optimizer} ;
@@ -231,4 +295,46 @@ mod tests {
231295 ] ,
232296 ) ;
233297 }
298+
299+ #[ test]
300+ fn test_tpch_03 ( ) {
301+ assert_after_optimizer (
302+ " select
303+ l_orderkey,
304+ sum(l_extendedprice * (1 - l_discount)) as revenue,
305+ o_orderdate,
306+ o_shippriority
307+ from
308+ customer,
309+ orders,
310+ lineitem
311+ where
312+ c_mktsegment = 'BUILDING'
313+ and c_custkey = o_custkey
314+ and l_orderkey = o_orderkey
315+ and o_orderdate < date '1995-03-15'
316+ and l_shipdate > date '1995-03-15'
317+ group by
318+ l_orderkey,
319+ o_orderdate,
320+ o_shippriority
321+ order by
322+ revenue desc,
323+ o_orderdate
324+ limit 10;" ,
325+ vec ! [ Box :: new( EliminateCrossJoin ) ] ,
326+ vec ! [
327+ "Limit: fetch=10, skip=0" ,
328+ " Sort: revenue DESC, orders.o_orderdate ASC" ,
329+ " Projection: (lineitem.l_orderkey, SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS revenue, orders.o_orderdate, orders.o_shippriority)" ,
330+ " Aggregate: group_expr=[lineitem.l_orderkey,orders.o_orderdate,orders.o_shippriority], aggregat_expr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]" ,
331+ " Filter: customer.c_mktsegment = Utf8('BUILDING') AND orders.o_orderdate < CAST(Utf8('1995-03-15') AS Date32) AND lineitem.l_shipdate > CAST(Utf8('1995-03-15') AS Date32)" ,
332+ " Inner Join: On: (orders.o_orderkey, lineitem.l_orderkey)" ,
333+ " Inner Join: On: (customer.c_custkey, orders.o_custkey)" ,
334+ " TableScan: customer" ,
335+ " TableScan: orders" ,
336+ " TableScan: lineitem" ,
337+ ] ,
338+ ) ;
339+ }
234340}
0 commit comments