Skip to content

Commit e790497

Browse files
committed
Add back default test_name, add --plot_log_scale
1 parent 8228393 commit e790497

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

scripts/compare-llama-bench.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@
125125
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
126126
parser.add_argument("--plot", help="generate a performance comparison plot and save to specified file (e.g., plot.png)")
127127
parser.add_argument("--plot_x", help="parameter to use as x axis for plotting (default: n_depth)", default="n_depth")
128+
parser.add_argument("--plot_log_scale", action="store_true", help="use log scale for x axis in plots (off by default)")
128129

129130
known_args, unknown_args = parser.parse_known_args()
130131

@@ -612,7 +613,7 @@ def valid_format(data_files: list[str]) -> bool:
612613
headers += ["Test", f"t/s {name_baseline}", f"t/s {name_compare}", "Speedup"]
613614

614615
if known_args.plot:
615-
def create_performance_plot(table_data: list[list[str]], headers: list[str], baseline_name: str, compare_name: str, output_file: str, plot_x_param: str):
616+
def create_performance_plot(table_data: list[list[str]], headers: list[str], baseline_name: str, compare_name: str, output_file: str, plot_x_param: str, log_scale: bool = False):
616617
try:
617618
import matplotlib.pyplot as plt
618619
import matplotlib
@@ -661,7 +662,7 @@ def create_performance_plot(table_data: list[list[str]], headers: list[str], bas
661662
base_test = test_name.split("@d")[0]
662663
x_value = int(test_name.split("@d")[1])
663664
else:
664-
assert False
665+
base_test = test_name
665666

666667
if base_test.strip():
667668
group_key_parts.append(f"Test={base_test}")
@@ -731,7 +732,7 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
731732
ax.plot(x_values, compare_vals, 's--', color='lightcoral', alpha=0.8,
732733
label=f'{compare_name}', linewidth=2, markersize=6)
733734

734-
if plot_x_param == "n_depth" and min(x_values) > 0 and max(x_values) > min(x_values) * 4:
735+
if log_scale and min(x_values) > 0:
735736
ax.set_xscale('log', base=2)
736737
unique_x = sorted(set(x_values))
737738
ax.set_xticks(unique_x)
@@ -764,7 +765,7 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
764765
plt.savefig(output_file, dpi=300, bbox_inches='tight')
765766
plt.close()
766767

767-
create_performance_plot(table, headers, name_baseline, name_compare, known_args.plot, known_args.plot_x)
768+
create_performance_plot(table, headers, name_baseline, name_compare, known_args.plot, known_args.plot_x, known_args.plot_log_scale)
768769

769770
print(tabulate( # noqa: NP100
770771
table,

0 commit comments

Comments
 (0)