11from rich .console import Console
22from rich .table import Table
33
4+ from codeflash .cli_cmds .console import logger
45
def validate_and_format_benchmark_table(
    function_benchmark_timings: dict[str, dict[str, int]],
    total_benchmark_timings: dict[str, int],
) -> dict[str, list[tuple[str, float, float, float]]]:
    """Validate raw benchmark timings and convert them into display-ready rows.

    Args:
        function_benchmark_timings: Maps a function path ("file:qualified_name")
            to {benchmark test name -> nanoseconds spent inside the function}.
        total_benchmark_timings: Maps a benchmark test name to its total
            runtime in nanoseconds.

    Returns:
        Maps each function path to a list of
        ``(test_name, total_time_ms, func_time_ms, percentage)`` rows sorted by
        percentage descending. An all-zero row marks a benchmark whose
        per-function time exceeded its total time (likely multithreading /
        multiprocessing skew); callers should render it as "N/A".
    """
    function_to_result: dict[str, list[tuple[str, float, float, float]]] = {}
    # Process each function's benchmark data
    for func_path, test_times in function_benchmark_timings.items():
        sorted_tests: list[tuple[str, float, float, float]] = []
        for test_name, func_time in test_times.items():
            total_time = total_benchmark_timings.get(test_name, 0)
            if func_time > total_time:
                logger.debug(f"Skipping test {test_name} due to func_time {func_time} > total_time {total_time}")
                # If the function time is greater than total time, likely to have multithreading / multiprocessing issues.
                # Do not try to project the optimization impact for this function.
                sorted_tests.append((test_name, 0.0, 0.0, 0.0))
                # BUG FIX: without this `continue`, a benchmark with
                # func_time > total_time > 0 fell through and was appended a
                # second time as a regular row (duplicate table entries).
                continue
            if total_time > 0:
                percentage = (func_time / total_time) * 100
                # Convert nanoseconds to milliseconds
                func_time_ms = func_time / 1_000_000
                total_time_ms = total_time / 1_000_000
                sorted_tests.append((test_name, total_time_ms, func_time_ms, percentage))
        # Sort by percentage (highest first)
        sorted_tests.sort(key=lambda x: x[3], reverse=True)
        function_to_result[func_path] = sorted_tests
    return function_to_result
30+
def print_benchmark_table(function_to_results: dict[str, list[tuple[str, float, float, float]]]) -> None:
    """Render one rich table of benchmark timings per function to the console.

    Args:
        function_to_results: Maps a function path ("file:qualified_name") to
            rows of ``(test_name, total_time_ms, func_time_ms, percentage)``,
            as produced by ``validate_and_format_benchmark_table``. Rows with
            ``total_time == 0.0`` are sentinels for unreliable timings and are
            rendered as "N/A".
    """
    console = Console()
    for func_path, sorted_tests in function_to_results.items():
        # func_path is "file:qualified_name"; show only the function name.
        function_name = func_path.split(":")[-1]

        # Create a table for this function
        table = Table(title=f"Function: {function_name}", border_style="blue")

        # Add columns
        table.add_column("Benchmark Test", style="cyan", no_wrap=True)
        table.add_column("Total Time (ms)", justify="right", style="green")
        table.add_column("Function Time (ms)", justify="right", style="yellow")
        table.add_column("Percentage (%)", justify="right", style="red")

        for test_name, total_time, func_time, percentage in sorted_tests:
            # test_name is "<file>::<test function>::<line>"; the line number
            # is dropped from the displayed benchmark name.
            benchmark_file, benchmark_func, _benchmark_line = test_name.split("::")
            benchmark_name = f"{benchmark_file}::{benchmark_func}"
            if total_time == 0.0:
                # Sentinel row: timing was unreliable (func_time exceeded
                # total_time upstream), so no numbers are shown.
                table.add_row(
                    benchmark_name,
                    "N/A",
                    "N/A",
                    "N/A"
                )
            else:
                table.add_row(
                    benchmark_name,
                    f"{total_time:.3f}",
                    f"{func_time:.3f}",
                    f"{percentage:.2f}"
                )

        # Print the table
        console.print(table)