@@ -94,7 +94,7 @@ def __init__(
9494 function_benchmark_timings : dict [BenchmarkKey , int ] | None = None ,
9595 total_benchmark_timings : dict [BenchmarkKey , int ] | None = None ,
9696 args : Namespace | None = None ,
97- replay_tests_dir : Path | None = None
97+ replay_tests_dir : Path | None = None ,
9898 ) -> None :
9999 self .project_root = test_cfg .project_root_path
100100 self .test_cfg = test_cfg
@@ -273,7 +273,7 @@ def optimize_function(self) -> Result[BestOptimization, str]:
273273 processed_benchmark_info = process_benchmark_data (
274274 replay_performance_gain = best_optimization .replay_performance_gain ,
275275 fto_benchmark_timings = self .function_benchmark_timings ,
276- total_benchmark_timings = self .total_benchmark_timings
276+ total_benchmark_timings = self .total_benchmark_timings ,
277277 )
278278 explanation = Explanation (
279279 raw_explanation_message = best_optimization .candidate .explanation ,
@@ -283,7 +283,7 @@ def optimize_function(self) -> Result[BestOptimization, str]:
283283 best_runtime_ns = best_optimization .runtime ,
284284 function_name = function_to_optimize_qualified_name ,
285285 file_path = self .function_to_optimize .file_path ,
286- benchmark_details = processed_benchmark_info .benchmark_details if processed_benchmark_info else None
286+ benchmark_details = processed_benchmark_info .benchmark_details if processed_benchmark_info else None ,
287287 )
288288
289289 self .log_successful_optimization (explanation , generated_tests )
@@ -328,7 +328,7 @@ def optimize_function(self) -> Result[BestOptimization, str]:
328328 coverage_message = coverage_message ,
329329 git_remote = self .args .git_remote ,
330330 )
331- if self .args .all or env_utils .get_pr_number ():
331+ if self .args .all or env_utils .get_pr_number () or ( self . args . file and not self . args . function ) :
332332 self .write_code_and_helpers (
333333 self .function_to_optimize_source_code ,
334334 original_helper_code ,
@@ -397,8 +397,10 @@ def determine_best_candidate(
397397 if done and (future_line_profile_results is not None ):
398398 line_profile_results = future_line_profile_results .result ()
399399 candidates .extend (line_profile_results )
400- original_len += len (candidates )
401- logger .info (f"Added results from line profiler to candidates, total candidates now: { original_len } " )
400+ original_len += len (candidates )
401+ logger .info (
402+ f"Added results from line profiler to candidates, total candidates now: { original_len } "
403+ )
402404 future_line_profile_results = None
403405 candidate_index += 1
404406 candidate = candidates .popleft ()
@@ -425,7 +427,6 @@ def determine_best_candidate(
425427 )
426428 continue
427429
428-
429430 run_results = self .run_optimized_candidate (
430431 optimization_candidate_index = candidate_index ,
431432 baseline_results = original_code_baseline ,
@@ -464,12 +465,19 @@ def determine_best_candidate(
464465 tree .add (f"Speedup ratio: { perf_gain + 1 :.1f} X" )
465466 replay_perf_gain = {}
466467 if self .args .benchmark :
467- test_results_by_benchmark = candidate_result .benchmarking_test_results .group_by_benchmarks (self .total_benchmark_timings .keys (), self .replay_tests_dir , self .project_root )
468+ test_results_by_benchmark = (
469+ candidate_result .benchmarking_test_results .group_by_benchmarks (
470+ self .total_benchmark_timings .keys (), self .replay_tests_dir , self .project_root
471+ )
472+ )
468473 if len (test_results_by_benchmark ) > 0 :
469474 benchmark_tree = Tree ("Speedup percentage on benchmarks:" )
470475 for benchmark_key , candidate_test_results in test_results_by_benchmark .items ():
471-
472- original_code_replay_runtime = original_code_baseline .replay_benchmarking_test_results [benchmark_key ].total_passed_runtime ()
476+ original_code_replay_runtime = (
477+ original_code_baseline .replay_benchmarking_test_results [
478+ benchmark_key
479+ ].total_passed_runtime ()
480+ )
473481 candidate_replay_runtime = candidate_test_results .total_passed_runtime ()
474482 replay_perf_gain [benchmark_key ] = performance_gain (
475483 original_runtime_ns = original_code_replay_runtime ,
@@ -958,13 +966,17 @@ def establish_original_code_baseline(
958966 logger .debug (f"Total original code runtime (ns): { total_timing } " )
959967
960968 if self .args .benchmark :
961- replay_benchmarking_test_results = benchmarking_results .group_by_benchmarks (self .total_benchmark_timings .keys (), self .replay_tests_dir , self .project_root )
969+ replay_benchmarking_test_results = benchmarking_results .group_by_benchmarks (
970+ self .total_benchmark_timings .keys (), self .replay_tests_dir , self .project_root
971+ )
962972 return Success (
963973 (
964974 OriginalCodeBaseline (
965975 behavioral_test_results = behavioral_results ,
966976 benchmarking_test_results = benchmarking_results ,
967- replay_benchmarking_test_results = replay_benchmarking_test_results if self .args .benchmark else None ,
977+ replay_benchmarking_test_results = replay_benchmarking_test_results
978+ if self .args .benchmark
979+ else None ,
968980 runtime = total_timing ,
969981 coverage_results = coverage_results ,
970982 line_profile_results = line_profile_results ,
@@ -1077,16 +1089,22 @@ def run_optimized_candidate(
10771089
10781090 logger .debug (f"Total optimized code { optimization_candidate_index } runtime (ns): { total_candidate_timing } " )
10791091 if self .args .benchmark :
1080- candidate_replay_benchmarking_results = candidate_benchmarking_results .group_by_benchmarks (self .total_benchmark_timings .keys (), self .replay_tests_dir , self .project_root )
1092+ candidate_replay_benchmarking_results = candidate_benchmarking_results .group_by_benchmarks (
1093+ self .total_benchmark_timings .keys (), self .replay_tests_dir , self .project_root
1094+ )
10811095 for benchmark_name , benchmark_results in candidate_replay_benchmarking_results .items ():
1082- logger .debug (f"Benchmark { benchmark_name } runtime (ns): { humanize_runtime (benchmark_results .total_passed_runtime ())} " )
1096+ logger .debug (
1097+ f"Benchmark { benchmark_name } runtime (ns): { humanize_runtime (benchmark_results .total_passed_runtime ())} "
1098+ )
10831099 return Success (
10841100 OptimizedCandidateResult (
10851101 max_loop_count = loop_count ,
10861102 best_test_runtime = total_candidate_timing ,
10871103 behavior_test_results = candidate_behavior_results ,
10881104 benchmarking_test_results = candidate_benchmarking_results ,
1089- replay_benchmarking_test_results = candidate_replay_benchmarking_results if self .args .benchmark else None ,
1105+ replay_benchmarking_test_results = candidate_replay_benchmarking_results
1106+ if self .args .benchmark
1107+ else None ,
10901108 optimization_candidate_index = optimization_candidate_index ,
10911109 total_candidate_timing = total_candidate_timing ,
10921110 )
0 commit comments