Skip to content

Commit d62a528

Browse files
refactor: Flatten multiple declarations
1 parent 7339b1b commit d62a528

File tree

3 files changed

+12
-6
lines changed

3 files changed

+12
-6
lines changed

parser/astgen/ast_node.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,11 @@ class AstDeclNode(AstNode):
100100
name = 'var_decl'
101101
scope: VarDeclScope
102102
type: VarType
103-
decls: list[tuple[AstIdent, AstNode | None]]
103+
ident: AstIdent
104+
value: AstNode | None
104105

105106
def _walk_members(self, fn: WalkerFnT):
106-
self.walk_multiple_objects(fn, (self.decls,))
107+
self.walk_multiple_objects(fn, (self.ident, self.value))
107108

108109

109110
@dataclass

parser/astgen/astgen.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,14 +134,17 @@ def _walk_smt(self, smt: AnyNode) -> list[AstNode]:
134134
f"the root level.", smt.region)
135135

136136
def _walk_var_decl(self, smt: DeclNode):
137-
decls = [(self._walk_ident(d.ident),
138-
None if d.value is None else self._walk_expr(d.value))
139-
for d in smt.decl_list.decls]
140137
scope = (VarDeclScope.LET if isinstance(smt.decl_scope, DeclScope_Let)
141138
else VarDeclScope.GLOBAL)
142139
tp = (VarType.LIST if isinstance(smt.decl_type, DeclType_List)
143140
else VarType.VARIABLE)
144-
return [AstDeclNode(smt.region, scope, tp, decls)]
141+
return [self._walk_single_decl(d, scope, tp) for d in smt.decl_list.decls]
142+
143+
def _walk_single_decl(self, d: DeclItemNode, scope: VarDeclScope, tp: VarType):
144+
return AstDeclNode(
145+
region_union(d.ident, d.value),
146+
scope, tp, self._walk_ident(d.ident),
147+
None if d.value is None else self._walk_expr(d.value))
145148

146149
def _walk_assign_left(self, lhs: AnyNode) -> AstNode:
147150
if isinstance(lhs, IdentNode):

parser/common/common.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ class HasRegion:
1616
def region_union(*args: RegionUnionArgT):
1717
regs = []
1818
for loc in args:
19+
if loc is None:
20+
continue
1921
if getattr(loc, 'region', None) is not None: # Duck-type HasRegion
2022
loc = loc.region
2123
if isinstance(loc, StrRegion):

0 commit comments

Comments
 (0)