
Commit 981e987

[BENCH] multiple performance improvements to routing code (#6546)
1 parent c5fed8e commit 981e987

2 files changed, +40 −60 lines changed


bench/tests/test_routing.py

Lines changed: 5 additions & 0 deletions
@@ -83,6 +83,11 @@ def bench_routing():
     tri_routing_data, tri_gather, tri_scatter = routing(tri_logits, n_expts_act)
     tri_metadata = compute_metadata(tri_routing_data, n_tokens * n_expts_act, block_m)
     proton.finalize()
+    try:
+        import os
+        os.system("proton-viewer -m time/ms routing.hatchet")
+    except:
+        pass
 
 
 if __name__ == "__main__":

bench/triton_bench/routing.py

Lines changed: 35 additions & 60 deletions
@@ -13,10 +13,9 @@ def _routing_compute_expt_offs(ExpertHist, FinalExpertOffs, hist_size, # histog
         offs_n = i * BLOCK_N + tl.arange(0, BLOCK_N)
         mask_n = offs_n < hist_size
         hist2 = tl.load(ExpertHist + offs_n, mask=mask_n)
-        tok_starts = tl.cumsum(hist2, 0) + x
+        tok_starts = tl.cumsum(hist2, 0) - hist2 + x
         x += tl.sum(hist2, 0)
-        tl.store(FinalExpertOffs, 0)
-        tl.store(FinalExpertOffs + 1 + offs_n, tok_starts, mask=mask_n)
+        tl.store(FinalExpertOffs + offs_n, tok_starts, mask=mask_n)
         offs_n += BLOCK_N
 
 
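The rewritten loop computes an exclusive rather than an inclusive prefix sum, so each expert's token-start offset can be stored directly at FinalExpertOffs + offs_n and the separate store of the leading zero disappears. A minimal NumPy sketch of the equivalence, with hypothetical histogram values that are not taken from the code:

import numpy as np

hist = np.array([3, 0, 5, 2], dtype=np.int64)   # hypothetical per-expert token counts
inclusive = np.cumsum(hist)                     # old: stored at offs_n + 1, with a 0 stored at index 0
exclusive = np.cumsum(hist) - hist              # new: stored directly at offs_n

# shifting the inclusive sum right by one (with a leading 0) yields the exclusive sum
assert np.array_equal(np.concatenate(([0], inclusive[:-1])), exclusive)
print(exclusive)                                # [0 3 3 8] -> start offset of each expert's tokens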

@@ -52,51 +51,33 @@ def _keyed_add(x, y):
 
 
 @triton.jit
-def _count_previous(x):
-    """
-    Input x : uint16[..., N]
-    Output y : uint32[..., N]
-    semantics : y[..., i] = sum_j((x[..., j] == x[..., i]) & (j < i))
-    credits: @apgoucher
-    """
+def _routing_compute_indx(GatherIndx, ScatterIndx, GateScal, ExptScal, ExptIndx, PartialOffs, stride_pm, n_gates,
+                          BLOCK_M: tl.constexpr, N_EXPTS_ACT: tl.constexpr):
 
-    BLOCK_N: tl.constexpr = x.shape[-1]  # summation axis
-    BATCHES: tl.constexpr = x.numel // BLOCK_N  # number of batches
+    pid_m = tl.program_id(0)
 
-    # reduce to two-dimensional case:
-    y = tl.reshape(x, [BATCHES, BLOCK_N]).to(tl.uint32)
+    tl.static_assert(N_EXPTS_ACT * BLOCK_M <= 32768)
 
-    tl.static_assert(BLOCK_N <= 32768, "compute_run_lengths requires axis to have length <= 32768")
+    local_offs = tl.arange(0, N_EXPTS_ACT * BLOCK_M)
+    offs = pid_m * BLOCK_M * N_EXPTS_ACT + local_offs
+    expert = tl.load(ExptIndx + offs, mask=(offs < n_gates), other=-1).to(tl.uint32)
 
-    # sort (expert, position) ordered pairs to perform an argsort:
-    kv_pairs = ((y << 16) | tl.arange(0, BLOCK_N)[None, :]).to(tl.uint32)
-    sorted_kv_pairs = tl.sort(kv_pairs, 1)
+    # stable-sort by expert ID:
+    kv_pairs = ((expert << 16) | local_offs).to(tl.uint32)
+    kv_pairs = tl.sort(kv_pairs, 0)
+    expert = kv_pairs >> 16
+    offs = pid_m * BLOCK_M * N_EXPTS_ACT + (kv_pairs & 0xffff)
+    mask = expert != 0xffff
+    gate_scal = tl.load(ExptScal + offs, mask=mask)
 
     # compute run lengths in expert-sorted order:
-    x = (sorted_kv_pairs & 0xffff0000 | 0x00000001)
-    expts_and_inclusive_run_lengths = tl.associative_scan(x, 1, _keyed_add)
+    x = (kv_pairs & 0xffff0000 | 0x00000001)
+    expts_and_inclusive_run_lengths = tl.associative_scan(x, 0, _keyed_add)
     exclusive_run_lengths = (expts_and_inclusive_run_lengths - 1) & 0xffff
 
-    # undo permutation by doing another sort
-    # TODO rewrite this when tl.scatter becomes available
-    kv_pairs = ((sorted_kv_pairs << 16) | exclusive_run_lengths).to(tl.uint32)
-    unsorted_run_lengths = tl.sort(kv_pairs) & 0xffff
-
-    res = tl.reshape(unsorted_run_lengths, x.shape)
-    return res
-
+    gates = tl.load(PartialOffs + pid_m * stride_pm + expert, mask=(expert != 0xffff))
+    gates += exclusive_run_lengths
 
-@triton.jit
-def _routing_compute_indx(GatherIndx, ScatterIndx, GateScal, ExptScal, ExptIndx, PartialOffs, stride_pm, n_gates,
-                          BLOCK_M: tl.constexpr, N_EXPTS_ACT: tl.constexpr):
-    pid_m = tl.program_id(0)
-    offs = pid_m * BLOCK_M * N_EXPTS_ACT + tl.arange(0, N_EXPTS_ACT * BLOCK_M)
-    mask = offs < n_gates
-    indx = tl.load(ExptIndx + offs, mask=mask)
-    mask = mask & (indx != -1)
-    gates = tl.load(PartialOffs + pid_m * stride_pm + indx, mask=mask)
-    gates += tl.reshape(_count_previous(indx), [BLOCK_M * N_EXPTS_ACT])
-    gate_scal = tl.load(ExptScal + offs, mask=mask)
     tl.store(ScatterIndx + offs, gates, mask=mask)
     tl.store(GatherIndx + gates, offs, mask=mask)
     tl.store(GateScal + gates, gate_scal, mask=mask)
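The new _routing_compute_indx folds the work of the deleted _count_previous helper into the kernel itself: each gate's (expert ID, local position) pair is packed into one 32-bit key, a single tl.sort groups gates of the same expert while positions break ties, and the _keyed_add scan recovers each gate's rank within its expert, so the second sort the old helper needed to undo the permutation is gone. A pure-NumPy sketch of the packed-key idea with toy values; the sequential loop below only stands in for the associative scan:

import numpy as np

expert = np.array([2, 0, 2, 1, 0, 2], dtype=np.uint32)   # hypothetical expert ID per gate
local = np.arange(expert.size, dtype=np.uint32)          # position of each gate in the block

kv = np.sort((expert << 16) | local)                     # one sort: expert ID in high bits, position breaks ties
sorted_expert, sorted_pos = kv >> 16, kv & 0xffff

rank = np.zeros(kv.size, dtype=np.int64)                 # rank of each gate within its expert
for i in range(1, kv.size):
    rank[i] = rank[i - 1] + 1 if sorted_expert[i] == sorted_expert[i - 1] else 0

# scatter index = per-expert base offset (PartialOffs in the kernel) + rank
print(list(zip(sorted_pos.tolist(), sorted_expert.tolist(), rank.tolist())))
# [(1, 0, 0), (4, 0, 1), (3, 1, 0), (0, 2, 0), (2, 2, 1), (5, 2, 2)]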
@@ -117,15 +98,16 @@ def _routing_clear_bitmatrix(Bitmatrix, stride_bm, shape_bn, cutoff, BLOCK_N: tl
 
 
 @triton.jit
-def _routing_memset_indx(Indx0, Indx1, size, sentinel, BLOCK: tl.constexpr):
+def _routing_memset_indx(Indx, size, sentinel, BLOCK: tl.constexpr, ExpertHist, FinalExpertOffs, hist_size,
+                         BLOCK_N: tl.constexpr):
     pid = tl.program_id(0)
-    buf = tl.program_id(1)
-    offs = pid * BLOCK + tl.arange(0, BLOCK)
-    mask = offs < size
-    if buf == 0:
-        tl.store(Indx0 + offs, sentinel, mask=mask)
-    if buf == 1:
-        tl.store(Indx1 + offs, sentinel, mask=mask)
+
+    if pid == 0:
+        _routing_compute_expt_offs(ExpertHist, FinalExpertOffs, hist_size, BLOCK_N)
+    else:
+        offs = (pid - 1) * BLOCK + tl.arange(0, BLOCK)
+        mask = offs < size
+        tl.store(Indx + offs, sentinel, mask=mask)
 
 
 @dataclass
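The memset kernel is now a single fused launch: instead of a 2D grid that cleared topk_indx and gate_indx separately plus a second launch for _routing_compute_expt_offs, one 1D grid gets one extra program, program 0 computes the expert offsets, and every other program clears one block of the combined index buffer. A small sketch of the launch-grid arithmetic with hypothetical sizes:

from math import ceil

n_gates, MEMSET_BLOCK = 16384, 1024                 # hypothetical sizes
memset_programs = ceil(n_gates * 2 / MEMSET_BLOCK)  # programs needed to clear the combined buffer
grid = (memset_programs + 1,)                       # +1: program 0 is reserved for the expert offsets
print(grid)                                         # (33,)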
@@ -204,22 +186,15 @@ def routing(logits, n_expts_act, expt_indx=None, simulated_ep=1):
     # perform compaction to update expt_scal / expt_indx
     hist, partial_hist = sum(bitmatrix, partials_block_size=HIST_BLOCK_M, dim=0)
     # scratchpad
-    expt_offs = torch.empty(n_expts_tot + 1, dtype=torch.int32, device=device)
+    expt_offs = torch.empty(n_expts_tot, dtype=torch.int32, device=device)
     indx_offs = torch.empty((cdiv(n_tokens, HIST_BLOCK_M), n_expts_tot), dtype=torch.int32, device=device)
+    combined_indx = torch.empty(n_gates * 2, dtype=torch.int32, device=device)
     # output
-    topk_indx = torch.empty(n_gates, dtype=torch.int32, device=device)
-    gate_indx = torch.empty(n_gates, dtype=torch.int32, device=device)
+    topk_indx = combined_indx[:n_gates]
+    gate_indx = combined_indx[n_gates:]
     gate_scal = torch.empty(n_gates, dtype=logits.dtype, device=device)
-    _routing_memset_indx[(cdiv(n_gates, MEMSET_BLOCK), 2)](
-        topk_indx,
-        gate_indx,
-        n_gates,
-        -1,
-        BLOCK=MEMSET_BLOCK,
-    )
-    _routing_compute_expt_offs[(1, )](
-        hist, expt_offs, hist.shape[0], BLOCK_N=512  # tunable parameters
-    )
+    _routing_memset_indx[(cdiv(n_gates * 2, MEMSET_BLOCK) + 1, )](combined_indx, n_gates * 2, -1, MEMSET_BLOCK, hist,
+                                                                  expt_offs, hist.shape[0], BLOCK_N=512)
     _routing_compute_indx_offs[(n_expts_tot, )](
         expt_offs, partial_hist,  # inputs
         indx_offs, partial_hist.shape[0], partial_hist.stride(0),  # outputs
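On the host side, topk_indx and gate_indx are no longer separate allocations: both are views into one combined_indx buffer of length 2 * n_gates, which is what lets the fused kernel above initialize them with a single contiguous memset. A minimal PyTorch sketch of the aliasing, using CPU tensors and a toy size:

import torch

n_gates = 8
combined_indx = torch.empty(n_gates * 2, dtype=torch.int32)
topk_indx = combined_indx[:n_gates]     # view, shares storage with combined_indx
gate_indx = combined_indx[n_gates:]     # view, shares storage with combined_indx

combined_indx.fill_(-1)                 # one pass initializes both index arrays
assert topk_indx.data_ptr() == combined_indx.data_ptr()
assert int(gate_indx[0]) == -1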
