@@ -88,8 +88,8 @@ def __init__(
         function_to_tests: dict[str, list[FunctionCalledInTest]] | None = None,
         function_to_optimize_ast: ast.FunctionDef | None = None,
         aiservice_client: AiServiceClient | None = None,
-        function_benchmark_timings: dict[str, int] | None = None,
-        total_benchmark_timings: dict[str, int] | None = None,
+        function_benchmark_timings: dict[BenchmarkKey, int] | None = None,
+        total_benchmark_timings: dict[BenchmarkKey, int] | None = None,
         args: Namespace | None = None,
     ) -> None:
         self.project_root = test_cfg.project_root_path
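`BenchmarkKey` itself is not defined anywhere in this diff. Judging by the key set built later from `(benchmark.file_name, benchmark.function_name)`, it is plausibly a named `(file_name, function_name)` pair; a minimal sketch, assuming exactly that shape:

```python
from typing import NamedTuple


class BenchmarkKey(NamedTuple):
    """Hypothetical sketch: identifies one benchmark by file and function.

    Not part of this diff; inferred from the (file_name, function_name)
    tuples built in determine_best_candidate below. A NamedTuple keeps
    tuple hashing and equality, so instances can key the timing dicts
    and compare equal to plain tuples.
    """

    file_name: str
    function_name: str
```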
@@ -428,20 +428,24 @@ def determine_best_candidate(
             tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%")
             tree.add(f"Speedup ratio: {perf_gain + 1:.1f}X")
             if self.args.benchmark:
-                original_code_replay_runtime = original_code_baseline.replay_benchmarking_test_results.total_passed_runtime()
-                candidate_replay_runtime = candidate_result.replay_benchmarking_test_results.total_passed_runtime()
-                replay_perf_gain = performance_gain(
-                    original_runtime_ns=original_code_replay_runtime,
-                    optimized_runtime_ns=candidate_replay_runtime,
-                )
-                tree.add(f"Original benchmark replay runtime: {humanize_runtime(original_code_replay_runtime)}")
-                tree.add(
-                    f"Best benchmark replay runtime: {humanize_runtime(candidate_replay_runtime)} "
-                    f"(measured over {candidate_result.max_loop_count} "
-                    f"loop{'s' if candidate_result.max_loop_count > 1 else ''})"
-                )
-                tree.add(f"Speedup percentage for benchmark replay test: {replay_perf_gain * 100:.1f}%")
-                tree.add(f"Speedup ratio for benchmark replay test: {replay_perf_gain + 1:.1f}X")
+
+                benchmark_keys = {(benchmark.file_name, benchmark.function_name) for benchmark in self.total_benchmark_timings}
+                test_results_by_benchmark = candidate_result.benchmarking_test_results.group_by_benchmark(benchmark_keys)
+                for benchmark_key, test_results in test_results_by_benchmark.items():
+                    original_code_replay_runtime = original_code_baseline.replay_benchmarking_test_results[benchmark_key].total_passed_runtime()
+                    candidate_replay_runtime = candidate_result.replay_benchmarking_test_results[benchmark_key].total_passed_runtime()
+                    replay_perf_gain = performance_gain(
+                        original_runtime_ns=original_code_replay_runtime,
+                        optimized_runtime_ns=candidate_replay_runtime,
+                    )
+                    tree.add(f"Original benchmark replay runtime: {humanize_runtime(original_code_replay_runtime)}")
+                    tree.add(
+                        f"Best benchmark replay runtime: {humanize_runtime(candidate_replay_runtime)} "
+                        f"(measured over {candidate_result.max_loop_count} "
+                        f"loop{'s' if candidate_result.max_loop_count > 1 else ''})"
+                    )
+                    tree.add(f"Speedup percentage for benchmark replay test: {replay_perf_gain * 100:.1f}%")
+                    tree.add(f"Speedup ratio for benchmark replay test: {replay_perf_gain + 1:.1f}X")
         best_optimization = BestOptimization(
             candidate=candidate,
             helper_functions=code_context.helper_functions,
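The `group_by_benchmark` call above is new in this PR and its body is not part of the diff. A minimal sketch of what it presumably does, assuming each test result carries a hypothetical `benchmark_key` attribute naming the benchmark that produced it:

```python
def group_by_benchmark(
    self, benchmark_keys: set[BenchmarkKey]
) -> dict[BenchmarkKey, "TestResults"]:
    """Hypothetical sketch, not the actual implementation from this PR.

    Buckets this container's results by the benchmark that produced them,
    keeping only the benchmarks listed in benchmark_keys. Assumes each
    result exposes a benchmark_key attribute (an assumption; the real
    attribute name is not visible in this diff).
    """
    grouped = {key: TestResults() for key in benchmark_keys}
    for result in self.test_results:
        if result.benchmark_key in grouped:
            grouped[result.benchmark_key].add(result)
    return grouped
```

Grouping once up front is what lets the reporting loop print one replay comparison per benchmark instead of a single aggregate, which is the point of this hunk.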
@@ -898,7 +902,7 @@ def establish_original_code_baseline(
         logger.debug(f"Total original code runtime (ns): {total_timing}")

         if self.args.benchmark:
-            replay_benchmarking_test_results = benchmarking_results.filter(TestType.REPLAY_TEST)
+            replay_benchmarking_test_results = benchmarking_results.filter_by_test_type(TestType.REPLAY_TEST)
             logger.info(f"Total replay test runtime: {humanize_runtime(replay_benchmarking_test_results.total_passed_runtime())}")
         return Success(
             (
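The rename from `filter` to `filter_by_test_type` (here and in `run_optimized_candidate` below) reads as a clarity fix; the method body is not shown in this diff. A sketch of the likely shape, assuming `TestResults` wraps a list of results that each carry a `test_type`:

```python
def filter_by_test_type(self, test_type: TestType) -> "TestResults":
    """Hypothetical sketch; the real method is not shown in this diff.

    Returns a new TestResults containing only results of the given type,
    e.g. TestType.REPLAY_TEST when isolating replay-benchmark timing.
    """
    return TestResults(
        test_results=[r for r in self.test_results if r.test_type == test_type]
    )
```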
@@ -1020,7 +1024,7 @@ def run_optimized_candidate(

         logger.debug(f"Total optimized code {optimization_candidate_index} runtime (ns): {total_candidate_timing}")
         if self.args.benchmark:
-            candidate_replay_benchmarking_results = candidate_benchmarking_results.filter(TestType.REPLAY_TEST)
+            candidate_replay_benchmarking_results = candidate_benchmarking_results.filter_by_test_type(TestType.REPLAY_TEST)
             logger.debug(
                 f"Total optimized code {optimization_candidate_index} replay benchmark runtime (ns): {candidate_replay_benchmarking_results.total_passed_runtime()}"
             )