@@ -27,6 +27,7 @@ use egg::{
2727} ;
2828use itertools:: Itertools ;
2929use std:: {
30+ borrow:: Cow ,
3031 fmt:: { self , Display , Formatter } ,
3132 ops:: Index ,
3233 slice:: Iter ,
@@ -737,7 +738,31 @@ where
737738 . unwrap ( )
738739}
739740
740- type ListMatches = Vec < Subst > ;
741+ struct ListMatches {
742+ len : usize ,
743+ substs : Vec < Subst > ,
744+ prevs : Vec < usize > ,
745+ start : usize ,
746+ }
747+ impl ListMatches {
748+ fn range ( & self ) -> std:: ops:: Range < usize > {
749+ self . start ..self . substs . len ( )
750+ }
751+ fn for_each ( & self , mut f : impl FnMut ( & [ & Subst ] ) ) {
752+ let mut substs = Vec :: with_capacity ( self . len ) ;
753+ for i in self . range ( ) {
754+ substs. clear ( ) ;
755+ let mut i = i;
756+ while i != usize:: MAX {
757+ substs. push ( & self . substs [ i] ) ;
758+ i = self . prevs [ i] ;
759+ }
760+ substs. reverse ( ) ;
761+ assert_eq ! ( substs. len( ) , self . len) ;
762+ f ( & substs) ;
763+ }
764+ }
765+ }
741766
742767#[ derive( Clone , PartialEq ) ]
743768pub enum ListType {
@@ -847,6 +872,68 @@ impl ListNodeSearcher {
847872 }
848873 }
849874 }
875+
876+ fn search_from_list_matches < ' a > (
877+ & ' a self ,
878+ egraph : & EGraph < LogicalPlanLanguage , LogicalPlanAnalysis > ,
879+ limit : usize ,
880+ list_subst : & Subst ,
881+ output : & mut Vec < Subst > ,
882+ ) {
883+ let list_id = list_subst[ self . list_var ] ;
884+ for node in egraph[ list_id] . iter ( ) {
885+ let list_children = node. children ( ) ;
886+ if !self . match_node ( node) || list_children. is_empty ( ) {
887+ continue ;
888+ }
889+
890+ let mut list_matches = ListMatches {
891+ len : list_children. len ( ) ,
892+ substs : vec ! [ ] ,
893+ prevs : vec ! [ ] ,
894+ start : 0 ,
895+ } ;
896+
897+ list_matches. substs = self
898+ . elem_pattern
899+ . search_eclass_with_limit ( egraph, list_children[ 0 ] , limit)
900+ . map_or ( vec ! [ ] , |ms| ms. substs ) ;
901+
902+ list_matches. prevs = vec ! [ usize :: MAX ; list_matches. substs. len( ) ] ;
903+
904+ let agree = |subst1 : & Subst , subst2 : & Subst | {
905+ self . top_level_elem_vars
906+ . iter ( )
907+ . all ( |& v| subst1. get ( v) == subst2. get ( v) )
908+ } ;
909+
910+ for & list_child in & list_children[ 1 ..] {
911+ debug_assert_eq ! ( list_matches. substs. len( ) , list_matches. prevs. len( ) ) ;
912+ let range = list_matches. range ( ) ;
913+ if range. is_empty ( ) {
914+ break ;
915+ }
916+ list_matches. start = list_matches. substs . len ( ) ;
917+ self . elem_pattern
918+ . search_eclass_with_fn ( egraph, list_child, |subst| {
919+ for i in range. clone ( ) {
920+ if agree ( & list_matches. substs [ i] , subst) {
921+ list_matches. substs . push ( subst. clone ( ) ) ;
922+ list_matches. prevs . push ( i) ;
923+ }
924+ }
925+ Ok ( ( ) )
926+ } )
927+ . unwrap_or_default ( ) ;
928+ }
929+
930+ if !list_matches. range ( ) . is_empty ( ) {
931+ let mut subst = list_subst. clone ( ) ;
932+ subst. data = Some ( Arc :: new ( list_matches) ) ;
933+ output. push ( subst) ;
934+ }
935+ }
936+ }
850937}
851938
852939impl Searcher < LogicalPlanLanguage , LogicalPlanAnalysis > for ListNodeSearcher {
@@ -856,55 +943,48 @@ impl Searcher<LogicalPlanLanguage, LogicalPlanAnalysis> for ListNodeSearcher {
856943 eclass : Id ,
857944 limit : usize ,
858945 ) -> Option < SearchMatches < LogicalPlanLanguage > > {
859- let mut matches = self
860- . list_pattern
861- . search_eclass_with_limit ( egraph, eclass, limit) ?;
862-
863- let mut new_substs: Vec < Subst > = vec ! [ ] ;
864- for subst in matches. substs {
865- let list_id = subst[ self . list_var ] ;
866- for node in egraph[ list_id] . iter ( ) {
867- let list_children = node. children ( ) ;
868- if !self . match_node ( node) || list_children. is_empty ( ) {
869- continue ;
870- }
946+ let mut matches = SearchMatches {
947+ substs : vec ! [ ] ,
948+ eclass,
949+ ast : Some ( Cow :: Borrowed ( & self . list_pattern . ast ) ) ,
950+ } ;
951+ self . list_pattern
952+ . search_eclass_with_fn ( egraph, eclass, |subst| {
953+ self . search_from_list_matches ( egraph, limit, subst, & mut matches. substs ) ;
954+ Ok ( ( ) )
955+ } )
956+ . unwrap_or_default ( ) ;
871957
872- let matches_product = list_children
873- . iter ( )
874- . map ( |child| {
875- self . elem_pattern
876- . search_eclass_with_limit ( egraph, * child, limit)
877- . map_or ( vec ! [ ] , |ms| ms. substs )
878- } )
879- . multi_cartesian_product ( ) ;
880-
881- // TODO(mwillsey) this could be optimized more by filtering the
882- // matches as you go
883- for list_matches in matches_product {
884- let subst0 = & list_matches[ 0 ] ;
885- let agree_with_top_level = list_matches. iter ( ) . all ( |m| {
886- self . top_level_elem_vars
887- . iter ( )
888- . all ( |& v| m. get ( v) == subst0. get ( v) )
889- } ) ;
890-
891- if agree_with_top_level {
892- let mut subst = subst. clone ( ) ;
893- assert_eq ! ( list_matches. len( ) , list_children. len( ) ) ;
894- for & var in & self . top_level_elem_vars {
895- if let Some ( id) = list_matches[ 0 ] . get ( var) {
896- subst. insert ( var, * id) ;
897- }
898- }
899- subst. data = Some ( Arc :: new ( list_matches) ) ;
900- new_substs. push ( subst) ;
958+ ( !matches. substs . is_empty ( ) ) . then ( || matches)
959+ }
960+
961+ fn search_with_limit (
962+ & self ,
963+ egraph : & EGraph < LogicalPlanLanguage , LogicalPlanAnalysis > ,
964+ limit : usize ,
965+ ) -> Vec < SearchMatches < LogicalPlanLanguage > > {
966+ let mut result: Vec < SearchMatches < _ > > = vec ! [ ] ;
967+ self . list_pattern
968+ . search_with_fn ( egraph, |id, list_subst| {
969+ let last = match result. last_mut ( ) {
970+ Some ( top) if top. eclass == id => top,
971+ _ => {
972+ result. push ( SearchMatches {
973+ substs : vec ! [ ] ,
974+ eclass : id,
975+ ast : Some ( Cow :: Borrowed ( & self . list_pattern . ast ) ) ,
976+ } ) ;
977+ result. last_mut ( ) . unwrap ( )
901978 }
902- }
903- }
904- }
979+ } ;
980+ debug_assert_eq ! ( last. eclass, id) ;
981+ self . search_from_list_matches ( egraph, limit, list_subst, & mut last. substs ) ;
982+ Ok ( ( ) )
983+ } )
984+ . unwrap_or_default ( ) ;
905985
906- matches. substs = new_substs ;
907- ( !matches . substs . is_empty ( ) ) . then ( || matches )
986+ result . retain ( | matches| !matches . substs . is_empty ( ) ) ;
987+ result
908988 }
909989
910990 fn vars ( & self ) -> Vec < Var > {
@@ -999,42 +1079,42 @@ impl Applier<LogicalPlanLanguage, LogicalPlanAnalysis> for ListNodeApplier {
9991079 fn apply_one (
10001080 & self ,
10011081 egraph : & mut EGraph < LogicalPlanLanguage , LogicalPlanAnalysis > ,
1002- eclass : Id ,
1082+ mut eclass : Id ,
10031083 subst : & Subst ,
10041084 _searcher_ast : Option < & PatternAst < LogicalPlanLanguage > > ,
10051085 _rule_name : Symbol ,
10061086 ) -> Vec < Id > {
1007- let mut subst = subst. clone ( ) ;
1008-
10091087 let data = subst
10101088 . data
10111089 . as_ref ( )
10121090 . expect ( "no data, did you use ListNodeSearcher?" ) ;
1013- let list_matches = data
1014- . downcast_ref :: < ListMatches > ( )
1015- . expect ( "wrong data type" )
1016- . clone ( ) ;
1017-
1018- for list in & self . lists {
1019- let new_list = list_matches
1020- . iter ( )
1021- . map ( |list_subst| {
1022- let mut subst = subst. clone ( ) ;
1023- subst. extend ( list_subst. iter ( ) ) ;
1024- egraph. add_instantiation ( & list. elem_pattern , & subst)
1025- } )
1026- . collect ( ) ;
1091+ let list_matches = data. downcast_ref :: < ListMatches > ( ) . expect ( "wrong data type" ) ;
10271092
1028- subst. insert ( list. new_list_var , egraph. add ( list. make_node ( new_list) ) ) ;
1029- }
1093+ let mut subst = subst. clone ( ) ;
1094+ let mut result_ids = vec ! [ ] ;
1095+ list_matches. for_each ( |list_substs| {
1096+ for list in & self . lists {
1097+ let new_list = list_substs
1098+ . iter ( )
1099+ . map ( |list_subst| {
1100+ let mut subst = subst. clone ( ) ;
1101+ subst. extend ( list_subst. iter ( ) ) ;
1102+ egraph. add_instantiation ( & list. elem_pattern , & subst)
1103+ } )
1104+ . collect ( ) ;
10301105
1031- let result_id = egraph. add_instantiation ( & self . list_pattern , & subst) ;
1106+ subst. insert ( list. new_list_var , egraph. add ( list. make_node ( new_list) ) ) ;
1107+ }
1108+ let mut subst = subst. clone ( ) ;
1109+ subst. extend ( list_substs[ 0 ] . iter ( ) ) ;
1110+ let new_id = egraph. add_instantiation ( & self . list_pattern , & subst) ;
1111+ if egraph. union ( eclass, new_id) {
1112+ result_ids. push ( new_id) ;
1113+ eclass = new_id;
1114+ }
1115+ } ) ;
10321116
1033- if egraph. union ( eclass, result_id) {
1034- vec ! [ result_id]
1035- } else {
1036- vec ! [ ]
1037- }
1117+ result_ids
10381118 }
10391119
10401120 fn vars ( & self ) -> Vec < Var > {
0 commit comments