Skip to content

Commit 80449c2

Browse files
authored
[BENCH] make MoE routing another 4% faster (#7396)
This reduces routing runtime from 12.3 µs to 11.8 µs by tweaking block sizes and by unrolling a loop when its iteration count is small.
1 parent 85e1eb3 commit 80449c2

File tree

3 files changed

+3
-3
lines changed

3 files changed

+3
-3
lines changed

python/triton_kernels/triton_kernels/reduction_details/reduce_bitmatrix.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -88,7 +88,7 @@ def sum_bitmatrix_rows(x, out_ret, partials_block_size=None, n_rows_raw=None):
8888
n_rows_pad, n_cols_raw = x.shape_pad[0], x.shape_raw[1]
8989
assert out_ret.shape == (n_cols_raw, )
9090

91-
TILE_SIZE = 2
91+
TILE_SIZE = max(1, 128 // PARTIALS_BLOCK_M)
9292
BLOCK_MM = PARTIALS_BLOCK_M * TILE_SIZE
9393

9494
pids_x = cdiv(n_rows_pad, BLOCK_MM)

python/triton_kernels/triton_kernels/routing.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -94,7 +94,7 @@ class SortTokens(torch.autograd.Function):
9494

9595
@staticmethod
9696
def forward(ctx, expt_scal, expt_indx, bitmatrix):
97-
HIST_BLOCK_M = 64
97+
HIST_BLOCK_M = 32
9898
INDX_OFFS_BLOCK_M = 512
9999
MEMSET_BLOCK = 1024
100100
cdiv = triton.cdiv

python/triton_kernels/triton_kernels/topk_details/_topk_forward.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -62,7 +62,7 @@ def streaming_topk(X, stride_xm, n_expts_tot, offs_m, mask_m, N_EXPTS_PAD: tl.co
6262
acc = tl.topk(x, N_EXPTS_ACT, dim=1)
6363

6464
# subsequent iterations:
65-
for _i in range(loop_iterations):
65+
for _i in (tl.static_range if loop_iterations <= 4 else range)(loop_iterations):
6666
acc = tl.bitonic_merge(acc) # ensure sorted ascending for the merge
6767
X_ptrs -= BLOCK_N
6868
offs_x_n -= BLOCK_N

0 commit comments

Comments
 (0)