From 4230738f9280ab3dcfd01ef9a886563e5bf108b5 Mon Sep 17 00:00:00 2001 From: Louis Mandel Date: Mon, 25 Nov 2024 10:56:57 -0500 Subject: [PATCH 1/2] Fix bug in `pdl_ast_utils.map_block_children` --- src/pdl/pdl_ast_utils.py | 10 ++++----- tests/test_ast_utils.py | 45 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 6 deletions(-) create mode 100644 tests/test_ast_utils.py diff --git a/src/pdl/pdl_ast_utils.py b/src/pdl/pdl_ast_utils.py index a75983c7b..4df30d67f 100644 --- a/src/pdl/pdl_ast_utils.py +++ b/src/pdl/pdl_ast_utils.py @@ -126,13 +126,9 @@ def f_expr(self, expr: ExpressionType) -> ExpressionType: def map_block_children(f: MappedFunctions, block: BlockType) -> BlockType: if not isinstance(block, Block): - return f.f_block(block) + return block defs = {x: f.f_block(b) for x, b in block.defs.items()} - if block.fallback is not None: - fallback = f.f_block(block.fallback) - else: - fallback = None - block = block.model_copy(update={"defs": defs, "fallback": fallback}) + block = block.model_copy(update={"defs": defs}) match block: case FunctionBlock(): block.returns = f.f_block(block.returns) @@ -211,4 +207,6 @@ def map_block_children(f: MappedFunctions, block: BlockType) -> BlockType: pass case PdlParser(): block.parser.pdl = f.f_block(block.parser.pdl) + if block.fallback is not None: + block.fallback = f.f_block(block.fallback) return block diff --git a/tests/test_ast_utils.py b/tests/test_ast_utils.py new file mode 100644 index 000000000..33bc613f9 --- /dev/null +++ b/tests/test_ast_utils.py @@ -0,0 +1,45 @@ +import pathlib + +from pdl.pdl_ast_utils import MappedFunctions, iter_block_children, map_block_children +from pdl.pdl_parser import PDLParseError, parse_file + + +class Counter: + def __init__(self): + self.cpt = 0 + + def incr(self, *args): + self.cpt += 1 + +class IterCounter: + def __init__(self): + self.cpt = 0 + + def count(self, ast): + self.cpt += 1 + iter_block_children(self.count, ast) + + +class MapCounter: + def __init__(self): + self.cpt = 0 + + def count(map_self, ast): + map_self.cpt += 1 + class C(MappedFunctions): + def f_block(c_self, block): + return map_self.count(block) + _ = map_block_children(C(), ast) + return ast + +def test_ast_iterators() -> None: + for yaml_file_name in pathlib.Path(".").glob("**/*.pdl"): + try: + ast, _ = parse_file(yaml_file_name) + iter_cpt = IterCounter() + iter_cpt.count(ast.root) + map_cpt = MapCounter() + map_cpt.count(ast.root) + assert iter_cpt.cpt == map_cpt.cpt, yaml_file_name + except PDLParseError: + pass From 2370e50eb262521712f3b8f81ac6a609c510bf1e Mon Sep 17 00:00:00 2001 From: Louis Mandel Date: Mon, 25 Nov 2024 11:05:03 -0500 Subject: [PATCH 2/2] Formatting --- tests/test_ast_utils.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/test_ast_utils.py b/tests/test_ast_utils.py index 33bc613f9..36b6b1796 100644 --- a/tests/test_ast_utils.py +++ b/tests/test_ast_utils.py @@ -7,10 +7,11 @@ class Counter: def __init__(self): self.cpt = 0 - + def incr(self, *args): self.cpt += 1 + class IterCounter: def __init__(self): self.cpt = 0 @@ -24,14 +25,17 @@ class MapCounter: def __init__(self): self.cpt = 0 - def count(map_self, ast): + def count(map_self, ast): # pylint: disable=no-self-argument map_self.cpt += 1 + class C(MappedFunctions): - def f_block(c_self, block): + def f_block(_, block): # pylint: disable=no-self-argument return map_self.count(block) + _ = map_block_children(C(), ast) return ast + def test_ast_iterators() -> None: for yaml_file_name in pathlib.Path(".").glob("**/*.pdl"): try: