Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions code_to_optimize/bubble_sort.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
def sorter(arr):
print("codeflash stdout: Sorting list")
for i in range(len(arr)):
for j in range(len(arr) - 1):
if arr[j] > arr[j + 1]:
temp = arr[j]
arr[j] = arr[j + 1]
arr[j + 1] = temp
print(f"result: {arr}")
return arr
5 changes: 5 additions & 0 deletions code_to_optimize/bubble_sort_method.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
import sys


class BubbleSorter:
def __init__(self, x=0):
self.x = x

def sorter(self, arr):
print("codeflash stdout : BubbleSorter.sorter() called")
for i in range(len(arr)):
for j in range(len(arr) - 1):
if arr[j] > arr[j + 1]:
temp = arr[j]
arr[j] = arr[j + 1]
arr[j + 1] = temp
print("stderr test", file=sys.stderr)
return arr
9 changes: 8 additions & 1 deletion codeflash/verification/equivalence.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import difflib
import sys

from codeflash.cli_cmds.console import logger
from codeflash.cli_cmds.console import console, logger
from codeflash.verification.comparator import comparator
from codeflash.verification.test_results import TestResults, TestType, VerificationType

Expand Down Expand Up @@ -61,6 +62,12 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
cdd_test_result.return_value,
)
break
if (original_test_result.stdout and cdd_test_result.stdout) and not comparator(
original_test_result.stdout, cdd_test_result.stdout
):
are_equal = False
break

if original_test_result.test_type in [TestType.EXISTING_UNIT_TEST, TestType.CONCOLIC_COVERAGE_TEST] and (
cdd_test_result.did_pass != original_test_result.did_pass
):
Expand Down
18 changes: 17 additions & 1 deletion codeflash/verification/parse_test_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ def parse_func(file_path: Path) -> XMLParser:
return parse(file_path, xml_parser)


matches_re = re.compile(r"!######(.*?):(.*?)([^\.:]*?):(.*?):(.*?):(.*?)######!")
cleaner_re = re.compile(r"!######.*?######!|-+\s*Captured\s+(Log|Out)\s*-+\n?")



def parse_test_return_values_bin(file_location: Path, test_files: TestFiles, test_config: TestConfig) -> TestResults:
test_results = TestResults()
if not file_location.exists():
Expand Down Expand Up @@ -259,7 +264,13 @@ def parse_test_xml(
message = testcase.result[0].message.lower()
if "timed out" in message:
timed_out = True
matches = re.findall(r"!######(.*?):(.*?)([^\.:]*?):(.*?):(.*?):(.*?)######!", testcase.system_out or "")

sys_stdout = testcase.system_out or ""
matches = matches_re.findall(sys_stdout)

if sys_stdout:
sys_stdout = cleaner_re.sub("", sys_stdout).strip()

if not matches or not len(matches):
test_results.add(
FunctionTestInvocation(
Expand All @@ -278,6 +289,7 @@ def parse_test_xml(
test_type=test_type,
return_value=None,
timed_out=timed_out,
stdout=sys_stdout,
)
)

Expand Down Expand Up @@ -306,6 +318,7 @@ def parse_test_xml(
test_type=test_type,
return_value=None,
timed_out=timed_out,
stdout=sys_stdout,
)
)

Expand Down Expand Up @@ -393,6 +406,7 @@ def merge_test_results(
verification_type=VerificationType(result_bin.verification_type)
if result_bin.verification_type
else None,
stdout=xml_result.stdout,
)
)
elif xml_results.test_results[0].id.iteration_id is not None:
Expand Down Expand Up @@ -422,6 +436,7 @@ def merge_test_results(
verification_type=VerificationType(bin_result.verification_type)
if bin_result.verification_type
else None,
stdout=xml_result.stdout,
)
)
else:
Expand All @@ -448,6 +463,7 @@ def merge_test_results(
verification_type=VerificationType(bin_result.verification_type)
if bin_result.verification_type
else None,
stdout=xml_result.stdout,
)
)

Expand Down
1 change: 1 addition & 0 deletions codeflash/verification/test_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ class FunctionTestInvocation:
return_value: Optional[object] # The return value of the function invocation
timed_out: Optional[bool]
verification_type: Optional[str] = VerificationType.FUNCTION_CALL
stdout: Optional[str] = None

@property
def unique_invocation_loop_id(self) -> str:
Expand Down
4 changes: 1 addition & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ types-gevent = "^24.11.0.20241230"
types-greenlet = "^3.1.0.20241221"
types-pexpect = "^4.9.0.20241208"
types-unidiff = "^0.7.0.20240505"
sqlalchemy = "^2.0.38"
uv = ">=0.6.2"

[tool.poetry.build]
Expand Down Expand Up @@ -178,8 +177,7 @@ ignore = [
"TD003",
"TD004",
"PLR2004",
"UP007",
"N802", # we use a lot of stdlib which follows this convention
"UP007" # remove once we drop 3.9 support.
]

[tool.ruff.lint.flake8-type-checking]
Expand Down
8 changes: 6 additions & 2 deletions tests/scripts/end_to_end_test_bubblesort_pytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,15 @@ def run_test(expected_improvement_pct: int) -> bool:
test_framework="pytest",
min_improvement_x=1.0,
coverage_expectations=[
CoverageExpectation(function_name="sorter", expected_coverage=100.0, expected_lines=[2, 3, 4, 5, 6, 7, 8])
CoverageExpectation(
function_name="sorter", expected_coverage=100.0, expected_lines=[2, 3, 4, 5, 6, 7, 8, 9, 10]
)
],
)
cwd = (pathlib.Path(__file__).parent.parent.parent / "code_to_optimize").resolve()
return run_codeflash_command(cwd, config, expected_improvement_pct)
return run_codeflash_command(
cwd, config, expected_improvement_pct, ['print("codeflash stdout: Sorting list")', 'print(f"result: {arr}")']
)


if __name__ == "__main__":
Expand Down
34 changes: 25 additions & 9 deletions tests/scripts/end_to_end_test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,19 +63,21 @@ def validate_coverage(stdout: str, expectations: list[CoverageExpectation]) -> b
assert coverage_match, f"Failed to find coverage data for {expect.function_name}"

coverage = float(coverage_match.group(1))
assert (
coverage == expect.expected_coverage
), f"Coverage was {coverage} instead of {expect.expected_coverage} for function: {expect.function_name}"
assert coverage == expect.expected_coverage, (
f"Coverage was {coverage} instead of {expect.expected_coverage} for function: {expect.function_name}"
)

executed_lines = list(map(int, coverage_match.group(2).split(", ")))
assert (
executed_lines == expect.expected_lines
), f"Executed lines were {executed_lines} instead of {expect.expected_lines} for function: {expect.function_name}"
assert executed_lines == expect.expected_lines, (
f"Executed lines were {executed_lines} instead of {expect.expected_lines} for function: {expect.function_name}"
)

return True


def run_codeflash_command(cwd: pathlib.Path, config: TestConfig, expected_improvement_pct: int) -> bool:
def run_codeflash_command(
cwd: pathlib.Path, config: TestConfig, expected_improvement_pct: int, expected_in_stdout: list[str] = None
) -> bool:
logging.basicConfig(level=logging.INFO)
if config.trace_mode:
return run_trace_test(cwd, config, expected_improvement_pct)
Expand All @@ -97,12 +99,21 @@ def run_codeflash_command(cwd: pathlib.Path, config: TestConfig, expected_improv
return_code = process.wait()
stdout = "".join(output)

if not validate_output(stdout, return_code, expected_improvement_pct, config):
validated = validate_output(stdout, return_code, expected_improvement_pct, config)
if not validated:
# Write original file contents back to file
path_to_file.write_text(file_contents, "utf-8")
logging.info("Codeflash run did not meet expected requirements for testing, reverting file changes.")
return False
return True

if expected_in_stdout:
stdout_validated = validate_stdout_in_candidate(stdout, expected_in_stdout)
if not stdout_validated:
logging.error("Failed to find expected output in candidate output")
validated = False
logging.info(f"Success: Expected output found in candidate output")

return validated


def build_command(cwd: pathlib.Path, config: TestConfig, test_root: pathlib.Path) -> list[str]:
Expand Down Expand Up @@ -164,6 +175,11 @@ def validate_output(stdout: str, return_code: int, expected_improvement_pct: int
return True


def validate_stdout_in_candidate(stdout: str, expected_in_stdout: list[str]) -> bool:
candidate_output = stdout[stdout.find("INFO Best candidate") : stdout.find("Best Candidate Explanation")]
return all(expected in candidate_output for expected in expected_in_stdout)


def run_trace_test(cwd: pathlib.Path, config: TestConfig, expected_improvement_pct: int) -> bool:
# First command: Run the tracer
test_root = cwd / "tests" / (config.test_framework or "")
Expand Down
47 changes: 46 additions & 1 deletion tests/test_codeflash_capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,6 @@ def __init__(self, x=2):
assert test_results[1].id.test_module_path == "code_to_optimize.tests.pytest.test_codeflash_capture_temp"
assert test_results[1].id.function_getting_tested == "some_function"
assert test_results[1].id.iteration_id == "11_0"

assert test_results[2].did_pass
assert test_results[2].return_value[0]["x"] == 2
assert test_results[2].id.test_function_name == "test_example_test_3"
Expand All @@ -494,6 +493,17 @@ def __init__(self, x=2):
assert test_results[2].id.function_getting_tested == "some_function"
assert test_results[2].id.iteration_id == "16_0"

test_results2, _ = func_optimizer.run_and_parse_tests(
testing_type=TestingMode.BEHAVIOR,
test_env=test_env,
test_files=func_optimizer.test_files,
optimization_iteration=0,
pytest_min_loops=1,
pytest_max_loops=1,
testing_time=0.1,
)
assert compare_test_results(test_results, test_results2)

finally:
test_path.unlink(missing_ok=True)
sample_code_path.unlink(missing_ok=True)
Expand Down Expand Up @@ -605,6 +615,18 @@ def __init__(self, *args, **kwargs):
assert test_results[2].id.function_getting_tested == "some_function"
assert test_results[2].id.iteration_id == "16_0"

results2, _ = func_optimizer.run_and_parse_tests(
testing_type=TestingMode.BEHAVIOR,
test_env=test_env,
test_files=func_optimizer.test_files,
optimization_iteration=0,
pytest_min_loops=1,
pytest_max_loops=1,
testing_time=0.1,
)

assert compare_test_results(test_results, results2)

finally:
test_path.unlink(missing_ok=True)
sample_code_path.unlink(missing_ok=True)
Expand Down Expand Up @@ -720,6 +742,17 @@ def __init__(self, x=2):
assert test_results[2].id.function_getting_tested == "some_function"
assert test_results[2].id.iteration_id == "12_2" # Third call

test_results2, _ = func_optimizer.run_and_parse_tests(
testing_type=TestingMode.BEHAVIOR,
test_env=test_env,
test_files=func_optimizer.test_files,
optimization_iteration=0,
pytest_min_loops=1,
pytest_max_loops=1,
testing_time=0.1,
)

assert compare_test_results(test_results, test_results2)
finally:
test_path.unlink(missing_ok=True)
sample_code_path.unlink(missing_ok=True)
Expand Down Expand Up @@ -856,6 +889,18 @@ def another_helper(self):
assert test_results[3].id.function_getting_tested == "AnotherHelperClass.__init__"
assert test_results[3].verification_type == VerificationType.INIT_STATE_HELPER

results2, _ = func_optimizer.run_and_parse_tests(
testing_type=TestingMode.BEHAVIOR,
test_env=test_env,
test_files=func_optimizer.test_files,
optimization_iteration=0,
pytest_min_loops=1,
pytest_max_loops=1,
testing_time=0.1,
)

assert compare_test_results(test_results, results2)

finally:
test_path.unlink(missing_ok=True)
fto_file_path.unlink(missing_ok=True)
Expand Down
42 changes: 39 additions & 3 deletions tests/test_instrument_all_and_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,13 @@ def test_sort():
pytest_max_loops=1,
testing_time=0.1,
)

out_str = """codeflash stdout: Sorting list
result: [0, 1, 2, 3, 4, 5]

codeflash stdout: Sorting list
result: [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]"""
assert out_str == test_results[0].stdout
assert test_results[0].id.function_getting_tested == "sorter"
assert test_results[0].id.iteration_id == "1_0"
assert test_results[0].id.test_class_name is None
Expand All @@ -179,6 +186,7 @@ def test_sort():
assert test_results[0].runtime > 0
assert test_results[0].did_pass
assert test_results[0].return_value == ([0, 1, 2, 3, 4, 5],)
assert out_str == test_results[1].stdout.strip()

assert test_results[1].id.function_getting_tested == "sorter"
assert test_results[1].id.iteration_id == "4_0"
Expand All @@ -190,6 +198,22 @@ def test_sort():
)
assert test_results[1].runtime > 0
assert test_results[1].did_pass
results2, _ = func_optimizer.run_and_parse_tests(
testing_type=TestingMode.BEHAVIOR,
test_env=test_env,
test_files=func_optimizer.test_files,
optimization_iteration=0,
pytest_min_loops=1,
pytest_max_loops=1,
testing_time=0.1,
)
out_str = """codeflash stdout: Sorting list
result: [0, 1, 2, 3, 4, 5]

codeflash stdout: Sorting list
result: [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]"""
assert out_str == results2[0].stdout.strip()
assert compare_test_results(test_results, results2)
finally:
fto_path.write_text(original_code, "utf-8")
test_path.unlink(missing_ok=True)
Expand Down Expand Up @@ -340,13 +364,11 @@ def test_sort():
pytest_max_loops=1,
testing_time=0.1,
)

assert len(test_results) == 4
assert test_results[0].id.function_getting_tested == "BubbleSorter.__init__"
assert test_results[0].id.test_function_name == "test_sort"
assert test_results[0].did_pass
assert test_results[0].return_value[0] == {"x": 0}

assert test_results[1].id.function_getting_tested == "BubbleSorter.sorter"
assert test_results[1].id.iteration_id == "2_0"
assert test_results[1].id.test_class_name is None
Expand All @@ -358,7 +380,9 @@ def test_sort():
assert test_results[1].runtime > 0
assert test_results[1].did_pass
assert test_results[1].return_value == ([0, 1, 2, 3, 4, 5],)

out_str = """codeflash stdout : BubbleSorter.sorter() called\n\n\ncodeflash stdout : BubbleSorter.sorter() called"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why does this have \n whereas all others have proper newlines?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

was debugging something, forgot to revert.

assert test_results[1].stdout == out_str
assert compare_test_results(test_results, test_results)
assert test_results[2].id.function_getting_tested == "BubbleSorter.__init__"
assert test_results[2].id.test_function_name == "test_sort"
assert test_results[2].did_pass
Expand All @@ -375,6 +399,18 @@ def test_sort():
assert test_results[3].runtime > 0
assert test_results[3].did_pass

results2, _ = func_optimizer.run_and_parse_tests(
testing_type=TestingMode.BEHAVIOR,
test_env=test_env,
test_files=func_optimizer.test_files,
optimization_iteration=0,
pytest_min_loops=1,
pytest_max_loops=1,
testing_time=0.1,
)

assert compare_test_results(test_results, results2)

# Replace with optimized code that mutated instance attribute
optimized_code = """
class BubbleSorter:
Expand Down
Loading
Loading