@@ -162,7 +162,7 @@ def optimize_function(self) -> Result[BestOptimization, str]:
             f"Generating new tests and optimizations for function {self.function_to_optimize.function_name}",
             transient=True,
         ):
-            #TODO: do a/b testing with same codegen but different testgen
+            # TODO: do a/b testing with same codegen but different testgen
             generated_results = self.generate_tests_and_optimizations(
                 testgen_context_code=code_context.testgen_context_code,
                 read_writable_code=code_context.read_writable_code,
@@ -760,7 +760,8 @@ def generate_tests_and_optimizations(
         run_experiment: bool = False,
     ) -> Result[tuple[GeneratedTestsList, dict[str, list[FunctionCalledInTest]], OptimizationSet], str]:
         assert len(generated_test_paths) == N_TESTS_TO_GENERATE
-        max_workers = N_TESTS_TO_GENERATE + 2 if not run_experiment else N_TESTS_TO_GENERATE + 3
+        max_workers = 2 * N_TESTS_TO_GENERATE + 2 if not run_experiment else 2 * N_TESTS_TO_GENERATE + 3
+        self.local_aiservice_client = LocalAiServiceClient()
         console.rule()
         with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
             # Submit the test generation task as future
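
Note on the new worker budget: each test slot can now fan out to two test-generation futures (one per AI service client), so the per-test term in max_workers doubles. A minimal sketch of that arithmetic, assuming the "+2"/"+3" constants cover the remaining optimization-candidate futures (that mapping is an assumption, not confirmed by this diff):

# Hypothetical sketch of the max_workers formula above; the exact futures
# behind the "+2" / "+3" constants are an assumption.
N_TESTS_TO_GENERATE = 2

def worker_budget(run_experiment: bool) -> int:
    # Two test-generation futures per test slot: one for the default
    # aiservice client, one for the local client used in the experiment arm.
    test_futures = 2 * N_TESTS_TO_GENERATE
    # Remaining futures (optimization candidates, plus one more in experiment mode).
    extra = 3 if run_experiment else 2
    return test_futures + extra

assert worker_budget(run_experiment=False) == 2 * N_TESTS_TO_GENERATE + 2
assert worker_budget(run_experiment=True) == 2 * N_TESTS_TO_GENERATE + 3
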
@@ -770,6 +771,7 @@ def generate_tests_and_optimizations(
                 [definition.fully_qualified_name for definition in helper_functions],
                 generated_test_paths,
                 generated_perf_test_paths,
+                run_experiment=True,
             )
             future_optimization_candidates = executor.submit(
                 self.aiservice_client.optimize_python_code,
@@ -1223,8 +1225,9 @@ def generate_and_instrument_tests(
         helper_function_names: list[str],
         generated_test_paths: list[Path],
         generated_perf_test_paths: list[Path],
+        run_experiment: bool,
     ) -> list[concurrent.futures.Future]:
-        return [
+        original = [
             executor.submit(
                 generate_tests,
                 self.aiservice_client,
@@ -1234,12 +1237,34 @@ def generate_and_instrument_tests(
                 Path(self.original_module_path),
                 self.test_cfg,
                 INDIVIDUAL_TESTCASE_TIMEOUT,
-                self.function_trace_id,
+                self.function_trace_id,  # [:-4]+"TST0" if run_experiment else self.function_trace_id,
                 test_index,
                 test_path,
                 test_perf_path,
+                single_prompt=False,
             )
             for test_index, (test_path, test_perf_path) in enumerate(
                 zip(generated_test_paths, generated_perf_test_paths)
             )
         ]
+        if run_experiment:
+            original += [
+                executor.submit(
+                    generate_tests,
+                    self.local_aiservice_client,
+                    source_code_being_tested,
+                    self.function_to_optimize,
+                    helper_function_names,
+                    Path(self.original_module_path),
+                    self.test_cfg,
+                    INDIVIDUAL_TESTCASE_TIMEOUT,
+                    self.function_trace_id,  # [:-4]+"TST1",
+                    test_index,
+                    test_path,
+                    test_perf_path,
+                    single_prompt=True,
+                )
+                for test_index, (test_path, test_perf_path) in enumerate(
+                    zip(generated_test_paths, generated_perf_test_paths)
+                )]
+        return original
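
The block above duplicates every test-generation submission against a second client when run_experiment is set, then returns the combined list so the non-experiment path is unchanged. A self-contained sketch of that fan-out pattern, with stand-in client names and a dummy generate function (none of these names come from the real API):

import concurrent.futures

def generate(client: str, test_index: int) -> str:
    # Stand-in for generate_tests(client, ...); just labels the work item.
    return f"{client}:test{test_index}"

def submit_all(executor: concurrent.futures.Executor, clients: list[str], n_tests: int) -> list:
    # One future per (client, test slot) pair, mirroring the experiment arm above.
    return [
        executor.submit(generate, client, i)
        for client in clients
        for i in range(n_tests)
    ]

with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
    futures = submit_all(executor, ["default", "local"], n_tests=2)
    print([f.result() for f in futures])

Extending the existing list (original += [...]) rather than building a separate return value keeps both arms' futures flowing through the same downstream handling.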