Skip to content

Commit 48f72df

Browse files
ptilletmeta-codesync[bot]
authored andcommitted
[Cherry-pick] [triton_kernels] minor rename (#8044) (#621)
Summary: Cherry-picked from upstream OAI repository. Original Commit: d4399a1 Original Author: Philippe Tillet Original Date: 2025-09-02 21:43:18 -0700 Original commit message: ``` [triton_kernels] minor rename (#8044) ``` This PR was automatically cherry-picked from the upstream triton-lang/triton repository. Pull Request resolved: #621 Reviewed By: dshi7, minjang Differential Revision: D86003266 Pulled By: agron911 fbshipit-source-id: 86d52ffba9429346220c4305dc8976ff51276e96
1 parent 98bdb85 commit 48f72df

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

python/triton_kernels/bench/bench_mlp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@
55
import torch
66
import argparse
77
import triton_kernels
8+
import triton_kernels.roofline as roofline
89
import triton_kernels.swiglu
910
from triton_kernels.matmul_ogs import matmul_ogs, PrecisionConfig, FlexCtx, FnSpecs, FusedActivation
1011
from triton_kernels.target_info import get_cdna_version
1112
import distributed as triton_dist
1213
from triton_kernels.tensor_details import layout
1314
from bench_utils import quantize_weight
1415
import tempfile
15-
import roofline
1616

1717

1818
def bench_mlp(batch_per_expt, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP, EP):

python/triton_kernels/triton_kernels/roofline.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def plot_roofline(series, flops_dtype, out_path, max_tbps="memset", max_tflops="
191191
comp_x = [x_knee] + xs[knee_idx:]
192192
comp_y = [max_tflops] * (1 + (n - knee_idx))
193193

194-
y_roof_sampled = [min(op * max_tbps, max_tflops) for op in opints]
194+
y_roof = [min(op * max_tbps, max_tflops) for op in opints]
195195

196196
# --- helpers ---
197197
def interp(yxs, yys, x):
@@ -233,13 +233,13 @@ def interp(yxs, yys, x):
233233
xmin, xmax = xs[0], xs[-1]
234234
dx = 0.05 * (xmax - xmin) if xmax > xmin else 1.0
235235
ax.set_xlim(xmin - dx, xmax + dx)
236-
ax.set_ylim(min(y_roof_sampled) * 0.8 if y_roof_sampled else 0.0, max_tflops * 1.05)
236+
ax.set_ylim(min(y_roof) * 0.8 if y_roof else 0.0, max_tflops * 1.05)
237237

238238
# Points of interest
239239
if points_of_interest:
240240
for x_pt, label in points_of_interest.items():
241241
y_pt = interp(xs, series_perf[0], x_pt)
242-
y_rf = interp(xs, y_roof_sampled, x_pt)
242+
y_rf = interp(xs, y_roof, x_pt)
243243
ax.plot([x_pt], [y_pt], marker="o", ms=4, mfc="white", mec="black", zorder=3)
244244
ax.annotate(f"{label}\n{int(y_pt)} TFLOP/s ({int(y_pt/y_rf*100)}%)", xy=(x_pt, y_pt), xytext=(5, -25),
245245
textcoords="offset points", fontsize=7, ha="left", va="bottom")

0 commit comments

Comments
 (0)