Skip to content

Commit 1294776

Browse files
authored
[BENCH] move code around; added tests; improved expert-parallelism simulation (triton-lang#6538)
1 parent d066e15 commit 1294776

File tree

18 files changed

+528
-406
lines changed

18 files changed

+528
-406
lines changed

bench/bench/bench_mlp.py

Lines changed: 16 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
from pathlib import Path
22
import json
3-
import triton
43
import triton.profiler as proton
54
import torch
65
import triton_bench.swiglu
76
from triton_bench.mxfp import downcast_to_mxfp
87
from triton_bench.matmul_ogs import MicroscalingCtx, matmul_ogs, PrecisionConfig, FlexCtx
98
from triton_bench.numerics import InFlexData
10-
from triton_bench.routing import routing, simulate_expert_sharded_routing
9+
from triton_bench.routing import routing
1110
from triton_bench.meta import cuda_capability_geq, is_hip, get_cdna_version
1211

1312
if torch.cuda.is_available() and not is_hip():
@@ -49,7 +48,8 @@ def _query_gpu_specs():
4948

5049
def quantize(w, dtype, dev, **opt):
5150
if dtype == "bf16":
52-
return w.to(torch.bfloat16), InFlexData(), MicroscalingCtx()
51+
wq = w.to(torch.bfloat16).transpose(-1, -2).contiguous().transpose(-1, -2)
52+
return wq, InFlexData(), MicroscalingCtx()
5353
elif dtype == "fp8":
5454
fp8e4_dtype = torch.float8_e4m3fn if get_cdna_version() != 3 \
5555
else torch.float8_e4m3fnuz
@@ -98,46 +98,35 @@ def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype,
9898
# -- benchmark --
9999
fpath = Path(f"logs/{name}/{batch}-{dim1}-{dim2}-{n_expts_tot}-{n_expts_act}-{x_dtype}-{w_dtype}.hatchet")
100100
fpath.parent.mkdir(parents=True, exist_ok=True)
101-
proton.start(str(fpath.with_suffix('')), hook="triton")
102-
proton.deactivate()
103-
# run layer
104101
x_dtype = {"bf16": torch.bfloat16, "fp8": torch.float8_e4m3fn}[x_dtype]
105102
# special treatment of fp8_e4m3 on AMD CDNA3 because it uses fp8_e4m3fnuz
106103
if x_dtype == torch.float8_e4m3fn and get_cdna_version() == 3:
107104
x_dtype = torch.float8_e4m3fnuz
105+
106+
x = torch.randn((batch, dim1), device=dev)
107+
xg = x.to(wg.dtype if n_expts_tot > 1 else x_dtype)
108+
x = x.to(x_dtype)
109+
# run layer
110+
proton.start(str(fpath.with_suffix('')), hook="triton")
108111
for i in range(100):
109-
x = torch.randn((batch, dim1), device=dev)
110-
x = x.to(wg.dtype if n_expts_tot > 1 else x_dtype)
111-
proton.activate()
112112
if n_expts_tot > 1:
113-
logits = matmul_ogs(x, wg, bg, precision_config=pcg)
114-
rdata, gather_indx, scatter_indx = routing(logits, n_expts_act)
115-
if EP > 1:
116-
proton.deactivate()
117-
# TODO: activate proton here when fast expert parallelism simulation is done
118-
m = logits.shape[0] * EP
119-
_, rdata, gather_indx, scatter_indx = simulate_expert_sharded_routing(m, rdata, EP, device=dev)
120-
proton.activate()
121-
x = x.to(x_dtype)
113+
logits = matmul_ogs(xg, wg, bg, precision_config=pcg)
114+
rdata, gather_indx, scatter_indx = routing(logits, n_expts_act, simulated_ep=EP)
122115
else:
123116
rdata, gather_indx, scatter_indx = None, None, None
124-
# c0 = torch.empty((x.shape[0], w1.shape[-1]), device=dev, dtype=x.dtype)
125-
# c1 = torch.empty((x.shape[0], w2.shape[-1]), device=dev, dtype=x.dtype)
126-
# TODO: cublas is simply set to None on AMD and may cause this to fail if uncommented
127-
# cublas.matmul(x, w1.squeeze(0), c0)
128-
# cublas.matmul(c0, w2.squeeze(0), c1)
129117
x = matmul_ogs(x, w1, b1, rdata, gather_indx=gather_indx, precision_config=pc1)
130118
x = triton_bench.swiglu.swiglu(x, 1.0, pcs)
131119
x = matmul_ogs(x, w2, b2, rdata, scatter_indx=scatter_indx, precision_config=pc2)
132-
proton.deactivate()
133120
proton.finalize()
134121

135122
# -- analyze --
136123
with open(f"{fpath}") as fd:
137124
data = json.load(fd)
138125
# TODO: this will be broken if kernels use scopes themselves
139126
# compute useful (a.k.a. matmul) bytes and flops
140-
matmuls = [x for x in data[0]["children"] if "matmul" in x["frame"]["name"]]
127+
matmuls = [
128+
x for x in data[0]["children"] if "_matmul" in x["frame"]["name"] and "metadata" not in x["frame"]["name"]
129+
]
141130
tot_bytes = sum([x["metrics"]["bytes"] for x in matmuls])
142131
tot_flops = {w: sum([x["metrics"].get(f"flops{w}", 0) for x in matmuls]) for w in [8, 16]}
143132
# compute total time (incl. "not useful" work)
@@ -163,5 +152,5 @@ def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype,
163152
qxdtype = "fp8" if has_native_mx4 else "bf16"
164153
print(bench_mlp(8192, 8192, 8192, 1, 1, "fp8", "fp8", TP=1, EP=1, name="dense"))
165154
print(bench_mlp(8192, 8192, 8192, 1, 1, qxdtype, "mx4", TP=1, EP=1, name="dense"))
166-
print(bench_mlp(2048, 5120, 8192, 128, 4, "fp8", "fp8", TP=4, EP=1, name="llama4"))
167-
print(bench_mlp(2048, 5120, 8192, 128, 4, qxdtype, "mx4", TP=4, EP=1, name="llama4"))
155+
print(bench_mlp(2048, 5120, 8192, 128, 4, "fp8", "fp8", TP=4, EP=2, name="llama4"))
156+
print(bench_mlp(2048, 5120, 8192, 128, 4, qxdtype, "mx4", TP=4, EP=2, name="llama4"))

bench/tests/test_compact.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import pytest
2+
import torch
3+
from triton_bench.compact import masked_compact, masked_compact_torch
4+
5+
6+
@pytest.mark.parametrize("n_tokens, n_cols, k, p", [
7+
(8192, 64, 4, 0.5),
8+
(8192, 64, 4, 1.0),
9+
(131, 128, 16, 0.6),
10+
(496, 128, 16, 0.),
11+
])
12+
def test_masked_compact(n_tokens, n_cols, k, p):
13+
device = "cuda"
14+
yi = torch.rand((n_tokens, n_cols), device=device).argsort(dim=-1)
15+
yi = yi[:, :k].to(torch.int32)
16+
yv = torch.randn((n_tokens, k), dtype=torch.bfloat16, device=device)
17+
# "drop" indices from yi with probability `p`
18+
mask = torch.zeros((n_tokens, n_cols), dtype=torch.int32, device=device)
19+
keep = (torch.rand(yi.shape, device=device) < p)
20+
if keep.any():
21+
rows = torch.arange(yi.size(0), device=device).unsqueeze(1).expand_as(yi)
22+
mask[rows[keep], yi[keep]] = 1
23+
chunks = mask.view(*mask.shape[:-1], -1, 32)
24+
weights = (1 << torch.arange(32, dtype=torch.int32, device=device))
25+
bitmask = (chunks.int() * weights).sum(dim=-1)
26+
yv_ref, yi_ref = masked_compact_torch(yv, yi, bitmask)
27+
yv_tri, yi_tri = masked_compact(yv, yi, bitmask)
28+
assert torch.all(yi_ref == yi_tri)
29+
assert torch.all(yv_ref == yv_tri)

bench/tests/test_matmul.py

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
1-
import itertools
21
from dataclasses import dataclass, fields
32
import pytest
43
import torch
54
# benchmarking utilities
6-
import triton.profiler as proton
75
# routing utilities
8-
from triton_bench.routing import routing_torch, simulate_expert_sharded_routing
6+
from triton_bench.routing import routing
97
# matmul utilities
108
import triton_bench.matmul_ogs_details.opt_flags as opt_flags
119
from triton_bench.matmul_ogs import FlexCtx, PrecisionConfig, MicroscalingCtx
@@ -43,22 +41,11 @@ def mask_indx(idx, n_expts_act):
4341

4442
def init_routing_data(m, n_expts_tot, n_expts_act, n_expt_shards, do_gather, do_scatter):
4543
dev = "cuda"
46-
logits = torch.randn((m, n_expts_tot), dtype=torch.float32, device=dev, requires_grad=True)
47-
routing_data, gather_idx, scatter_idx = routing_torch(logits, n_expts_act)
48-
if n_expt_shards > 1:
49-
m = logits.shape[0] * n_expt_shards
50-
_, routing_data, gather_idx, scatter_idx = simulate_expert_sharded_routing(m, routing_data, n_expt_shards,
51-
device=logits.device)
44+
logits = torch.randn((m, n_expts_tot), dtype=torch.float16, device=dev, requires_grad=True)
45+
routing_data, gather_idx, scatter_idx = routing(logits, n_expts_act, simulated_ep=n_expt_shards)
5246
routing_data.gate_scal = None
5347
gather_idx = gather_idx if do_gather else None
5448
scatter_idx = scatter_idx if do_scatter else None
55-
if do_gather and do_scatter and n_expts_act == 1 and n_expt_shards == 1:
56-
# Compute expt_indx as in routing_torch to access routing_data.expt_hist
57-
expt_indx = torch.argsort(-torch.softmax(logits, dim=-1), dim=1,
58-
stable=True)[:, :n_expts_act].reshape(-1).to(torch.int32)
59-
assert (torch.argsort(expt_indx, stable=True) == scatter_idx.dst_indx).all()
60-
routing_data.expt_hist[expt_indx[scatter_idx.dst_indx[-n_expts_act:]]] -= 1
61-
scatter_idx = mask_indx(scatter_idx, n_expts_act)
6249
return m, routing_data, gather_idx, scatter_idx
6350

6451

@@ -315,7 +302,7 @@ def round_x(x, idx):
315302
scale = lambda val, scal: val if scal is None else val / scal
316303
if n_expt_shards > 1:
317304
if not do_scatter:
318-
n_rows = rdata.expt_hist[-1].item()
305+
n_rows = rdata.expt_hist.sum()
319306
assert n_rows > 0
320307
ref_y = ref_y[:n_rows]
321308
tri_y = tri_y[:n_rows]

bench/tests/test_routing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def _assert_indx_equal(ref, tri):
7373

7474
def bench_routing():
7575
import triton.profiler as proton
76-
n_tokens = 2048
76+
n_tokens = 8192
7777
block_m = 128
7878
n_expts_tot, n_expts_act = 128, 4
7979
tri_logits = init_data(n_tokens, n_expts_tot)

bench/triton_bench/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from dataclasses import dataclass
2+
3+
4+
@dataclass
5+
class Bitmatrix:
6+
data: "torch.Tensor"
7+
shape: tuple[int]

bench/triton_bench/compact.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import torch
2+
import triton
3+
import triton.language as tl
4+
from triton_bench import Bitmatrix
5+
6+
7+
@triton.jit
8+
def _masked_compact(Yv, Yi, BitMask, stride_bm, RetYv, RetYi, sentinel, K: tl.constexpr):
9+
pid_m = tl.program_id(0)
10+
yv = tl.load(Yv + pid_m * K + tl.arange(0, K))
11+
yi = tl.load(Yi + pid_m * K + tl.arange(0, K))
12+
div = yi // 32
13+
rem = yi % 32
14+
active_bits = (tl.load(BitMask + pid_m * stride_bm + div) >> rem) & 1
15+
exc_cumsum = tl.cumsum(active_bits, 0) - active_bits
16+
rev_arange = tl.where(active_bits, 0, K - 1 - tl.arange(0, K))
17+
write_indx = exc_cumsum + rev_arange
18+
yv = tl.where(active_bits, yv, sentinel)
19+
yi = tl.where(active_bits, yi, sentinel)
20+
tl.store(RetYv + pid_m * K + write_indx, yv)
21+
tl.store(RetYi + pid_m * K + write_indx, yi)
22+
23+
24+
def masked_compact(yv, yi, bitmask, sentinel=-1):
25+
"""
26+
Return compacted copies of *yv* and *yi* based on a per-row bitmask.
27+
28+
Only the elements whose index appears among the active bits of *bitmask*
29+
are kept; the rest are replaced by *sentinel*. Kept elements preserve
30+
their original left-to-right order.
31+
32+
Parameters
33+
----------
34+
yv : torch.Tensor, shape (B, K)
35+
Values tensor.
36+
yi : torch.Tensor, shape (B, K), dtype torch.long
37+
Integer indices (0 ≤ index < 32) associated with *yv*.
38+
bitmask : torch.Tensor, shape (B,) **or** (B, 32)
39+
Per-row mask of active indices. See the in-place version for details.
40+
sentinel : int, default -1
41+
Value written into dropped positions of the returned tensors.
42+
43+
Returns
44+
-------
45+
(yv_out, yi_out) : Tuple[torch.Tensor, torch.Tensor], each shape (B, K)
46+
New tensors with the same dtype/device as the inputs.
47+
48+
"""
49+
50+
n_rows, n_cols = yi.shape
51+
ret_yv = torch.empty_like(yv)
52+
ret_yi = torch.empty_like(yi)
53+
if isinstance(bitmask, Bitmatrix):
54+
bitmask = bitmask.data
55+
56+
_masked_compact[(n_rows, )](
57+
yv, yi, bitmask, bitmask.stride(0), # inputs
58+
ret_yv, ret_yi, # outputs
59+
sentinel, # sentinel
60+
K=n_cols # constants
61+
)
62+
return ret_yv, ret_yi
63+
64+
65+
def masked_compact_torch(yv: torch.Tensor, yi: torch.Tensor, bitmask: torch.Tensor, sentinel=-1):
66+
"""
67+
reference implementation of `masked_compact`
68+
"""
69+
B, K = yi.shape
70+
device, dtype = yi.device, yi.dtype
71+
# Expand bitmask to a boolean matrix of active bits (B, 32)
72+
w = (1 << torch.arange(32, device=device, dtype=bitmask.dtype))
73+
bits = (bitmask.unsqueeze(-1) & w) != 0
74+
mask = bits.flatten(start_dim=-2) # or bits.reshape(B, -1)
75+
# For every yi element decide whether it should be kept
76+
keep = mask.gather(1, yi.long())
77+
# Build a stable permutation that brings all "keep" items forward
78+
# False→0, True→1 ==> invert so kept==0, dropped==1, then argsort
79+
order = (~keep).to(torch.int).argsort(dim=1, stable=True)
80+
# Re‑order tensors according to above permutation
81+
yi_sorted = yi.gather(1, order)
82+
yv_sorted = yv.gather(1, order)
83+
# fill relevant positions with sentinel
84+
keep_sorted = keep.gather(1, order)
85+
yi_sorted[~keep_sorted] = sentinel
86+
yv_sorted[~keep_sorted] = sentinel
87+
return yv_sorted, yi_sorted

bench/triton_bench/matmul_ogs_details/_common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,5 +91,5 @@ def matmul_launch_metadata(grid, kernel, args):
9191
sindx = args.get("WriteBackIndx", None)
9292
if sindx is not None:
9393
skipped = (sindx == -1).sum() / sindx.numel()
94-
ret["bytes"] = ((1 - skipped) * Y.numel() * Y.element_size() + X.numel() * X.element_size() + n_w_bytes)
94+
ret["bytes"] = int((1 - skipped) * Y.numel() * Y.element_size() + X.numel() * X.element_size() + n_w_bytes)
9595
return ret

bench/triton_bench/matmul_ogs_details/metadata.py

Lines changed: 36 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -14,41 +14,35 @@ class ExptData:
1414

1515

1616
@triton.jit
17-
def _memset_metadata(Metadata, metadata_size, BLOCK: tl.constexpr):
17+
def _matmul_metadata_memset(Hist, n_expts_tot, MDHist, MDTokStarts, MDTileStarts, MDTileInfo, md_n_tiles,
18+
BLOCK: tl.constexpr, TILE_DIM: tl.constexpr):
1819
pid = tl.program_id(0)
20+
# if pid == 0 - initialize cumsums
21+
if pid == 0:
22+
x_tok = tl.zeros([BLOCK], dtype=MDTokStarts.dtype.element_ty)
23+
x_tile = tl.zeros([BLOCK], dtype=MDTileStarts.dtype.element_ty)
24+
tl.store(MDTokStarts, 0)
25+
tl.store(MDTileStarts, 0)
26+
for i in range(0, n_expts_tot, BLOCK):
27+
offs_n = tl.arange(0, BLOCK) + i
28+
mask = offs_n < n_expts_tot
29+
hist_tok = tl.load(Hist + offs_n, mask=mask)
30+
hist_tile = tl.cdiv(hist_tok, TILE_DIM)
31+
tok_starts = tl.cumsum(hist_tok, 0) + x_tok
32+
x_tok += tl.sum(hist_tok, 0).to(MDTokStarts.dtype.element_ty)
33+
tile_starts = tl.cumsum(hist_tile, 0) + x_tile
34+
x_tile += tl.sum(hist_tile, 0).to(MDTileStarts.dtype.element_ty)
35+
tl.store(MDHist + offs_n, hist_tok, mask=mask)
36+
tl.store(MDTokStarts + 1 + offs_n, tok_starts, mask=mask)
37+
tl.store(MDTileStarts + 1 + offs_n, tile_starts, mask=mask)
38+
39+
# initialize block data
1940
offs = pid * BLOCK + tl.arange(0, BLOCK)
20-
tl.store(Metadata + offs, 0xffffffff, mask=offs < metadata_size)
41+
tl.store(MDTileInfo + offs, 0xffffffff, mask=offs < md_n_tiles)
2142

2243

2344
@triton.jit
24-
def _compute_metadata_1(Hist, n_expts_tot, MDHist, MDTokStarts, MDTileStarts, MDTileInfo, N_EXPTS_PAD: tl.constexpr,
25-
BLOCK: tl.constexpr, TILE_DIM: tl.constexpr):
26-
27-
BLOCK_N: tl.constexpr = 1024
28-
29-
x_tok = tl.zeros([BLOCK_N], dtype=MDTokStarts.dtype.element_ty)
30-
x_tile = tl.zeros([BLOCK_N], dtype=MDTileStarts.dtype.element_ty)
31-
32-
tl.store(MDTokStarts, 0)
33-
tl.store(MDTileStarts, 0)
34-
35-
for i in range(0, n_expts_tot, BLOCK_N):
36-
offs_n = tl.arange(0, BLOCK_N) + i
37-
mask = offs_n < n_expts_tot
38-
hist_tok = tl.load(Hist + offs_n, mask=mask)
39-
hist_tile = tl.cdiv(hist_tok, TILE_DIM)
40-
tok_starts = tl.cumsum(hist_tok, 0) + x_tok
41-
x_tok += tl.sum(hist_tok, 0)
42-
tile_starts = tl.cumsum(hist_tile, 0) + x_tile
43-
x_tile += tl.sum(hist_tile, 0)
44-
tl.store(MDHist + offs_n, hist_tok, mask=mask)
45-
tl.store(MDTokStarts + 1 + offs_n, tok_starts, mask=mask)
46-
tl.store(MDTileStarts + 1 + offs_n, tile_starts, mask=mask)
47-
48-
49-
@triton.jit
50-
def _compute_metadata_2(Hist, n_expts_tot, MDHist, MDTokStarts, MDTileStarts, MDTileInfo, N_EXPTS_PAD: tl.constexpr,
51-
BLOCK: tl.constexpr, TILE_DIM: tl.constexpr):
45+
def _matmul_metadata_compute(Hist, MDTileStarts, MDTileInfo, BLOCK: tl.constexpr, TILE_DIM: tl.constexpr):
5246

5347
expt_id = tl.program_id(0)
5448
n_tokens = tl.load(Hist + expt_id)
@@ -75,26 +69,21 @@ def compute_metadata(routing_data, n_rows, block_m):
7569
grid_m = n_rows
7670
else:
7771
grid_m = n_expts_tot - 1 - ((n_expts_tot - n_rows - 1) // block_m)
78-
n_expts_pad = cdiv(n_expts_tot, 128) * 128
7972
metadata_size = 3 * n_expts_tot + 2 + grid_m
8073
metadata = torch.empty(metadata_size, dtype=torch.int32, device=device)
8174
md_hist = metadata[:n_expts_tot]
82-
md_tok_starts = metadata[n_expts_tot:n_expts_tot * 2 + 1]
75+
md_offs = metadata[n_expts_tot:n_expts_tot * 2 + 1]
76+
md_offs_sum = metadata[3 * n_expts_tot + 2 - 1]
8377
md_tile_starts = metadata[n_expts_tot * 2 + 1:n_expts_tot * 3 + 2]
8478
md_tile_infos = metadata[n_expts_tot * 3 + 2:]
85-
_memset_metadata[(cdiv(metadata_size, MEMSET_BLOCK), )](
86-
metadata, metadata_size, # inputs
87-
BLOCK=MEMSET_BLOCK # optimization parameters
79+
_matmul_metadata_memset[(cdiv(metadata_size, MEMSET_BLOCK), )](
80+
routing_data.expt_hist, n_expts_tot, md_hist, md_offs, md_tile_starts, md_tile_infos, md_tile_infos.shape[0],
81+
BLOCK=MEMSET_BLOCK, # optimization parameters
82+
TILE_DIM=block_m, # constants
83+
)
84+
_matmul_metadata_compute[(n_expts_tot, )](
85+
routing_data.expt_hist, md_tile_starts, md_tile_infos, # outputs
86+
BLOCK=HIST2_BLOCK_M, # optimization parameters
87+
TILE_DIM=block_m, # constants
8888
)
89-
for kernel, num_blocks in [(_compute_metadata_1, 1), (_compute_metadata_2, n_expts_tot)]:
90-
kernel[(num_blocks, )](
91-
routing_data.expt_hist, n_expts_tot, # inputs
92-
md_hist, md_tok_starts, md_tile_starts, md_tile_infos, # outputs
93-
BLOCK=HIST2_BLOCK_M, # optimization parameters
94-
N_EXPTS_PAD=n_expts_pad, TILE_DIM=block_m, # constants
95-
)
96-
hist = metadata[:n_expts_tot]
97-
offs = metadata[n_expts_tot:2 * n_expts_tot + 1]
98-
offs_sum = metadata[3 * n_expts_tot + 2 - 1]
99-
blocks = metadata[n_expts_tot + 2 * (n_expts_tot + 1):]
100-
return ExptData(hist, offs, offs_sum, blocks, metadata)
89+
return ExptData(md_hist, md_offs, md_offs_sum, md_tile_infos, metadata)

0 commit comments

Comments
 (0)