Skip to content

Commit c6c6fff

Browse files
committed
first pass
1 parent c93a7ca commit c6c6fff

File tree

3 files changed

+82
-2
lines changed

3 files changed

+82
-2
lines changed

codeflash/code_utils/code_replacer.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,3 +336,27 @@ def function_to_optimize_original_worktree_fqn(
336336
+ "."
337337
+ function_to_optimize.qualified_name
338338
)
339+
340+
341+
def clean_concolic_tests(test_suite_code: str) -> str:
342+
try:
343+
tree = ast.parse(test_suite_code)
344+
345+
for node in ast.walk(tree):
346+
if isinstance(node, ast.FunctionDef) and node.name.startswith("test_"):
347+
new_body = []
348+
for stmt in node.body:
349+
if isinstance(stmt, ast.Assert):
350+
if isinstance(stmt.test, ast.Compare) and isinstance(stmt.test.left, ast.Call):
351+
new_body.append(ast.Expr(value=stmt.test.left))
352+
else:
353+
new_body.append(stmt)
354+
355+
else:
356+
new_body.append(stmt)
357+
node.body = new_body
358+
359+
return ast.unparse(tree).strip()
360+
except SyntaxError:
361+
logger.warning("Failed to parse and modify CrossHair generated tests. Using original output.")
362+
return test_suite_code

codeflash/verification/concolic_testing.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
from __future__ import annotations
22

33
import ast
4+
import difflib
45
import subprocess
56
import tempfile
67
from argparse import Namespace
78
from pathlib import Path
89

910
from codeflash.cli_cmds.console import console, logger
11+
from codeflash.code_utils.code_replacer import clean_concolic_tests
1012
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
1113
from codeflash.code_utils.static_analysis import has_typed_parameters
1214
from codeflash.discovery.discover_unit_tests import discover_unit_tests
@@ -21,7 +23,11 @@ def generate_concolic_tests(
2123
) -> tuple[dict[str, list[FunctionCalledInTest]], str]:
2224
function_to_concolic_tests = {}
2325
concolic_test_suite_code = ""
24-
if test_cfg.concolic_test_root_dir and has_typed_parameters(function_to_optimize_ast, function_to_optimize.parents):
26+
if (
27+
test_cfg.concolic_test_root_dir
28+
and isinstance(function_to_optimize_ast, (ast.FunctionDef, ast.AsyncFunctionDef))
29+
and has_typed_parameters(function_to_optimize_ast, function_to_optimize.parents)
30+
):
2531
logger.info("Generating concolic opcode coverage tests for the original code…")
2632
console.rule()
2733
try:
@@ -54,7 +60,8 @@ def generate_concolic_tests(
5460
return function_to_concolic_tests, concolic_test_suite_code
5561

5662
if cover_result.returncode == 0:
57-
concolic_test_suite_code: str = cover_result.stdout
63+
original_code: str = cover_result.stdout
64+
concolic_test_suite_code: str = clean_concolic_tests(original_code)
5865
concolic_test_suite_dir = Path(tempfile.mkdtemp(dir=test_cfg.concolic_test_root_dir))
5966
concolic_test_suite_path = concolic_test_suite_dir / "test_concolic_coverage.py"
6067
concolic_test_suite_path.write_text(concolic_test_suite_code, encoding="utf8")

tests/test_code_utils.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import pytest
88

9+
from codeflash.code_utils.code_replacer import clean_concolic_tests
910
from codeflash.code_utils.code_utils import (
1011
cleanup_paths,
1112
file_name_from_test_module_name,
@@ -378,3 +379,51 @@ def test_prepare_coverage_files(mock_get_run_tmp_file: MagicMock) -> None:
378379
assert coverage_database_file == mock_coverage_file
379380
assert coveragercfile == mock_coveragerc_file
380381
mock_coveragerc_file.write_text.assert_called_once_with(f"[run]\n branch = True\ndata_file={mock_coverage_file}\n")
382+
383+
384+
def test_clean_concolic_tests() -> None:
385+
original_code = """
386+
def test_add_numbers(x: int, y: int) -> None:
387+
assert add_numbers(1, 2) == 3
388+
389+
390+
def test_concatenate_strings(s1: str, s2: str) -> None:
391+
assert concatenate_strings("hello", "world") == "helloworld"
392+
393+
394+
def test_append_to_list(my_list: list[int], element: int) -> None:
395+
assert append_to_list([1, 2, 3], 4) == [1, 2, 3, 4]
396+
397+
398+
def test_get_dict_value(my_dict: dict[str, int], key: str) -> None:
399+
assert get_dict_value({"a": 1, "b": 2}, "a") == 1
400+
401+
402+
def test_union_sets(set1: set[int], set2: set[int]) -> None:
403+
assert union_sets({1, 2, 3}, {3, 4, 5}) == {1, 2, 3, 4, 5}
404+
405+
def test_calculate_tuple_sum(my_tuple: tuple[int, int, int]) -> None:
406+
assert calculate_tuple_sum((1, 2, 3)) == 6
407+
"""
408+
409+
cleaned_code = clean_concolic_tests(original_code)
410+
expected_cleaned_code = """
411+
def test_add_numbers(x: int, y: int) -> None:
412+
add_numbers(1, 2)
413+
414+
def test_concatenate_strings(s1: str, s2: str) -> None:
415+
concatenate_strings('hello', 'world')
416+
417+
def test_append_to_list(my_list: list[int], element: int) -> None:
418+
append_to_list([1, 2, 3], 4)
419+
420+
def test_get_dict_value(my_dict: dict[str, int], key: str) -> None:
421+
get_dict_value({'a': 1, 'b': 2}, 'a')
422+
423+
def test_union_sets(set1: set[int], set2: set[int]) -> None:
424+
union_sets({1, 2, 3}, {3, 4, 5})
425+
426+
def test_calculate_tuple_sum(my_tuple: tuple[int, int, int]) -> None:
427+
calculate_tuple_sum((1, 2, 3))
428+
"""
429+
assert cleaned_code == expected_cleaned_code.strip()

0 commit comments

Comments
 (0)