22
33from dataclasses import dataclass
44from enum import Enum
5+ from typing import Callable , TypeAlias , Iterable
56
67from ..common import HasRegion , StrRegion
78
1011 "AstRepeat" , "AstIf" , "AstWhile" , "AstAssign" , "AstAugAssign" , "AstDefine" ,
1112 "AstNumber" , "AstString" , "AstAnyName" , "AstIdent" , "AstAttrName" ,
1213 "AstListLiteral" , "AstAttribute" , "AstItem" , "AstCall" , "AstOp" , "AstBinOp" ,
13- "AstUnaryOp" ,
14+ "AstUnaryOp" , 'walk_ast' , 'WalkableT' , 'WalkerFnT' , 'WalkerCallType'
1415]
1516
1617
18+ class WalkerCallType (Enum ):
19+ PRE = 'pre'
20+ POST = 'post'
21+
22+
23+ WalkableL0 : TypeAlias = 'AstNode | list[AstNode] | tuple[AstNode, ...] | None'
24+ WalkableT : TypeAlias = 'WalkableL0 | list[WalkableL0] | tuple[WalkableL0, ...]'
25+ WalkerFnT : TypeAlias = Callable [[WalkableT , WalkerCallType ], bool | None ]
26+ """Returns True if skip"""
27+
28+
1729@dataclass
1830class AstNode (HasRegion ):
1931 region : StrRegion
2032 name = None # type: str
2133 del name # So we get better error msg if we forget to add it to a class
2234
35+ def walk (self , fn : WalkerFnT ):
36+ if fn (self , WalkerCallType .PRE ):
37+ return
38+ self ._walk_members (fn )
39+ fn (self , WalkerCallType .POST )
40+
41+ def _walk_members (self , fn : WalkerFnT ):
42+ """We have to define this manually on all subclasses with children.
43+ We don't try to do anything overcomplicated as it is hard to tell
44+ if a dataclass field is a child or not."""
45+
46+ @classmethod
47+ def _walk_obj_members (cls , o : WalkableT , fn : WalkerFnT ):
48+ if o is None :
49+ return
50+ if isinstance (o , AstNode ):
51+ # noinspection PyProtectedMember
52+ return o ._walk_members (fn )
53+ try :
54+ it = iter (o )
55+ except TypeError :
56+ raise TypeError ("Don't know how to walk object" )
57+ for i in it :
58+ cls .walk_obj (i , fn )
59+
60+ @classmethod
61+ def walk_obj (cls , o : WalkableT , fn : WalkerFnT ):
62+ if isinstance (o , AstNode ):
63+ return o .walk (fn ) # Delegate straight away (might have special functionality)
64+ if fn (o , WalkerCallType .PRE ):
65+ return
66+ cls ._walk_obj_members (o , fn )
67+ fn (o , WalkerCallType .POST )
68+
69+ @classmethod
70+ def walk_multiple_objects (cls , fn : WalkerFnT , objs : Iterable [WalkableT ]):
71+ for o in objs :
72+ cls .walk_obj (o , fn )
73+
74+
75+ walk_ast = AstNode .walk_obj
76+
2377
2478@dataclass
2579class AstProgramNode (AstNode ):
2680 name = 'program'
2781 statements : list [AstNode ]
2882
83+ def _walk_members (self , fn : WalkerFnT ):
84+ self .walk_multiple_objects (fn , (self .statements ,))
85+
2986
3087# region ---- <Statements> ----
3188class VarDeclScope (Enum ):
@@ -45,13 +102,19 @@ class AstDeclNode(AstNode):
45102 type : VarType
46103 decls : list [tuple [AstIdent , AstNode | None ]]
47104
105+ def _walk_members (self , fn : WalkerFnT ):
106+ self .walk_multiple_objects (fn , (self .decls ,))
107+
48108
49109@dataclass
50110class AstRepeat (AstNode ):
51111 name = 'repeat'
52112 count : AstNode
53113 body : list [AstNode ]
54114
115+ def _walk_members (self , fn : WalkerFnT ):
116+ self .walk_multiple_objects (fn , (self .count , self .body ))
117+
55118
56119@dataclass
57120class AstIf (AstNode ):
@@ -63,20 +126,29 @@ class AstIf(AstNode):
63126 # ^ Separate cases for no block and empty block (can be else {} to easily
64127 # add extra blocks in scratch interface)
65128
129+ def _walk_members (self , fn : WalkerFnT ):
130+ self .walk_multiple_objects (fn , (self .cond , self .if_body , self .else_body ))
131+
66132
67133@dataclass
68134class AstWhile (AstNode ):
69135 name = 'while'
70136 cond : AstNode
71137 body : list [AstNode ]
72138
139+ def _walk_members (self , fn : WalkerFnT ):
140+ self .walk_multiple_objects (fn , (self .cond , self .body ))
141+
73142
74143@dataclass
75144class AstAssign (AstNode ):
76145 name = '='
77146 target : AstNode
78147 source : AstNode
79148
149+ def _walk_members (self , fn : WalkerFnT ):
150+ self .walk_multiple_objects (fn , (self .target , self .source ))
151+
80152
81153@dataclass
82154class AstAugAssign (AstNode ):
@@ -88,6 +160,9 @@ class AstAugAssign(AstNode):
88160 def name (self ):
89161 return self .op
90162
163+ def _walk_members (self , fn : WalkerFnT ):
164+ self .walk_multiple_objects (fn , (self .target , self .source ))
165+
91166
92167@dataclass
93168class AstDefine (AstNode ):
@@ -96,6 +171,9 @@ class AstDefine(AstNode):
96171 ident : AstIdent
97172 params : list [tuple [AstIdent , AstIdent ]] # type, ident
98173 body : list [AstNode ]
174+
175+ def _walk_members (self , fn : WalkerFnT ):
176+ self .walk_multiple_objects (fn , (self .ident , self .params , self .body ))
99177# endregion ---- </Statements> ----
100178
101179
@@ -135,27 +213,39 @@ class AstListLiteral(AstNode):
135213 name = 'list'
136214 items : list [AstNode ]
137215
216+ def _walk_members (self , fn : WalkerFnT ):
217+ self .walk_multiple_objects (fn , (self .items ,))
218+
138219
139220@dataclass
140221class AstAttribute (AstNode ):
141222 name = '.'
142223 obj : AstNode
143224 attr : AstAttrName
144225
226+ def _walk_members (self , fn : WalkerFnT ):
227+ self .walk_multiple_objects (fn , (self .obj , self .attr ))
228+
145229
146230@dataclass
147231class AstItem (AstNode ):
148232 name = 'item'
149233 obj : AstNode
150234 index : AstNode
151235
236+ def _walk_members (self , fn : WalkerFnT ):
237+ self .walk_multiple_objects (fn , (self .obj , self .index ))
238+
152239
153240@dataclass
154241class AstCall (AstNode ):
155242 name = 'call'
156243 obj : AstNode
157244 args : list [AstNode ]
158245
246+ def _walk_members (self , fn : WalkerFnT ):
247+ self .walk_multiple_objects (fn , (self .obj , self .args ))
248+
159249
160250@dataclass
161251class AstOp (AstNode ):
@@ -178,6 +268,9 @@ def __post_init__(self):
178268 def name (self ):
179269 return self .op
180270
271+ def _walk_members (self , fn : WalkerFnT ):
272+ self .walk_multiple_objects (fn , (self .left , self .right ))
273+
181274
182275@dataclass
183276class AstUnaryOp (AstOp ):
@@ -191,4 +284,7 @@ def __post_init__(self):
191284 @property
192285 def name (self ):
193286 return self .op
287+
288+ def _walk_members (self , fn : WalkerFnT ):
289+ self .walk_multiple_objects (fn , (self .operand ,))
194290# endregion ---- </Expressions> ----
0 commit comments