5757from codeflash .models .models import (
5858 BestOptimization ,
5959 CodeOptimizationContext ,
60- FunctionCalledInTest ,
6160 GeneratedTests ,
6261 GeneratedTestsList ,
6362 OptimizationSet ,
8786
8887 from codeflash .discovery .functions_to_optimize import FunctionToOptimize
8988 from codeflash .either import Result
90- from codeflash .models .models import BenchmarkKey , CoverageData , FunctionSource , OptimizedCandidate
89+ from codeflash .models .models import (
90+ BenchmarkKey ,
91+ CoverageData ,
92+ FunctionCalledInTest ,
93+ FunctionSource ,
94+ OptimizedCandidate ,
95+ )
9196 from codeflash .verification .verification_utils import TestConfig
9297
9398
@@ -97,7 +102,7 @@ def __init__(
97102 function_to_optimize : FunctionToOptimize ,
98103 test_cfg : TestConfig ,
99104 function_to_optimize_source_code : str = "" ,
100- function_to_tests : dict [str , list [FunctionCalledInTest ]] | None = None ,
105+ function_to_tests : dict [str , set [FunctionCalledInTest ]] | None = None ,
101106 function_to_optimize_ast : ast .FunctionDef | None = None ,
102107 aiservice_client : AiServiceClient | None = None ,
103108 function_benchmark_timings : dict [BenchmarkKey , int ] | None = None ,
@@ -213,7 +218,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911
213218
214219 function_to_optimize_qualified_name = self .function_to_optimize .qualified_name
215220 function_to_all_tests = {
216- key : self .function_to_tests .get (key , []) + function_to_concolic_tests .get (key , [] )
221+ key : self .function_to_tests .get (key , set ()) | function_to_concolic_tests .get (key , set () )
217222 for key in set (self .function_to_tests ) | set (function_to_concolic_tests )
218223 }
219224 instrumented_unittests_created_for_function = self .instrument_existing_tests (function_to_all_tests )
@@ -690,7 +695,7 @@ def cleanup_leftover_test_return_values() -> None:
690695 get_run_tmp_file (Path ("test_return_values_0.bin" )).unlink (missing_ok = True )
691696 get_run_tmp_file (Path ("test_return_values_0.sqlite" )).unlink (missing_ok = True )
692697
693- def instrument_existing_tests (self , function_to_all_tests : dict [str , list [FunctionCalledInTest ]]) -> set [Path ]:
698+ def instrument_existing_tests (self , function_to_all_tests : dict [str , set [FunctionCalledInTest ]]) -> set [Path ]:
694699 existing_test_files_count = 0
695700 replay_test_files_count = 0
696701 concolic_coverage_test_files_count = 0
@@ -701,7 +706,7 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, list[Functi
701706 logger .info (f"Did not find any pre-existing tests for '{ func_qualname } ', will only use generated tests." )
702707 console .rule ()
703708 else :
704- test_file_invocation_positions = defaultdict (list [ FunctionCalledInTest ] )
709+ test_file_invocation_positions = defaultdict (list )
705710 for tests_in_file in function_to_all_tests .get (func_qualname ):
706711 test_file_invocation_positions [
707712 (tests_in_file .tests_in_file .test_file , tests_in_file .tests_in_file .test_type )
@@ -787,7 +792,7 @@ def generate_tests_and_optimizations(
787792 generated_test_paths : list [Path ],
788793 generated_perf_test_paths : list [Path ],
789794 run_experiment : bool = False , # noqa: FBT001, FBT002
790- ) -> Result [tuple [GeneratedTestsList , dict [str , list [FunctionCalledInTest ]], OptimizationSet ], str ]:
795+ ) -> Result [tuple [GeneratedTestsList , dict [str , set [FunctionCalledInTest ]], OptimizationSet ], str ]:
791796 assert len (generated_test_paths ) == N_TESTS_TO_GENERATE
792797 max_workers = N_TESTS_TO_GENERATE + 2 if not run_experiment else N_TESTS_TO_GENERATE + 3
793798 console .rule ()
0 commit comments