Skip to content

Commit 7f1e666

Browse files
committed
stdout comparison in E2E
1 parent a0e0e17 commit 7f1e666

File tree

3 files changed

+33
-11
lines changed

3 files changed

+33
-11
lines changed

code_to_optimize/bubble_sort.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
def sorter(arr):
2+
print("codeflash stdout: Sorting list")
23
for i in range(len(arr)):
34
for j in range(len(arr) - 1):
45
if arr[j] > arr[j + 1]:
56
temp = arr[j]
67
arr[j] = arr[j + 1]
78
arr[j + 1] = temp
9+
print(f"result: {arr}")
810
return arr

tests/scripts/end_to_end_test_bubblesort_pytest.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,15 @@ def run_test(expected_improvement_pct: int) -> bool:
1111
test_framework="pytest",
1212
min_improvement_x=1.0,
1313
coverage_expectations=[
14-
CoverageExpectation(function_name="sorter", expected_coverage=100.0, expected_lines=[2, 3, 4, 5, 6, 7, 8])
14+
CoverageExpectation(
15+
function_name="sorter", expected_coverage=100.0, expected_lines=[2, 3, 4, 5, 6, 7, 8, 9, 10]
16+
)
1517
],
1618
)
1719
cwd = (pathlib.Path(__file__).parent.parent.parent / "code_to_optimize").resolve()
18-
return run_codeflash_command(cwd, config, expected_improvement_pct)
20+
return run_codeflash_command(
21+
cwd, config, expected_improvement_pct, ['print("codeflash stdout: Sorting list")', 'print(f"result: {arr}")']
22+
)
1923

2024

2125
if __name__ == "__main__":

tests/scripts/end_to_end_test_utilities.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,19 +63,21 @@ def validate_coverage(stdout: str, expectations: list[CoverageExpectation]) -> b
6363
assert coverage_match, f"Failed to find coverage data for {expect.function_name}"
6464

6565
coverage = float(coverage_match.group(1))
66-
assert (
67-
coverage == expect.expected_coverage
68-
), f"Coverage was {coverage} instead of {expect.expected_coverage} for function: {expect.function_name}"
66+
assert coverage == expect.expected_coverage, (
67+
f"Coverage was {coverage} instead of {expect.expected_coverage} for function: {expect.function_name}"
68+
)
6969

7070
executed_lines = list(map(int, coverage_match.group(2).split(", ")))
71-
assert (
72-
executed_lines == expect.expected_lines
73-
), f"Executed lines were {executed_lines} instead of {expect.expected_lines} for function: {expect.function_name}"
71+
assert executed_lines == expect.expected_lines, (
72+
f"Executed lines were {executed_lines} instead of {expect.expected_lines} for function: {expect.function_name}"
73+
)
7474

7575
return True
7676

7777

78-
def run_codeflash_command(cwd: pathlib.Path, config: TestConfig, expected_improvement_pct: int) -> bool:
78+
def run_codeflash_command(
79+
cwd: pathlib.Path, config: TestConfig, expected_improvement_pct: int, expected_in_stdout: list[str] = None
80+
) -> bool:
7981
logging.basicConfig(level=logging.INFO)
8082
if config.trace_mode:
8183
return run_trace_test(cwd, config, expected_improvement_pct)
@@ -97,12 +99,21 @@ def run_codeflash_command(cwd: pathlib.Path, config: TestConfig, expected_improv
9799
return_code = process.wait()
98100
stdout = "".join(output)
99101

100-
if not validate_output(stdout, return_code, expected_improvement_pct, config):
102+
validated = validate_output(stdout, return_code, expected_improvement_pct, config)
103+
if not validated:
101104
# Write original file contents back to file
102105
path_to_file.write_text(file_contents, "utf-8")
103106
logging.info("Codeflash run did not meet expected requirements for testing, reverting file changes.")
104107
return False
105-
return True
108+
109+
if expected_in_stdout:
110+
stdout_validated = validate_stdout_in_candidate(stdout, expected_in_stdout)
111+
if not stdout_validated:
112+
logging.error("Failed to find expected output in candidate output")
113+
validated = False
114+
logging.info(f"Success: Expected output found in candidate output")
115+
116+
return validated
106117

107118

108119
def build_command(cwd: pathlib.Path, config: TestConfig, test_root: pathlib.Path) -> list[str]:
@@ -164,6 +175,11 @@ def validate_output(stdout: str, return_code: int, expected_improvement_pct: int
164175
return True
165176

166177

178+
def validate_stdout_in_candidate(stdout: str, expected_in_stdout: list[str]) -> bool:
179+
candidate_output = stdout[stdout.find("INFO Best candidate") : stdout.find("Best Candidate Explanation")]
180+
return all(expected in candidate_output for expected in expected_in_stdout)
181+
182+
167183
def run_trace_test(cwd: pathlib.Path, config: TestConfig, expected_improvement_pct: int) -> bool:
168184
# First command: Run the tracer
169185
test_root = cwd / "tests" / (config.test_framework or "")

0 commit comments

Comments
 (0)