Skip to content

Commit 916ddbb

Browse files
authored
fix(tools): handle AnnAssign nodes (#1813)
1 parent bff0f2e commit 916ddbb

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

src/ethereum_spec_tools/new_fork/codemod/constant.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,45 @@ def leave_Assign( # noqa: D102
124124

125125
return updated_node.with_changes(value=self.value.deep_clone())
126126

127+
@override
128+
def visit_AnnAssign_target(self, node: cst.AnnAssign) -> None: # noqa: D102
129+
if self._in_assign_target:
130+
raise Exception("already in assign target")
131+
self._in_assign_target = True
132+
133+
@override
134+
def leave_AnnAssign_target(self, node: cst.AnnAssign) -> None: # noqa: D102
135+
if not self._in_assign_target:
136+
raise Exception("not in assign target")
137+
self._in_assign_target = False
138+
139+
@override
140+
def visit_AnnAssign(self, node: cst.AnnAssign) -> None: # noqa: D102
141+
if self._matches or self._in_assign_target:
142+
raise Exception("nested assign")
143+
144+
@override
145+
def leave_AnnAssign( # noqa: D102
146+
self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign
147+
) -> cst.AnnAssign:
148+
if self._in_assign_target:
149+
raise Exception("still in assign target")
150+
151+
if not self._matches:
152+
return updated_node
153+
154+
self._matches = False
155+
156+
for module, identifier in self.imports:
157+
AddImportsVisitor.add_needed_import(
158+
self.context, module, identifier
159+
)
160+
RemoveImportsVisitor.remove_unused_import(
161+
self.context, module, identifier
162+
)
163+
164+
return updated_node.with_changes(value=self.value.deep_clone())
165+
127166
@override
128167
def visit_Name(self, node: cst.Name) -> None: # noqa: D102
129168
if not self._in_assign_target:

tests/json_infra/test_tools_new_fork.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def test_end_to_end(template_fork: str) -> None:
6262
source = f.read()
6363

6464
assert '"""' not in source[:20]
65-
assert "FORK_CRITERIA = ByTimestamp(7)" in source
65+
assert "FORK_CRITERIA: ForkCriteria = ByTimestamp(7)" in source
6666
assert template_fork.capitalize() not in source
6767

6868
with (fork_dir / "utils" / "hexadecimal.py").open("r") as f:

0 commit comments

Comments
 (0)