@@ -186,7 +186,6 @@ impl DPhpy {
186186 . get_relation_set ( & ( 0 ..self . join_relations . len ( ) ) . collect ( ) ) ?;
187187 if optimized {
188188 if let Some ( final_plan) = self . dp_table . get ( & all_relations) {
189- dbg ! ( final_plan) ;
190189 self . join_reorder ( final_plan, s_expr)
191190 } else {
192191 // Maybe exist cross join, which make graph disconnected
@@ -209,6 +208,7 @@ impl DPhpy {
209208 children : vec ! [ ] ,
210209 join_conditions : vec ! [ ] ,
211210 cost : relation. cardinality ( ) ?,
211+ cardinality : Some ( relation. cardinality ( ) ?) ,
212212 } ;
213213 let _ = self . dp_table . insert ( nodes, join) ;
214214 }
@@ -284,16 +284,16 @@ impl DPhpy {
284284 . get_relation_set_by_index ( * neighbor) ?;
285285 let merged_relation_set = union ( nodes, & neighbor_relations) ;
286286 if self . dp_table . contains_key ( & merged_relation_set)
287+ && merged_relation_set. len ( ) > nodes. len ( )
287288 && !self . emit_csg ( & merged_relation_set) ?
288289 {
289290 return Ok ( false ) ;
290291 }
291292 merged_sets. push ( merged_relation_set) ;
292293 }
293294
294- let mut new_forbidden_nodes;
295+ let mut new_forbidden_nodes = forbidden_nodes . clone ( ) ;
295296 for ( idx, neighbor) in neighbors. iter ( ) . enumerate ( ) {
296- new_forbidden_nodes = forbidden_nodes. clone ( ) ;
297297 new_forbidden_nodes. insert ( * neighbor) ;
298298 if !self . enumerate_csg_rec ( & merged_sets[ idx] , & new_forbidden_nodes) ? {
299299 return Ok ( false ) ;
@@ -311,15 +311,26 @@ impl DPhpy {
311311 ) -> Result < bool > {
312312 debug_assert ! ( self . dp_table. contains_key( left) ) ;
313313 debug_assert ! ( self . dp_table. contains_key( right) ) ;
314- let mut left_join = self . dp_table . get ( left) . unwrap ( ) ;
315- let mut right_join = self . dp_table . get ( right) . unwrap ( ) ;
316314 let parent_set = union ( left, right) ;
317315
318- let left_cardinality = left_join. cardinality ( self ) ?;
319- let right_cardinality = right_join. cardinality ( self ) ?;
316+ let mut left_join = self . dp_table . get ( left) . unwrap ( ) . clone ( ) ;
317+ let mut right_join = self . dp_table . get ( right) . unwrap ( ) . clone ( ) ;
318+ let left_cardinality = match left_join. cardinality {
319+ Some ( cardinality) => cardinality,
320+ None => left_join. cardinality ( self ) ?,
321+ } ;
322+ let right_cardinality = match right_join. cardinality {
323+ Some ( cardinality) => cardinality,
324+ None => right_join. cardinality ( self ) ?,
325+ } ;
326+
327+ self . dp_table
328+ . entry ( left. to_vec ( ) )
329+ . and_modify ( |v| * v = left_join. clone ( ) ) ;
330+ self . dp_table
331+ . entry ( right. to_vec ( ) )
332+ . and_modify ( |v| * v = right_join. clone ( ) ) ;
320333 if left_cardinality < right_cardinality {
321- // swap left_join and right_join
322- std:: mem:: swap ( & mut left_join, & mut right_join) ;
323334 for join_condition in join_conditions. iter_mut ( ) {
324335 std:: mem:: swap ( & mut join_condition. 0 , & mut join_condition. 1 ) ;
325336 }
@@ -329,10 +340,15 @@ impl DPhpy {
329340 JoinNode {
330341 join_type : JoinType :: Inner ,
331342 leaves : parent_set. clone ( ) ,
332- children : vec ! [ left_join. clone( ) , right_join. clone( ) ] ,
343+ children : if left_cardinality < right_cardinality {
344+ vec ! [ right_join, left_join]
345+ } else {
346+ vec ! [ left_join, right_join]
347+ } ,
333348 cost : left_cardinality * COST_FACTOR_COMPUTE_PER_ROW
334349 + right_cardinality * COST_FACTOR_HASH_TABLE_PER_ROW ,
335350 join_conditions,
351+ cardinality : None ,
336352 }
337353 } else {
338354 JoinNode {
@@ -341,6 +357,7 @@ impl DPhpy {
341357 children : vec ! [ left_join. clone( ) , right_join. clone( ) ] ,
342358 cost : left_cardinality * right_cardinality,
343359 join_conditions : vec ! [ ] ,
360+ cardinality : None ,
344361 }
345362 } ;
346363
@@ -372,7 +389,8 @@ impl DPhpy {
372389 let merged_relation_set = union ( right, & neighbor_relations) ;
373390 let ( connected, join_conditions) =
374391 self . query_graph . is_connected ( left, & merged_relation_set) ?;
375- if self . dp_table . contains_key ( & merged_relation_set)
392+ if merged_relation_set. len ( ) > right. len ( )
393+ && self . dp_table . contains_key ( & merged_relation_set)
376394 && connected
377395 && !self . emit_csg_cmp ( left, & merged_relation_set, join_conditions) ?
378396 {
@@ -381,9 +399,8 @@ impl DPhpy {
381399 merged_sets. push ( merged_relation_set) ;
382400 }
383401 // Continue to enumerate cmp
384- let mut new_forbidden_nodes;
402+ let mut new_forbidden_nodes = forbidden_nodes . clone ( ) ;
385403 for ( idx, neighbor) in neighbor_set. iter ( ) . enumerate ( ) {
386- new_forbidden_nodes = forbidden_nodes. clone ( ) ;
387404 new_forbidden_nodes. insert ( * neighbor) ;
388405 if !self . enumerate_cmp_rec ( left, & merged_sets[ idx] , & new_forbidden_nodes) ? {
389406 return Ok ( false ) ;
@@ -395,7 +412,7 @@ impl DPhpy {
395412 // Map join order in `JoinNode` to `SExpr`
396413 fn join_reorder ( & self , final_plan : & JoinNode , s_expr : SExpr ) -> Result < ( SExpr , bool ) > {
397414 // Convert `final_plan` to `SExpr`
398- let join_expr = self . s_expr ( final_plan) ? ;
415+ let join_expr = self . s_expr ( final_plan) ;
399416 // Find first join node in `s_expr`, then replace it with `join_expr`
400417 let new_s_expr = self . replace_join_expr ( & join_expr, & s_expr) ?;
401418 Ok ( ( new_s_expr, true ) )
@@ -417,13 +434,13 @@ impl DPhpy {
417434 Ok ( new_s_expr)
418435 }
419436
420- pub fn s_expr ( & self , join_node : & JoinNode ) -> Result < SExpr > {
437+ pub fn s_expr ( & self , join_node : & JoinNode ) -> SExpr {
421438 // Traverse JoinNode
422439 if join_node. children . is_empty ( ) {
423440 // The node is leaf, get relation for `join_relations`
424441 let idx = join_node. leaves [ 0 ] ;
425442 let relation = self . join_relations [ idx] . s_expr ( ) ;
426- return Ok ( relation) ;
443+ return relation;
427444 }
428445
429446 // The node is join
@@ -452,8 +469,7 @@ impl DPhpy {
452469 . children
453470 . iter ( )
454471 . map ( |child| self . s_expr ( child) )
455- . collect :: < Result < Vec < _ > > > ( ) ?;
456- let join_expr = SExpr :: create ( rel_op, children, None , None ) ;
457- Ok ( join_expr)
472+ . collect :: < Vec < _ > > ( ) ;
473+ SExpr :: create ( rel_op, children, None , None )
458474 }
459475}
0 commit comments