|
8 | 8 | class AssertCleanup: |
9 | 9 | def transform_asserts(self, code: str) -> str: |
10 | 10 | lines = code.splitlines() |
11 | | - result_lines = [] |
| 11 | + transformed_lines = [] |
12 | 12 |
|
13 | 13 | for line in lines: |
14 | 14 | transformed = self._transform_assert_line(line) |
15 | 15 | if transformed is not None: |
16 | | - result_lines.append(transformed) |
| 16 | + transformed_lines.append(transformed) |
17 | 17 | else: |
18 | | - result_lines.append(line) |
| 18 | + transformed_lines.append(line) |
19 | 19 |
|
20 | | - return "\n".join(result_lines) |
| 20 | + return "\n".join(transformed_lines) |
21 | 21 |
|
22 | 22 | def _transform_assert_line(self, line: str) -> Optional[str]: |
23 | 23 | indent = line[: len(line) - len(line.lstrip())] |
@@ -68,26 +68,25 @@ def _split_top_level_args(self, args_str: str) -> list[str]: |
68 | 68 |
|
69 | 69 | def clean_concolic_tests(test_suite_code: str) -> str: |
70 | 70 | try: |
71 | | - can_parse = True |
72 | 71 | tree = ast.parse(test_suite_code) |
| 72 | + can_parse = True |
73 | 73 | except SyntaxError: |
74 | 74 | can_parse = False |
75 | 75 |
|
76 | 76 | if not can_parse: |
77 | 77 | return AssertCleanup().transform_asserts(test_suite_code) |
78 | 78 |
|
79 | | - for node in ast.walk(tree): |
80 | | - if isinstance(node, ast.FunctionDef) and node.name.startswith("test_"): |
81 | | - new_body = [] |
82 | | - for stmt in node.body: |
83 | | - if isinstance(stmt, ast.Assert): |
84 | | - if isinstance(stmt.test, ast.Compare) and isinstance(stmt.test.left, ast.Call): |
85 | | - new_body.append(ast.Expr(value=stmt.test.left)) |
86 | | - else: |
87 | | - new_body.append(stmt) |
| 79 | + class AssertTransform(ast.NodeTransformer): |
| 80 | + def visit_Assert(self, node): |
| 81 | + if isinstance(node.test, ast.Compare) and isinstance(node.test.left, ast.Call): |
| 82 | + return ast.Expr(value=node.test.left, lineno=node.lineno, col_offset=node.col_offset) |
| 83 | + return node |
88 | 84 |
|
89 | | - else: |
90 | | - new_body.append(stmt) |
91 | | - node.body = new_body |
| 85 | + def visit_FunctionDef(self, node): |
| 86 | + if node.name.startswith("test_"): |
| 87 | + node.body = [self.visit(stmt) for stmt in node.body] |
| 88 | + return node |
92 | 89 |
|
| 90 | + transformer = AssertTransform() |
| 91 | + transformer.visit(tree) |
93 | 92 | return ast.unparse(tree).strip() |
0 commit comments