@@ -244,7 +244,7 @@ def __init__(
244244 ) = None
245245 n_tests = N_TESTS_TO_GENERATE_EFFECTIVE
246246 self .executor = concurrent .futures .ThreadPoolExecutor (
247- max_workers = n_tests + 2 if self .experiment_id is None else n_tests + 3
247+ max_workers = n_tests + 3 if self .experiment_id is None else n_tests + 4
248248 )
249249
250250 def can_be_optimized (self ) -> Result [tuple [bool , CodeOptimizationContext , dict [Path , str ]], str ]:
@@ -286,6 +286,7 @@ def generate_and_instrument_tests(
286286 list [Path ],
287287 set [Path ],
288288 dict | None ,
289+ str ,
289290 ]
290291 ]:
291292 """Generate and instrument tests, returning all necessary data for optimization."""
@@ -323,9 +324,14 @@ def generate_and_instrument_tests(
323324
324325 generated_tests : GeneratedTestsList
325326 optimizations_set : OptimizationSet
326- count_tests , generated_tests , function_to_concolic_tests , concolic_test_str , optimizations_set = (
327- generated_results .unwrap ()
328- )
327+ (
328+ count_tests ,
329+ generated_tests ,
330+ function_to_concolic_tests ,
331+ concolic_test_str ,
332+ optimizations_set ,
333+ function_references ,
334+ ) = generated_results .unwrap ()
329335
330336 for i , generated_test in enumerate (generated_tests .generated_tests ):
331337 with generated_test .behavior_file_path .open ("w" , encoding = "utf8" ) as f :
@@ -371,6 +377,7 @@ def generate_and_instrument_tests(
371377 generated_perf_test_paths ,
372378 instrumented_unittests_created_for_function ,
373379 original_conftest_content ,
380+ function_references ,
374381 )
375382 )
376383
@@ -403,6 +410,7 @@ def optimize_function(self) -> Result[BestOptimization, str]:
403410 generated_perf_test_paths ,
404411 instrumented_unittests_created_for_function ,
405412 original_conftest_content ,
413+ function_references ,
406414 ) = test_setup_result .unwrap ()
407415
408416 baseline_setup_result = self .setup_and_establish_baseline (
@@ -437,6 +445,7 @@ def optimize_function(self) -> Result[BestOptimization, str]:
437445 generated_tests = generated_tests ,
438446 test_functions_to_remove = test_functions_to_remove ,
439447 concolic_test_str = concolic_test_str ,
448+ function_references = function_references ,
440449 )
441450
442451 # Add function to code context hash if in gh actions
@@ -458,6 +467,7 @@ def determine_best_candidate(
458467 original_helper_code : dict [Path , str ],
459468 file_path_to_helper_classes : dict [Path , set [str ]],
460469 exp_type : str ,
470+ function_references : str ,
461471 ) -> BestOptimization | None :
462472 best_optimization : BestOptimization | None = None
463473 _best_runtime_until_now = original_code_baseline .runtime
@@ -667,6 +677,7 @@ def determine_best_candidate(
667677 else self .function_trace_id ,
668678 ai_service_client = ai_service_client ,
669679 executor = self .executor ,
680+ function_references = function_references ,
670681 )
671682 )
672683 else :
@@ -753,6 +764,7 @@ def determine_best_candidate(
753764 optimization_ids = optimization_ids ,
754765 speedups = speedups_list ,
755766 trace_id = self .function_trace_id [:- 4 ] + exp_type if self .experiment_id else self .function_trace_id ,
767+ function_references = function_references ,
756768 )
757769 concurrent .futures .wait ([future_ranking ])
758770 ranking = future_ranking .result ()
@@ -766,7 +778,7 @@ def determine_best_candidate(
766778 min_key = min (overall_ranking , key = overall_ranking .get )
767779 elif len (optimization_ids ) == 1 :
768780 min_key = 0 # only one candidate in valid _opts, already returns if there are no valid candidates
769- else : # 0? shouldn't happen but it's there to escape potential bugs
781+ else : # 0? shouldn't happen, but it's there to escape potential bugs
770782 return None
771783 best_optimization = valid_candidates_with_shorter_code [min_key ]
772784 # reassign code string which is the shortest
@@ -790,6 +802,7 @@ def refine_optimizations(
790802 trace_id : str ,
791803 ai_service_client : AiServiceClient ,
792804 executor : concurrent .futures .ThreadPoolExecutor ,
805+ function_references : str | None = None ,
793806 ) -> concurrent .futures .Future :
794807 request = [
795808 AIServiceRefinerRequest (
@@ -804,6 +817,7 @@ def refine_optimizations(
804817 trace_id = trace_id ,
805818 original_line_profiler_results = original_code_baseline .line_profile_results ["str_out" ],
806819 optimized_line_profiler_results = opt .line_profiler_test_results ["str_out" ],
820+ function_references = function_references ,
807821 )
808822 for opt in valid_optimizations
809823 ]
@@ -1089,7 +1103,7 @@ def generate_tests_and_optimizations(
10891103 generated_test_paths : list [Path ],
10901104 generated_perf_test_paths : list [Path ],
10911105 run_experiment : bool = False , # noqa: FBT001, FBT002
1092- ) -> Result [tuple [GeneratedTestsList , dict [str , set [FunctionCalledInTest ]], OptimizationSet ], str ]:
1106+ ) -> Result [tuple [GeneratedTestsList , dict [str , set [FunctionCalledInTest ]], OptimizationSet ], str , str ]:
10931107 n_tests = N_TESTS_TO_GENERATE_EFFECTIVE
10941108 assert len (generated_test_paths ) == n_tests
10951109 console .rule ()
@@ -1116,7 +1130,15 @@ def generate_tests_and_optimizations(
11161130 future_concolic_tests = self .executor .submit (
11171131 generate_concolic_tests , self .test_cfg , self .args , self .function_to_optimize , self .function_to_optimize_ast
11181132 )
1119- futures = [* future_tests , future_optimization_candidates , future_concolic_tests ]
1133+ future_references = self .executor .submit (
1134+ get_opt_review_metrics ,
1135+ self .function_to_optimize_source_code ,
1136+ self .function_to_optimize .file_path ,
1137+ self .function_to_optimize .qualified_name ,
1138+ self .project_root ,
1139+ self .test_cfg .tests_root ,
1140+ )
1141+ futures = [* future_tests , future_optimization_candidates , future_concolic_tests , future_references ]
11201142 if run_experiment :
11211143 future_candidates_exp = self .executor .submit (
11221144 self .local_aiservice_client .optimize_python_code ,
@@ -1168,7 +1190,7 @@ def generate_tests_and_optimizations(
11681190 logger .warning (f"Failed to generate and instrument tests for { self .function_to_optimize .function_name } " )
11691191 return Failure (f"/!\\ NO TESTS GENERATED for { self .function_to_optimize .function_name } " )
11701192 function_to_concolic_tests , concolic_test_str = future_concolic_tests .result ()
1171-
1193+ function_references = future_references . result ()
11721194 count_tests = len (tests )
11731195 if concolic_test_str :
11741196 count_tests += 1
@@ -1182,6 +1204,7 @@ def generate_tests_and_optimizations(
11821204 function_to_concolic_tests ,
11831205 concolic_test_str ,
11841206 OptimizationSet (control = candidates , experiment = candidates_experiment ),
1207+ function_references ,
11851208 )
11861209 self .generate_and_instrument_tests_results = result
11871210 return Success (result )
@@ -1263,6 +1286,7 @@ def find_and_process_best_optimization(
12631286 generated_tests : GeneratedTestsList ,
12641287 test_functions_to_remove : list [str ],
12651288 concolic_test_str : str | None ,
1289+ function_references : str ,
12661290 ) -> BestOptimization | None :
12671291 """Find the best optimization candidate and process it with all required steps."""
12681292 best_optimization = None
@@ -1279,6 +1303,7 @@ def find_and_process_best_optimization(
12791303 original_helper_code = original_helper_code ,
12801304 file_path_to_helper_classes = file_path_to_helper_classes ,
12811305 exp_type = exp_type ,
1306+ function_references = function_references ,
12821307 )
12831308 ph (
12841309 "cli-optimize-function-finished" ,
@@ -1347,6 +1372,7 @@ def find_and_process_best_optimization(
13471372 exp_type ,
13481373 original_helper_code ,
13491374 code_context ,
1375+ function_references ,
13501376 )
13511377 return best_optimization
13521378
@@ -1364,6 +1390,7 @@ def process_review(
13641390 exp_type : str ,
13651391 original_helper_code : dict [Path , str ],
13661392 code_context : CodeOptimizationContext ,
1393+ function_references : str ,
13671394 ) -> None :
13681395 coverage_message = (
13691396 original_code_baseline .coverage_results .build_message ()
@@ -1430,6 +1457,7 @@ def process_review(
14301457 original_throughput = original_throughput_str ,
14311458 optimized_throughput = optimized_throughput_str ,
14321459 throughput_improvement = throughput_improvement_str ,
1460+ function_references = function_references ,
14331461 )
14341462 new_explanation = Explanation (
14351463 raw_explanation_message = new_explanation_raw_str or explanation .raw_explanation_message ,
@@ -1466,16 +1494,9 @@ def process_review(
14661494 opt_review_response = ""
14671495 if raise_pr or staging_review :
14681496 data ["root_dir" ] = git_root_dir ()
1469- calling_fn_details = get_opt_review_metrics (
1470- self .function_to_optimize_source_code ,
1471- self .function_to_optimize .file_path ,
1472- self .function_to_optimize .qualified_name ,
1473- self .project_root ,
1474- self .test_cfg .tests_root ,
1475- )
14761497 try :
14771498 opt_review_response = self .aiservice_client .get_optimization_review (
1478- ** data , calling_fn_details = calling_fn_details
1499+ ** data , calling_fn_details = function_references
14791500 )
14801501 except Exception as e :
14811502 logger .debug (f"optimization review response failed, investigate { e } " )
0 commit comments