1- use std:: { collections:: HashMap , fmt:: Debug , hash:: Hash , marker:: PhantomData , sync:: Arc } ;
1+ use std:: {
2+ collections:: HashMap , fmt:: Debug , hash:: Hash , marker:: PhantomData , mem:: take, sync:: Arc ,
3+ } ;
24
35use crate :: {
46 compile:: rewrite:: {
@@ -531,6 +533,9 @@ where
531533 egraph : & ' a EGraph < L , A > ,
532534 // Caches results. `None` for nodes in progress to prevent recursion
533535 extract_map : HashMap < IdWithState < L , S > , Option < ( usize , C ) > > ,
536+ has_deep_recursion : bool ,
537+ // Cache for second pass to calculate recursive nodes cost.
538+ extract_map_recursive_cache : Option < HashMap < IdWithState < L , S > , Option < ( usize , C ) > > > ,
534539 cost_fn : Arc < CF > ,
535540 root_state : Arc < S > ,
536541}
@@ -547,6 +552,8 @@ where
547552 Self {
548553 egraph,
549554 extract_map : HashMap :: new ( ) ,
555+ has_deep_recursion : false ,
556+ extract_map_recursive_cache : None ,
550557 cost_fn : Arc :: new ( cost_fn) ,
551558 root_state : Arc :: new ( root_state) ,
552559 }
@@ -555,8 +562,17 @@ where
555562 /// Returns cost and path for best plan for provided root eclass.
556563 ///
557564 /// If all nodes happen to be recursive, returns `None`.
565+ ///
566+ /// If there were any nodes with deep recursion, the cost is calculated in two passes;
567+ /// the second pass fetches the cost from the extract map obtained on the first pass
568+ /// for recursive nodes only.
558569 pub fn find_best ( & mut self , root : Id ) -> Option < ( C , RecExpr < L > ) > {
559- let cost = self . extract ( root, Arc :: clone ( & self . root_state ) ) ?;
570+ let mut cost = self . extract ( root, Arc :: clone ( & self . root_state ) ) ?;
571+ if self . has_deep_recursion {
572+ self . extract_map_recursive_cache = Some ( take ( & mut self . extract_map ) ) ;
573+ cost = self . extract ( root, Arc :: clone ( & self . root_state ) ) ?;
574+ }
575+
560576 let root_id_with_state = IdWithState :: new ( root, Arc :: clone ( & self . root_state ) ) ;
561577 let root_node = self . choose_node ( & root_id_with_state) ?;
562578 let recexpr =
@@ -569,14 +585,33 @@ where
569585 /// Recursively extracts the cost of each node in the eclass
570586 /// and returns cost of the node with least cost based on the passed state,
571587 /// caching the cost together with node index inside eclass in `extract_map`.
588+ /// If `extract_map_recursive_cache` is available, fetches the costs
589+ /// of deep recursion nodes from there.
572590 ///
573591 /// Yields `None` if eclass is already in progress
574592 /// or all its nodes happen to be recursive.
575593 fn extract ( & mut self , eclass : Id , state : Arc < S > ) -> Option < C > {
576594 let id_with_state = IdWithState :: new ( eclass, state) ;
577595 if let Some ( cached_index_and_cost) = self . extract_map . get ( & id_with_state) {
578- // TODO: avoid cloning here?
579- return cached_index_and_cost. as_ref ( ) . map ( |( _, cost) | cost. clone ( ) ) ;
596+ // If the cost has been computed, return it
597+ if let Some ( ( _, cached_cost) ) = cached_index_and_cost {
598+ // TODO: avoid cloning here?
599+ return Some ( cached_cost. clone ( ) ) ;
600+ }
601+
602+ // If the cost is recursive, fetch from recursive cache if available
603+ if let Some ( extract_map_recursive_cache) = & self . extract_map_recursive_cache {
604+ if let Some ( Some ( ( _, cached_cost) ) ) =
605+ extract_map_recursive_cache. get ( & id_with_state)
606+ {
607+ // TODO: avoid cloning here?
608+ return Some ( cached_cost. clone ( ) ) ;
609+ }
610+ }
611+
612+ // Otherwise, mark this extractor as having deep recursion
613+ self . has_deep_recursion = true ;
614+ return None ;
580615 }
581616
582617 // Mark this eclass as in progress
@@ -595,6 +630,12 @@ where
595630 // Recursively get children cost
596631 let mut total_node_cost = this_node_cost;
597632 for child in node. children ( ) {
633+ // If a child is recursive to self, skip this node, as it will never compute
634+ // the cost
635+ if child == & eclass {
636+ continue ' nodes;
637+ }
638+
598639 let Some ( child_cost) = self . extract ( * child, Arc :: clone ( & new_state) ) else {
599640 // This path is inevitably recursive, try the next node
600641 continue ' nodes;
0 commit comments