Skip to content

Commit b79db77

Browse files
committed
cleanup
1 parent c6c6fff commit b79db77

File tree

3 files changed

+98
-18
lines changed

3 files changed

+98
-18
lines changed

codeflash/code_utils/code_replacer.py

Lines changed: 82 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from __future__ import annotations
22

33
import ast
4+
import re
45
from collections import defaultdict
56
from functools import lru_cache
6-
from typing import TYPE_CHECKING, TypeVar
7+
from typing import TYPE_CHECKING, Optional, TypeVar
78

89
import libcst as cst
910

@@ -338,25 +339,91 @@ def function_to_optimize_original_worktree_fqn(
338339
)
339340

340341

342+
class AssertCleanup:
343+
def transform_asserts(self, code: str) -> str:
344+
lines = code.splitlines()
345+
result_lines = []
346+
347+
for line in lines:
348+
transformed = self._transform_assert_line(line)
349+
if transformed is not None:
350+
result_lines.append(transformed)
351+
else:
352+
result_lines.append(line)
353+
354+
return "\n".join(result_lines)
355+
356+
def _transform_assert_line(self, line: str) -> Optional[str]:
357+
indent = line[: len(line) - len(line.lstrip())]
358+
359+
assert_match = re.match(r"\s*assert\s+(.*?)(?:\s*==\s*.*)?$", line)
360+
if assert_match:
361+
expression = assert_match.group(1).strip()
362+
if expression.startswith("not "):
363+
return f"{indent}{expression}"
364+
365+
expression = re.sub(r"[,;]\s*$", "", expression)
366+
return f"{indent}{expression}"
367+
368+
unittest_match = re.match(r"(\s*)self\.assert([A-Za-z]+)\((.*)\)$", line)
369+
if unittest_match:
370+
indent, assert_method, args = unittest_match.groups()
371+
372+
if args:
373+
arg_parts = self._split_top_level_args(args)
374+
if arg_parts and arg_parts[0]:
375+
return f"{indent}{arg_parts[0]}"
376+
377+
return None
378+
379+
def _split_top_level_args(self, args_str: str) -> list[str]:
380+
result = []
381+
current = []
382+
depth = 0
383+
384+
for char in args_str:
385+
if char in "([{":
386+
depth += 1
387+
current.append(char)
388+
elif char in ")]}":
389+
depth -= 1
390+
current.append(char)
391+
elif char == "," and depth == 0:
392+
result.append("".join(current).strip())
393+
current = []
394+
else:
395+
current.append(char)
396+
397+
if current:
398+
result.append("".join(current).strip())
399+
400+
return result
401+
402+
341403
def clean_concolic_tests(test_suite_code: str) -> str:
342404
try:
405+
can_parse = True
343406
tree = ast.parse(test_suite_code)
407+
except SyntaxError:
408+
can_parse = False
409+
410+
if not can_parse:
411+
return AssertCleanup().transform_asserts(test_suite_code)
344412

345-
for node in ast.walk(tree):
346-
if isinstance(node, ast.FunctionDef) and node.name.startswith("test_"):
347-
new_body = []
348-
for stmt in node.body:
349-
if isinstance(stmt, ast.Assert):
350-
if isinstance(stmt.test, ast.Compare) and isinstance(stmt.test.left, ast.Call):
351-
new_body.append(ast.Expr(value=stmt.test.left))
352-
else:
353-
new_body.append(stmt)
413+
tree = ast.parse(test_suite_code)
354414

415+
for node in ast.walk(tree):
416+
if isinstance(node, ast.FunctionDef) and node.name.startswith("test_"):
417+
new_body = []
418+
for stmt in node.body:
419+
if isinstance(stmt, ast.Assert):
420+
if isinstance(stmt.test, ast.Compare) and isinstance(stmt.test.left, ast.Call):
421+
new_body.append(ast.Expr(value=stmt.test.left))
355422
else:
356423
new_body.append(stmt)
357-
node.body = new_body
358424

359-
return ast.unparse(tree).strip()
360-
except SyntaxError:
361-
logger.warning("Failed to parse and modify CrossHair generated tests. Using original output.")
362-
return test_suite_code
425+
else:
426+
new_body.append(stmt)
427+
node.body = new_body
428+
429+
return ast.unparse(tree).strip()

codeflash/verification/concolic_testing.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
import ast
4-
import difflib
54
import subprocess
65
import tempfile
76
from argparse import Namespace
@@ -60,8 +59,8 @@ def generate_concolic_tests(
6059
return function_to_concolic_tests, concolic_test_suite_code
6160

6261
if cover_result.returncode == 0:
63-
original_code: str = cover_result.stdout
64-
concolic_test_suite_code: str = clean_concolic_tests(original_code)
62+
generated_concolic_test: str = cover_result.stdout
63+
concolic_test_suite_code: str = clean_concolic_tests(generated_concolic_test)
6564
concolic_test_suite_dir = Path(tempfile.mkdtemp(dir=test_cfg.concolic_test_root_dir))
6665
concolic_test_suite_path = concolic_test_suite_dir / "test_concolic_coverage.py"
6766
concolic_test_suite_path.write_text(concolic_test_suite_code, encoding="utf8")

tests/test_code_utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,3 +427,17 @@ def test_calculate_tuple_sum(my_tuple: tuple[int, int, int]) -> None:
427427
calculate_tuple_sum((1, 2, 3))
428428
"""
429429
assert cleaned_code == expected_cleaned_code.strip()
430+
431+
concolic_generated_repr_code = """from src.blib2to3.pgen2.grammar import Grammar
432+
433+
def test_Grammar_copy():
434+
assert Grammar.copy(Grammar()) == <src.blib2to3.pgen2.grammar.Grammar object at 0x104c30f50>
435+
"""
436+
cleaned_code = clean_concolic_tests(concolic_generated_repr_code)
437+
expected_cleaned_code = """
438+
from src.blib2to3.pgen2.grammar import Grammar
439+
440+
def test_Grammar_copy():
441+
Grammar.copy(Grammar())
442+
"""
443+
assert cleaned_code == expected_cleaned_code.strip()

0 commit comments

Comments
 (0)