Skip to content

Commit 69e43dd

Browse files
authored
Merge pull request #26 from codeflash-ai/clean_concolic_tests
Clean concolic tests
2 parents e758b64 + b9c19e9 commit 69e43dd

File tree

5 files changed

+185
-18
lines changed

5 files changed

+185
-18
lines changed

codeflash/code_utils/code_replacer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from __future__ import annotations
22

33
import ast
4+
import re
45
from collections import defaultdict
56
from functools import lru_cache
6-
from typing import TYPE_CHECKING, TypeVar
7+
from typing import TYPE_CHECKING, Optional, TypeVar
78

89
import libcst as cst
910

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
from __future__ import annotations
2+
3+
import ast
4+
import re
5+
from typing import Optional
6+
7+
8+
class AssertCleanup:
9+
def transform_asserts(self, code: str) -> str:
10+
lines = code.splitlines()
11+
result_lines = []
12+
13+
for line in lines:
14+
transformed = self._transform_assert_line(line)
15+
result_lines.append(transformed if transformed is not None else line)
16+
17+
return "\n".join(result_lines)
18+
19+
def _transform_assert_line(self, line: str) -> Optional[str]:
20+
indent = line[: len(line) - len(line.lstrip())]
21+
22+
assert_match = self.assert_re.match(line)
23+
if assert_match:
24+
expression = assert_match.group(1).strip()
25+
if expression.startswith("not "):
26+
return f"{indent}{expression}"
27+
28+
expression = expression.rstrip(",;")
29+
return f"{indent}{expression}"
30+
31+
unittest_match = self.unittest_re.match(line)
32+
if unittest_match:
33+
indent, assert_method, args = unittest_match.groups()
34+
35+
if args:
36+
arg_parts = self._split_top_level_args(args)
37+
if arg_parts and arg_parts[0]:
38+
return f"{indent}{arg_parts[0]}"
39+
40+
return None
41+
42+
def _split_top_level_args(self, args_str: str) -> list[str]:
43+
result = []
44+
current = []
45+
depth = 0
46+
47+
for char in args_str:
48+
if char in "([{":
49+
depth += 1
50+
current.append(char)
51+
elif char in ")]}":
52+
depth -= 1
53+
current.append(char)
54+
elif char == "," and depth == 0:
55+
result.append("".join(current).strip())
56+
current = []
57+
else:
58+
current.append(char)
59+
60+
if current:
61+
result.append("".join(current).strip())
62+
63+
return result
64+
65+
def __init__(self):
66+
# Pre-compiling regular expressions for faster execution
67+
self.assert_re = re.compile(r"\s*assert\s+(.*?)(?:\s*==\s*.*)?$")
68+
self.unittest_re = re.compile(r"(\s*)self\.assert([A-Za-z]+)\((.*)\)$")
69+
70+
71+
def clean_concolic_tests(test_suite_code: str) -> str:
72+
try:
73+
can_parse = True
74+
tree = ast.parse(test_suite_code)
75+
except SyntaxError:
76+
can_parse = False
77+
78+
if not can_parse:
79+
return AssertCleanup().transform_asserts(test_suite_code)
80+
81+
for node in ast.walk(tree):
82+
if isinstance(node, ast.FunctionDef) and node.name.startswith("test_"):
83+
new_body = []
84+
for stmt in node.body:
85+
if isinstance(stmt, ast.Assert):
86+
if isinstance(stmt.test, ast.Compare) and isinstance(stmt.test.left, ast.Call):
87+
new_body.append(ast.Expr(value=stmt.test.left))
88+
else:
89+
new_body.append(stmt)
90+
91+
else:
92+
new_body.append(stmt)
93+
node.body = new_body
94+
95+
return ast.unparse(tree).strip()

codeflash/discovery/functions_to_optimize.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import ast
4+
import json
45
import os
56
import random
67
import warnings
@@ -156,9 +157,9 @@ def get_functions_to_optimize(
156157
project_root: Path,
157158
module_root: Path,
158159
) -> tuple[dict[Path, list[FunctionToOptimize]], int]:
159-
assert (
160-
sum([bool(optimize_all), bool(replay_test), bool(file)]) <= 1
161-
), "Only one of optimize_all, replay_test, or file should be provided"
160+
assert sum([bool(optimize_all), bool(replay_test), bool(file)]) <= 1, (
161+
"Only one of optimize_all, replay_test, or file should be provided"
162+
)
162163
functions: dict[str, list[FunctionToOptimize]]
163164
with warnings.catch_warnings():
164165
warnings.simplefilter(action="ignore", category=SyntaxWarning)
@@ -434,9 +435,7 @@ def filter_functions(
434435
test_functions_removed_count += len(functions)
435436
continue
436437
if file_path in ignore_paths or any(
437-
# file_path.startswith(ignore_path + os.sep) for ignore_path in ignore_paths if ignore_path
438-
file_path.startswith(str(ignore_path) + os.sep)
439-
for ignore_path in ignore_paths
438+
file_path.startswith(str(ignore_path) + os.sep) for ignore_path in ignore_paths
440439
):
441440
ignore_paths_removed_count += 1
442441
continue
@@ -457,15 +456,17 @@ def filter_functions(
457456
malformed_paths_count += 1
458457
continue
459458
if blocklist_funcs:
460-
for function in functions.copy():
461-
path = Path(function.file_path).name
462-
if path in blocklist_funcs and function.function_name in blocklist_funcs[path]:
463-
functions.remove(function)
464-
logger.debug(f"Skipping {function.function_name} in {path} as it has already been optimized")
465-
continue
466-
459+
functions = [
460+
function
461+
for function in functions
462+
if not (
463+
function.file_path.name in blocklist_funcs
464+
and function.qualified_name in blocklist_funcs[function.file_path.name]
465+
)
466+
]
467467
filtered_modified_functions[file_path] = functions
468468
functions_count += len(functions)
469+
469470
if not disable_logs:
470471
log_info = {
471472
f"{test_functions_removed_count} test function{'s' if test_functions_removed_count != 1 else ''}": test_functions_removed_count,
@@ -475,10 +476,11 @@ def filter_functions(
475476
f"{ignore_paths_removed_count} file{'s' if ignore_paths_removed_count != 1 else ''} from ignored paths": ignore_paths_removed_count,
476477
f"{submodule_ignored_paths_count} file{'s' if submodule_ignored_paths_count != 1 else ''} from ignored submodules": submodule_ignored_paths_count,
477478
}
478-
log_string: str
479-
if log_string := "\n".join([k for k, v in log_info.items() if v > 0]):
479+
log_string = "\n".join([k for k, v in log_info.items() if v > 0])
480+
if log_string:
480481
logger.info(f"Ignoring: {log_string}")
481482
console.rule()
483+
482484
return {Path(k): v for k, v in filtered_modified_functions.items() if v}, functions_count
483485

484486

codeflash/verification/concolic_testing.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from pathlib import Path
88

99
from codeflash.cli_cmds.console import console, logger
10+
from codeflash.code_utils.concolic_utils import clean_concolic_tests
1011
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
1112
from codeflash.code_utils.static_analysis import has_typed_parameters
1213
from codeflash.discovery.discover_unit_tests import discover_unit_tests
@@ -21,7 +22,11 @@ def generate_concolic_tests(
2122
) -> tuple[dict[str, list[FunctionCalledInTest]], str]:
2223
function_to_concolic_tests = {}
2324
concolic_test_suite_code = ""
24-
if test_cfg.concolic_test_root_dir and has_typed_parameters(function_to_optimize_ast, function_to_optimize.parents):
25+
if (
26+
test_cfg.concolic_test_root_dir
27+
and isinstance(function_to_optimize_ast, (ast.FunctionDef, ast.AsyncFunctionDef))
28+
and has_typed_parameters(function_to_optimize_ast, function_to_optimize.parents)
29+
):
2530
logger.info("Generating concolic opcode coverage tests for the original code…")
2631
console.rule()
2732
try:
@@ -54,7 +59,8 @@ def generate_concolic_tests(
5459
return function_to_concolic_tests, concolic_test_suite_code
5560

5661
if cover_result.returncode == 0:
57-
concolic_test_suite_code: str = cover_result.stdout
62+
generated_concolic_test: str = cover_result.stdout
63+
concolic_test_suite_code: str = clean_concolic_tests(generated_concolic_test)
5864
concolic_test_suite_dir = Path(tempfile.mkdtemp(dir=test_cfg.concolic_test_root_dir))
5965
concolic_test_suite_path = concolic_test_suite_dir / "test_concolic_coverage.py"
6066
concolic_test_suite_path.write_text(concolic_test_suite_code, encoding="utf8")

tests/test_code_utils.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
module_name_from_file_path,
1919
path_belongs_to_site_packages,
2020
)
21+
from codeflash.code_utils.concolic_utils import clean_concolic_tests
2122
from codeflash.code_utils.coverage_utils import generate_candidates, prepare_coverage_files
2223

2324

@@ -378,3 +379,65 @@ def test_prepare_coverage_files(mock_get_run_tmp_file: MagicMock) -> None:
378379
assert coverage_database_file == mock_coverage_file
379380
assert coveragercfile == mock_coveragerc_file
380381
mock_coveragerc_file.write_text.assert_called_once_with(f"[run]\n branch = True\ndata_file={mock_coverage_file}\n")
382+
383+
384+
def test_clean_concolic_tests() -> None:
    """Check clean_concolic_tests on parseable and on unparseable test code."""
    # Case 1: valid Python -- each `assert call(...) == value` becomes `call(...)`.
    parseable_suite = """
def test_add_numbers(x: int, y: int) -> None:
    assert add_numbers(1, 2) == 3


def test_concatenate_strings(s1: str, s2: str) -> None:
    assert concatenate_strings("hello", "world") == "helloworld"


def test_append_to_list(my_list: list[int], element: int) -> None:
    assert append_to_list([1, 2, 3], 4) == [1, 2, 3, 4]


def test_get_dict_value(my_dict: dict[str, int], key: str) -> None:
    assert get_dict_value({"a": 1, "b": 2}, "a") == 1


def test_union_sets(set1: set[int], set2: set[int]) -> None:
    assert union_sets({1, 2, 3}, {3, 4, 5}) == {1, 2, 3, 4, 5}

def test_calculate_tuple_sum(my_tuple: tuple[int, int, int]) -> None:
    assert calculate_tuple_sum((1, 2, 3)) == 6
"""

    expected_parseable = """
def test_add_numbers(x: int, y: int) -> None:
    add_numbers(1, 2)

def test_concatenate_strings(s1: str, s2: str) -> None:
    concatenate_strings('hello', 'world')

def test_append_to_list(my_list: list[int], element: int) -> None:
    append_to_list([1, 2, 3], 4)

def test_get_dict_value(my_dict: dict[str, int], key: str) -> None:
    get_dict_value({'a': 1, 'b': 2}, 'a')

def test_union_sets(set1: set[int], set2: set[int]) -> None:
    union_sets({1, 2, 3}, {3, 4, 5})

def test_calculate_tuple_sum(my_tuple: tuple[int, int, int]) -> None:
    calculate_tuple_sum((1, 2, 3))
"""
    assert clean_concolic_tests(parseable_suite) == expected_parseable.strip()

    # Case 2: the right-hand side is an object repr, which is not valid Python,
    # so the input cannot be parsed as a module.
    unparseable_suite = """from src.blib2to3.pgen2.grammar import Grammar

def test_Grammar_copy():
    assert Grammar.copy(Grammar()) == <src.blib2to3.pgen2.grammar.Grammar object at 0x104c30f50>
"""
    expected_unparseable = """
from src.blib2to3.pgen2.grammar import Grammar

def test_Grammar_copy():
    Grammar.copy(Grammar())
"""
    assert clean_concolic_tests(unparseable_suite) == expected_unparseable.strip()

0 commit comments

Comments
 (0)