Skip to content

Commit c6dbfba

Browse files
authored
[KERNELS] Fix roofline plots for kernels that are only compute-bound (#7670)
Previous result <img width="840" height="600" alt="roofline" src="https://github.com/user-attachments/assets/73a532f3-7af2-48fb-83e1-78618ca0538e" /> After fix <img width="840" height="600" alt="roofline" src="https://github.com/user-attachments/assets/dcb84a2a-4723-4880-be10-a4d83a2e7587" />
1 parent 9e3886f commit c6dbfba

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

python/triton_kernels/bench/bench_mlp.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,16 @@ def roofline_mlp(batch_ranges, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_
188188
ax.set_ylim(100, max_tflops + 500)
189189
# plot roofline
190190
opints = [p.opint for p in perfs]
191-
knee = bisect_left(opints, max_tflops / max_tbps) - 1
191+
knee = bisect_left(opints, max_tflops / max_tbps)
192+
if knee > 0: # has a bandwidth-bound knee
193+
x_bw = [xs[0], xs[knee - 1]]
194+
y_bw = [opints[0] * max_tbps, max_tflops]
195+
else: # no knee found, compute-bound only
196+
x_bw = y_bw = []
197+
x_comp = xs[knee:]
198+
y_comp = [max_tflops] * len(x_comp)
199+
ax.plot(x_bw, y_bw, "--", label=f"BW-bound ({max_tbps:.1f} TB/s)", color="blue")
200+
ax.plot(x_comp, y_comp, "--", label=f"Compute-bound ({max_tflops:.0f} TFLOP/s)", color="orange")
192201
x_bw, x_comp = xs[:knee], xs[knee:]
193202
x_bw = [x_bw[0], x_comp[0]]
194203
y_bw = [opints[0] * max_tbps, max_tflops]

0 commit comments

Comments
 (0)