
Commit 8776dd9

Merge OpenAI Triton commit a0cc214 (#3910)
This PR changes the Triton base from 413b521 to a0cc214 (Apr 9). Pass rate: 88.42%.
2 parents 9facf00 + 221abbb commit 8776dd9

File tree

27 files changed: +4431 −55 lines

bench/bench/bench_mlp.py

Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
from pathlib import Path
import json
import triton.profiler as proton
import torch
import triton_bench.swiglu
from triton_bench.mxfp import downcast_to_mxfp
from triton_bench.matmul_ogs import MicroscalingCtx, matmul_ogs, PrecisionConfig, FlexCtx
from triton_bench.numerics import InFlexData
from triton_bench.routing import routing_torch, simulate_expert_sharded_routing
from triton_bench.meta import cuda_capability_geq

if torch.cuda.is_available():
    from triton._C.libtriton import nvidia
    cublas_workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8)
    cublas = nvidia.cublas.CublasLt(cublas_workspace)
else:
    cublas = None


def _query_gpu_specs():
    import subprocess
    cmd = ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader", "-i=0"]
    output = subprocess.check_output(cmd, stderr=subprocess.DEVNULL).decode().strip()
    name = output.splitlines()[0]
    return {
        "NVIDIA H100 80GB HBM3": {"MAX_TFLOPS8": 1979, "MAX_TFLOPS16": 989, "MAX_TBPS": 3.35}, "HGX GB200":
        {"MAX_TFLOPS8": 4500, "MAX_TFLOPS16": 2250, "MAX_TBPS": 8.0}
    }[name]


SPECS = _query_gpu_specs()


def quantize(w, dtype, dev, **opt):
    if dtype == "bf16":
        return w.to(torch.bfloat16), InFlexData(), MicroscalingCtx()
    elif dtype == "fp8":
        wq = w.to(torch.float8_e4m3fn).transpose(-1, -2).contiguous().transpose(-1, -2)
        return wq, InFlexData(dtype=wq.dtype, scale=w.abs().max().unsqueeze(0)), \
            MicroscalingCtx()
    else:
        assert dtype == "mx4", f"{dtype=}"
        swizzle_mx_scale = opt["swizzle_mx_scale"]
        swizzle_axis = 2 if swizzle_mx_scale else None
        w = w.to(torch.bfloat16)
        w, mx_scales, weight_scale_shape = downcast_to_mxfp(w, torch.uint8, axis=1, swizzle_axis=swizzle_axis)
        return w, InFlexData(), MicroscalingCtx(weight_scale=mx_scales, swizzle_mx=swizzle_mx_scale,
                                                actual_weight_scale_shape=weight_scale_shape)


def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype,
              # tensor / expert parallelism
              TP=1, EP=1, name=""):
    assert n_expts_tot % EP == 0
    assert dim2 % TP == 0
    dev = "cuda"
    # input
    # weights
    wg = torch.randn((dim1, n_expts_tot), device=dev)
    w1 = torch.randn((n_expts_tot // EP, dim1, dim2 // TP), device=dev)
    w2 = torch.randn((n_expts_tot // EP, dim2 // TP // 2, dim1), device=dev)
    # biases
    bg = torch.randn((n_expts_tot, ), device=dev)
    b1 = torch.randn((dim2 // TP, ), device=dev)
    b2 = torch.randn((dim1, ), device=dev)

    # -- numerics --
    optg = dict()
    opt1 = {"swizzle_mx_scale": True} if w_dtype == "mx4" else dict()
    opt2 = {"swizzle_mx_scale": True} if w_dtype == "mx4" else dict()
    wg, wg_flex, wg_mx = quantize(wg, "bf16", dev, **optg)
    w1, w1_flex, w1_mx = quantize(w1, w_dtype, dev, **opt1)
    w2, w2_flex, w2_mx = quantize(w2, w_dtype, dev, **opt2)
    pcg = PrecisionConfig(mx_ctx=wg_mx, flex_ctx=FlexCtx(rhs_data=wg_flex))
    pcs = triton_bench.swiglu.PrecisionConfig(limit=1.0)
    pc1 = PrecisionConfig(mx_ctx=w1_mx, flex_ctx=FlexCtx(rhs_data=w1_flex))
    pc2 = PrecisionConfig(mx_ctx=w2_mx, flex_ctx=FlexCtx(rhs_data=w2_flex))

    # -- benchmark --
    fpath = Path(f"logs/{name}/{batch}-{dim1}-{dim2}-{n_expts_tot}-{n_expts_act}-{x_dtype}-{w_dtype}.hatchet")
    fpath.parent.mkdir(parents=True, exist_ok=True)
    proton.start(str(fpath.with_suffix('')), hook="triton")
    proton.deactivate()
    # run layer
    x_dtype = {"bf16": torch.bfloat16, "fp8": torch.float8_e4m3fn}[x_dtype]
    for i in range(100):
        x = torch.randn((batch, dim1), device=dev)
        x = x.to(wg.dtype if n_expts_tot > 1 else x_dtype)
        # TODO: activate proton here when fast routing is done
        if n_expts_tot > 1:
            logits = matmul_ogs(x, wg, bg, precision_config=pcg)
            rdata, gather_indx, scatter_indx = routing_torch(logits, n_expts_act)
            if EP > 1:
                m = logits.shape[0] * EP
                _, rdata, gather_indx, scatter_indx = simulate_expert_sharded_routing(m, rdata, EP, device=dev)
            x = x.to(x_dtype)
        else:
            rdata, gather_indx, scatter_indx = None, None, None
        proton.activate()
        # c0 = torch.empty((x.shape[0], w1.shape[-1]), device=dev, dtype=x.dtype)
        # c1 = torch.empty((x.shape[0], w2.shape[-1]), device=dev, dtype=x.dtype)
        # cublas.matmul(x, w1.squeeze(0), c0)
        # cublas.matmul(c0, w2.squeeze(0), c1)
        x = matmul_ogs(x, w1, b1, rdata, gather_indx=gather_indx, precision_config=pc1)
        x = triton_bench.swiglu.swiglu(x, 1.0, pcs)
        x = matmul_ogs(x, w2, b2, rdata, scatter_indx=scatter_indx, precision_config=pc2)
    proton.deactivate()
    proton.finalize()

    # -- analyze --
    with open(f"{fpath}") as fd:
        data = json.load(fd)
    # TODO: this will be broken if kernels use scopes themselves
    # compute useful (a.k.a. matmul) bytes and flops
    matmuls = [x for x in data[0]["children"] if "matmul" in x["frame"]["name"]]
    tot_bytes = sum([x["metrics"]["bytes"] for x in matmuls])
    tot_flops = {w: sum([x["metrics"].get(f"flops{w}", 0) for x in matmuls]) for w in [8, 16]}
    # compute total time (incl. "not useful" work)
    # TODO: proton should really be recording that in the json instead of
    # relying on the user to aggregate
    tot_time = sum(x["metrics"].get("time (ns)", 0) for x in data[0]["children"])
    min_time_flops = sum([tot_flops[w] / SPECS[f"MAX_TFLOPS{w}"] for w in [8, 16]]) * 1e-3
    min_time_bytes = tot_bytes / SPECS["MAX_TBPS"] * 1e-3
    min_time = max(min_time_flops, min_time_bytes)
    util = min_time / tot_time
    tflops = sum([tot_flops[w] for w in [8, 16]]) / tot_time * 1e-3
    tbps = tot_bytes / tot_time * 1e-3

    return util, tflops, tbps


if __name__ == "__main__":
    has_native_mx4 = torch.cuda.get_device_capability(0)[0] >= 10
    qxdtype = "fp8" if has_native_mx4 else "bf16"
    print(bench_mlp(8192, 8192, 8192, 1, 1, "fp8", "fp8", TP=1, EP=1, name="dense"))
    print(bench_mlp(8192, 8192, 8192, 1, 1, qxdtype, "mx4", TP=1, EP=1, name="dense"))
    print(bench_mlp(1024, 5120, 8192, 128, 4, "fp8", "fp8", TP=4, EP=2, name="llama4"))
    print(bench_mlp(1024, 5120, 8192, 128, 4, qxdtype, "mx4", TP=4, EP=2, name="llama4"))
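
Note on the "analyze" section above: proton reports raw flop/byte counts and time in nanoseconds, while the peaks in SPECS are in TFLOP/s and TB/s (1e12 per second), so dividing a count by its peak and scaling by 1e-3 yields a roofline lower bound in nanoseconds, which makes util = min_time / tot_time dimensionless. A minimal standalone sketch of that unit bookkeeping, using illustrative workload numbers rather than values from a real profile:

# Roofline-style lower bound mirroring the unit handling in bench_mlp.py.
# The workload numbers below are illustrative assumptions, not profiler output.
MAX_TFLOPS8 = 1979   # H100 fp8 peak from SPECS, in 1e12 flop/s
MAX_TBPS = 3.35      # H100 HBM3 peak from SPECS, in 1e12 byte/s

flops8 = 2 * 8192 ** 3          # one 8192 x 8192 x 8192 matmul
bytes_moved = 3 * 8192 ** 2     # roughly one byte per fp8 element read/written

# flops / (peak * 1e12 flop/s) seconds * 1e9 ns/s  ==  flops / peak * 1e-3 ns
min_time_flops_ns = flops8 / MAX_TFLOPS8 * 1e-3
min_time_bytes_ns = bytes_moved / MAX_TBPS * 1e-3
min_time_ns = max(min_time_flops_ns, min_time_bytes_ns)
print(f"compute floor: {min_time_flops_ns:.0f} ns, "
      f"bandwidth floor: {min_time_bytes_ns:.0f} ns, roofline: {min_time_ns:.0f} ns")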

bench/pyproject.toml

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
[project]
name = "triton_bench"
version = "1.0.0"
dependencies = ["torch", "numpy", "pytest"]

[build-system]
requires = ["setuptools>=64.0"]
build-backend = "setuptools.build_meta"

[tool.setuptools.packages.find]
include = ["triton_bench*"]
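
bench_mlp.py above imports triton_bench as an installed package, which this pyproject.toml provides. As a quick, hypothetical sanity check (assuming the package has been installed from bench/, e.g. in editable mode), one might confirm that the modules the benchmark uses resolve:

# Hypothetical check: the module names follow the imports in bench_mlp.py above.
import importlib

for mod in ("triton_bench.swiglu", "triton_bench.matmul_ogs",
            "triton_bench.numerics", "triton_bench.routing", "triton_bench.meta"):
    importlib.import_module(mod)
    print(f"{mod}: OK")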

bench/tests/__init__.py

Whitespace-only changes.
