129129
130130logging .basicConfig (level = logging .DEBUG if known_args .verbose else logging .INFO )
131131
132- # Check for matplotlib if plotting is requested
133132if known_args .plot :
134133 try :
135134 import matplotlib .pyplot as plt
@@ -511,7 +510,6 @@ def valid_format(data_files: list[str]) -> bool:
511510
512511name_compare = bench_data .get_commit_name (hexsha8_compare )
513512
514-
515513# If the user provided columns to group the results by, use them:
516514if known_args .show is not None :
517515 show = known_args .show .split ("," )
@@ -556,6 +554,14 @@ def valid_format(data_files: list[str]) -> bool:
556554 show .remove (prop )
557555 except ValueError :
558556 pass
557+
558+ # add plot_x parameter to if it's not already there
559+ if known_args .plot :
560+ for k , v in PRETTY_NAMES .items ():
561+ if v == known_args .plot_x and k not in show :
562+ show .append (k )
563+ break
564+
559565 rows_show = bench_data .get_rows (show , hexsha8_baseline , hexsha8_compare )
560566
561567if not rows_show :
@@ -629,7 +635,6 @@ def create_performance_plot(table_data: list[list[str]], headers: list[str], bas
629635 plot_x_label = plot_x_param
630636 else :
631637 logger .error (f"Parameter '{ plot_x_param } ' not found in current table columns. Available columns: { ', ' .join (data_headers )} " )
632- logger .error (f"To plot by '{ plot_x_param } ', include it in --show parameter or ensure it varies in your data." )
633638 return
634639
635640 grouped_data = {}
@@ -671,7 +676,7 @@ def create_performance_plot(table_data: list[list[str]], headers: list[str], bas
671676
672677 group_key_parts .append (f"Test={ test_name } " )
673678
674- group_key = tuple (sorted ( group_key_parts ) )
679+ group_key = tuple (group_key_parts )
675680
676681 if group_key not in grouped_data :
677682 grouped_data [group_key ] = []
@@ -692,7 +697,7 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
692697 cols = 1 if num_groups == 1 else min (max_cols , num_groups )
693698 rows = ceil (num_groups / cols )
694699
695- # scale figure size by grid dimensions
700+ # Scale figure size by grid dimensions
696701 w , h = base_size
697702 fig , ax_arr = plt .subplots (rows , cols ,
698703 figsize = (w * cols , h * rows ),
@@ -726,7 +731,7 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
726731 ax .plot (x_values , compare_vals , 's--' , color = 'lightcoral' , alpha = 0.8 ,
727732 label = f'{ compare_name } ' , linewidth = 2 , markersize = 6 )
728733
729- if plot_x_param == "n_depth" and max (x_values ) > 0 and max (x_values ) > min (x_values ) * 4 :
734+ if plot_x_param == "n_depth" and min (x_values ) > 0 and max (x_values ) > min (x_values ) * 4 :
730735 ax .set_xscale ('log' , base = 2 )
731736 unique_x = sorted (set (x_values ))
732737 ax .set_xticks (unique_x )
@@ -741,7 +746,7 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
741746 title = ', ' .join (title_parts ) if title_parts else "Performance comparison"
742747
743748 ax .set_xlabel (plot_x_label , fontsize = 12 , fontweight = 'bold' )
744- ax .set_ylabel ('Tokens per Second (t/s)' , fontsize = 12 , fontweight = 'bold' )
749+ ax .set_ylabel ('Tokens per second (t/s)' , fontsize = 12 , fontweight = 'bold' )
745750 ax .set_title (title , fontsize = 12 , fontweight = 'bold' )
746751 ax .legend (loc = 'best' , fontsize = 10 )
747752 ax .grid (True , alpha = 0.3 )
@@ -751,7 +756,7 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
751756 for i in range (plot_idx , len (axes )):
752757 axes [i ].set_visible (False )
753758
754- fig .suptitle (f'Performance comparison: { compare_name } vs { baseline_name } ' ,
759+ fig .suptitle (f'Performance comparison: { compare_name } vs. { baseline_name } ' ,
755760 fontsize = 14 , fontweight = 'bold' )
756761 fig .subplots_adjust (top = 1 )
757762
0 commit comments