@@ -94,7 +94,7 @@ def __init__(
         function_benchmark_timings: dict[BenchmarkKey, int] | None = None,
         total_benchmark_timings: dict[BenchmarkKey, int] | None = None,
         args: Namespace | None = None,
-        replay_tests_dir: Path | None = None
+        replay_tests_dir: Path | None = None,
     ) -> None:
         self.project_root = test_cfg.project_root_path
         self.test_cfg = test_cfg
@@ -273,7 +273,7 @@ def optimize_function(self) -> Result[BestOptimization, str]:
                     processed_benchmark_info = process_benchmark_data(
                         replay_performance_gain=best_optimization.replay_performance_gain,
                         fto_benchmark_timings=self.function_benchmark_timings,
-                        total_benchmark_timings=self.total_benchmark_timings
+                        total_benchmark_timings=self.total_benchmark_timings,
                     )
                 explanation = Explanation(
                     raw_explanation_message=best_optimization.candidate.explanation,
@@ -283,7 +283,7 @@ def optimize_function(self) -> Result[BestOptimization, str]:
                     best_runtime_ns=best_optimization.runtime,
                     function_name=function_to_optimize_qualified_name,
                     file_path=self.function_to_optimize.file_path,
-                    benchmark_details=processed_benchmark_info.benchmark_details if processed_benchmark_info else None
+                    benchmark_details=processed_benchmark_info.benchmark_details if processed_benchmark_info else None,
                 )
 
                 self.log_successful_optimization(explanation, generated_tests)
@@ -328,7 +328,7 @@ def optimize_function(self) -> Result[BestOptimization, str]:
                         coverage_message=coverage_message,
                         git_remote=self.args.git_remote,
                     )
-                    if self.args.all or env_utils.get_pr_number():
+                    if self.args.all or env_utils.get_pr_number() or (self.args.file and not self.args.function):
                         self.write_code_and_helpers(
                             self.function_to_optimize_source_code,
                             original_helper_code,
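The new branch also writes results back when a whole file was passed without a specific function. A minimal sketch of the gating logic, assuming an argparse Namespace with all, file, and function attributes (the helper name should_write_code is hypothetical):

    from argparse import Namespace

    def should_write_code(args: Namespace, pr_number: int | None) -> bool:
        # Write optimized code to disk when optimizing everything (--all),
        # when running against an open PR, or when a specific file (but no
        # single function) was requested on the command line.
        return bool(args.all or pr_number or (args.file and not args.function))

    # A file-wide run with no --function flag now triggers the write-back path.
    assert should_write_code(Namespace(all=False, file="mod.py", function=None), pr_number=None)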
@@ -397,8 +397,10 @@ def determine_best_candidate(
                     if done and (future_line_profile_results is not None):
                         line_profile_results = future_line_profile_results.result()
                         candidates.extend(line_profile_results)
-                        original_len += len(candidates)
-                        logger.info(f"Added results from line profiler to candidates, total candidates now: {original_len}")
+                        original_len += len(candidates)
+                        logger.info(
+                            f"Added results from line profiler to candidates, total candidates now: {original_len}"
+                        )
                         future_line_profile_results = None
                     candidate_index += 1
                     candidate = candidates.popleft()
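For readers skimming the hunk: candidates is a deque that is drained front-to-back while the line-profiler future can append extra candidates mid-loop. A rough stand-in for that pattern, with strings in place of optimization candidates (late_results plays the role of future_line_profile_results.result()):

    from collections import deque

    candidates = deque(["cand_1", "cand_2"])
    late_results = ["lp_cand_1", "lp_cand_2"]  # stands in for the future's result
    processed = []

    while candidates:
        if late_results is not None:
            candidates.extend(late_results)  # fold line-profiler candidates into the queue once
            late_results = None
        processed.append(candidates.popleft())

    print(processed)  # ['cand_1', 'cand_2', 'lp_cand_1', 'lp_cand_2']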
@@ -425,7 +427,6 @@ def determine_best_candidate(
                         )
                         continue
 
-
                     run_results = self.run_optimized_candidate(
                         optimization_candidate_index=candidate_index,
                         baseline_results=original_code_baseline,
@@ -464,12 +465,19 @@ def determine_best_candidate(
                             tree.add(f"Speedup ratio: {perf_gain + 1:.1f}X")
                             replay_perf_gain = {}
                             if self.args.benchmark:
-                                test_results_by_benchmark = candidate_result.benchmarking_test_results.group_by_benchmarks(self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root)
+                                test_results_by_benchmark = (
+                                    candidate_result.benchmarking_test_results.group_by_benchmarks(
+                                        self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root
+                                    )
+                                )
                                 if len(test_results_by_benchmark) > 0:
                                     benchmark_tree = Tree("Speedup percentage on benchmarks:")
                                 for benchmark_key, candidate_test_results in test_results_by_benchmark.items():
-
-                                    original_code_replay_runtime = original_code_baseline.replay_benchmarking_test_results[benchmark_key].total_passed_runtime()
+                                    original_code_replay_runtime = (
+                                        original_code_baseline.replay_benchmarking_test_results[
+                                            benchmark_key
+                                        ].total_passed_runtime()
+                                    )
                                     candidate_replay_runtime = candidate_test_results.total_passed_runtime()
                                     replay_perf_gain[benchmark_key] = performance_gain(
                                         original_runtime_ns=original_code_replay_runtime,
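Judging by the surrounding code, which prints perf_gain + 1 as the speedup ratio, performance_gain appears to be the relative improvement (original minus optimized runtime, over the optimized runtime). A worked sketch under that assumption:

    def performance_gain(original_runtime_ns: int, optimized_runtime_ns: int) -> float:
        # Relative gain: 0.5 means the optimized code is 1.5x faster overall.
        return (original_runtime_ns - optimized_runtime_ns) / optimized_runtime_ns

    gain = performance_gain(original_runtime_ns=3_000_000, optimized_runtime_ns=2_000_000)
    print(f"Speedup ratio: {gain + 1:.1f}X")  # Speedup ratio: 1.5X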
@@ -958,13 +966,17 @@ def establish_original_code_baseline(
             logger.debug(f"Total original code runtime (ns): {total_timing}")
 
             if self.args.benchmark:
-                replay_benchmarking_test_results = benchmarking_results.group_by_benchmarks(self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root)
+                replay_benchmarking_test_results = benchmarking_results.group_by_benchmarks(
+                    self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root
+                )
             return Success(
                 (
                     OriginalCodeBaseline(
                         behavioral_test_results=behavioral_results,
                         benchmarking_test_results=benchmarking_results,
-                        replay_benchmarking_test_results = replay_benchmarking_test_results if self.args.benchmark else None,
+                        replay_benchmarking_test_results=replay_benchmarking_test_results
+                        if self.args.benchmark
+                        else None,
                         runtime=total_timing,
                         coverage_results=coverage_results,
                         line_profile_results=line_profile_results,
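group_by_benchmarks is not shown in this diff; from its call sites it takes the known benchmark keys, the replay-tests directory, and the project root, and buckets test results per benchmark. A simplified, hypothetical stand-in keyed on a plain string:

    from collections import defaultdict

    def group_by_benchmarks(results: list[dict]) -> dict[str, list[dict]]:
        # Hypothetical stand-in: bucket replay test results by benchmark key.
        grouped: dict[str, list[dict]] = defaultdict(list)
        for result in results:
            grouped[result["benchmark"]].append(result)
        return dict(grouped)

    results = [
        {"benchmark": "bench_sort", "runtime_ns": 1_200},
        {"benchmark": "bench_sort", "runtime_ns": 1_100},
        {"benchmark": "bench_parse", "runtime_ns": 900},
    ]
    print(group_by_benchmarks(results))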
@@ -1077,16 +1089,22 @@ def run_optimized_candidate(
 
             logger.debug(f"Total optimized code {optimization_candidate_index} runtime (ns): {total_candidate_timing}")
             if self.args.benchmark:
-                candidate_replay_benchmarking_results = candidate_benchmarking_results.group_by_benchmarks(self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root)
+                candidate_replay_benchmarking_results = candidate_benchmarking_results.group_by_benchmarks(
+                    self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root
+                )
                 for benchmark_name, benchmark_results in candidate_replay_benchmarking_results.items():
-                    logger.debug(f"Benchmark {benchmark_name} runtime (ns): {humanize_runtime(benchmark_results.total_passed_runtime())}")
+                    logger.debug(
+                        f"Benchmark {benchmark_name} runtime (ns): {humanize_runtime(benchmark_results.total_passed_runtime())}"
+                    )
             return Success(
                 OptimizedCandidateResult(
                     max_loop_count=loop_count,
                     best_test_runtime=total_candidate_timing,
                     behavior_test_results=candidate_behavior_results,
                     benchmarking_test_results=candidate_benchmarking_results,
-                    replay_benchmarking_test_results = candidate_replay_benchmarking_results if self.args.benchmark else None,
+                    replay_benchmarking_test_results=candidate_replay_benchmarking_results
+                    if self.args.benchmark
+                    else None,
                     optimization_candidate_index=optimization_candidate_index,
                     total_candidate_timing=total_candidate_timing,
                 )
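The per-benchmark debug line relies on humanize_runtime, which presumably renders a nanosecond count as a readable string; a hypothetical minimal version for illustration:

    def humanize_runtime(runtime_ns: int) -> str:
        # Hypothetical sketch: choose the largest unit that keeps the value >= 1.
        for unit, factor in (("s", 10**9), ("ms", 10**6), ("us", 10**3)):
            if runtime_ns >= factor:
                return f"{runtime_ns / factor:.2f}{unit}"
        return f"{runtime_ns}ns"

    print(humanize_runtime(1_250_000))  # 1.25ms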