
Commit e206b17

[BENCH] 2% speedup from aligning matmul_ogs metadata (#6882)
This allows unmasked vectorised stores
1 parent ade4d3a · commit e206b17
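As a rough illustration of the alignment idea behind this commit (a toy Python sketch with made-up sizes, not code from the diff): once a buffer is padded up to a multiple of the block size each program writes, every store covers a full block and needs no tail mask.

# Toy sketch of block-aligned storage (hypothetical sizes, not commit code).
BLOCK = 128                                   # elements written per program
n = 300                                       # logical length, not a multiple of BLOCK
n_pad = ((n + BLOCK - 1) // BLOCK) * BLOCK    # 384: padded up to a BLOCK multiple

buf = [0] * n_pad
for pid in range(n_pad // BLOCK):
    # every "program" writes a whole block; no bounds check is needed on the
    # tail because the buffer itself was padded to a BLOCK multiple
    buf[pid * BLOCK:(pid + 1) * BLOCK] = [-1] * BLOCK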

File tree: 3 files changed, +59 -27 lines


python/triton_kernels/tests/test_routing.py

Lines changed: 7 additions & 2 deletions
@@ -52,11 +52,16 @@ def test_op(n_tokens, n_expts_tot, n_expts_act, block_m, device):
     ref_routing_data, ref_gather, ref_scatter = routing_torch(ref_logits, n_expts_act)
     tri_routing_data, tri_gather, tri_scatter = routing(tri_logits, n_expts_act)
     ref_metadata = ref_expt_data(ref_routing_data, n_tokens * n_expts_act, block_m)
-    tri_metadata = compute_metadata(tri_routing_data, n_tokens * n_expts_act, block_m).buffer
+    tri_metadata = compute_metadata(tri_routing_data, n_tokens * n_expts_act, block_m)
 
     assert_close(ref_routing_data.gate_scal, tri_routing_data.gate_scal, 2e-2, 4e-3)
     assert_equal(ref_routing_data.expt_hist, tri_routing_data.expt_hist)
-    assert_equal(ref_metadata, tri_metadata)
+
+    assert_equal(ref_metadata[:n_expts_tot], tri_metadata.hist)
+    assert_equal(ref_metadata[n_expts_tot:2 * n_expts_tot + 1], tri_metadata.offs)
+    assert_equal(ref_metadata[3 * n_expts_tot + 1], tri_metadata.offs_sum)
+    assert_equal(ref_metadata[3 * n_expts_tot + 2:], tri_metadata.blocks)
+
     assert ref_routing_data.n_expts_tot == ref_routing_data.n_expts_tot
     assert ref_routing_data.n_expts_act == ref_routing_data.n_expts_act
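The new assertions slice the flat reference buffer into the per-field attributes of the ExptData returned by compute_metadata. The hypothetical helper below only restates that slicing in one place; it is not part of the commit.

# Hypothetical helper restating the slicing in test_op above; not commit code.
def split_ref_metadata(ref_metadata, n):
    # assumed flat layout of the reference buffer, with n = n_expts_tot:
    #   [0, n)            tokens per expert (histogram)
    #   [n, 2n + 1)       cumulative token offsets, with a leading zero
    #   [2n + 1, 3n + 2)  cumulative tile offsets, with a leading zero
    #   3n + 1            last tile offset, i.e. the total tile count
    #   [3n + 2, ...)     per-tile info words
    return {
        "hist": ref_metadata[:n],
        "offs": ref_metadata[n:2 * n + 1],
        "offs_sum": ref_metadata[3 * n + 1],
        "blocks": ref_metadata[3 * n + 2:],
    }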

python/triton_kernels/triton_kernels/matmul_ogs.py

Lines changed: 3 additions & 2 deletions
@@ -548,8 +548,9 @@ def matmul_ogs(x, w, bias,
         x, w, gather_indx, scatter_indx, routing_data, opt_flags, preprocessing_features
     )
     if expt_data.buffer is not None:
-        assert expt_data.buffer.shape[0] == 3*n_expts_tot + 2 + grid_m, \
-            f"invalid expt_data, {expt_data.buffer.shape}, {n_expts_tot=}, {grid_m=}"
+        assert expt_data.hist.shape[0] == n_expts_tot, "invalid expt_data"
+        assert expt_data.offs.shape[0] == n_expts_tot + 1, "invalid expt_data"
+        assert expt_data.blocks.shape[0] == grid_m, "invalid expt_data"
     # matrix multiplication
     n_cta = batch_size * grid_m * grid_n * opt_flags.split_k
     n_cta = min(target_info.num_sms(), n_cta) if opt_flags.is_persistent else n_cta
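With the metadata no longer packed into a single flat buffer of length 3 * n_expts_tot + 2 + grid_m, the shape check moves to the individual fields. A compact restatement as a hypothetical helper (illustration only; matmul_ogs keeps the checks inline as shown above):

# Hypothetical helper, for illustration only; not part of the commit.
def _check_expt_data(expt_data, n_expts_tot, grid_m):
    # one histogram entry per expert
    assert expt_data.hist.shape[0] == n_expts_tot, "invalid expt_data"
    # token offsets carry a leading zero, hence the extra entry
    assert expt_data.offs.shape[0] == n_expts_tot + 1, "invalid expt_data"
    # one info word per output tile row
    assert expt_data.blocks.shape[0] == grid_m, "invalid expt_data"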

python/triton_kernels/triton_kernels/matmul_ogs_details/metadata.py

Lines changed: 49 additions & 23 deletions
@@ -14,31 +14,49 @@ class ExptData:
 
 
 @triton.jit
-def _matmul_metadata_memset(Hist, n_expts_tot, MDHist, MDTokStarts, MDTileStarts, MDTileInfo, md_n_tiles,
-                            BLOCK: tl.constexpr, TILE_DIM: tl.constexpr):
+def _matmul_metadata_memset(Hist, n_expts_tot, MDTokStarts, MDTileStarts, MDTileInfo, BLOCK: tl.constexpr,
+                            TILE_DIM: tl.constexpr, extra_block: tl.constexpr):
     pid = tl.program_id(0)
+
+    TileInfoOut = MDTileInfo + (pid - 1) * BLOCK + tl.arange(0, BLOCK)
+
     # if pid == 0 - initialize cumsums
     if pid == 0:
         x_tok = tl.zeros([BLOCK], dtype=MDTokStarts.dtype.element_ty)
         x_tile = tl.zeros([BLOCK], dtype=MDTileStarts.dtype.element_ty)
-        tl.store(MDTokStarts, 0)
-        tl.store(MDTileStarts, 0)
+
+        Tok_ptrs = MDTokStarts + tl.arange(0, BLOCK)
+        Tile_ptrs = MDTileStarts + tl.arange(0, BLOCK)
+
         for i in range(0, n_expts_tot, BLOCK):
             offs_n = tl.arange(0, BLOCK) + i
-            mask = offs_n < n_expts_tot
-            hist_tok = tl.load(Hist + offs_n, mask=mask)
+            if extra_block:
+                # we need an extra block at the end just to contain the final
+                # sum; this only happens if our total number of experts is an
+                # exact multiple of BLOCK, obviating the need for any masking
+                hist_tok = tl.load(Hist + offs_n)
+            else:
+                mask = offs_n < n_expts_tot
+                hist_tok = tl.load(Hist + offs_n, mask=mask, other=0)
             hist_tile = tl.cdiv(hist_tok, TILE_DIM)
             tok_starts = tl.cumsum(hist_tok, 0) + x_tok
             x_tok += tl.sum(hist_tok, 0).to(MDTokStarts.dtype.element_ty)
             tile_starts = tl.cumsum(hist_tile, 0) + x_tile
             x_tile += tl.sum(hist_tile, 0).to(MDTileStarts.dtype.element_ty)
-            tl.store(MDHist + offs_n, hist_tok, mask=mask)
-            tl.store(MDTokStarts + 1 + offs_n, tok_starts, mask=mask)
-            tl.store(MDTileStarts + 1 + offs_n, tile_starts, mask=mask)
 
-    # initialize block data
-    offs = pid * BLOCK + tl.arange(0, BLOCK)
-    tl.store(MDTileInfo + offs, 0xffffffff, mask=offs < md_n_tiles)
+            tl.store(Tok_ptrs, tok_starts - hist_tok)
+            tl.store(Tile_ptrs, tile_starts - hist_tile)
+
+            Tok_ptrs += BLOCK
+            Tile_ptrs += BLOCK
+
+        if extra_block:
+            tl.store(Tok_ptrs, x_tok)
+            tl.store(Tile_ptrs, x_tile)
+
+    else:
+
+        tl.store(TileInfoOut, 0xffffffff)
 
 
 @triton.jit
@@ -60,7 +78,7 @@ def _matmul_metadata_compute(Hist, MDTileStarts, MDTileInfo, BLOCK: tl.constexpr
 def compute_metadata(routing_data, n_rows, block_m):
     if routing_data.expt_hist is None:
         return ExptData(None, None, None, None, None)
-    MEMSET_BLOCK = 512
+    MEMSET_BLOCK = 128
     HIST2_BLOCK_M = 512
     device = routing_data.expt_hist.device
     n_expts_tot = routing_data.n_expts_tot
@@ -69,21 +87,29 @@ def compute_metadata(routing_data, n_rows, block_m):
         grid_m = n_rows
     else:
         grid_m = n_expts_tot - 1 - ((n_expts_tot - n_rows - 1) // block_m)
-    metadata_size = 3 * n_expts_tot + 2 + grid_m
+
+    n_expts_pad = cdiv(n_expts_tot, MEMSET_BLOCK) * MEMSET_BLOCK
+    pad2 = cdiv(n_expts_tot + 1, MEMSET_BLOCK) * MEMSET_BLOCK
+    extra_block = (n_expts_pad != pad2)
+    pids = cdiv(grid_m, MEMSET_BLOCK) + 1
+
+    metadata_size = n_expts_pad + 2 * pad2 + MEMSET_BLOCK * (pids - 1)
+
     metadata = torch.empty(metadata_size, dtype=torch.int32, device=device)
-    md_hist = metadata[:n_expts_tot]
-    md_offs = metadata[n_expts_tot:n_expts_tot * 2 + 1]
-    md_offs_sum = metadata[3 * n_expts_tot + 2 - 1]
-    md_tile_starts = metadata[n_expts_tot * 2 + 1:n_expts_tot * 3 + 2]
-    md_tile_infos = metadata[n_expts_tot * 3 + 2:]
-    _matmul_metadata_memset[(cdiv(metadata_size, MEMSET_BLOCK), )](
-        routing_data.expt_hist, n_expts_tot, md_hist, md_offs, md_tile_starts, md_tile_infos, md_tile_infos.shape[0],
+
+    md_hist = routing_data.expt_hist[:n_expts_tot]
+    md_offs = metadata[:n_expts_tot + 1]
+    md_tile_starts = metadata[pad2:][:n_expts_tot + 1]
+    md_offs_sum = md_tile_starts[-1]
+    md_tile_infos = metadata[2 * pad2:][:grid_m]
+    _matmul_metadata_memset[(pids, )](
+        routing_data.expt_hist, n_expts_tot, md_offs, md_tile_starts, md_tile_infos,
         BLOCK=MEMSET_BLOCK, # optimization parameters
         TILE_DIM=block_m, # constants
-    )
+        extra_block=extra_block, num_warps=1)
     _matmul_metadata_compute[(n_expts_tot, )](
         routing_data.expt_hist, md_tile_starts, md_tile_infos, # outputs
         BLOCK=HIST2_BLOCK_M, # optimization parameters
         TILE_DIM=block_m, # constants
-    )
+        num_warps=4)
     return ExptData(md_hist, md_offs, md_offs_sum, md_tile_infos, metadata)
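To make the new sizing concrete, the sketch below re-runs the padding arithmetic from compute_metadata for one example configuration (the numbers are made up; the formulas are taken from the diff). Every region of the metadata buffer starts at a multiple of MEMSET_BLOCK, which is what lets _matmul_metadata_memset issue full, unmasked block stores.

# Example run of the padding arithmetic above (illustrative values only).
def cdiv(a, b):
    return (a + b - 1) // b

MEMSET_BLOCK = 128
n_expts_tot, grid_m = 300, 97       # hypothetical routing configuration

n_expts_pad = cdiv(n_expts_tot, MEMSET_BLOCK) * MEMSET_BLOCK        # 384
pad2 = cdiv(n_expts_tot + 1, MEMSET_BLOCK) * MEMSET_BLOCK           # 384
extra_block = (n_expts_pad != pad2)                                 # False here
pids = cdiv(grid_m, MEMSET_BLOCK) + 1                               # 2: one cumsum program, one tile-info program
metadata_size = n_expts_pad + 2 * pad2 + MEMSET_BLOCK * (pids - 1)  # 384 + 768 + 128 = 1280

# Region starts inside `metadata`, all MEMSET_BLOCK-aligned:
#   md_offs        -> metadata[0 : n_expts_tot + 1]
#   md_tile_starts -> metadata[pad2 : pad2 + n_expts_tot + 1]
#   md_tile_infos  -> metadata[2 * pad2 : 2 * pad2 + grid_m]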
