1818//! [`TreeNode`] for visiting and rewriting expression and plan trees
1919
2020use crate :: Result ;
21+ use std:: marker:: PhantomData ;
2122use std:: sync:: Arc ;
2223
2324/// These macros are used to determine continuation during transforming traversals.
@@ -912,104 +913,6 @@ macro_rules! map_until_stop_and_collect {
912913 } }
913914}
914915
915- macro_rules! rewrite_recursive {
916- ( $START: ident, $NAME: ident, $TRANSFORM_UP: expr, $TRANSFORM_DOWN: expr) => {
917- let mut queue = vec![ ProcessingState :: NotStarted ( $START) ] ;
918-
919- while let Some ( item) = queue. pop( ) {
920- match item {
921- ProcessingState :: NotStarted ( $NAME) => {
922- let node = $TRANSFORM_DOWN?;
923-
924- queue. push( match node. tnr {
925- TreeNodeRecursion :: Continue => {
926- ProcessingState :: ProcessingChildren {
927- non_processed_children: node
928- . data
929- . arc_children( )
930- . into_iter( )
931- . cloned( )
932- . rev( )
933- . collect( ) ,
934- item: node,
935- processed_children: vec![ ] ,
936- }
937- }
938- TreeNodeRecursion :: Jump => ProcessingState :: ProcessedAllChildren (
939- node. with_tnr( TreeNodeRecursion :: Continue ) ,
940- ) ,
941- TreeNodeRecursion :: Stop => {
942- ProcessingState :: ProcessedAllChildren ( node)
943- }
944- } )
945- }
946- ProcessingState :: ProcessingChildren {
947- mut item,
948- mut non_processed_children,
949- mut processed_children,
950- } => match item. tnr {
951- TreeNodeRecursion :: Continue | TreeNodeRecursion :: Jump => {
952- if let Some ( non_processed_item) = non_processed_children. pop( ) {
953- queue. push( ProcessingState :: ProcessingChildren {
954- item,
955- non_processed_children,
956- processed_children,
957- } ) ;
958- queue. push( ProcessingState :: NotStarted ( non_processed_item) ) ;
959- } else {
960- item. transformed |=
961- processed_children. iter( ) . any( |item| item. transformed) ;
962- item. data = item. data. with_new_arc_children(
963- processed_children. into_iter( ) . map( |c| c. data) . collect( ) ,
964- ) ?;
965- queue. push( ProcessingState :: ProcessedAllChildren ( item) )
966- }
967- }
968- TreeNodeRecursion :: Stop => {
969- processed_children. extend(
970- non_processed_children
971- . into_iter( )
972- . rev( )
973- . map( Transformed :: no) ,
974- ) ;
975- item. transformed |=
976- processed_children. iter( ) . any( |item| item. transformed) ;
977- item. data = item. data. with_new_arc_children(
978- processed_children. into_iter( ) . map( |c| c. data) . collect( ) ,
979- ) ?;
980- queue. push( ProcessingState :: ProcessedAllChildren ( item) ) ;
981- }
982- } ,
983- ProcessingState :: ProcessedAllChildren ( node) => {
984- let node = node. transform_parent( |$NAME| $TRANSFORM_UP) ?;
985-
986- if let Some ( ProcessingState :: ProcessingChildren {
987- item: mut parent_node,
988- non_processed_children,
989- mut processed_children,
990- ..
991- } ) = queue. pop( )
992- {
993- parent_node. tnr = node. tnr;
994- processed_children. push( node) ;
995-
996- queue. push( ProcessingState :: ProcessingChildren {
997- item: parent_node,
998- non_processed_children,
999- processed_children,
1000- } )
1001- } else {
1002- debug_assert_eq!( queue. len( ) , 0 ) ;
1003- return Ok ( node) ;
1004- }
1005- }
1006- }
1007- }
1008-
1009- unreachable!( ) ;
1010- } ;
1011- }
1012-
1013916/// Transformation helper to access [`Transformed`] fields in a [`Result`] easily.
1014917///
1015918/// # Example
@@ -1063,6 +966,59 @@ pub trait DynTreeNode {
1063966 ) -> Result < Arc < Self > > ;
1064967}
1065968
969+ pub struct LegacyRewriter <
970+ FD : FnMut ( Node ) -> Result < Transformed < Node > > ,
971+ FU : FnMut ( Node ) -> Result < Transformed < Node > > ,
972+ Node : TreeNode ,
973+ > {
974+ f_down_func : FD ,
975+ f_up_func : FU ,
976+ _node : PhantomData < Node > ,
977+ }
978+
979+ impl <
980+ FD : FnMut ( Node ) -> Result < Transformed < Node > > ,
981+ FU : FnMut ( Node ) -> Result < Transformed < Node > > ,
982+ Node : TreeNode ,
983+ > LegacyRewriter < FD , FU , Node >
984+ {
985+ pub fn new ( f_down_func : FD , f_up_func : FU ) -> Self {
986+ Self {
987+ f_down_func,
988+ f_up_func,
989+ _node : PhantomData ,
990+ }
991+ }
992+ }
993+ impl <
994+ FD : FnMut ( Node ) -> Result < Transformed < Node > > ,
995+ FU : FnMut ( Node ) -> Result < Transformed < Node > > ,
996+ Node : TreeNode ,
997+ > TreeNodeRewriter for LegacyRewriter < FD , FU , Node >
998+ {
999+ type Node = Node ;
1000+
1001+ fn f_down ( & mut self , node : Self :: Node ) -> Result < Transformed < Self :: Node > > {
1002+ ( self . f_down_func ) ( node)
1003+ }
1004+
1005+ fn f_up ( & mut self , node : Self :: Node ) -> Result < Transformed < Self :: Node > > {
1006+ ( self . f_up_func ) ( node)
1007+ }
1008+ }
1009+
1010+ macro_rules! update_rec_node {
1011+ ( $NAME: ident, $CHILDREN: ident) => { {
1012+ $NAME. transformed |= $CHILDREN. iter( ) . any( |item| item. transformed) ;
1013+
1014+ $NAME. data = $NAME
1015+ . data
1016+ . with_new_arc_children( $CHILDREN. into_iter( ) . map( |c| c. data) . collect( ) ) ?;
1017+
1018+ $NAME
1019+ } } ;
1020+ }
1021+
10661022/// Blanket implementation for any `Arc<T>` where `T` implements [`DynTreeNode`]
10671023/// (such as [`Arc<dyn PhysicalExpr>`]).
10681024impl < T : DynTreeNode + ?Sized > TreeNode for Arc < T > {
@@ -1102,43 +1058,134 @@ impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
11021058 FU : FnMut ( Self ) -> Result < Transformed < Self > > ,
11031059 > (
11041060 self ,
1105- mut f_down : FD ,
1106- mut f_up : FU ,
1061+ f_down : FD ,
1062+ f_up : FU ,
11071063 ) -> Result < Transformed < Self > > {
1108- rewrite_recursive ! ( self , node , f_up( node ) , f_down ( node ) ) ;
1064+ self . rewrite ( & mut LegacyRewriter :: new ( f_down , f_up) )
11091065 }
11101066
11111067 fn transform_down < F : FnMut ( Self ) -> Result < Transformed < Self > > > (
11121068 self ,
11131069 f : F ,
11141070 ) -> Result < Transformed < Self > > {
1115- self . transform_down_up ( f, |node| Ok ( Transformed :: no ( node) ) )
1071+ self . rewrite ( & mut LegacyRewriter :: new ( f, |node| {
1072+ Ok ( Transformed :: no ( node) )
1073+ } ) )
11161074 }
11171075
11181076 fn transform_up < F : FnMut ( Self ) -> Result < Transformed < Self > > > (
11191077 self ,
11201078 f : F ,
11211079 ) -> Result < Transformed < Self > > {
1122- self . transform_down_up ( |node| Ok ( Transformed :: no ( node) ) , f)
1080+ self . rewrite ( & mut LegacyRewriter :: new (
1081+ |node| Ok ( Transformed :: no ( node) ) ,
1082+ f,
1083+ ) )
11231084 }
11241085 fn rewrite < R : TreeNodeRewriter < Node = Self > > (
11251086 self ,
11261087 rewriter : & mut R ,
11271088 ) -> Result < Transformed < Self > > {
1128- rewrite_recursive ! ( self , node, rewriter. f_up( node) , rewriter. f_down( node) ) ;
1089+ let mut stack = vec ! [ ProcessingState :: NotStarted ( self ) ] ;
1090+
1091+ while let Some ( item) = stack. pop ( ) {
1092+ match item {
1093+ ProcessingState :: NotStarted ( node) => {
1094+ let node = rewriter. f_down ( node) ?;
1095+
1096+ stack. push ( match node. tnr {
1097+ TreeNodeRecursion :: Continue => {
1098+ ProcessingState :: ProcessingChildren {
1099+ non_processed_children : node
1100+ . data
1101+ . arc_children ( )
1102+ . into_iter ( )
1103+ . cloned ( )
1104+ . rev ( )
1105+ . collect ( ) ,
1106+ item : node,
1107+ processed_children : vec ! [ ] ,
1108+ }
1109+ }
1110+ TreeNodeRecursion :: Jump => ProcessingState :: ProcessedAllChildren (
1111+ node. with_tnr ( TreeNodeRecursion :: Continue ) ,
1112+ ) ,
1113+ TreeNodeRecursion :: Stop => {
1114+ ProcessingState :: ProcessedAllChildren ( node)
1115+ }
1116+ } )
1117+ }
1118+ ProcessingState :: ProcessingChildren {
1119+ mut item,
1120+ mut non_processed_children,
1121+ mut processed_children,
1122+ } => match item. tnr {
1123+ TreeNodeRecursion :: Continue | TreeNodeRecursion :: Jump => {
1124+ if let Some ( non_processed_item) = non_processed_children. pop ( ) {
1125+ stack. push ( ProcessingState :: ProcessingChildren {
1126+ item,
1127+ non_processed_children,
1128+ processed_children,
1129+ } ) ;
1130+ stack. push ( ProcessingState :: NotStarted ( non_processed_item) ) ;
1131+ } else {
1132+ stack. push ( ProcessingState :: ProcessedAllChildren (
1133+ update_rec_node ! ( item, processed_children) ,
1134+ ) )
1135+ }
1136+ }
1137+ TreeNodeRecursion :: Stop => {
1138+ processed_children. extend (
1139+ non_processed_children
1140+ . into_iter ( )
1141+ . rev ( )
1142+ . map ( Transformed :: no) ,
1143+ ) ;
1144+ stack. push ( ProcessingState :: ProcessedAllChildren (
1145+ update_rec_node ! ( item, processed_children) ,
1146+ ) ) ;
1147+ }
1148+ } ,
1149+ ProcessingState :: ProcessedAllChildren ( node) => {
1150+ let node = node. transform_parent ( |n| rewriter. f_up ( n) ) ?;
1151+
1152+ if let Some ( ProcessingState :: ProcessingChildren {
1153+ item : mut parent_node,
1154+ non_processed_children,
1155+ mut processed_children,
1156+ ..
1157+ } ) = stack. pop ( )
1158+ {
1159+ parent_node. tnr = node. tnr ;
1160+ processed_children. push ( node) ;
1161+
1162+ stack. push ( ProcessingState :: ProcessingChildren {
1163+ item : parent_node,
1164+ non_processed_children,
1165+ processed_children,
1166+ } )
1167+ } else {
1168+ debug_assert_eq ! ( stack. len( ) , 0 ) ;
1169+ return Ok ( node) ;
1170+ }
1171+ }
1172+ }
1173+ }
1174+
1175+ unreachable ! ( ) ;
11291176 }
11301177
11311178 fn visit < ' n , V : TreeNodeVisitor < ' n , Node = Self > > (
11321179 & ' n self ,
11331180 visitor : & mut V ,
11341181 ) -> Result < TreeNodeRecursion > {
1135- let mut queue = vec ! [ VisitingState :: NotStarted ( self ) ] ;
1182+ let mut stack = vec ! [ VisitingState :: NotStarted ( self ) ] ;
11361183
1137- while let Some ( item) = queue . pop ( ) {
1184+ while let Some ( item) = stack . pop ( ) {
11381185 match item {
11391186 VisitingState :: NotStarted ( item) => {
11401187 let tnr = visitor. f_down ( item) ?;
1141- queue . push ( match tnr {
1188+ stack . push ( match tnr {
11421189 TreeNodeRecursion :: Continue => VisitingState :: VisitingChildren {
11431190 non_processed_children : item
11441191 . arc_children ( )
@@ -1165,14 +1212,14 @@ impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
11651212 } => match tnr {
11661213 TreeNodeRecursion :: Continue | TreeNodeRecursion :: Jump => {
11671214 if let Some ( non_processed_item) = non_processed_children. pop ( ) {
1168- queue . push ( VisitingState :: VisitingChildren {
1215+ stack . push ( VisitingState :: VisitingChildren {
11691216 item,
11701217 non_processed_children,
11711218 tnr,
11721219 } ) ;
1173- queue . push ( VisitingState :: NotStarted ( non_processed_item) ) ;
1220+ stack . push ( VisitingState :: NotStarted ( non_processed_item) ) ;
11741221 } else {
1175- queue . push ( VisitingState :: VisitedAllChildren { item, tnr } ) ;
1222+ stack . push ( VisitingState :: VisitedAllChildren { item, tnr } ) ;
11761223 }
11771224 }
11781225 TreeNodeRecursion :: Stop => {
@@ -1186,15 +1233,15 @@ impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
11861233 item,
11871234 non_processed_children,
11881235 ..
1189- } ) = queue . pop ( )
1236+ } ) = stack . pop ( )
11901237 {
1191- queue . push ( VisitingState :: VisitingChildren {
1238+ stack . push ( VisitingState :: VisitingChildren {
11921239 item,
11931240 non_processed_children,
11941241 tnr,
11951242 } ) ;
11961243 } else {
1197- debug_assert_eq ! ( queue . len( ) , 0 ) ;
1244+ debug_assert_eq ! ( stack . len( ) , 0 ) ;
11981245 return Ok ( tnr) ;
11991246 }
12001247 }
0 commit comments