Skip to content

Commit 4230738

Browse files
committed
Fix bug in pdl_ast_utils.map_block_children
1 parent b454ffa commit 4230738

File tree

2 files changed

+49
-6
lines changed

2 files changed

+49
-6
lines changed

src/pdl/pdl_ast_utils.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,9 @@ def f_expr(self, expr: ExpressionType) -> ExpressionType:
126126

127127
def map_block_children(f: MappedFunctions, block: BlockType) -> BlockType:
128128
if not isinstance(block, Block):
129-
return f.f_block(block)
129+
return block
130130
defs = {x: f.f_block(b) for x, b in block.defs.items()}
131-
if block.fallback is not None:
132-
fallback = f.f_block(block.fallback)
133-
else:
134-
fallback = None
135-
block = block.model_copy(update={"defs": defs, "fallback": fallback})
131+
block = block.model_copy(update={"defs": defs})
136132
match block:
137133
case FunctionBlock():
138134
block.returns = f.f_block(block.returns)
@@ -211,4 +207,6 @@ def map_block_children(f: MappedFunctions, block: BlockType) -> BlockType:
211207
pass
212208
case PdlParser():
213209
block.parser.pdl = f.f_block(block.parser.pdl)
210+
if block.fallback is not None:
211+
block.fallback = f.f_block(block.fallback)
214212
return block

tests/test_ast_utils.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import pathlib
2+
3+
from pdl.pdl_ast_utils import MappedFunctions, iter_block_children, map_block_children
4+
from pdl.pdl_parser import PDLParseError, parse_file
5+
6+
7+
class Counter:
8+
def __init__(self):
9+
self.cpt = 0
10+
11+
def incr(self, *args):
12+
self.cpt += 1
13+
14+
class IterCounter:
15+
def __init__(self):
16+
self.cpt = 0
17+
18+
def count(self, ast):
19+
self.cpt += 1
20+
iter_block_children(self.count, ast)
21+
22+
23+
class MapCounter:
24+
def __init__(self):
25+
self.cpt = 0
26+
27+
def count(map_self, ast):
28+
map_self.cpt += 1
29+
class C(MappedFunctions):
30+
def f_block(c_self, block):
31+
return map_self.count(block)
32+
_ = map_block_children(C(), ast)
33+
return ast
34+
35+
def test_ast_iterators() -> None:
36+
for yaml_file_name in pathlib.Path(".").glob("**/*.pdl"):
37+
try:
38+
ast, _ = parse_file(yaml_file_name)
39+
iter_cpt = IterCounter()
40+
iter_cpt.count(ast.root)
41+
map_cpt = MapCounter()
42+
map_cpt.count(ast.root)
43+
assert iter_cpt.cpt == map_cpt.cpt, yaml_file_name
44+
except PDLParseError:
45+
pass

0 commit comments

Comments
 (0)