diff --git a/tests/scripts/end_to_end_test_futurehouse.py b/tests/scripts/end_to_end_test_futurehouse.py
index 430982b77..e4fe3103d 100644
--- a/tests/scripts/end_to_end_test_futurehouse.py
+++ b/tests/scripts/end_to_end_test_futurehouse.py
@@ -7,7 +7,7 @@ def run_test(expected_improvement_pct: int) -> bool:
     config = TestConfig(
         file_path="src/aviary/common_tags.py",
-        expected_unit_tests=0,  # todo: fix bug https://linear.app/codeflash-ai/issue/CF-921/test-discovery-does-not-work-properly-for-e2e-futurehouse-example for context
+        expected_unit_tests_count=2,
         min_improvement_x=0.05,
         coverage_expectations=[
             CoverageExpectation(
diff --git a/tests/scripts/end_to_end_test_topological_sort_worktree.py b/tests/scripts/end_to_end_test_topological_sort_worktree.py
index 6a6b30122..3d4f86b77 100644
--- a/tests/scripts/end_to_end_test_topological_sort_worktree.py
+++ b/tests/scripts/end_to_end_test_topological_sort_worktree.py
@@ -17,7 +17,7 @@ def run_test(expected_improvement_pct: int) -> bool:
                 expected_lines=[25, 26, 27, 28, 29, 30, 31],
             )
         ],
-        expected_unit_tests=1,
+        expected_unit_test_files=1,  # Per-function count
     )
     cwd = (pathlib.Path(__file__).parent.parent.parent / "code_to_optimize").resolve()
     return_var = run_codeflash_command(cwd, config, expected_improvement_pct)
diff --git a/tests/scripts/end_to_end_test_tracer_replay.py b/tests/scripts/end_to_end_test_tracer_replay.py
index 72d6fe97f..26efd8ed2 100644
--- a/tests/scripts/end_to_end_test_tracer_replay.py
+++ b/tests/scripts/end_to_end_test_tracer_replay.py
@@ -8,7 +8,7 @@ def run_test(expected_improvement_pct: int) -> bool:
     config = TestConfig(
         trace_mode=True,
         min_improvement_x=0.1,
-        expected_unit_tests=0,
+        expected_unit_tests_count=None,  # Tracer creates replay tests dynamically, skip validation
         coverage_expectations=[
             CoverageExpectation(function_name="funcA", expected_coverage=100.0, expected_lines=[6, 7, 8, 9, 11, 14])
         ],
diff --git a/tests/scripts/end_to_end_test_utilities.py b/tests/scripts/end_to_end_test_utilities.py
index 8649e1abb..777bf16ba 100644
--- a/tests/scripts/end_to_end_test_utilities.py
+++ b/tests/scripts/end_to_end_test_utilities.py
@@ -1,3 +1,4 @@
+import contextlib
 import logging
 import os
 import pathlib
@@ -8,6 +9,11 @@ from dataclasses import dataclass, field
 from typing import Optional
 
+try:
+    import tomllib
+except ImportError:
+    import tomli as tomllib
+
 
 @dataclass
 class CoverageExpectation:
@@ -21,7 +27,10 @@ class TestConfig:
     # Make file_path optional when trace_mode is True
     file_path: Optional[pathlib.Path] = None
     function_name: Optional[str] = None
-    expected_unit_tests: Optional[int] = None
+    # Global count: "Discovered X existing unit tests and Y replay tests in Z.Zs at /path"
+    expected_unit_tests_count: Optional[int] = None
+    # Per-function count: "Discovered X existing unit test files, Y replay test files, and Z concolic..."
+    expected_unit_test_files: Optional[int] = None
     min_improvement_x: float = 0.1
     trace_mode: bool = False
     coverage_expectations: list[CoverageExpectation] = field(default_factory=list)
@@ -129,7 +138,20 @@ def build_command(
 
     if config.function_name:
         base_command.extend(["--function", config.function_name])
-    base_command.extend(["--tests-root", str(test_root), "--module-root", str(cwd)])
+
+    # Check if pyproject.toml exists with codeflash config - if so, don't override it
+    pyproject_path = cwd / "pyproject.toml"
+    has_codeflash_config = False
+    if pyproject_path.exists():
+        with contextlib.suppress(Exception):
+            with open(pyproject_path, "rb") as f:
+                pyproject_data = tomllib.load(f)
+            has_codeflash_config = "tool" in pyproject_data and "codeflash" in pyproject_data["tool"]
+
+    # Only pass --tests-root and --module-root if they're not configured in pyproject.toml
+    if not has_codeflash_config:
+        base_command.extend(["--tests-root", str(test_root), "--module-root", str(cwd)])
+
     if benchmarks_root:
         base_command.extend(["--benchmark", "--benchmarks-root", str(benchmarks_root)])
     if config.use_worktree:
@@ -163,15 +185,30 @@ def validate_output(stdout: str, return_code: int, expected_improvement_pct: int
         logging.error(f"Performance improvement rate {improvement_x}x not above {config.min_improvement_x}x")
         return False
 
-    if config.expected_unit_tests is not None:
-        unit_test_match = re.search(r"Discovered (\d+) existing unit test file", stdout)
+    if config.expected_unit_tests_count is not None:
+        # Match the global test discovery message from optimizer.py which counts test invocations
+        # Format: "Discovered X existing unit tests and Y replay tests in Z.Zs at /path/to/tests"
+        unit_test_match = re.search(r"Discovered (\d+) existing unit tests? and \d+ replay tests? in [\d.]+s at", stdout)
         if not unit_test_match:
-            logging.error("Could not find unit test count")
+            logging.error("Could not find global unit test count")
             return False
 
         num_tests = int(unit_test_match.group(1))
-        if num_tests != config.expected_unit_tests:
-            logging.error(f"Expected {config.expected_unit_tests} unit tests, found {num_tests}")
+        if num_tests != config.expected_unit_tests_count:
+            logging.error(f"Expected {config.expected_unit_tests_count} global unit tests, found {num_tests}")
+            return False
+
+    if config.expected_unit_test_files is not None:
+        # Match the per-function test discovery message from function_optimizer.py
+        # Format: "Discovered X existing unit test files, Y replay test files, and Z concolic..."
+        unit_test_files_match = re.search(r"Discovered (\d+) existing unit test files?", stdout)
+        if not unit_test_files_match:
+            logging.error("Could not find per-function unit test file count")
+            return False
+
+        num_test_files = int(unit_test_files_match.group(1))
+        if num_test_files != config.expected_unit_test_files:
+            logging.error(f"Expected {config.expected_unit_test_files} unit test files, found {num_test_files}")
             return False
 
     if config.coverage_expectations:
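
Reviewer note: the two new discovery regexes are easy to confuse, so here is a minimal, self-contained sketch of what each one matches. The sample stdout lines are hypothetical, built from the formats documented in the patch comments rather than captured from a real codeflash run:

```python
import re

# Hypothetical log lines in the two formats the patch targets (assumed
# examples; real codeflash output may differ in wording or counts).
global_line = "Discovered 2 existing unit tests and 0 replay tests in 1.43s at /repo/tests"
per_function_line = "Discovered 1 existing unit test files, 0 replay test files, and 0 concolic coverage tests"

# Global message (optimizer.py): captures the number of test invocations.
global_re = re.compile(r"Discovered (\d+) existing unit tests? and \d+ replay tests? in [\d.]+s at")
# Per-function message (function_optimizer.py): captures the number of test files.
per_function_re = re.compile(r"Discovered (\d+) existing unit test files?")

m = global_re.search(global_line)
assert m is not None and int(m.group(1)) == 2

m = per_function_re.search(per_function_line)
assert m is not None and int(m.group(1)) == 1

# The patterns cannot accidentally match each other's message: "unit tests and"
# never contains "unit test file", and "unit test files," never contains
# "unit tests? and".
assert per_function_re.search(global_line) is None
assert global_re.search(per_function_line) is None
```

Similarly, a small sketch of the pyproject.toml detection added to build_command, using an inline TOML string instead of a real file (tomllib ships with Python 3.11+; the tomli fallback in the patch covers older interpreters):

```python
import tomllib  # Python 3.11+; the patch falls back to `import tomli as tomllib`

# Inline stand-in for a project's pyproject.toml (hypothetical contents).
pyproject_data = tomllib.loads("""
[tool.codeflash]
module-root = "src"
tests-root = "tests"
""")

# Same membership test the patch uses to decide whether to skip the
# --tests-root/--module-root CLI flags.
has_codeflash_config = "tool" in pyproject_data and "codeflash" in pyproject_data["tool"]
assert has_codeflash_config
```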