Skip to content

Commit 618ec40

Browse files
authored
[Tutorials] Add units to result tables (#8631)
Closes #8588 Example: ``` vector-add-performance: size Triton (GB/s) Torch (GB/s) 0 4096.0 8.827586 8.777143 1 8192.0 16.695652 17.454545 2 16384.0 34.711865 34.516854 ```
1 parent faeb1eb commit 618ec40

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

python/triton/testing.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -320,11 +320,11 @@ def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: b
320320

321321
import matplotlib.pyplot as plt
322322
import pandas as pd
323-
y_mean = bench.line_names
324-
y_min = [f'{x}-min' for x in bench.line_names]
325-
y_max = [f'{x}-max' for x in bench.line_names]
323+
y_mean_labels = [f'{x} ({bench.ylabel})' for x in bench.line_names]
324+
y_min_labels = [f'{x}-min ({bench.ylabel})' for x in bench.line_names]
325+
y_max_labels = [f'{x}-max ({bench.ylabel})' for x in bench.line_names]
326326
x_names = list(bench.x_names)
327-
df = pd.DataFrame(columns=x_names + y_mean + y_min + y_max)
327+
df = pd.DataFrame(columns=x_names + y_mean_labels + y_min_labels + y_max_labels)
328328
for x in bench.x_vals:
329329
# x can be a single value or a sequence of values.
330330
if not isinstance(x, (list, tuple)):
@@ -351,11 +351,11 @@ def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: b
351351
ax = plt.subplot()
352352
# Plot first x value on x axis if there are multiple.
353353
first_x = x_names[0]
354-
for i, y in enumerate(bench.line_names):
355-
y_min, y_max = df[y + '-min'], df[y + '-max']
354+
for i, (mean_label, min_label, max_label) in enumerate(zip(y_mean_labels, y_min_labels, y_max_labels)):
355+
y_min, y_max = df[min_label], df[max_label]
356356
col = bench.styles[i][0] if bench.styles else None
357357
sty = bench.styles[i][1] if bench.styles else None
358-
ax.plot(df[first_x], df[y], label=y, color=col, ls=sty)
358+
ax.plot(df[first_x], df[mean_label], label=mean_label, color=col, ls=sty)
359359
if not y_min.isnull().all() and not y_max.isnull().all():
360360
y_min = y_min.astype(float)
361361
y_max = y_max.astype(float)
@@ -370,7 +370,7 @@ def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: b
370370
plt.show()
371371
if save_path:
372372
plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png"))
373-
df = df[x_names + bench.line_names]
373+
df = df[x_names + y_mean_labels]
374374
if diff_col and df.shape[1] == 2:
375375
col0, col1 = df.columns.tolist()
376376
df['Diff'] = df[col1] - df[col0]

0 commit comments

Comments
 (0)