|
18 | 18 | //! [`TreeNode`] for visiting and rewriting expression and plan trees |
19 | 19 |
|
20 | 20 | use crate::Result; |
| 21 | +use std::marker::PhantomData; |
21 | 22 | use std::sync::Arc; |
22 | 23 |
|
23 | 24 | /// These macros are used to determine continuation during transforming traversals. |
@@ -912,104 +913,6 @@ macro_rules! map_until_stop_and_collect { |
912 | 913 | }} |
913 | 914 | } |
914 | 915 |
|
915 | | -macro_rules! rewrite_recursive { |
916 | | - ($START:ident, $NAME:ident, $TRANSFORM_UP:expr, $TRANSFORM_DOWN:expr) => { |
917 | | - let mut queue = vec![ProcessingState::NotStarted($START)]; |
918 | | - |
919 | | - while let Some(item) = queue.pop() { |
920 | | - match item { |
921 | | - ProcessingState::NotStarted($NAME) => { |
922 | | - let node = $TRANSFORM_DOWN?; |
923 | | - |
924 | | - queue.push(match node.tnr { |
925 | | - TreeNodeRecursion::Continue => { |
926 | | - ProcessingState::ProcessingChildren { |
927 | | - non_processed_children: node |
928 | | - .data |
929 | | - .arc_children() |
930 | | - .into_iter() |
931 | | - .cloned() |
932 | | - .rev() |
933 | | - .collect(), |
934 | | - item: node, |
935 | | - processed_children: vec![], |
936 | | - } |
937 | | - } |
938 | | - TreeNodeRecursion::Jump => ProcessingState::ProcessedAllChildren( |
939 | | - node.with_tnr(TreeNodeRecursion::Continue), |
940 | | - ), |
941 | | - TreeNodeRecursion::Stop => { |
942 | | - ProcessingState::ProcessedAllChildren(node) |
943 | | - } |
944 | | - }) |
945 | | - } |
946 | | - ProcessingState::ProcessingChildren { |
947 | | - mut item, |
948 | | - mut non_processed_children, |
949 | | - mut processed_children, |
950 | | - } => match item.tnr { |
951 | | - TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => { |
952 | | - if let Some(non_processed_item) = non_processed_children.pop() { |
953 | | - queue.push(ProcessingState::ProcessingChildren { |
954 | | - item, |
955 | | - non_processed_children, |
956 | | - processed_children, |
957 | | - }); |
958 | | - queue.push(ProcessingState::NotStarted(non_processed_item)); |
959 | | - } else { |
960 | | - item.transformed |= |
961 | | - processed_children.iter().any(|item| item.transformed); |
962 | | - item.data = item.data.with_new_arc_children( |
963 | | - processed_children.into_iter().map(|c| c.data).collect(), |
964 | | - )?; |
965 | | - queue.push(ProcessingState::ProcessedAllChildren(item)) |
966 | | - } |
967 | | - } |
968 | | - TreeNodeRecursion::Stop => { |
969 | | - processed_children.extend( |
970 | | - non_processed_children |
971 | | - .into_iter() |
972 | | - .rev() |
973 | | - .map(Transformed::no), |
974 | | - ); |
975 | | - item.transformed |= |
976 | | - processed_children.iter().any(|item| item.transformed); |
977 | | - item.data = item.data.with_new_arc_children( |
978 | | - processed_children.into_iter().map(|c| c.data).collect(), |
979 | | - )?; |
980 | | - queue.push(ProcessingState::ProcessedAllChildren(item)); |
981 | | - } |
982 | | - }, |
983 | | - ProcessingState::ProcessedAllChildren(node) => { |
984 | | - let node = node.transform_parent(|$NAME| $TRANSFORM_UP)?; |
985 | | - |
986 | | - if let Some(ProcessingState::ProcessingChildren { |
987 | | - item: mut parent_node, |
988 | | - non_processed_children, |
989 | | - mut processed_children, |
990 | | - .. |
991 | | - }) = queue.pop() |
992 | | - { |
993 | | - parent_node.tnr = node.tnr; |
994 | | - processed_children.push(node); |
995 | | - |
996 | | - queue.push(ProcessingState::ProcessingChildren { |
997 | | - item: parent_node, |
998 | | - non_processed_children, |
999 | | - processed_children, |
1000 | | - }) |
1001 | | - } else { |
1002 | | - debug_assert_eq!(queue.len(), 0); |
1003 | | - return Ok(node); |
1004 | | - } |
1005 | | - } |
1006 | | - } |
1007 | | - } |
1008 | | - |
1009 | | - unreachable!(); |
1010 | | - }; |
1011 | | -} |
1012 | | - |
1013 | 916 | /// Transformation helper to access [`Transformed`] fields in a [`Result`] easily. |
1014 | 917 | /// |
1015 | 918 | /// # Example |
@@ -1063,6 +966,59 @@ pub trait DynTreeNode { |
1063 | 966 | ) -> Result<Arc<Self>>; |
1064 | 967 | } |
1065 | 968 |
|
| 969 | +pub struct LegacyRewriter< |
| 970 | + FD: FnMut(Node) -> Result<Transformed<Node>>, |
| 971 | + FU: FnMut(Node) -> Result<Transformed<Node>>, |
| 972 | + Node: TreeNode, |
| 973 | +> { |
| 974 | + f_down_func: FD, |
| 975 | + f_up_func: FU, |
| 976 | + _node: PhantomData<Node>, |
| 977 | +} |
| 978 | + |
| 979 | +impl< |
| 980 | + FD: FnMut(Node) -> Result<Transformed<Node>>, |
| 981 | + FU: FnMut(Node) -> Result<Transformed<Node>>, |
| 982 | + Node: TreeNode, |
| 983 | + > LegacyRewriter<FD, FU, Node> |
| 984 | +{ |
| 985 | + pub fn new(f_down_func: FD, f_up_func: FU) -> Self { |
| 986 | + Self { |
| 987 | + f_down_func, |
| 988 | + f_up_func, |
| 989 | + _node: PhantomData, |
| 990 | + } |
| 991 | + } |
| 992 | +} |
| 993 | +impl< |
| 994 | + FD: FnMut(Node) -> Result<Transformed<Node>>, |
| 995 | + FU: FnMut(Node) -> Result<Transformed<Node>>, |
| 996 | + Node: TreeNode, |
| 997 | + > TreeNodeRewriter for LegacyRewriter<FD, FU, Node> |
| 998 | +{ |
| 999 | + type Node = Node; |
| 1000 | + |
| 1001 | + fn f_down(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> { |
| 1002 | + (self.f_down_func)(node) |
| 1003 | + } |
| 1004 | + |
| 1005 | + fn f_up(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> { |
| 1006 | + (self.f_up_func)(node) |
| 1007 | + } |
| 1008 | +} |
| 1009 | + |
| 1010 | +macro_rules! update_rec_node { |
| 1011 | + ($NAME:ident, $CHILDREN:ident) => {{ |
| 1012 | + $NAME.transformed |= $CHILDREN.iter().any(|item| item.transformed); |
| 1013 | + |
| 1014 | + $NAME.data = $NAME |
| 1015 | + .data |
| 1016 | + .with_new_arc_children($CHILDREN.into_iter().map(|c| c.data).collect())?; |
| 1017 | + |
| 1018 | + $NAME |
| 1019 | + }}; |
| 1020 | +} |
| 1021 | + |
1066 | 1022 | /// Blanket implementation for any `Arc<T>` where `T` implements [`DynTreeNode`] |
1067 | 1023 | /// (such as [`Arc<dyn PhysicalExpr>`]). |
1068 | 1024 | impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> { |
@@ -1102,30 +1058,121 @@ impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> { |
1102 | 1058 | FU: FnMut(Self) -> Result<Transformed<Self>>, |
1103 | 1059 | >( |
1104 | 1060 | self, |
1105 | | - mut f_down: FD, |
1106 | | - mut f_up: FU, |
| 1061 | + f_down: FD, |
| 1062 | + f_up: FU, |
1107 | 1063 | ) -> Result<Transformed<Self>> { |
1108 | | - rewrite_recursive!(self, node, f_up(node), f_down(node)); |
| 1064 | + self.rewrite(&mut LegacyRewriter::new(f_down, f_up)) |
1109 | 1065 | } |
1110 | 1066 |
|
1111 | 1067 | fn transform_down<F: FnMut(Self) -> Result<Transformed<Self>>>( |
1112 | 1068 | self, |
1113 | 1069 | f: F, |
1114 | 1070 | ) -> Result<Transformed<Self>> { |
1115 | | - self.transform_down_up(f, |node| Ok(Transformed::no(node))) |
| 1071 | + self.rewrite(&mut LegacyRewriter::new(f, |node| { |
| 1072 | + Ok(Transformed::no(node)) |
| 1073 | + })) |
1116 | 1074 | } |
1117 | 1075 |
|
1118 | 1076 | fn transform_up<F: FnMut(Self) -> Result<Transformed<Self>>>( |
1119 | 1077 | self, |
1120 | 1078 | f: F, |
1121 | 1079 | ) -> Result<Transformed<Self>> { |
1122 | | - self.transform_down_up(|node| Ok(Transformed::no(node)), f) |
| 1080 | + self.rewrite(&mut LegacyRewriter::new( |
| 1081 | + |node| Ok(Transformed::no(node)), |
| 1082 | + f, |
| 1083 | + )) |
1123 | 1084 | } |
1124 | 1085 | fn rewrite<R: TreeNodeRewriter<Node = Self>>( |
1125 | 1086 | self, |
1126 | 1087 | rewriter: &mut R, |
1127 | 1088 | ) -> Result<Transformed<Self>> { |
1128 | | - rewrite_recursive!(self, node, rewriter.f_up(node), rewriter.f_down(node)); |
| 1089 | + let mut queue = vec![ProcessingState::NotStarted(self)]; |
| 1090 | + |
| 1091 | + while let Some(item) = queue.pop() { |
| 1092 | + match item { |
| 1093 | + ProcessingState::NotStarted(node) => { |
| 1094 | + let node = rewriter.f_down(node)?; |
| 1095 | + |
| 1096 | + queue.push(match node.tnr { |
| 1097 | + TreeNodeRecursion::Continue => { |
| 1098 | + ProcessingState::ProcessingChildren { |
| 1099 | + non_processed_children: node |
| 1100 | + .data |
| 1101 | + .arc_children() |
| 1102 | + .into_iter() |
| 1103 | + .cloned() |
| 1104 | + .rev() |
| 1105 | + .collect(), |
| 1106 | + item: node, |
| 1107 | + processed_children: vec![], |
| 1108 | + } |
| 1109 | + } |
| 1110 | + TreeNodeRecursion::Jump => ProcessingState::ProcessedAllChildren( |
| 1111 | + node.with_tnr(TreeNodeRecursion::Continue), |
| 1112 | + ), |
| 1113 | + TreeNodeRecursion::Stop => { |
| 1114 | + ProcessingState::ProcessedAllChildren(node) |
| 1115 | + } |
| 1116 | + }) |
| 1117 | + } |
| 1118 | + ProcessingState::ProcessingChildren { |
| 1119 | + mut item, |
| 1120 | + mut non_processed_children, |
| 1121 | + mut processed_children, |
| 1122 | + } => match item.tnr { |
| 1123 | + TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => { |
| 1124 | + if let Some(non_processed_item) = non_processed_children.pop() { |
| 1125 | + queue.push(ProcessingState::ProcessingChildren { |
| 1126 | + item, |
| 1127 | + non_processed_children, |
| 1128 | + processed_children, |
| 1129 | + }); |
| 1130 | + queue.push(ProcessingState::NotStarted(non_processed_item)); |
| 1131 | + } else { |
| 1132 | + queue.push(ProcessingState::ProcessedAllChildren( |
| 1133 | + update_rec_node!(item, processed_children), |
| 1134 | + )) |
| 1135 | + } |
| 1136 | + } |
| 1137 | + TreeNodeRecursion::Stop => { |
| 1138 | + processed_children.extend( |
| 1139 | + non_processed_children |
| 1140 | + .into_iter() |
| 1141 | + .rev() |
| 1142 | + .map(Transformed::no), |
| 1143 | + ); |
| 1144 | + queue.push(ProcessingState::ProcessedAllChildren( |
| 1145 | + update_rec_node!(item, processed_children), |
| 1146 | + )); |
| 1147 | + } |
| 1148 | + }, |
| 1149 | + ProcessingState::ProcessedAllChildren(node) => { |
| 1150 | + let node = node.transform_parent(|n| rewriter.f_up(n))?; |
| 1151 | + |
| 1152 | + if let Some(ProcessingState::ProcessingChildren { |
| 1153 | + item: mut parent_node, |
| 1154 | + non_processed_children, |
| 1155 | + mut processed_children, |
| 1156 | + .. |
| 1157 | + }) = queue.pop() |
| 1158 | + { |
| 1159 | + parent_node.tnr = node.tnr; |
| 1160 | + processed_children.push(node); |
| 1161 | + |
| 1162 | + queue.push(ProcessingState::ProcessingChildren { |
| 1163 | + item: parent_node, |
| 1164 | + non_processed_children, |
| 1165 | + processed_children, |
| 1166 | + }) |
| 1167 | + } else { |
| 1168 | + debug_assert_eq!(queue.len(), 0); |
| 1169 | + return Ok(node); |
| 1170 | + } |
| 1171 | + } |
| 1172 | + } |
| 1173 | + } |
| 1174 | + |
| 1175 | + unreachable!(); |
1129 | 1176 | } |
1130 | 1177 |
|
1131 | 1178 | fn visit<'n, V: TreeNodeVisitor<'n, Node = Self>>( |
|
0 commit comments