Skip to content

Commit e64cda2

Browse files
authored
[Bench] Make roofline graph lines and title clearer (#6852)
Use the first compute-bound data point as the end point of the bandwidth-bound line so that it connects with the compute-bound line. This is visually better, especially when the data points are sparse. Along the way, also make the title clearer about which case was benchmarked.
1 parent e6b9efd commit e64cda2

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

python/triton_kernels/bench/bench_mlp.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,8 @@ def roofline_mlp(batch_ranges, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_
175175
batches = list(chain(*[range(*r) for r in batch_ranges]))
176176
# collect performance data
177177
perfs = []
178-
print(f"Benchmarking {name} ({x_dtype}x{w_dtype}, TP={TP}, EP={EP})...")
178+
bench_case = f"{name} ({x_dtype}x{w_dtype}, TP={TP}, EP={EP})"
179+
print(f"Benchmarking {bench_case}...")
179180
print("===============================================================")
180181
for batch in batches:
181182
perfs += [bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP, EP, name)]
@@ -186,7 +187,7 @@ def roofline_mlp(batch_ranges, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_
186187
fig, ax = plt.subplots(figsize=(7, 5), dpi=120)
187188
ax.set_xlabel("batch size (toks/expt)")
188189
ax.set_ylabel("performance [TFLOP/s]")
189-
ax.set_title("roofline")
190+
ax.set_title(f"{bench_case} roofline")
190191
# add a tiny margin so points are not flush with the frame
191192
xs = [batch * n_expts_act / n_expts_tot for batch in batches]
192193
perf = [p.tflops for p in perfs]
@@ -200,7 +201,8 @@ def roofline_mlp(batch_ranges, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_
200201
opints = [p.opint for p in perfs]
201202
knee = bisect_left(opints, max_tflops / max_tbps) - 1
202203
x_bw, x_comp = xs[:knee], xs[knee:]
203-
y_bw = [op * max_tbps for op in opints[:knee]]
204+
x_bw = [x_bw[0], x_comp[0]]
205+
y_bw = [opints[0] * max_tbps, max_tflops]
204206
y_comp = [max_tflops] * len(x_comp)
205207
ax.plot(x_bw, y_bw, "--", label=f"BW-bound ({max_tbps:.0f} TB/s)")
206208
ax.plot(x_comp, y_comp, "--", label=f"Compute-bound ({max_tflops:.0f} TFLOP/s)")

0 commit comments

Comments
 (0)