Skip to content
93 changes: 92 additions & 1 deletion codeflash/code_utils/code_replacer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

import ast
import re
from collections import defaultdict
from functools import lru_cache
from typing import TYPE_CHECKING, TypeVar
from typing import TYPE_CHECKING, Optional, TypeVar

import libcst as cst

Expand Down Expand Up @@ -336,3 +337,93 @@ def function_to_optimize_original_worktree_fqn(
+ "."
+ function_to_optimize.qualified_name
)


class AssertCleanup:
def transform_asserts(self, code: str) -> str:
lines = code.splitlines()
result_lines = []

for line in lines:
transformed = self._transform_assert_line(line)
if transformed is not None:
result_lines.append(transformed)
else:
result_lines.append(line)

return "\n".join(result_lines)

def _transform_assert_line(self, line: str) -> Optional[str]:
indent = line[: len(line) - len(line.lstrip())]

assert_match = re.match(r"\s*assert\s+(.*?)(?:\s*==\s*.*)?$", line)
if assert_match:
expression = assert_match.group(1).strip()
if expression.startswith("not "):
return f"{indent}{expression}"

expression = re.sub(r"[,;]\s*$", "", expression)
return f"{indent}{expression}"

unittest_match = re.match(r"(\s*)self\.assert([A-Za-z]+)\((.*)\)$", line)
if unittest_match:
indent, assert_method, args = unittest_match.groups()

if args:
arg_parts = self._split_top_level_args(args)
if arg_parts and arg_parts[0]:
return f"{indent}{arg_parts[0]}"

return None

def _split_top_level_args(self, args_str: str) -> list[str]:
result = []
current = []
depth = 0

for char in args_str:
if char in "([{":
depth += 1
current.append(char)
elif char in ")]}":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for char in args_str:
if char in "([{":
depth += 1
current.append(char)
elif char in ")]}":
char_to_depth_change = {"(": 1, "[": 1, "{": 1, ")": -1, "]": -1, "}": -1}
if char in char_to_depth_change:
depth += char_to_depth_change[char]

depth -= 1
current.append(char)
elif char == "," and depth == 0:
result.append("".join(current).strip())
current = []
else:
current.append(char)

if current:
result.append("".join(current).strip())

return result


def clean_concolic_tests(test_suite_code: str) -> str:
try:
can_parse = True
tree = ast.parse(test_suite_code)
except SyntaxError:
can_parse = False

if not can_parse:
return AssertCleanup().transform_asserts(test_suite_code)

tree = ast.parse(test_suite_code)

for node in ast.walk(tree):
if isinstance(node, ast.FunctionDef) and node.name.startswith("test_"):
new_body = []
for stmt in node.body:
if isinstance(stmt, ast.Assert):
if isinstance(stmt.test, ast.Compare) and isinstance(stmt.test.left, ast.Call):
new_body.append(ast.Expr(value=stmt.test.left))
else:
new_body.append(stmt)

else:
new_body.append(stmt)
node.body = new_body

return ast.unparse(tree).strip()
32 changes: 17 additions & 15 deletions codeflash/discovery/functions_to_optimize.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import ast
import json
import os
import random
import warnings
Expand Down Expand Up @@ -156,9 +157,9 @@ def get_functions_to_optimize(
project_root: Path,
module_root: Path,
) -> tuple[dict[Path, list[FunctionToOptimize]], int]:
assert (
sum([bool(optimize_all), bool(replay_test), bool(file)]) <= 1
), "Only one of optimize_all, replay_test, or file should be provided"
assert sum([bool(optimize_all), bool(replay_test), bool(file)]) <= 1, (
"Only one of optimize_all, replay_test, or file should be provided"
)
functions: dict[str, list[FunctionToOptimize]]
with warnings.catch_warnings():
warnings.simplefilter(action="ignore", category=SyntaxWarning)
Expand Down Expand Up @@ -434,9 +435,7 @@ def filter_functions(
test_functions_removed_count += len(functions)
continue
if file_path in ignore_paths or any(
# file_path.startswith(ignore_path + os.sep) for ignore_path in ignore_paths if ignore_path
file_path.startswith(str(ignore_path) + os.sep)
for ignore_path in ignore_paths
file_path.startswith(str(ignore_path) + os.sep) for ignore_path in ignore_paths
):
ignore_paths_removed_count += 1
continue
Expand All @@ -457,15 +456,17 @@ def filter_functions(
malformed_paths_count += 1
continue
if blocklist_funcs:
for function in functions.copy():
path = Path(function.file_path).name
if path in blocklist_funcs and function.function_name in blocklist_funcs[path]:
functions.remove(function)
logger.debug(f"Skipping {function.function_name} in {path} as it has already been optimized")
continue

functions = [
function
for function in functions
if not (
Path(function.file_path).name in blocklist_funcs
and function.qualified_name in blocklist_funcs[Path(function.file_path).name]
)
]
filtered_modified_functions[file_path] = functions
functions_count += len(functions)

if not disable_logs:
log_info = {
f"{test_functions_removed_count} test function{'s' if test_functions_removed_count != 1 else ''}": test_functions_removed_count,
Expand All @@ -475,10 +476,11 @@ def filter_functions(
f"{ignore_paths_removed_count} file{'s' if ignore_paths_removed_count != 1 else ''} from ignored paths": ignore_paths_removed_count,
f"{submodule_ignored_paths_count} file{'s' if submodule_ignored_paths_count != 1 else ''} from ignored submodules": submodule_ignored_paths_count,
}
log_string: str
if log_string := "\n".join([k for k, v in log_info.items() if v > 0]):
log_string = "\n".join([k for k, v in log_info.items() if v > 0])
if log_string:
logger.info(f"Ignoring: {log_string}")
console.rule()

return {Path(k): v for k, v in filtered_modified_functions.items() if v}, functions_count


Expand Down
10 changes: 8 additions & 2 deletions codeflash/verification/concolic_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pathlib import Path

from codeflash.cli_cmds.console import console, logger
from codeflash.code_utils.code_replacer import clean_concolic_tests
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
from codeflash.code_utils.static_analysis import has_typed_parameters
from codeflash.discovery.discover_unit_tests import discover_unit_tests
Expand All @@ -21,7 +22,11 @@ def generate_concolic_tests(
) -> tuple[dict[str, list[FunctionCalledInTest]], str]:
function_to_concolic_tests = {}
concolic_test_suite_code = ""
if test_cfg.concolic_test_root_dir and has_typed_parameters(function_to_optimize_ast, function_to_optimize.parents):
if (
test_cfg.concolic_test_root_dir
and isinstance(function_to_optimize_ast, (ast.FunctionDef, ast.AsyncFunctionDef))
and has_typed_parameters(function_to_optimize_ast, function_to_optimize.parents)
):
logger.info("Generating concolic opcode coverage tests for the original code…")
console.rule()
try:
Expand Down Expand Up @@ -54,7 +59,8 @@ def generate_concolic_tests(
return function_to_concolic_tests, concolic_test_suite_code

if cover_result.returncode == 0:
concolic_test_suite_code: str = cover_result.stdout
generated_concolic_test: str = cover_result.stdout
concolic_test_suite_code: str = clean_concolic_tests(generated_concolic_test)
concolic_test_suite_dir = Path(tempfile.mkdtemp(dir=test_cfg.concolic_test_root_dir))
concolic_test_suite_path = concolic_test_suite_dir / "test_concolic_coverage.py"
concolic_test_suite_path.write_text(concolic_test_suite_code, encoding="utf8")
Expand Down
63 changes: 63 additions & 0 deletions tests/test_code_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import pytest

from codeflash.code_utils.code_replacer import clean_concolic_tests
from codeflash.code_utils.code_utils import (
cleanup_paths,
file_name_from_test_module_name,
Expand Down Expand Up @@ -378,3 +379,65 @@ def test_prepare_coverage_files(mock_get_run_tmp_file: MagicMock) -> None:
assert coverage_database_file == mock_coverage_file
assert coveragercfile == mock_coveragerc_file
mock_coveragerc_file.write_text.assert_called_once_with(f"[run]\n branch = True\ndata_file={mock_coverage_file}\n")


def test_clean_concolic_tests() -> None:
# Case 1: every test function holds a single `assert call(...) == value`.
# The expected output keeps only the call, with string literals rendered in
# single quotes and one blank line between functions.
original_code = """
def test_add_numbers(x: int, y: int) -> None:
assert add_numbers(1, 2) == 3


def test_concatenate_strings(s1: str, s2: str) -> None:
assert concatenate_strings("hello", "world") == "helloworld"


def test_append_to_list(my_list: list[int], element: int) -> None:
assert append_to_list([1, 2, 3], 4) == [1, 2, 3, 4]


def test_get_dict_value(my_dict: dict[str, int], key: str) -> None:
assert get_dict_value({"a": 1, "b": 2}, "a") == 1


def test_union_sets(set1: set[int], set2: set[int]) -> None:
assert union_sets({1, 2, 3}, {3, 4, 5}) == {1, 2, 3, 4, 5}

def test_calculate_tuple_sum(my_tuple: tuple[int, int, int]) -> None:
assert calculate_tuple_sum((1, 2, 3)) == 6
"""

cleaned_code = clean_concolic_tests(original_code)
expected_cleaned_code = """
def test_add_numbers(x: int, y: int) -> None:
add_numbers(1, 2)

def test_concatenate_strings(s1: str, s2: str) -> None:
concatenate_strings('hello', 'world')

def test_append_to_list(my_list: list[int], element: int) -> None:
append_to_list([1, 2, 3], 4)

def test_get_dict_value(my_dict: dict[str, int], key: str) -> None:
get_dict_value({'a': 1, 'b': 2}, 'a')

def test_union_sets(set1: set[int], set2: set[int]) -> None:
union_sets({1, 2, 3}, {3, 4, 5})

def test_calculate_tuple_sum(my_tuple: tuple[int, int, int]) -> None:
calculate_tuple_sum((1, 2, 3))
"""
assert cleaned_code == expected_cleaned_code.strip()

# Case 2: the right-hand side of the comparison is an object repr
# (`<... object at 0x...>`), which is not a Python expression. The call on
# the left-hand side is still expected to survive, and the import line is
# expected to be carried over unchanged.
concolic_generated_repr_code = """from src.blib2to3.pgen2.grammar import Grammar

def test_Grammar_copy():
assert Grammar.copy(Grammar()) == <src.blib2to3.pgen2.grammar.Grammar object at 0x104c30f50>
"""
cleaned_code = clean_concolic_tests(concolic_generated_repr_code)
expected_cleaned_code = """
from src.blib2to3.pgen2.grammar import Grammar

def test_Grammar_copy():
Grammar.copy(Grammar())
"""
assert cleaned_code == expected_cleaned_code.strip()
Loading