diff --git a/src/pdl/pdl_ast_utils.py b/src/pdl/pdl_ast_utils.py index a75983c7b..4df30d67f 100644 --- a/src/pdl/pdl_ast_utils.py +++ b/src/pdl/pdl_ast_utils.py @@ -126,13 +126,9 @@ def f_expr(self, expr: ExpressionType) -> ExpressionType: def map_block_children(f: MappedFunctions, block: BlockType) -> BlockType: if not isinstance(block, Block): - return f.f_block(block) + return block defs = {x: f.f_block(b) for x, b in block.defs.items()} - if block.fallback is not None: - fallback = f.f_block(block.fallback) - else: - fallback = None - block = block.model_copy(update={"defs": defs, "fallback": fallback}) + block = block.model_copy(update={"defs": defs}) match block: case FunctionBlock(): block.returns = f.f_block(block.returns) @@ -211,4 +207,6 @@ def map_block_children(f: MappedFunctions, block: BlockType) -> BlockType: pass case PdlParser(): block.parser.pdl = f.f_block(block.parser.pdl) + if block.fallback is not None: + block.fallback = f.f_block(block.fallback) return block diff --git a/tests/test_ast_utils.py b/tests/test_ast_utils.py new file mode 100644 index 000000000..36b6b1796 --- /dev/null +++ b/tests/test_ast_utils.py @@ -0,0 +1,49 @@ +import pathlib + +from pdl.pdl_ast_utils import MappedFunctions, iter_block_children, map_block_children +from pdl.pdl_parser import PDLParseError, parse_file + + +class Counter: + def __init__(self): + self.cpt = 0 + + def incr(self, *args): + self.cpt += 1 + + +class IterCounter: + def __init__(self): + self.cpt = 0 + + def count(self, ast): + self.cpt += 1 + iter_block_children(self.count, ast) + + +class MapCounter: + def __init__(self): + self.cpt = 0 + + def count(map_self, ast): # pylint: disable=no-self-argument + map_self.cpt += 1 + + class C(MappedFunctions): + def f_block(_, block): # pylint: disable=no-self-argument + return map_self.count(block) + + _ = map_block_children(C(), ast) + return ast + + +def test_ast_iterators() -> None: + for yaml_file_name in pathlib.Path(".").glob("**/*.pdl"): + try: + ast, _ = parse_file(yaml_file_name) + iter_cpt = IterCounter() + iter_cpt.count(ast.root) + map_cpt = MapCounter() + map_cpt.count(ast.root) + assert iter_cpt.cpt == map_cpt.cpt, yaml_file_name + except PDLParseError: + pass