|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | import ast |
| 4 | +import re |
4 | 5 | from collections import defaultdict |
5 | 6 | from functools import lru_cache |
6 | | -from typing import TYPE_CHECKING, TypeVar |
| 7 | +from typing import TYPE_CHECKING, Optional, TypeVar |
7 | 8 |
|
8 | 9 | import libcst as cst |
9 | 10 |
|
@@ -338,25 +339,91 @@ def function_to_optimize_original_worktree_fqn( |
338 | 339 | ) |
339 | 340 |
|
340 | 341 |
|
| 342 | +class AssertCleanup: |
| 343 | + def transform_asserts(self, code: str) -> str: |
| 344 | + lines = code.splitlines() |
| 345 | + result_lines = [] |
| 346 | + |
| 347 | + for line in lines: |
| 348 | + transformed = self._transform_assert_line(line) |
| 349 | + if transformed is not None: |
| 350 | + result_lines.append(transformed) |
| 351 | + else: |
| 352 | + result_lines.append(line) |
| 353 | + |
| 354 | + return "\n".join(result_lines) |
| 355 | + |
| 356 | + def _transform_assert_line(self, line: str) -> Optional[str]: |
| 357 | + indent = line[: len(line) - len(line.lstrip())] |
| 358 | + |
| 359 | + assert_match = re.match(r"\s*assert\s+(.*?)(?:\s*==\s*.*)?$", line) |
| 360 | + if assert_match: |
| 361 | + expression = assert_match.group(1).strip() |
| 362 | + if expression.startswith("not "): |
| 363 | + return f"{indent}{expression}" |
| 364 | + |
| 365 | + expression = re.sub(r"[,;]\s*$", "", expression) |
| 366 | + return f"{indent}{expression}" |
| 367 | + |
| 368 | + unittest_match = re.match(r"(\s*)self\.assert([A-Za-z]+)\((.*)\)$", line) |
| 369 | + if unittest_match: |
| 370 | + indent, assert_method, args = unittest_match.groups() |
| 371 | + |
| 372 | + if args: |
| 373 | + arg_parts = self._split_top_level_args(args) |
| 374 | + if arg_parts and arg_parts[0]: |
| 375 | + return f"{indent}{arg_parts[0]}" |
| 376 | + |
| 377 | + return None |
| 378 | + |
| 379 | + def _split_top_level_args(self, args_str: str) -> list[str]: |
| 380 | + result = [] |
| 381 | + current = [] |
| 382 | + depth = 0 |
| 383 | + |
| 384 | + for char in args_str: |
| 385 | + if char in "([{": |
| 386 | + depth += 1 |
| 387 | + current.append(char) |
| 388 | + elif char in ")]}": |
| 389 | + depth -= 1 |
| 390 | + current.append(char) |
| 391 | + elif char == "," and depth == 0: |
| 392 | + result.append("".join(current).strip()) |
| 393 | + current = [] |
| 394 | + else: |
| 395 | + current.append(char) |
| 396 | + |
| 397 | + if current: |
| 398 | + result.append("".join(current).strip()) |
| 399 | + |
| 400 | + return result |
| 401 | + |
| 402 | + |
341 | 403 | def clean_concolic_tests(test_suite_code: str) -> str: |
342 | 404 | try: |
| 405 | + can_parse = True |
343 | 406 | tree = ast.parse(test_suite_code) |
| 407 | + except SyntaxError: |
| 408 | + can_parse = False |
| 409 | + |
| 410 | + if not can_parse: |
| 411 | + return AssertCleanup().transform_asserts(test_suite_code) |
344 | 412 |
|
345 | | - for node in ast.walk(tree): |
346 | | - if isinstance(node, ast.FunctionDef) and node.name.startswith("test_"): |
347 | | - new_body = [] |
348 | | - for stmt in node.body: |
349 | | - if isinstance(stmt, ast.Assert): |
350 | | - if isinstance(stmt.test, ast.Compare) and isinstance(stmt.test.left, ast.Call): |
351 | | - new_body.append(ast.Expr(value=stmt.test.left)) |
352 | | - else: |
353 | | - new_body.append(stmt) |
| 413 | + tree = ast.parse(test_suite_code) |
354 | 414 |
|
| 415 | + for node in ast.walk(tree): |
| 416 | + if isinstance(node, ast.FunctionDef) and node.name.startswith("test_"): |
| 417 | + new_body = [] |
| 418 | + for stmt in node.body: |
| 419 | + if isinstance(stmt, ast.Assert): |
| 420 | + if isinstance(stmt.test, ast.Compare) and isinstance(stmt.test.left, ast.Call): |
| 421 | + new_body.append(ast.Expr(value=stmt.test.left)) |
355 | 422 | else: |
356 | 423 | new_body.append(stmt) |
357 | | - node.body = new_body |
358 | 424 |
|
359 | | - return ast.unparse(tree).strip() |
360 | | - except SyntaxError: |
361 | | - logger.warning("Failed to parse and modify CrossHair generated tests. Using original output.") |
362 | | - return test_suite_code |
| 425 | + else: |
| 426 | + new_body.append(stmt) |
| 427 | + node.body = new_body |
| 428 | + |
| 429 | + return ast.unparse(tree).strip() |
0 commit comments