123123parser .add_argument ("-s" , "--show" , help = help_s )
124124parser .add_argument ("--verbose" , action = "store_true" , help = "increase output verbosity" )
125125parser .add_argument ("--plot" , help = "generate a performance comparison plot and save to specified file (e.g., plot.png)" )
126- parser .add_argument ("--plot_x" , help = "parameter to use as x- axis for plotting (default: n_depth)" , default = "n_depth" )
126+ parser .add_argument ("--plot_x" , help = "parameter to use as x axis for plotting (default: n_depth)" , default = "n_depth" )
127127
128128known_args , unknown_args = parser .parse_known_args ()
129129
136136 import matplotlib
137137 matplotlib .use ('Agg' )
138138 except ImportError as e :
139- print ("matplotlib is required for --plot." )
139+ logger . error ("matplotlib is required for --plot." )
140140 raise e
141141
142142if known_args .check :
@@ -613,9 +613,9 @@ def valid_format(data_files: list[str]) -> bool:
613613headers += ["Test" , f"t/s { name_baseline } " , f"t/s { name_compare } " , "Speedup" ]
614614
615615if known_args .plot :
616- def create_performance_plot (table_data , headers , baseline_name , compare_name , output_file , plot_x_param ):
616+ def create_performance_plot (table_data : list [ list [ str ]] , headers : list [ str ] , baseline_name : str , compare_name : str , output_file : str , plot_x_param : str ):
617617
618- data_headers = headers [:- 4 ] #Exclude the last 4 columns (Test, baseline t/s, compare t/s, Speedup)
618+ data_headers = headers [:- 4 ] # Exclude the last 4 columns (Test, baseline t/s, compare t/s, Speedup)
619619 plot_x_index = None
620620 plot_x_label = plot_x_param
621621
@@ -687,7 +687,6 @@ def create_performance_plot(table_data, headers, baseline_name, compare_name, ou
687687 logger .error ("No data available for plotting" )
688688 return
689689
690-
691690 def make_axes (num_groups , max_cols = 2 , base_size = (8 , 4 )):
692691 from math import ceil
693692 cols = 1 if num_groups == 1 else min (max_cols , num_groups )
@@ -696,8 +695,8 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
696695 # scale figure size by grid dimensions
697696 w , h = base_size
698697 fig , ax_arr = plt .subplots (rows , cols ,
699- figsize = (w * cols , h * rows ),
700- squeeze = False )
698+ figsize = (w * cols , h * rows ),
699+ squeeze = False )
701700
702701 axes = ax_arr .flatten ()[:num_groups ]
703702 return fig , axes
@@ -739,7 +738,7 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
739738 key , value = part .split ('=' , 1 )
740739 title_parts .append (f"{ key } : { value } " )
741740
742- title = ', ' .join (title_parts ) if title_parts else "Performance Comparison "
741+ title = ', ' .join (title_parts ) if title_parts else "Performance comparison "
743742
744743 ax .set_xlabel (plot_x_label , fontsize = 12 , fontweight = 'bold' )
745744 ax .set_ylabel ('Tokens per Second (t/s)' , fontsize = 12 , fontweight = 'bold' )
@@ -752,11 +751,10 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
752751 for i in range (plot_idx , len (axes )):
753752 axes [i ].set_visible (False )
754753
755- fig .suptitle (f'Performance Comparison : { compare_name } vs { baseline_name } ' ,
756- fontsize = 14 , fontweight = 'bold' )
754+ fig .suptitle (f'Performance comparison : { compare_name } vs { baseline_name } ' ,
755+ fontsize = 14 , fontweight = 'bold' )
757756 fig .subplots_adjust (top = 1 )
758757
759-
760758 plt .tight_layout ()
761759 plt .savefig (output_file , dpi = 300 , bbox_inches = 'tight' )
762760 plt .close ()
0 commit comments