
Commit fc024d8

Merge OpenAI Triton commit f05cdc4 (#4134)
This PR changes the Triton base from 553d01d to f05cdc4 (May 5). Pass rate: 94.57%.
2 parents 5271aa4 + 71c5fd4 commit fc024d8

16 files changed: +142 additions, −78 deletions

.github/workflows/documentation.yml

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ permissions: read-all
 
 jobs:
   Build-Documentation:
-    runs-on: [a100-runner-set]
+    runs-on: [nvidia-a100]
     timeout-minutes: 30
 
     steps:

bench/bench/bench_mlp.py

Lines changed: 94 additions & 35 deletions
@@ -1,4 +1,5 @@
 from pathlib import Path
+import matplotlib.pyplot as plt
 import json
 import triton.profiler as proton
 import torch
@@ -8,6 +9,7 @@
 from triton_bench.numerics import InFlexData
 from triton_bench.routing import routing
 from triton_bench.target_info import is_hip, get_cdna_version
+from dataclasses import dataclass
 
 if torch.cuda.is_available() and not is_hip():
     from triton._C.libtriton import nvidia
@@ -66,9 +68,38 @@ def quantize(w, dtype, dev, **opt):
                     actual_weight_scale_shape=weight_scale_shape)
 
 
-def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype,
-              # tensor / expert parallelism
-              TP=1, EP=1, name=""):
+@dataclass
+class PerfData:
+    time: float
+    flops: float
+    bytes: float
+
+    @property
+    def tflops(self):
+        return self.flops / self.time * 1e-3
+
+    @property
+    def tbps(self):
+        return self.bytes / self.time * 1e-3
+
+    @property
+    def opint(self):
+        # operational intensity
+        assert self.bytes > 0
+        return self.flops / self.bytes
+
+    @property
+    def util(self) -> float:
+        if SPECS is None:
+            return 0.0
+
+        peak_flops = max(SPECS["MAX_TFLOPS8"], SPECS.get("MAX_TFLOPS16", 0))
+        min_t_flop = self.flops / peak_flops * 1e-3  # ns → µs
+        min_t_bw = self.bytes / SPECS["MAX_TBPS"] * 1e-3
+        return max(min_t_flop, min_t_bw) / self.time
+
+
+def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP, EP, name):
     assert n_expts_tot % EP == 0
     assert dim2 % TP == 0
     dev = "cuda"
@@ -96,7 +127,7 @@ def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype,
     pc2 = PrecisionConfig(mx_ctx=w2_mx, flex_ctx=FlexCtx(rhs_data=w2_flex))
 
     # -- benchmark --
-    fpath = Path(f"logs/{name}/{batch}-{dim1}-{dim2}-{n_expts_tot}-{n_expts_act}-{x_dtype}-{w_dtype}.hatchet")
+    fpath = Path(f"logs/{name}/{x_dtype}-{w_dtype}-TP{TP}-EP{EP}/profiles/batch-{batch}.hatchet")
     fpath.parent.mkdir(parents=True, exist_ok=True)
     x_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.float8_e4m3fn}[x_dtype]
     # special treatment of fp8_e4m3 on AMD CDNA3 because it uses fp8_e4m3fnuz
@@ -115,7 +146,7 @@ def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype,
     else:
         rdata, gather_indx, scatter_indx = None, None, None
     x = matmul_ogs(x, w1, b1, rdata, gather_indx=gather_indx, precision_config=pc1)
-    x = triton_bench.swiglu.swiglu(x, 1.0, pcs)
+    x = triton_bench.swiglu.swiglu(x, 1.0, pcs, routing_data=rdata)
     x = matmul_ogs(x, w2, b2, rdata, scatter_indx=scatter_indx, precision_config=pc2)
     proton.finalize()
 
@@ -127,42 +158,70 @@ def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype,
     matmuls = [
         x for x in data[0]["children"] if "_matmul" in x["frame"]["name"] and "metadata" not in x["frame"]["name"]
     ]
-    tot_bytes = sum([x["metrics"]["bytes"] for x in matmuls])
-    tot_flops = {w: sum([x["metrics"].get(f"flops{w}", 0) for x in matmuls]) for w in [8, 16]}
+    bytes = sum([x["metrics"]["bytes"] for x in matmuls])
+    flops = {w: sum([x["metrics"].get(f"flops{w}", 0) for x in matmuls]) for w in [8, 16]}
+    flops = sum([flops[w] for w in [8, 16]])
     # compute total time (incl. "not useful" work)
     # TODO: proton should really be recording that in the json instead of
     # relying on the user to aggregate
-    tot_time = sum(x["metrics"].get("time (ns)", 0) for x in data[0]["children"])
-    min_time_flops = min_time_bytes = 0
-    if SPECS is not None:
-        min_time_flops = sum([tot_flops[w] / SPECS[f"MAX_TFLOPS{w}"] for w in [8, 16]]) * 1e-3
-        min_time_bytes = tot_bytes / SPECS["MAX_TBPS"] * 1e-3
-        min_time = max(min_time_flops, min_time_bytes)
-        util = min_time / tot_time
-    else:
-        util = 0.0
-    tflops = sum([tot_flops[w] for w in [8, 16]]) / tot_time * 1e-3
-    tbps = tot_bytes / tot_time * 1e-3
-    print(f"Utilization: {util:.0%}; {tflops:>6.1f} TFLOPs, {tbps:.1f} TB/s")
-
-    return util, tflops, tbps
+    time = sum(x["metrics"].get("time (ns)", 0) for x in data[0]["children"])
+    return PerfData(time, flops, bytes)
+
+
+def roofline_mlp(batch_ranges, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP=1, EP=1, name="",
+                 verbose=True):
+    import numpy as np
+    from itertools import chain
+    from bisect import bisect_left
+    batches = list(chain(*[range(*r) for r in batch_ranges]))
+    # collect performance data
+    perfs = []
+    print(f"Benchmarking {name} ({x_dtype}x{w_dtype}, TP={TP}, EP={EP})...")
+    print("===============================================================")
+    for batch in batches:
+        perfs += [bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP, EP, name)]
+        if verbose:
+            print(f"Batch: {batch}; Util: {perfs[-1].util}; TFLOPS: {perfs[-1].tflops}; TBPS: {perfs[-1].tbps}")
+    print("===============================================================")
+    # machine limits
+    fig, ax = plt.subplots(figsize=(7, 5), dpi=120)
+    ax.set_xlabel("batch size (toks/expt)")
+    ax.set_ylabel("performance [TFLOP/s]")
+    ax.set_title("roofline")
+    # add a tiny margin so points are not flush with the frame
+    xs = [batch * n_expts_act / n_expts_tot for batch in batches]
+    perf = [p.tflops for p in perfs]
+    xmin, xmax = min(xs), max(xs)
+    dx = 0.05 * (xmax - xmin) if xmax > xmin else 1.0
+    ax.set_xlim(xmin - dx, xmax + dx)
+    ax.set_ylim(100, SPECS["MAX_TFLOPS8"] + 500)
+    # plot roofline
+    max_tbps = SPECS["MAX_TBPS"]
+    max_tflops = SPECS["MAX_TFLOPS8"]
+    opints = [p.opint for p in perfs]
+    knee = bisect_left(opints, max_tflops / max_tbps) - 1
+    x_bw, x_comp = xs[:knee], xs[knee:]
+    y_bw = [op * max_tbps for op in opints[:knee]]
+    y_comp = [max_tflops] * len(x_comp)
+    ax.plot(x_bw, y_bw, "--", label=f"BW-bound ({max_tbps:.0f} TB/s)")
+    ax.plot(x_comp, y_comp, "--", label=f"Compute-bound ({max_tflops:.0f} TFLOP/s)")
+    # plot data
+    ax.scatter(xs, perf, marker="+")
+    ax.legend(frameon=False, loc="lower right")
+    ax.grid(True, which="both", ls=":", lw=0.5)
+    fig.tight_layout()
+    fpath = Path(f"logs/{name}/{x_dtype}-{w_dtype}-TP{TP}-EP{EP}/roofline.png")
+    plt.savefig(fpath)
 
 
 if __name__ == "__main__":
     has_native_mx4 = torch.cuda.get_device_capability(0)[0] >= 10 or get_cdna_version() == 4
     if SPECS is None:
         print("Current GPU has no specs provided, utilization is N/A")
-    if has_native_mx4:
-        bench_mlp(8192, 8192, 8192, 1, 1, "fp8", "fp8", TP=1, EP=1, name="dense")
-        bench_mlp(8192, 8192, 8192, 1, 1, "fp8", "mx4", TP=1, EP=1, name="dense")
-        bench_mlp(2048, 5120, 8192, 128, 4, "fp8", "fp8", TP=4, EP=1, name="llama4")
-        bench_mlp(2048, 5120, 8192, 128, 4, "fp8", "mx4", TP=4, EP=1, name="llama4")
-    else:
-        # bf16/fp16 x fp8 is skipped because matmul_ogs requires x and w has the
-        # same type when not doing mxfp operation
-        bench_mlp(8192, 8192, 8192, 1, 1, "fp8", "fp8", TP=1, EP=1, name="dense")
-        bench_mlp(8192, 8192, 8192, 1, 1, "fp16", "mx4", TP=1, EP=1, name="dense")
-        bench_mlp(8192, 8192, 8192, 1, 1, "bf16", "mx4", TP=1, EP=1, name="dense")
-        bench_mlp(2048, 5120, 8192, 128, 4, "fp8", "fp8", TP=4, EP=1, name="llama4")
-        bench_mlp(2048, 5120, 8192, 128, 4, "bf16", "mx4", TP=4, EP=1, name="llama4")
-        bench_mlp(2048, 5120, 8192, 128, 4, "fp16", "mx4", TP=4, EP=1, name="llama4")
+    batch_ranges = [(1024, 32768, 1024)]
+    dense_dtypes = ["fp8", "fp8"]
+    quantized_dtypes = ["fp8", "mx4"] if has_native_mx4 else ["bf16", "mx4"]
+    roofline_mlp(batch_ranges, 8192, 8192, 1, 1, *dense_dtypes, TP=1, EP=1, name="dense")
+    roofline_mlp(batch_ranges, 8192, 8192, 1, 1, *quantized_dtypes, TP=1, EP=1, name="dense")
+    roofline_mlp(batch_ranges, 5120, 8192, 128, 4, *dense_dtypes, TP=1, EP=1, name="llama4-maverick")
+    roofline_mlp(batch_ranges, 5120, 8192, 128, 4, *quantized_dtypes, TP=1, EP=1, name="llama4-maverick")
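
Note on the new PerfData helper: time is recorded in nanoseconds, so flops / time is FLOP per ns and the 1e-3 factor converts it to TFLOP/s; the same scaling turns bytes / time into TB/s. A minimal sketch of the unit handling, with entirely hypothetical numbers (not taken from any benchmark run):

# Hypothetical values: 2 ms of kernel time, 2e12 FLOP, 6e9 bytes of traffic.
pd = PerfData(time=2.0e6, flops=2.0e12, bytes=6.0e9)  # time in ns
print(pd.tflops)  # -> 1000.0 TFLOP/s
print(pd.tbps)    # -> 3.0 TB/s
print(pd.opint)   # -> ~333 FLOP/byte (operational intensity)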

bench/triton_bench/matmul_ogs_details/_common.py

Lines changed: 5 additions & 4 deletions
@@ -87,9 +87,10 @@ def matmul_launch_metadata(grid, kernel, args):
     fM = M if M is not None else n_tokens
     fK = K if K is not None else n_tokens
     ret[f"flops{nbits}"] = 2.0 * fM * N * fK
-    skipped = 0
+    gindx = args.get("GatherIndx", None)
     sindx = args.get("WriteBackIndx", None)
-    if sindx is not None:
-        skipped = (sindx == -1).sum() / sindx.numel()
-    ret["bytes"] = int((1 - skipped) * Y.numel() * Y.element_size() + X.numel() * X.element_size() + n_w_bytes)
+    sskipped = 0. if sindx is None else (sindx == -1).sum() / sindx.shape[0]
+    gskipped = 0. if gindx is None else (gindx == -1).sum() / gindx.shape[0]
+    ret["bytes"] = int((1 - sskipped) * Y.numel() * Y.element_size() + (1 - gskipped) * X.numel() * X.element_size() +
+                       n_w_bytes)
     return ret
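
The change above discounts both output rows whose WriteBackIndx is -1 and input rows whose GatherIndx is -1 when estimating memory traffic, instead of only the scatter side. A small sketch of the skip-fraction arithmetic, using invented index tensors purely for illustration:

import torch

sindx = torch.tensor([0, 3, -1, 2, -1])         # hypothetical WriteBackIndx
gindx = torch.tensor([1, -1, 0, 2, 4])          # hypothetical GatherIndx
sskipped = (sindx == -1).sum() / sindx.shape[0]  # 0.4: 40% of Y rows are never written
gskipped = (gindx == -1).sum() / gindx.shape[0]  # 0.2: 20% of X rows are never read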

bench/triton_bench/matmul_ogs_details/_ptma_matmul_ogs.py

Lines changed: 0 additions & 4 deletions
@@ -58,17 +58,13 @@ def _make_tensor_desc(ptr, shape, strides, block_shape, transpose: tl.constexpr
     tl.static_assert(len(shape) == len(strides))
     tl.static_assert(len(strides) == len(block_shape))
     if transpose:
-        # Pass constexpr(1) to workaround torchflow tracer changing values of 1 to 2 during compile.
-        # We check that the stride is actually 1 before launching the kernel.
         return tl.make_tensor_descriptor(
             ptr,
             shape=shape[:-2] + [shape[-1], shape[-2]],
             strides=strides[:-2] + [strides[-1], tl.constexpr(1)],
             block_shape=block_shape[:-2] + [block_shape[-1], block_shape[-2]],
         )
     else:
-        # Pass constexpr(1) to workaround torchflow tracer changing values of 1 to 2 during compile.
-        # We check that the stride is actually 1 before launching the kernel.
         return tl.make_tensor_descriptor(
             ptr,
             shape=shape,

bench/triton_bench/matmul_ogs_details/opt_flags.py

Lines changed: 1 addition & 2 deletions
@@ -59,7 +59,6 @@ def make_default_opt_flags_amd(
         block_m = 128
     elif tokens_per_expt >= 512 and n >= 2048:
         block_m = 128
-
     else:
         block_m = max(32, min(triton.next_power_of_2(tokens_per_expt), 64))
     if routing_data is not None:
@@ -139,7 +138,7 @@ make_default_opt_flags_nvidia(
     elif enforce_bitwise_invariance:
         block_m = 128
     else:
-        block_m = max(16, min(triton.next_power_of_2(tokens_per_expt), 128))
+        block_m = max(64, min(triton.next_power_of_2(tokens_per_expt), 128))
     # TODO: remove when triton is more optimized for H100 MXFP4
     arch = None
     if (
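
The NVIDIA default now clamps block_m to the range [64, 128] rather than [16, 128]. A small standalone illustration of the clamp, with arbitrarily chosen token counts:

import triton

for tokens_per_expt in (8, 48, 100, 700):
    block_m = max(64, min(triton.next_power_of_2(tokens_per_expt), 128))
    print(tokens_per_expt, block_m)  # 8 -> 64, 48 -> 64, 100 -> 128, 700 -> 128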

bench/triton_bench/swiglu.py

Lines changed: 8 additions & 4 deletions
@@ -5,6 +5,7 @@
 from triton.tools.tensor_descriptor import TensorDescriptor
 from .swiglu_details._swiglu import _swiglu
 from triton_bench import target_info
+from .matmul_ogs_details.metadata import compute_metadata
 
 
 @dataclass(frozen=True)
@@ -23,7 +24,7 @@ class PrecisionConfig:
 class SwiGLU(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx, a, alpha, precision_config, expt_data, num_experts):
+    def forward(ctx, a, alpha, precision_config, routing_data, num_experts):
         N = a.shape[-1]
         M = a.numel() // N
         assert a.stride()[-1] == 1
@@ -48,7 +49,7 @@ def forward(ctx, a, alpha, precision_config, expt_data, num_experts):
         # launch semi-persistent kernel
         N_BLOCKS = triton.cdiv(N // 2, BLOCK_N)
         num_sms = target_info.num_sms()
-        if expt_data is not None:
+        if routing_data is not None:
             waves_per_sm = 32 if target_info.is_hip() else 128
             num_pid = num_sms * (waves_per_sm // num_warps)
             M_BLOCKS = max(1, triton.cdiv(num_pid, N_BLOCKS))
@@ -59,6 +60,9 @@ def forward(ctx, a, alpha, precision_config, expt_data, num_experts):
             grid = (8 * num_sms, )
         else:
             grid = (min(M_BLOCKS * N_BLOCKS, 4 * num_sms), )
+        expt_data = None
+        if routing_data is not None:
+            expt_data = compute_metadata(routing_data, M, BLOCK_M).buffer
         _swiglu[grid](
             out_desc,
             flex_ctx.out_data.reinterpret(out),
@@ -91,8 +95,8 @@ def forward(ctx, a, alpha, precision_config, expt_data, num_experts):
         return out
 
 
-def swiglu(a, alpha, precision_config, expt_data=None, num_experts=0):
-    return SwiGLU.apply(a, alpha, precision_config, expt_data, num_experts)
+def swiglu(a, alpha, precision_config, routing_data=None, num_experts=0):
+    return SwiGLU.apply(a, alpha, precision_config, routing_data, num_experts)
 
 
 def swiglu_torch(a, alpha, precision_config):
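
With this change, callers pass routing_data instead of a precomputed expt_data buffer; the expert metadata is now derived inside forward via compute_metadata. A hedged usage sketch mirroring the bench_mlp.py call site above, where x, pcs, and rdata are assumed to already exist (rdata may be None on the dense path):

import triton_bench

# routing_data is optional; passing None falls back to the non-routed grid.
y = triton_bench.swiglu.swiglu(x, 1.0, pcs, routing_data=rdata)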

docs/conf.py

Lines changed: 2 additions & 2 deletions
@@ -43,7 +43,7 @@ def get_cmake_dir():
     plat_name = sysconfig.get_platform()
     python_version = sysconfig.get_python_version()
     dir_name = f"cmake.{plat_name}-{sys.implementation.name}-{python_version}"
-    cmake_dir = Path("../python") / "build" / dir_name
+    cmake_dir = Path("../build") / dir_name
     return cmake_dir
 
 
@@ -100,7 +100,7 @@ def setup(app):
     app.connect("autodoc-process-signature", process_sig)
     max_jobs = os.getenv("MAX_JOBS", str(2 * os.cpu_count()))
    print(f"Installing Triton Python package using {max_jobs} threads")
-    subprocess.run("pip install -e ..", shell=True, env=os.environ.copy())
+    subprocess.run("pip install -e ../", shell=True, env=os.environ.copy())
 
     setup_generated_mlir_docs()
 
include/triton/Dialect/Triton/IR/Utility.h

Lines changed: 4 additions & 0 deletions
@@ -173,6 +173,10 @@ template <typename T> auto seq(T start, T end, T step) {
                   [=](T i) { return start + i * step; });
 }
 
+// Combine the current mask with the given predicate.
+Value getPredMask(RewriterBase &rewriter, Type typeLike, Value currentMask,
+                  Value pred);
+
 } // namespace triton
 } // namespace mlir
 
include/triton/Dialect/Triton/Transforms/Utility.h

Lines changed: 0 additions & 3 deletions
@@ -7,9 +7,6 @@ using namespace mlir;
 
 namespace mlir::triton {
 
-Value getPredMask(RewriterBase &rewriter, Type typeLike, Value currentMask,
-                  Value pred);
-
 triton::MakeTensorPtrOp getMakeTensorPtrOp(Value v);
 
 } // namespace mlir::triton
} // namespace mlir::triton

lib/Dialect/Triton/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ add_triton_library(TritonIR
   Traits.cpp
   Types.cpp
   OpInterfaces.cpp
+  Utility.cpp
 
   DEPENDS
   TritonTableGen
