Skip to content

Commit ec657fa

Browse files
Merge pull request #62 from MarcellPerger1/add-typechecking
refactor: Add AST walking
2 parents 8e5f541 + 7339b1b commit ec657fa

File tree

1 file changed

+97
-1
lines changed

1 file changed

+97
-1
lines changed

parser/astgen/ast_node.py

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from dataclasses import dataclass
44
from enum import Enum
5+
from typing import Callable, TypeAlias, Iterable
56

67
from ..common import HasRegion, StrRegion
78

@@ -10,22 +11,78 @@
1011
"AstRepeat", "AstIf", "AstWhile", "AstAssign", "AstAugAssign", "AstDefine",
1112
"AstNumber", "AstString", "AstAnyName", "AstIdent", "AstAttrName",
1213
"AstListLiteral", "AstAttribute", "AstItem", "AstCall", "AstOp", "AstBinOp",
13-
"AstUnaryOp",
14+
"AstUnaryOp", 'walk_ast', 'WalkableT', 'WalkerFnT', 'WalkerCallType'
1415
]
1516

1617

18+
class WalkerCallType(Enum):
19+
PRE = 'pre'
20+
POST = 'post'
21+
22+
23+
WalkableL0: TypeAlias = 'AstNode | list[AstNode] | tuple[AstNode, ...] | None'
24+
WalkableT: TypeAlias = 'WalkableL0 | list[WalkableL0] | tuple[WalkableL0, ...]'
25+
WalkerFnT: TypeAlias = Callable[[WalkableT, WalkerCallType], bool | None]
26+
"""Returns True if skip"""
27+
28+
1729
@dataclass
1830
class AstNode(HasRegion):
1931
region: StrRegion
2032
name = None # type: str
2133
del name # So we get better error msg if we forget to add it to a class
2234

35+
def walk(self, fn: WalkerFnT):
36+
if fn(self, WalkerCallType.PRE):
37+
return
38+
self._walk_members(fn)
39+
fn(self, WalkerCallType.POST)
40+
41+
def _walk_members(self, fn: WalkerFnT):
42+
"""We have to define this manually on all subclasses with children.
43+
We don't try to do anything overcomplicated as it is hard to tell
44+
if a dataclass field is a child or not."""
45+
46+
@classmethod
47+
def _walk_obj_members(cls, o: WalkableT, fn: WalkerFnT):
48+
if o is None:
49+
return
50+
if isinstance(o, AstNode):
51+
# noinspection PyProtectedMember
52+
return o._walk_members(fn)
53+
try:
54+
it = iter(o)
55+
except TypeError:
56+
raise TypeError("Don't know how to walk object")
57+
for i in it:
58+
cls.walk_obj(i, fn)
59+
60+
@classmethod
61+
def walk_obj(cls, o: WalkableT, fn: WalkerFnT):
62+
if isinstance(o, AstNode):
63+
return o.walk(fn) # Delegate straight away (might have special functionality)
64+
if fn(o, WalkerCallType.PRE):
65+
return
66+
cls._walk_obj_members(o, fn)
67+
fn(o, WalkerCallType.POST)
68+
69+
@classmethod
70+
def walk_multiple_objects(cls, fn: WalkerFnT, objs: Iterable[WalkableT]):
71+
for o in objs:
72+
cls.walk_obj(o, fn)
73+
74+
75+
walk_ast = AstNode.walk_obj
76+
2377

2478
@dataclass
2579
class AstProgramNode(AstNode):
2680
name = 'program'
2781
statements: list[AstNode]
2882

83+
def _walk_members(self, fn: WalkerFnT):
84+
self.walk_multiple_objects(fn, (self.statements,))
85+
2986

3087
# region ---- <Statements> ----
3188
class VarDeclScope(Enum):
@@ -45,13 +102,19 @@ class AstDeclNode(AstNode):
45102
type: VarType
46103
decls: list[tuple[AstIdent, AstNode | None]]
47104

105+
def _walk_members(self, fn: WalkerFnT):
106+
self.walk_multiple_objects(fn, (self.decls,))
107+
48108

49109
@dataclass
50110
class AstRepeat(AstNode):
51111
name = 'repeat'
52112
count: AstNode
53113
body: list[AstNode]
54114

115+
def _walk_members(self, fn: WalkerFnT):
116+
self.walk_multiple_objects(fn, (self.count, self.body))
117+
55118

56119
@dataclass
57120
class AstIf(AstNode):
@@ -63,20 +126,29 @@ class AstIf(AstNode):
63126
# ^ Separate cases for no block and empty block (can be else {} to easily
64127
# add extra blocks in scratch interface)
65128

129+
def _walk_members(self, fn: WalkerFnT):
130+
self.walk_multiple_objects(fn, (self.cond, self.if_body, self.else_body))
131+
66132

67133
@dataclass
68134
class AstWhile(AstNode):
69135
name = 'while'
70136
cond: AstNode
71137
body: list[AstNode]
72138

139+
def _walk_members(self, fn: WalkerFnT):
140+
self.walk_multiple_objects(fn, (self.cond, self.body))
141+
73142

74143
@dataclass
75144
class AstAssign(AstNode):
76145
name = '='
77146
target: AstNode
78147
source: AstNode
79148

149+
def _walk_members(self, fn: WalkerFnT):
150+
self.walk_multiple_objects(fn, (self.target, self.source))
151+
80152

81153
@dataclass
82154
class AstAugAssign(AstNode):
@@ -88,6 +160,9 @@ class AstAugAssign(AstNode):
88160
def name(self):
89161
return self.op
90162

163+
def _walk_members(self, fn: WalkerFnT):
164+
self.walk_multiple_objects(fn, (self.target, self.source))
165+
91166

92167
@dataclass
93168
class AstDefine(AstNode):
@@ -96,6 +171,9 @@ class AstDefine(AstNode):
96171
ident: AstIdent
97172
params: list[tuple[AstIdent, AstIdent]] # type, ident
98173
body: list[AstNode]
174+
175+
def _walk_members(self, fn: WalkerFnT):
176+
self.walk_multiple_objects(fn, (self.ident, self.params, self.body))
99177
# endregion ---- </Statements> ----
100178

101179

@@ -135,27 +213,39 @@ class AstListLiteral(AstNode):
135213
name = 'list'
136214
items: list[AstNode]
137215

216+
def _walk_members(self, fn: WalkerFnT):
217+
self.walk_multiple_objects(fn, (self.items,))
218+
138219

139220
@dataclass
140221
class AstAttribute(AstNode):
141222
name = '.'
142223
obj: AstNode
143224
attr: AstAttrName
144225

226+
def _walk_members(self, fn: WalkerFnT):
227+
self.walk_multiple_objects(fn, (self.obj, self.attr))
228+
145229

146230
@dataclass
147231
class AstItem(AstNode):
148232
name = 'item'
149233
obj: AstNode
150234
index: AstNode
151235

236+
def _walk_members(self, fn: WalkerFnT):
237+
self.walk_multiple_objects(fn, (self.obj, self.index))
238+
152239

153240
@dataclass
154241
class AstCall(AstNode):
155242
name = 'call'
156243
obj: AstNode
157244
args: list[AstNode]
158245

246+
def _walk_members(self, fn: WalkerFnT):
247+
self.walk_multiple_objects(fn, (self.obj, self.args))
248+
159249

160250
@dataclass
161251
class AstOp(AstNode):
@@ -178,6 +268,9 @@ def __post_init__(self):
178268
def name(self):
179269
return self.op
180270

271+
def _walk_members(self, fn: WalkerFnT):
272+
self.walk_multiple_objects(fn, (self.left, self.right))
273+
181274

182275
@dataclass
183276
class AstUnaryOp(AstOp):
@@ -191,4 +284,7 @@ def __post_init__(self):
191284
@property
192285
def name(self):
193286
return self.op
287+
288+
def _walk_members(self, fn: WalkerFnT):
289+
self.walk_multiple_objects(fn, (self.operand,))
194290
# endregion ---- </Expressions> ----

0 commit comments

Comments
 (0)