Skip to content

Commit 832b2eb

Browse files
authored
Merge pull request #18 from codeflash-ai/stdout_comparison_
stdout comparison
2 parents 9a002b7 + dcf9384 commit 832b2eb

12 files changed

+189
-20
lines changed

code_to_optimize/bubble_sort.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
def sorter(arr):
2+
print("codeflash stdout: Sorting list")
23
for i in range(len(arr)):
34
for j in range(len(arr) - 1):
45
if arr[j] > arr[j + 1]:
56
temp = arr[j]
67
arr[j] = arr[j + 1]
78
arr[j + 1] = temp
9+
print(f"result: {arr}")
810
return arr
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
1+
import sys
2+
3+
14
class BubbleSorter:
25
def __init__(self, x=0):
36
self.x = x
47

58
def sorter(self, arr):
9+
print("codeflash stdout : BubbleSorter.sorter() called")
610
for i in range(len(arr)):
711
for j in range(len(arr) - 1):
812
if arr[j] > arr[j + 1]:
913
temp = arr[j]
1014
arr[j] = arr[j + 1]
1115
arr[j + 1] = temp
16+
print("stderr test", file=sys.stderr)
1217
return arr

codeflash/verification/equivalence.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
import difflib
12
import sys
23

3-
from codeflash.cli_cmds.console import logger
4+
from codeflash.cli_cmds.console import console, logger
45
from codeflash.verification.comparator import comparator
56
from codeflash.verification.test_results import TestResults, TestType, VerificationType
67

@@ -61,6 +62,12 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
6162
cdd_test_result.return_value,
6263
)
6364
break
65+
if (original_test_result.stdout and cdd_test_result.stdout) and not comparator(
66+
original_test_result.stdout, cdd_test_result.stdout
67+
):
68+
are_equal = False
69+
break
70+
6471
if original_test_result.test_type in [TestType.EXISTING_UNIT_TEST, TestType.CONCOLIC_COVERAGE_TEST] and (
6572
cdd_test_result.did_pass != original_test_result.did_pass
6673
):

codeflash/verification/parse_test_output.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,11 @@ def parse_func(file_path: Path) -> XMLParser:
4242
return parse(file_path, xml_parser)
4343

4444

45+
matches_re = re.compile(r"!######(.*?):(.*?)([^\.:]*?):(.*?):(.*?):(.*?)######!")
46+
cleaner_re = re.compile(r"!######.*?######!|-+\s*Captured\s+(Log|Out)\s*-+\n?")
47+
48+
49+
4550
def parse_test_return_values_bin(file_location: Path, test_files: TestFiles, test_config: TestConfig) -> TestResults:
4651
test_results = TestResults()
4752
if not file_location.exists():
@@ -259,7 +264,13 @@ def parse_test_xml(
259264
message = testcase.result[0].message.lower()
260265
if "timed out" in message:
261266
timed_out = True
262-
matches = re.findall(r"!######(.*?):(.*?)([^\.:]*?):(.*?):(.*?):(.*?)######!", testcase.system_out or "")
267+
268+
sys_stdout = testcase.system_out or ""
269+
matches = matches_re.findall(sys_stdout)
270+
271+
if sys_stdout:
272+
sys_stdout = cleaner_re.sub("", sys_stdout).strip()
273+
263274
if not matches or not len(matches):
264275
test_results.add(
265276
FunctionTestInvocation(
@@ -278,6 +289,7 @@ def parse_test_xml(
278289
test_type=test_type,
279290
return_value=None,
280291
timed_out=timed_out,
292+
stdout=sys_stdout,
281293
)
282294
)
283295

@@ -306,6 +318,7 @@ def parse_test_xml(
306318
test_type=test_type,
307319
return_value=None,
308320
timed_out=timed_out,
321+
stdout=sys_stdout,
309322
)
310323
)
311324

@@ -393,6 +406,7 @@ def merge_test_results(
393406
verification_type=VerificationType(result_bin.verification_type)
394407
if result_bin.verification_type
395408
else None,
409+
stdout=xml_result.stdout,
396410
)
397411
)
398412
elif xml_results.test_results[0].id.iteration_id is not None:
@@ -422,6 +436,7 @@ def merge_test_results(
422436
verification_type=VerificationType(bin_result.verification_type)
423437
if bin_result.verification_type
424438
else None,
439+
stdout=xml_result.stdout,
425440
)
426441
)
427442
else:
@@ -448,6 +463,7 @@ def merge_test_results(
448463
verification_type=VerificationType(bin_result.verification_type)
449464
if bin_result.verification_type
450465
else None,
466+
stdout=xml_result.stdout,
451467
)
452468
)
453469

codeflash/verification/test_results.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ class FunctionTestInvocation:
9393
return_value: Optional[object] # The return value of the function invocation
9494
timed_out: Optional[bool]
9595
verification_type: Optional[str] = VerificationType.FUNCTION_CALL
96+
stdout: Optional[str] = None
9697

9798
@property
9899
def unique_invocation_loop_id(self) -> str:

pyproject.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,6 @@ types-gevent = "^24.11.0.20241230"
119119
types-greenlet = "^3.1.0.20241221"
120120
types-pexpect = "^4.9.0.20241208"
121121
types-unidiff = "^0.7.0.20240505"
122-
sqlalchemy = "^2.0.38"
123122
uv = ">=0.6.2"
124123

125124
[tool.poetry.build]
@@ -178,8 +177,7 @@ ignore = [
178177
"TD003",
179178
"TD004",
180179
"PLR2004",
181-
"UP007",
182-
"N802", # we use a lot of stdlib which follows this convention
180+
"UP007" # remove once we drop 3.9 support.
183181
]
184182

185183
[tool.ruff.lint.flake8-type-checking]

tests/scripts/end_to_end_test_bubblesort_pytest.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,15 @@ def run_test(expected_improvement_pct: int) -> bool:
1111
test_framework="pytest",
1212
min_improvement_x=1.0,
1313
coverage_expectations=[
14-
CoverageExpectation(function_name="sorter", expected_coverage=100.0, expected_lines=[2, 3, 4, 5, 6, 7, 8])
14+
CoverageExpectation(
15+
function_name="sorter", expected_coverage=100.0, expected_lines=[2, 3, 4, 5, 6, 7, 8, 9, 10]
16+
)
1517
],
1618
)
1719
cwd = (pathlib.Path(__file__).parent.parent.parent / "code_to_optimize").resolve()
18-
return run_codeflash_command(cwd, config, expected_improvement_pct)
20+
return run_codeflash_command(
21+
cwd, config, expected_improvement_pct, ['print("codeflash stdout: Sorting list")', 'print(f"result: {arr}")']
22+
)
1923

2024

2125
if __name__ == "__main__":

tests/scripts/end_to_end_test_utilities.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,19 +63,21 @@ def validate_coverage(stdout: str, expectations: list[CoverageExpectation]) -> b
6363
assert coverage_match, f"Failed to find coverage data for {expect.function_name}"
6464

6565
coverage = float(coverage_match.group(1))
66-
assert (
67-
coverage == expect.expected_coverage
68-
), f"Coverage was {coverage} instead of {expect.expected_coverage} for function: {expect.function_name}"
66+
assert coverage == expect.expected_coverage, (
67+
f"Coverage was {coverage} instead of {expect.expected_coverage} for function: {expect.function_name}"
68+
)
6969

7070
executed_lines = list(map(int, coverage_match.group(2).split(", ")))
71-
assert (
72-
executed_lines == expect.expected_lines
73-
), f"Executed lines were {executed_lines} instead of {expect.expected_lines} for function: {expect.function_name}"
71+
assert executed_lines == expect.expected_lines, (
72+
f"Executed lines were {executed_lines} instead of {expect.expected_lines} for function: {expect.function_name}"
73+
)
7474

7575
return True
7676

7777

78-
def run_codeflash_command(cwd: pathlib.Path, config: TestConfig, expected_improvement_pct: int) -> bool:
78+
def run_codeflash_command(
79+
cwd: pathlib.Path, config: TestConfig, expected_improvement_pct: int, expected_in_stdout: list[str] = None
80+
) -> bool:
7981
logging.basicConfig(level=logging.INFO)
8082
if config.trace_mode:
8183
return run_trace_test(cwd, config, expected_improvement_pct)
@@ -97,12 +99,21 @@ def run_codeflash_command(cwd: pathlib.Path, config: TestConfig, expected_improv
9799
return_code = process.wait()
98100
stdout = "".join(output)
99101

100-
if not validate_output(stdout, return_code, expected_improvement_pct, config):
102+
validated = validate_output(stdout, return_code, expected_improvement_pct, config)
103+
if not validated:
101104
# Write original file contents back to file
102105
path_to_file.write_text(file_contents, "utf-8")
103106
logging.info("Codeflash run did not meet expected requirements for testing, reverting file changes.")
104107
return False
105-
return True
108+
109+
if expected_in_stdout:
110+
stdout_validated = validate_stdout_in_candidate(stdout, expected_in_stdout)
111+
if not stdout_validated:
112+
logging.error("Failed to find expected output in candidate output")
113+
validated = False
114+
logging.info(f"Success: Expected output found in candidate output")
115+
116+
return validated
106117

107118

108119
def build_command(cwd: pathlib.Path, config: TestConfig, test_root: pathlib.Path) -> list[str]:
@@ -164,6 +175,11 @@ def validate_output(stdout: str, return_code: int, expected_improvement_pct: int
164175
return True
165176

166177

178+
def validate_stdout_in_candidate(stdout: str, expected_in_stdout: list[str]) -> bool:
179+
candidate_output = stdout[stdout.find("INFO Best candidate") : stdout.find("Best Candidate Explanation")]
180+
return all(expected in candidate_output for expected in expected_in_stdout)
181+
182+
167183
def run_trace_test(cwd: pathlib.Path, config: TestConfig, expected_improvement_pct: int) -> bool:
168184
# First command: Run the tracer
169185
test_root = cwd / "tests" / (config.test_framework or "")

tests/test_codeflash_capture.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,6 @@ def __init__(self, x=2):
485485
assert test_results[1].id.test_module_path == "code_to_optimize.tests.pytest.test_codeflash_capture_temp"
486486
assert test_results[1].id.function_getting_tested == "some_function"
487487
assert test_results[1].id.iteration_id == "11_0"
488-
489488
assert test_results[2].did_pass
490489
assert test_results[2].return_value[0]["x"] == 2
491490
assert test_results[2].id.test_function_name == "test_example_test_3"
@@ -494,6 +493,17 @@ def __init__(self, x=2):
494493
assert test_results[2].id.function_getting_tested == "some_function"
495494
assert test_results[2].id.iteration_id == "16_0"
496495

496+
test_results2, _ = func_optimizer.run_and_parse_tests(
497+
testing_type=TestingMode.BEHAVIOR,
498+
test_env=test_env,
499+
test_files=func_optimizer.test_files,
500+
optimization_iteration=0,
501+
pytest_min_loops=1,
502+
pytest_max_loops=1,
503+
testing_time=0.1,
504+
)
505+
assert compare_test_results(test_results, test_results2)
506+
497507
finally:
498508
test_path.unlink(missing_ok=True)
499509
sample_code_path.unlink(missing_ok=True)
@@ -605,6 +615,18 @@ def __init__(self, *args, **kwargs):
605615
assert test_results[2].id.function_getting_tested == "some_function"
606616
assert test_results[2].id.iteration_id == "16_0"
607617

618+
results2, _ = func_optimizer.run_and_parse_tests(
619+
testing_type=TestingMode.BEHAVIOR,
620+
test_env=test_env,
621+
test_files=func_optimizer.test_files,
622+
optimization_iteration=0,
623+
pytest_min_loops=1,
624+
pytest_max_loops=1,
625+
testing_time=0.1,
626+
)
627+
628+
assert compare_test_results(test_results, results2)
629+
608630
finally:
609631
test_path.unlink(missing_ok=True)
610632
sample_code_path.unlink(missing_ok=True)
@@ -720,6 +742,17 @@ def __init__(self, x=2):
720742
assert test_results[2].id.function_getting_tested == "some_function"
721743
assert test_results[2].id.iteration_id == "12_2" # Third call
722744

745+
test_results2, _ = func_optimizer.run_and_parse_tests(
746+
testing_type=TestingMode.BEHAVIOR,
747+
test_env=test_env,
748+
test_files=func_optimizer.test_files,
749+
optimization_iteration=0,
750+
pytest_min_loops=1,
751+
pytest_max_loops=1,
752+
testing_time=0.1,
753+
)
754+
755+
assert compare_test_results(test_results, test_results2)
723756
finally:
724757
test_path.unlink(missing_ok=True)
725758
sample_code_path.unlink(missing_ok=True)
@@ -856,6 +889,18 @@ def another_helper(self):
856889
assert test_results[3].id.function_getting_tested == "AnotherHelperClass.__init__"
857890
assert test_results[3].verification_type == VerificationType.INIT_STATE_HELPER
858891

892+
results2, _ = func_optimizer.run_and_parse_tests(
893+
testing_type=TestingMode.BEHAVIOR,
894+
test_env=test_env,
895+
test_files=func_optimizer.test_files,
896+
optimization_iteration=0,
897+
pytest_min_loops=1,
898+
pytest_max_loops=1,
899+
testing_time=0.1,
900+
)
901+
902+
assert compare_test_results(test_results, results2)
903+
859904
finally:
860905
test_path.unlink(missing_ok=True)
861906
fto_file_path.unlink(missing_ok=True)

tests/test_instrument_all_and_run.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,13 @@ def test_sort():
168168
pytest_max_loops=1,
169169
testing_time=0.1,
170170
)
171+
172+
out_str = """codeflash stdout: Sorting list
173+
result: [0, 1, 2, 3, 4, 5]
174+
175+
codeflash stdout: Sorting list
176+
result: [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]"""
177+
assert out_str == test_results[0].stdout
171178
assert test_results[0].id.function_getting_tested == "sorter"
172179
assert test_results[0].id.iteration_id == "1_0"
173180
assert test_results[0].id.test_class_name is None
@@ -179,6 +186,7 @@ def test_sort():
179186
assert test_results[0].runtime > 0
180187
assert test_results[0].did_pass
181188
assert test_results[0].return_value == ([0, 1, 2, 3, 4, 5],)
189+
assert out_str == test_results[1].stdout.strip()
182190

183191
assert test_results[1].id.function_getting_tested == "sorter"
184192
assert test_results[1].id.iteration_id == "4_0"
@@ -190,6 +198,22 @@ def test_sort():
190198
)
191199
assert test_results[1].runtime > 0
192200
assert test_results[1].did_pass
201+
results2, _ = func_optimizer.run_and_parse_tests(
202+
testing_type=TestingMode.BEHAVIOR,
203+
test_env=test_env,
204+
test_files=func_optimizer.test_files,
205+
optimization_iteration=0,
206+
pytest_min_loops=1,
207+
pytest_max_loops=1,
208+
testing_time=0.1,
209+
)
210+
out_str = """codeflash stdout: Sorting list
211+
result: [0, 1, 2, 3, 4, 5]
212+
213+
codeflash stdout: Sorting list
214+
result: [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]"""
215+
assert out_str == results2[0].stdout.strip()
216+
assert compare_test_results(test_results, results2)
193217
finally:
194218
fto_path.write_text(original_code, "utf-8")
195219
test_path.unlink(missing_ok=True)
@@ -340,13 +364,11 @@ def test_sort():
340364
pytest_max_loops=1,
341365
testing_time=0.1,
342366
)
343-
344367
assert len(test_results) == 4
345368
assert test_results[0].id.function_getting_tested == "BubbleSorter.__init__"
346369
assert test_results[0].id.test_function_name == "test_sort"
347370
assert test_results[0].did_pass
348371
assert test_results[0].return_value[0] == {"x": 0}
349-
350372
assert test_results[1].id.function_getting_tested == "BubbleSorter.sorter"
351373
assert test_results[1].id.iteration_id == "2_0"
352374
assert test_results[1].id.test_class_name is None
@@ -358,7 +380,9 @@ def test_sort():
358380
assert test_results[1].runtime > 0
359381
assert test_results[1].did_pass
360382
assert test_results[1].return_value == ([0, 1, 2, 3, 4, 5],)
361-
383+
out_str = """codeflash stdout : BubbleSorter.sorter() called\n\n\ncodeflash stdout : BubbleSorter.sorter() called"""
384+
assert test_results[1].stdout == out_str
385+
assert compare_test_results(test_results, test_results)
362386
assert test_results[2].id.function_getting_tested == "BubbleSorter.__init__"
363387
assert test_results[2].id.test_function_name == "test_sort"
364388
assert test_results[2].did_pass
@@ -375,6 +399,18 @@ def test_sort():
375399
assert test_results[3].runtime > 0
376400
assert test_results[3].did_pass
377401

402+
results2, _ = func_optimizer.run_and_parse_tests(
403+
testing_type=TestingMode.BEHAVIOR,
404+
test_env=test_env,
405+
test_files=func_optimizer.test_files,
406+
optimization_iteration=0,
407+
pytest_min_loops=1,
408+
pytest_max_loops=1,
409+
testing_time=0.1,
410+
)
411+
412+
assert compare_test_results(test_results, results2)
413+
378414
# Replace with optimized code that mutated instance attribute
379415
optimized_code = """
380416
class BubbleSorter:

0 commit comments

Comments
 (0)