@@ -94,7 +94,7 @@ def __init__(
9494 function_benchmark_timings : dict [BenchmarkKey , int ] | None = None ,
9595 total_benchmark_timings : dict [BenchmarkKey , int ] | None = None ,
9696 args : Namespace | None = None ,
97- replay_tests_dir : Path | None = None
97+ replay_tests_dir : Path | None = None ,
9898 ) -> None :
9999 self .project_root = test_cfg .project_root_path
100100 self .test_cfg = test_cfg
@@ -242,8 +242,7 @@ def optimize_function(self) -> Result[BestOptimization, str]:
242242 # request for new optimizations but don't block execution, check for completion later
243243 # adding to control and experiment set but with same traceid
244244 best_optimization = None
245-
246- for _u , candidates in enumerate ([optimizations_set .control , optimizations_set .experiment ]):
245+ for _u , (candidates , exp_type ) in enumerate (zip ([optimizations_set .control , optimizations_set .experiment ],["EXP0" ,"EXP1" ])):
247246 if candidates is None :
248247 continue
249248
@@ -253,8 +252,9 @@ def optimize_function(self) -> Result[BestOptimization, str]:
253252 original_code_baseline = original_code_baseline ,
254253 original_helper_code = original_helper_code ,
255254 file_path_to_helper_classes = file_path_to_helper_classes ,
255+ exp_type = exp_type ,
256256 )
257- ph ("cli-optimize-function-finished" , {"function_trace_id" : self .function_trace_id })
257+ ph ("cli-optimize-function-finished" , {"function_trace_id" : self .function_trace_id [: - 4 ] + exp_type if self . experiment_id else self . function_trace_id })
258258
259259 generated_tests = remove_functions_from_generated_tests (
260260 generated_tests = generated_tests , test_functions_to_remove = test_functions_to_remove
@@ -273,7 +273,7 @@ def optimize_function(self) -> Result[BestOptimization, str]:
273273 processed_benchmark_info = process_benchmark_data (
274274 replay_performance_gain = best_optimization .replay_performance_gain ,
275275 fto_benchmark_timings = self .function_benchmark_timings ,
276- total_benchmark_timings = self .total_benchmark_timings
276+ total_benchmark_timings = self .total_benchmark_timings ,
277277 )
278278 explanation = Explanation (
279279 raw_explanation_message = best_optimization .candidate .explanation ,
@@ -283,10 +283,10 @@ def optimize_function(self) -> Result[BestOptimization, str]:
283283 best_runtime_ns = best_optimization .runtime ,
284284 function_name = function_to_optimize_qualified_name ,
285285 file_path = self .function_to_optimize .file_path ,
286- benchmark_details = processed_benchmark_info .benchmark_details if processed_benchmark_info else None
286+ benchmark_details = processed_benchmark_info .benchmark_details if processed_benchmark_info else None ,
287287 )
288288
289- self .log_successful_optimization (explanation , generated_tests )
289+ self .log_successful_optimization (explanation , generated_tests , exp_type )
290290
291291 self .replace_function_and_helpers_with_optimized_code (
292292 code_context = code_context , optimized_code = best_optimization .candidate .source_code
@@ -324,11 +324,11 @@ def optimize_function(self) -> Result[BestOptimization, str]:
324324 explanation = explanation ,
325325 existing_tests_source = existing_tests ,
326326 generated_original_test_source = generated_tests_str ,
327- function_trace_id = self .function_trace_id ,
327+ function_trace_id = self .function_trace_id [: - 4 ] + exp_type if self . experiment_id else self . function_trace_id ,
328328 coverage_message = coverage_message ,
329329 git_remote = self .args .git_remote ,
330330 )
331- if self .args .all or env_utils .get_pr_number ():
331+ if self .args .all or env_utils .get_pr_number () or ( self . args . file and not self . args . function ) :
332332 self .write_code_and_helpers (
333333 self .function_to_optimize_source_code ,
334334 original_helper_code ,
@@ -361,6 +361,7 @@ def determine_best_candidate(
361361 original_code_baseline : OriginalCodeBaseline ,
362362 original_helper_code : dict [Path , str ],
363363 file_path_to_helper_classes : dict [Path , set [str ]],
364+ exp_type : str ,
364365 ) -> BestOptimization | None :
365366 best_optimization : BestOptimization | None = None
366367 best_runtime_until_now = original_code_baseline .runtime
@@ -377,28 +378,29 @@ def determine_best_candidate(
377378 candidates = deque (candidates )
378379 # Start a new thread for AI service request, start loop in main thread
379380 # check if aiservice request is complete, when it is complete, append result to the candidates list
380- with concurrent .futures .ThreadPoolExecutor (max_workers = 1 ) as executor :
381+ with concurrent .futures .ThreadPoolExecutor (max_workers = 2 ) as executor :
382+ ai_service_client = self .aiservice_client if exp_type == "EXP0" else self .local_aiservice_client
381383 future_line_profile_results = executor .submit (
382- self . aiservice_client .optimize_python_code_line_profiler ,
384+ ai_service_client .optimize_python_code_line_profiler ,
383385 source_code = code_context .read_writable_code ,
384386 dependency_code = code_context .read_only_context_code ,
385- trace_id = self .function_trace_id ,
387+ trace_id = self .function_trace_id [: - 4 ] + exp_type if self . experiment_id else self . function_trace_id ,
386388 line_profiler_results = original_code_baseline .line_profile_results ["str_out" ],
387389 num_candidates = 10 ,
388- experiment_metadata = None ,
390+ experiment_metadata = ExperimentMetadata ( id = self . experiment_id , group = "control" if exp_type == "EXP0" else "experiment" ) if self . experiment_id else None ,
389391 )
390392 try :
391393 candidate_index = 0
392- done = False
393394 original_len = len (candidates )
394395 while candidates :
395- # for candidate_index, candidate in enumerate(candidates, start=1):
396396 done = True if future_line_profile_results is None else future_line_profile_results .done ()
397397 if done and (future_line_profile_results is not None ):
398398 line_profile_results = future_line_profile_results .result ()
399399 candidates .extend (line_profile_results )
400- original_len += len (candidates )
401- logger .info (f"Added results from line profiler to candidates, total candidates now: { original_len } " )
400+ original_len += len (line_profile_results )
401+ logger .info (
402+ f"Added results from line profiler to candidates, total candidates now: { original_len } "
403+ )
402404 future_line_profile_results = None
403405 candidate_index += 1
404406 candidate = candidates .popleft ()
@@ -425,7 +427,6 @@ def determine_best_candidate(
425427 )
426428 continue
427429
428-
429430 run_results = self .run_optimized_candidate (
430431 optimization_candidate_index = candidate_index ,
431432 baseline_results = original_code_baseline ,
@@ -464,12 +465,19 @@ def determine_best_candidate(
464465 tree .add (f"Speedup ratio: { perf_gain + 1 :.1f} X" )
465466 replay_perf_gain = {}
466467 if self .args .benchmark :
467- test_results_by_benchmark = candidate_result .benchmarking_test_results .group_by_benchmarks (self .total_benchmark_timings .keys (), self .replay_tests_dir , self .project_root )
468+ test_results_by_benchmark = (
469+ candidate_result .benchmarking_test_results .group_by_benchmarks (
470+ self .total_benchmark_timings .keys (), self .replay_tests_dir , self .project_root
471+ )
472+ )
468473 if len (test_results_by_benchmark ) > 0 :
469474 benchmark_tree = Tree ("Speedup percentage on benchmarks:" )
470475 for benchmark_key , candidate_test_results in test_results_by_benchmark .items ():
471-
472- original_code_replay_runtime = original_code_baseline .replay_benchmarking_test_results [benchmark_key ].total_passed_runtime ()
476+ original_code_replay_runtime = (
477+ original_code_baseline .replay_benchmarking_test_results [
478+ benchmark_key
479+ ].total_passed_runtime ()
480+ )
473481 candidate_replay_runtime = candidate_test_results .total_passed_runtime ()
474482 replay_perf_gain [benchmark_key ] = performance_gain (
475483 original_runtime_ns = original_code_replay_runtime ,
@@ -511,16 +519,16 @@ def determine_best_candidate(
511519 logger .exception (f"Optimization interrupted: { e } " )
512520 raise
513521
514- self . aiservice_client .log_results (
515- function_trace_id = self .function_trace_id ,
522+ ai_service_client .log_results (
523+ function_trace_id = self .function_trace_id [: - 4 ] + exp_type if self . experiment_id else self . function_trace_id ,
516524 speedup_ratio = speedup_ratios ,
517525 original_runtime = original_code_baseline .runtime ,
518526 optimized_runtime = optimized_runtimes ,
519527 is_correct = is_correct ,
520528 )
521529 return best_optimization
522530
523- def log_successful_optimization (self , explanation : Explanation , generated_tests : GeneratedTestsList ) -> None :
531+ def log_successful_optimization (self , explanation : Explanation , generated_tests : GeneratedTestsList , exp_type : str ) -> None :
524532 explanation_panel = Panel (
525533 f"⚡️ Optimization successful! 📄 { self .function_to_optimize .qualified_name } in { explanation .file_path } \n "
526534 f"📈 { explanation .perf_improvement_line } \n "
@@ -547,7 +555,7 @@ def log_successful_optimization(self, explanation: Explanation, generated_tests:
547555 ph (
548556 "cli-optimize-success" ,
549557 {
550- "function_trace_id" : self .function_trace_id ,
558+ "function_trace_id" : self .function_trace_id [: - 4 ] + exp_type if self . experiment_id else self . function_trace_id ,
551559 "speedup_x" : explanation .speedup_x ,
552560 "speedup_pct" : explanation .speedup_pct ,
553561 "best_runtime" : explanation .best_runtime_ns ,
@@ -958,13 +966,17 @@ def establish_original_code_baseline(
958966 logger .debug (f"Total original code runtime (ns): { total_timing } " )
959967
960968 if self .args .benchmark :
961- replay_benchmarking_test_results = benchmarking_results .group_by_benchmarks (self .total_benchmark_timings .keys (), self .replay_tests_dir , self .project_root )
969+ replay_benchmarking_test_results = benchmarking_results .group_by_benchmarks (
970+ self .total_benchmark_timings .keys (), self .replay_tests_dir , self .project_root
971+ )
962972 return Success (
963973 (
964974 OriginalCodeBaseline (
965975 behavioral_test_results = behavioral_results ,
966976 benchmarking_test_results = benchmarking_results ,
967- replay_benchmarking_test_results = replay_benchmarking_test_results if self .args .benchmark else None ,
977+ replay_benchmarking_test_results = replay_benchmarking_test_results
978+ if self .args .benchmark
979+ else None ,
968980 runtime = total_timing ,
969981 coverage_results = coverage_results ,
970982 line_profile_results = line_profile_results ,
@@ -1077,16 +1089,22 @@ def run_optimized_candidate(
10771089
10781090 logger .debug (f"Total optimized code { optimization_candidate_index } runtime (ns): { total_candidate_timing } " )
10791091 if self .args .benchmark :
1080- candidate_replay_benchmarking_results = candidate_benchmarking_results .group_by_benchmarks (self .total_benchmark_timings .keys (), self .replay_tests_dir , self .project_root )
1092+ candidate_replay_benchmarking_results = candidate_benchmarking_results .group_by_benchmarks (
1093+ self .total_benchmark_timings .keys (), self .replay_tests_dir , self .project_root
1094+ )
10811095 for benchmark_name , benchmark_results in candidate_replay_benchmarking_results .items ():
1082- logger .debug (f"Benchmark { benchmark_name } runtime (ns): { humanize_runtime (benchmark_results .total_passed_runtime ())} " )
1096+ logger .debug (
1097+ f"Benchmark { benchmark_name } runtime (ns): { humanize_runtime (benchmark_results .total_passed_runtime ())} "
1098+ )
10831099 return Success (
10841100 OptimizedCandidateResult (
10851101 max_loop_count = loop_count ,
10861102 best_test_runtime = total_candidate_timing ,
10871103 behavior_test_results = candidate_behavior_results ,
10881104 benchmarking_test_results = candidate_benchmarking_results ,
1089- replay_benchmarking_test_results = candidate_replay_benchmarking_results if self .args .benchmark else None ,
1105+ replay_benchmarking_test_results = candidate_replay_benchmarking_results
1106+ if self .args .benchmark
1107+ else None ,
10901108 optimization_candidate_index = optimization_candidate_index ,
10911109 total_candidate_timing = total_candidate_timing ,
10921110 )
0 commit comments