313 changes: 226 additions & 87 deletions codeflash/discovery/discover_unit_tests.py

Large diffs are not rendered by default.
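The full diff for this file is collapsed, but the call sites updated elsewhere in this PR show the new interface: discover_unit_tests now returns a (mapping, count) tuple, maps each qualified function name to a set of FunctionCalledInTest rather than a list, and accepts the keyword arguments discover_only_these_tests and file_to_funcs_to_optimize. A minimal sketch of that shape, with the count semantics and the exact defaults assumed rather than taken from the collapsed diff:

from __future__ import annotations

from pathlib import Path


def discover_unit_tests(
    cfg: object,
    discover_only_these_tests: list[Path] | None = None,
    file_to_funcs_to_optimize: dict[Path, list[object]] | None = None,
) -> tuple[dict[str, set[object]], int]:
    # Stand-in annotations: the real signature uses TestConfig, FunctionToOptimize
    # and FunctionCalledInTest from codeflash's own models.
    function_to_tests: dict[str, set[object]] = {}
    ...  # the test collection itself lives in the collapsed diff above
    # Assumption: the returned count is the total number of discovered test
    # invocations across all functions.
    num_discovered_tests = sum(len(tests) for tests in function_to_tests.values())
    return function_to_tests, num_discovered_tests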

2 changes: 1 addition & 1 deletion codeflash/discovery/functions_to_optimize.py
@@ -268,7 +268,7 @@ def find_all_functions_in_file(file_path: Path) -> dict[Path, list[FunctionToOpt
def get_all_replay_test_functions(
replay_test: Path, test_cfg: TestConfig, project_root_path: Path
) -> dict[Path, list[FunctionToOptimize]]:
-function_tests = discover_unit_tests(test_cfg, discover_only_these_tests=[replay_test])
+function_tests, _ = discover_unit_tests(test_cfg, discover_only_these_tests=[replay_test])
# Get the absolute file paths for each function, excluding class name if present
filtered_valid_functions = defaultdict(list)
file_to_functions_map = defaultdict(list)
19 changes: 12 additions & 7 deletions codeflash/optimization/function_optimizer.py
@@ -57,7 +57,6 @@
from codeflash.models.models import (
BestOptimization,
CodeOptimizationContext,
-FunctionCalledInTest,
GeneratedTests,
GeneratedTestsList,
OptimizationSet,
@@ -87,7 +86,13 @@

from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.either import Result
-from codeflash.models.models import BenchmarkKey, CoverageData, FunctionSource, OptimizedCandidate
+from codeflash.models.models import (
+    BenchmarkKey,
+    CoverageData,
+    FunctionCalledInTest,
+    FunctionSource,
+    OptimizedCandidate,
+)
from codeflash.verification.verification_utils import TestConfig


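FunctionCalledInTest moves out of the module's runtime import block (previous hunk) and into the block above, which already supplies names used only in annotations such as FunctionSource and TestConfig. Assuming that block sits under an if TYPE_CHECKING: guard, which this hunk does not show, the pattern is the standard annotation-only import, sketched below with a hypothetical helper:

from __future__ import annotations  # annotations stay strings at runtime

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only for the type checker; never executed at runtime,
    # so there is no import cost and no risk of a circular import.
    from codeflash.models.models import FunctionCalledInTest


def count_invocations(function_to_tests: dict[str, set[FunctionCalledInTest]]) -> int:
    # Hypothetical helper: the annotation above resolves for type checkers
    # without triggering the guarded import when this module runs.
    return sum(len(tests) for tests in function_to_tests.values())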
@@ -97,7 +102,7 @@ def __init__(
function_to_optimize: FunctionToOptimize,
test_cfg: TestConfig,
function_to_optimize_source_code: str = "",
-function_to_tests: dict[str, list[FunctionCalledInTest]] | None = None,
+function_to_tests: dict[str, set[FunctionCalledInTest]] | None = None,
function_to_optimize_ast: ast.FunctionDef | None = None,
aiservice_client: AiServiceClient | None = None,
function_benchmark_timings: dict[BenchmarkKey, int] | None = None,
@@ -213,7 +218,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911

function_to_optimize_qualified_name = self.function_to_optimize.qualified_name
function_to_all_tests = {
-key: self.function_to_tests.get(key, []) + function_to_concolic_tests.get(key, [])
+key: self.function_to_tests.get(key, set()) | function_to_concolic_tests.get(key, set())
for key in set(self.function_to_tests) | set(function_to_concolic_tests)
}
instrumented_unittests_created_for_function = self.instrument_existing_tests(function_to_all_tests)
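With both mappings now holding sets, the merge above switches from list concatenation (+) to set union (|), so an invocation found by both regular test discovery and concolic generation is kept once instead of twice. A standalone sketch of the same merge, using strings in place of FunctionCalledInTest:

from __future__ import annotations


def merge_test_maps(existing: dict[str, set[str]], concolic: dict[str, set[str]]) -> dict[str, set[str]]:
    # Union the per-function test sets over every key present in either map.
    return {
        key: existing.get(key, set()) | concolic.get(key, set())
        for key in set(existing) | set(concolic)
    }


# A shared invocation is deduplicated rather than repeated.
merged = merge_test_maps(
    {"pkg.mod.func": {"tests/test_a.py::test_x"}},
    {"pkg.mod.func": {"tests/test_a.py::test_x", "tests/test_gen.py::test_y"}},
)
assert merged == {"pkg.mod.func": {"tests/test_a.py::test_x", "tests/test_gen.py::test_y"}}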
@@ -690,7 +695,7 @@ def cleanup_leftover_test_return_values() -> None:
get_run_tmp_file(Path("test_return_values_0.bin")).unlink(missing_ok=True)
get_run_tmp_file(Path("test_return_values_0.sqlite")).unlink(missing_ok=True)

-def instrument_existing_tests(self, function_to_all_tests: dict[str, list[FunctionCalledInTest]]) -> set[Path]:
+def instrument_existing_tests(self, function_to_all_tests: dict[str, set[FunctionCalledInTest]]) -> set[Path]:
existing_test_files_count = 0
replay_test_files_count = 0
concolic_coverage_test_files_count = 0
@@ -701,7 +706,7 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, list[Functi
logger.info(f"Did not find any pre-existing tests for '{func_qualname}', will only use generated tests.")
console.rule()
else:
-test_file_invocation_positions = defaultdict(list[FunctionCalledInTest])
+test_file_invocation_positions = defaultdict(list)
for tests_in_file in function_to_all_tests.get(func_qualname):
test_file_invocation_positions[
(tests_in_file.tests_in_file.test_file, tests_in_file.tests_in_file.test_type)
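defaultdict(list[FunctionCalledInTest]) becomes defaultdict(list): the subscripted alias happens to be callable at runtime, so both work, but the bare list factory is the idiomatic spelling now that the incoming invocations arrive as a set. The grouping itself, keyed by (test file, test type), looks roughly like this self-contained sketch; the tests_in_file.test_file and tests_in_file.test_type fields come from the lines above, while the reduced dataclasses are stand-ins for codeflash's models:

from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path


@dataclass(frozen=True)  # frozen makes instances hashable, so they can live in sets
class TestsInFile:
    test_file: Path
    test_type: str


@dataclass(frozen=True)
class FunctionCalledInTest:
    tests_in_file: TestsInFile


def group_by_file_and_type(
    invocations: set[FunctionCalledInTest],
) -> dict[tuple[Path, str], list[FunctionCalledInTest]]:
    # Bucket one function's test invocations per (test file, test type) pair.
    positions: dict[tuple[Path, str], list[FunctionCalledInTest]] = defaultdict(list)
    for invocation in invocations:
        key = (invocation.tests_in_file.test_file, invocation.tests_in_file.test_type)
        positions[key].append(invocation)
    return positions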
@@ -787,7 +792,7 @@ def generate_tests_and_optimizations(
generated_test_paths: list[Path],
generated_perf_test_paths: list[Path],
run_experiment: bool = False, # noqa: FBT001, FBT002
-) -> Result[tuple[GeneratedTestsList, dict[str, list[FunctionCalledInTest]], OptimizationSet], str]:
+) -> Result[tuple[GeneratedTestsList, dict[str, set[FunctionCalledInTest]], OptimizationSet], str]:
assert len(generated_test_paths) == N_TESTS_TO_GENERATE
max_workers = N_TESTS_TO_GENERATE + 2 if not run_experiment else N_TESTS_TO_GENERATE + 3
console.rule()
7 changes: 4 additions & 3 deletions codeflash/optimization/optimizer.py
@@ -48,7 +48,7 @@ def create_function_optimizer(
self,
function_to_optimize: FunctionToOptimize,
function_to_optimize_ast: ast.FunctionDef | None = None,
-function_to_tests: dict[str, list[FunctionCalledInTest]] | None = None,
+function_to_tests: dict[str, set[FunctionCalledInTest]] | None = None,
function_to_optimize_source_code: str | None = "",
function_benchmark_timings: dict[str, dict[BenchmarkKey, float]] | None = None,
total_benchmark_timings: dict[BenchmarkKey, float] | None = None,
@@ -162,8 +162,9 @@ def run(self) -> None:

console.rule()
start_time = time.time()
-function_to_tests: dict[str, list[FunctionCalledInTest]] = discover_unit_tests(self.test_cfg)
-num_discovered_tests: int = sum([len(value) for value in function_to_tests.values()])
+function_to_tests, num_discovered_tests = discover_unit_tests(
+    self.test_cfg, file_to_funcs_to_optimize=file_to_funcs_to_optimize
+)
console.rule()
logger.info(
f"Discovered {num_discovered_tests} existing unit tests in {(time.time() - start_time):.1f}s at {self.test_cfg.tests_root}"
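run() now takes both the mapping and the test count straight from discover_unit_tests, forwarding file_to_funcs_to_optimize so discovery can be scoped to the functions selected for optimization, instead of re-deriving the count with sum([len(value) for value in ...]). Assuming the returned count is still the sum of the per-function set sizes (the collapsed discovery diff is what actually defines it), the removed line is equivalent to:

from __future__ import annotations


def count_discovered_tests(function_to_tests: dict[str, set[str]]) -> int:
    # Same total as the removed inline sum, but a generator expression avoids
    # materialising the intermediate list that sum([...]) built.
    return sum(len(tests) for tests in function_to_tests.values())


function_to_tests = {
    "pkg.mod.f": {"tests/test_f.py::test_basic", "tests/test_f.py::test_edge"},
    "pkg.mod.g": {"tests/test_g.py::test_basic"},
}
assert count_discovered_tests(function_to_tests) == 3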
2 changes: 1 addition & 1 deletion codeflash/result/create_pr.py
@@ -25,7 +25,7 @@

def existing_tests_source_for(
function_qualified_name_with_modules_from_root: str,
-function_to_tests: dict[str, list[FunctionCalledInTest]],
+function_to_tests: dict[str, set[FunctionCalledInTest]],
tests_root: Path,
) -> str:
test_files = function_to_tests.get(function_qualified_name_with_modules_from_root)
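existing_tests_source_for now receives sets of FunctionCalledInTest rather than lists. Sets have no defined iteration order, so any text rendered from them for the PR body presumably needs an explicit sort to stay stable between runs; the function body is outside this hunk, so the sketch below only illustrates that ordering concern, with plain paths standing in for the invocation records:

from __future__ import annotations

from pathlib import Path


def existing_tests_listing(test_files: set[Path], tests_root: Path) -> str:
    # sorted() pins an order that iterating the raw set would not guarantee.
    lines = [f"- {path.relative_to(tests_root)}" for path in sorted(test_files)]
    return "\n".join(lines)


print(existing_tests_listing({Path("/repo/tests/test_b.py"), Path("/repo/tests/test_a.py")}, Path("/repo/tests")))
# - test_a.py
# - test_b.py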
5 changes: 2 additions & 3 deletions codeflash/verification/concolic_testing.py
@@ -24,7 +24,7 @@

def generate_concolic_tests(
test_cfg: TestConfig, args: Namespace, function_to_optimize: FunctionToOptimize, function_to_optimize_ast: ast.AST
-) -> tuple[dict[str, list[FunctionCalledInTest]], str]:
+) -> tuple[dict[str, set[FunctionCalledInTest]], str]:
start_time = time.perf_counter()
function_to_concolic_tests = {}
concolic_test_suite_code = ""
@@ -78,8 +78,7 @@ def generate_concolic_tests(
test_framework=args.test_framework,
pytest_cmd=args.pytest_cmd,
)
-function_to_concolic_tests = discover_unit_tests(concolic_test_cfg)
-num_discovered_concolic_tests: int = sum([len(value) for value in function_to_concolic_tests.values()])
+function_to_concolic_tests, num_discovered_concolic_tests = discover_unit_tests(concolic_test_cfg)
logger.info(
f"Created {num_discovered_concolic_tests} "
f"concolic unit test case{'s' if num_discovered_concolic_tests != 1 else ''} "
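The concolic path mirrors the main optimizer: a temporary TestConfig is pointed at the generated concolic suite and handed to discover_unit_tests, which now hands back the count directly instead of it being re-summed at the call site. A toy sketch of the unpack-and-log step, where the stub only mimics the new (mapping, count) return shape:

from __future__ import annotations

import logging
import time

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def discover_unit_tests_stub(cfg: object) -> tuple[dict[str, set[str]], int]:
    # Stand-in for the real discovery call; only the return shape matters here.
    tests = {"pkg.mod.f": {"tests/test_concolic.py::test_f_0"}}
    return tests, sum(len(v) for v in tests.values())


start_time = time.perf_counter()
function_to_concolic_tests, num_discovered_concolic_tests = discover_unit_tests_stub(None)
logger.info(
    f"Created {num_discovered_concolic_tests} "
    f"concolic unit test case{'s' if num_discovered_concolic_tests != 1 else ''} "
    f"in {time.perf_counter() - start_time:.2f}s"
)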
2 changes: 1 addition & 1 deletion tests/test_static_analysis.py
@@ -1,4 +1,4 @@
-import ast
+import ast
from pathlib import Path

from codeflash.code_utils.static_analysis import (