@@ -511,7 +511,6 @@ def valid_format(data_files: list[str]) -> bool:
511511
512512name_compare = bench_data .get_commit_name (hexsha8_compare )
513513
514-
515514# If the user provided columns to group the results by, use them:
516515if known_args .show is not None :
517516 show = known_args .show .split ("," )
@@ -556,6 +555,14 @@ def valid_format(data_files: list[str]) -> bool:
556555 show .remove (prop )
557556 except ValueError :
558557 pass
558+
559+ # add plot_x parameter to if it's not already there
560+ if known_args .plot :
561+ for k , v in PRETTY_NAMES .items ():
562+ if v == known_args .plot_x and k not in show :
563+ show .append (k )
564+ break
565+
559566 rows_show = bench_data .get_rows (show , hexsha8_baseline , hexsha8_compare )
560567
561568if not rows_show :
@@ -629,7 +636,6 @@ def create_performance_plot(table_data: list[list[str]], headers: list[str], bas
629636 plot_x_label = plot_x_param
630637 else :
631638 logger .error (f"Parameter '{ plot_x_param } ' not found in current table columns. Available columns: { ', ' .join (data_headers )} " )
632- logger .error (f"To plot by '{ plot_x_param } ', include it in --show parameter or ensure it varies in your data." )
633639 return
634640
635641 grouped_data = {}
@@ -671,7 +677,7 @@ def create_performance_plot(table_data: list[list[str]], headers: list[str], bas
671677
672678 group_key_parts .append (f"Test={ test_name } " )
673679
674- group_key = tuple (sorted ( group_key_parts ) )
680+ group_key = tuple (group_key_parts )
675681
676682 if group_key not in grouped_data :
677683 grouped_data [group_key ] = []
@@ -692,7 +698,7 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
692698 cols = 1 if num_groups == 1 else min (max_cols , num_groups )
693699 rows = ceil (num_groups / cols )
694700
695- # scale figure size by grid dimensions
701+ # Scale figure size by grid dimensions
696702 w , h = base_size
697703 fig , ax_arr = plt .subplots (rows , cols ,
698704 figsize = (w * cols , h * rows ),
@@ -726,7 +732,7 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
726732 ax .plot (x_values , compare_vals , 's--' , color = 'lightcoral' , alpha = 0.8 ,
727733 label = f'{ compare_name } ' , linewidth = 2 , markersize = 6 )
728734
729- if plot_x_param == "n_depth" and max (x_values ) > 0 and max (x_values ) > min (x_values ) * 4 :
735+ if plot_x_param == "n_depth" and min (x_values ) > 0 and max (x_values ) > min (x_values ) * 4 :
730736 ax .set_xscale ('log' , base = 2 )
731737 unique_x = sorted (set (x_values ))
732738 ax .set_xticks (unique_x )
@@ -741,7 +747,7 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
741747 title = ', ' .join (title_parts ) if title_parts else "Performance comparison"
742748
743749 ax .set_xlabel (plot_x_label , fontsize = 12 , fontweight = 'bold' )
744- ax .set_ylabel ('Tokens per Second (t/s)' , fontsize = 12 , fontweight = 'bold' )
750+ ax .set_ylabel ('Tokens per second (t/s)' , fontsize = 12 , fontweight = 'bold' )
745751 ax .set_title (title , fontsize = 12 , fontweight = 'bold' )
746752 ax .legend (loc = 'best' , fontsize = 10 )
747753 ax .grid (True , alpha = 0.3 )
@@ -751,7 +757,7 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
751757 for i in range (plot_idx , len (axes )):
752758 axes [i ].set_visible (False )
753759
754- fig .suptitle (f'Performance comparison: { compare_name } vs { baseline_name } ' ,
760+ fig .suptitle (f'Performance comparison: { compare_name } vs. { baseline_name } ' ,
755761 fontsize = 14 , fontweight = 'bold' )
756762 fig .subplots_adjust (top = 1 )
757763
0 commit comments