Skip to content

Commit 2b34a9a

Browse files
authored
Fix bug in pdl_ast_utils.map_block_children (#198)
1 parent c3bc892 commit 2b34a9a

File tree

2 files changed

+53
-6
lines changed

2 files changed

+53
-6
lines changed

src/pdl/pdl_ast_utils.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,9 @@ def f_expr(self, expr: ExpressionType) -> ExpressionType:
126126

127127
def map_block_children(f: MappedFunctions, block: BlockType) -> BlockType:
128128
if not isinstance(block, Block):
129-
return f.f_block(block)
129+
return block
130130
defs = {x: f.f_block(b) for x, b in block.defs.items()}
131-
if block.fallback is not None:
132-
fallback = f.f_block(block.fallback)
133-
else:
134-
fallback = None
135-
block = block.model_copy(update={"defs": defs, "fallback": fallback})
131+
block = block.model_copy(update={"defs": defs})
136132
match block:
137133
case FunctionBlock():
138134
block.returns = f.f_block(block.returns)
@@ -211,4 +207,6 @@ def map_block_children(f: MappedFunctions, block: BlockType) -> BlockType:
211207
pass
212208
case PdlParser():
213209
block.parser.pdl = f.f_block(block.parser.pdl)
210+
if block.fallback is not None:
211+
block.fallback = f.f_block(block.fallback)
214212
return block

tests/test_ast_utils.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import pathlib
2+
3+
from pdl.pdl_ast_utils import MappedFunctions, iter_block_children, map_block_children
4+
from pdl.pdl_parser import PDLParseError, parse_file
5+
6+
7+
class Counter:
8+
def __init__(self):
9+
self.cpt = 0
10+
11+
def incr(self, *args):
12+
self.cpt += 1
13+
14+
15+
class IterCounter:
16+
def __init__(self):
17+
self.cpt = 0
18+
19+
def count(self, ast):
20+
self.cpt += 1
21+
iter_block_children(self.count, ast)
22+
23+
24+
class MapCounter:
25+
def __init__(self):
26+
self.cpt = 0
27+
28+
def count(map_self, ast): # pylint: disable=no-self-argument
29+
map_self.cpt += 1
30+
31+
class C(MappedFunctions):
32+
def f_block(_, block): # pylint: disable=no-self-argument
33+
return map_self.count(block)
34+
35+
_ = map_block_children(C(), ast)
36+
return ast
37+
38+
39+
def test_ast_iterators() -> None:
40+
for yaml_file_name in pathlib.Path(".").glob("**/*.pdl"):
41+
try:
42+
ast, _ = parse_file(yaml_file_name)
43+
iter_cpt = IterCounter()
44+
iter_cpt.count(ast.root)
45+
map_cpt = MapCounter()
46+
map_cpt.count(ast.root)
47+
assert iter_cpt.cpt == map_cpt.cpt, yaml_file_name
48+
except PDLParseError:
49+
pass

0 commit comments

Comments
 (0)