
Commit 9b13c1c

[BENCH] Bitmatrix refactor (#6883)
This gives another 10% speed improvement to routing
1 parent: e206b17

10 files changed (+132, -86 lines)
Lines changed: 0 additions & 7 deletions
@@ -1,7 +0,0 @@
-from dataclasses import dataclass
-
-
-@dataclass
-class Bitmatrix:
-    data: "torch.Tensor"  # noqa: F821
-    shape: tuple[int]

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+from dataclasses import dataclass
+
+import torch
+
+from .reduction_details.reduce_bitmatrix import clear_sums, sum_bitmatrix_rows
+
+
+@dataclass
+class Bitmatrix:
+    """
+    Represents a boolean matrix in a packed format where each element occupies
+    a single bit of memory.
+
+    We use a Bitmatrix to represent the routing information, where each row
+    corresponds to a token and each column corresponds to an expert.
+
+    S is either None or an all-zero array of size >= n_cols; we pass it along
+    with the actual bitmatrix to avoid having to launch a separate memset
+    kernel when we call Bitmatrix::sum().
+    """
+
+    data: torch.Tensor
+    shape: tuple[int]
+    S: torch.tensor
+
+    def sum(self, partials_block_size):
+        n_rows, n_cols = self.shape
+        dev = self.data.device
+        if self.S is None:
+            self.S = clear_sums(n_cols, dev)
+        out_ret = self.S[:n_cols]
+        self.S = None  # throw error if we try to sum again
+        return sum_bitmatrix_rows(self, out_ret, partials_block_size)
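
For reference, the packed layout described in the docstring maps a (token, expert) pair to one bit of a 32-bit word. The helper below is a hypothetical, pure-Python illustration of that addressing (it mirrors the `div = yi // 32` / `rem = yi % 32` arithmetic used by `_masked_compaction` further down); it is not part of this commit.

```python
import torch

def is_routed(data: torch.Tensor, token: int, expert: int) -> bool:
    # Expert `expert` lives in 32-bit word `expert // 32` of row `token`,
    # at bit position `expert % 32`.
    word = int(data[token, expert // 32].item()) & 0xFFFFFFFF
    return ((word >> (expert % 32)) & 1) == 1
```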

python/triton_kernels/triton_kernels/compaction.py

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 import torch
 from .compaction_details._masked_compaction import _masked_compaction
-from triton_kernels import Bitmatrix
+from .bitmatrix import Bitmatrix
 
 
 def compaction(yv, yi, bitmask, sentinel=-1):
@@ -36,7 +36,7 @@ def compaction(yv, yi, bitmask, sentinel=-1):
     bitmask = bitmask.data
 
     _masked_compaction[(n_rows, )](
-        yv, yi, bitmask, bitmask.stride(0),  # inputs
+        yv, yi, bitmask, bitmask.stride(0), bitmask.stride(1),  # inputs
         ret_yv, ret_yi,  # outputs
         sentinel,  # sentinel
         K=n_cols  # constants

python/triton_kernels/triton_kernels/compaction_details/_masked_compaction.py

Lines changed: 2 additions & 2 deletions
@@ -3,13 +3,13 @@
 
 
 @triton.jit
-def _masked_compaction(Yv, Yi, BitMask, stride_bm, RetYv, RetYi, sentinel, K: tl.constexpr):
+def _masked_compaction(Yv, Yi, BitMask, stride_bm, stride_bn, RetYv, RetYi, sentinel, K: tl.constexpr):
     pid_m = tl.program_id(0)
     yv = tl.load(Yv + pid_m * K + tl.arange(0, K))
     yi = tl.load(Yi + pid_m * K + tl.arange(0, K))
     div = yi // 32
     rem = yi % 32
-    active_bits = (tl.load(BitMask + pid_m * stride_bm + div) >> rem) & 1
+    active_bits = (tl.load(BitMask + pid_m * stride_bm + div * stride_bn) >> rem) & 1
     exc_cumsum = tl.cumsum(active_bits, 0) - active_bits
     rev_arange = tl.where(active_bits, 0, K - 1 - tl.arange(0, K))
     write_indx = exc_cumsum + rev_arange

python/triton_kernels/triton_kernels/reduction.py

Lines changed: 0 additions & 16 deletions
This file was deleted.

python/triton_kernels/triton_kernels/reduction_details/reduce_bitmatrix.py

Lines changed: 48 additions & 31 deletions
@@ -1,3 +1,4 @@
+import torch
 import triton
 import triton.language as tl
 
@@ -42,49 +43,65 @@ def vpopc(x):
 
 
 @triton.jit
-def _sum_bitmatrix_memset(Ret, ret_size, BLOCK: tl.constexpr):
+def _sum_bitmatrix_memset(Ret, BLOCK: tl.constexpr):
     pid = tl.program_id(0)
     offs = pid * BLOCK + tl.arange(0, BLOCK)
-    tl.store(Ret + offs, 0, mask=offs < ret_size)
+    tl.store(Ret + offs, 0)
 
 
 @triton.jit
-def _sum_bitmatrix_rows(B, shape_bm, stride_bm,  # input bitmatrix
-                        Ret, Partials, stride_pm, shape_pn,  # outputs
-                        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
-    tl.static_assert(BLOCK_N % 32 == 0)
+def _sum_bitmatrix_rows(B, shape_bm, stride_bm: tl.constexpr, stride_bn: tl.constexpr,  # input bitmatrix
+                        Ret, Partials, stride_pm: tl.constexpr, stride_pn, shape_pn,  # outputs
+                        BLOCK_MM: tl.constexpr, BLOCK_M: tl.constexpr):
+
+    tl.static_assert(BLOCK_MM % BLOCK_M == 0)
+    TILE_SIZE: tl.constexpr = BLOCK_MM // BLOCK_M
     pid_m = tl.program_id(0)
     pid_n = tl.program_id(1)
-    BLOCK_B: tl.constexpr = BLOCK_N // 32
-    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
-    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
-    offs_b = pid_n * BLOCK_B + tl.arange(0, BLOCK_B)
-    bits = tl.load(B + offs_m[None, :] * stride_bm + offs_b[:, None], mask=offs_m[None, :] < shape_bm)
-    ret = tl.reshape(vpopc(bits), [BLOCK_N])
-    mask = offs_n < shape_pn
-    tl.atomic_add(Ret + offs_n, ret, mask=mask, sem="relaxed")
-    tl.store(Partials + pid_m * stride_pm + offs_n, ret, mask=mask)
-
-
-def sum_bitmatrix_rows(x, out_ret, out_partials, partials_block_size=None):
+    offs_m = pid_m * BLOCK_MM + tl.arange(0, BLOCK_MM)
+    offs_n = pid_n * 32 + tl.arange(0, 32)
+    bits = tl.load(B + pid_n * stride_bn + offs_m * stride_bm, mask=offs_m < shape_bm, other=0)
+    bits = tl.reshape(bits, [TILE_SIZE, BLOCK_M])
+    ret = vpopc(bits)  # [TILE_SIZE, 32]
+
+    offs_t = pid_m * TILE_SIZE + tl.arange(0, TILE_SIZE)
+
+    tl.atomic_add(Ret + offs_n, tl.sum(ret, 0), sem="relaxed")
+    tl.store(Partials + offs_t[:, None] * stride_pm + offs_n[None, :] * stride_pn, ret)
+
+
+def clear_sums(n_cols, device, MEMSET_BLOCK=512):
+    cdiv = triton.cdiv
+    blocks = cdiv(n_cols, MEMSET_BLOCK)
+    out_ret = torch.empty((blocks * MEMSET_BLOCK, ), device=device, dtype=torch.int32)
+    _sum_bitmatrix_memset[(blocks, )](out_ret, MEMSET_BLOCK)
+    return out_ret
+
+
+def sum_bitmatrix_rows(x, out_ret, partials_block_size=None):
     assert partials_block_size is not None
     cdiv = triton.cdiv
     PARTIALS_BLOCK_M = partials_block_size
-    BLOCK_N = 32
-    MEMSET_BLOCK = 512
     n_rows, n_cols = x.shape
     assert out_ret.shape == (n_cols, )
-    assert out_partials.shape == (cdiv(n_rows, PARTIALS_BLOCK_M), n_cols)
+
+    TILE_SIZE = 2
+    BLOCK_MM = PARTIALS_BLOCK_M * TILE_SIZE
+
+    pids_x = cdiv(n_rows, BLOCK_MM)
+    pids_y = cdiv(n_cols, 32)
+    out_partials = torch.empty((pids_y * 32, pids_x * TILE_SIZE), device=out_ret.device, dtype=torch.int32)
+    out_partials = torch.transpose(out_partials, 0, 1)
+
     # output tensors
-    _sum_bitmatrix_memset[(cdiv(out_ret.shape[0], MEMSET_BLOCK), )](
-        out_ret, out_ret.shape[0],  # outputs
-        BLOCK=512  # tunable parameter
-    )
-    _sum_bitmatrix_rows[(cdiv(n_rows, PARTIALS_BLOCK_M), cdiv(n_cols, BLOCK_N))](
-        x.data, x.data.shape[0], x.data.stride(0),  # input
+    _sum_bitmatrix_rows[(pids_x, pids_y)](
+        x.data, x.data.shape[0], x.data.stride(0), x.data.stride(1),  # input
        out_ret,  # output [final reduction]
-        out_partials, out_partials.stride(0), out_partials.shape[1],  # output [partial reductions]
-        BLOCK_N=BLOCK_N,  # tunable parameters
-        BLOCK_M=PARTIALS_BLOCK_M,  # constants
-    )
+        out_partials, out_partials.stride(0), out_partials.stride(1),
+        out_partials.shape[1],  # output [partial reductions]
+        BLOCK_M=PARTIALS_BLOCK_M, BLOCK_MM=BLOCK_MM,  # constants
+        num_warps=8)
+
+    out_partials = out_partials[:cdiv(n_rows, PARTIALS_BLOCK_M), :n_cols]
+
     return out_ret, out_partials
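
As a sanity reference for the refactored reduction above: `out_ret[e]` should hold the number of rows with bit `e` set (the per-expert token count), and `out_partials[b, e]` the same count restricted to the `b`-th block of `PARTIALS_BLOCK_M` rows. A slow, pure-Python sketch of that contract (a hypothetical helper, not part of the commit):

```python
def reference_sum(packed_rows, n_cols, block_m):
    # packed_rows: one list of 32-bit words (Python ints) per bitmatrix row
    n_rows = len(packed_rows)
    n_blocks = (n_rows + block_m - 1) // block_m
    hist = [0] * n_cols
    partials = [[0] * n_cols for _ in range(n_blocks)]
    for r, row in enumerate(packed_rows):
        for e in range(n_cols):
            bit = (row[e // 32] >> (e % 32)) & 1
            hist[e] += bit
            partials[r // block_m][e] += bit
    return hist, partials
```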

python/triton_kernels/triton_kernels/routing.py

Lines changed: 6 additions & 5 deletions
@@ -55,7 +55,6 @@ def n_blocks(self, n_rows, block_m):
 
 def routing(logits, n_expts_act, expt_indx=None, simulated_ep=1):
     from .topk import topk
-    from .reduction import sum
     from .compaction import compaction
     assert expt_indx is None
     cdiv = triton.cdiv
@@ -72,6 +71,7 @@ def routing(logits, n_expts_act, expt_indx=None, simulated_ep=1):
         _routing_clear_bitmatrix[(n_tokens, )](
             bitmatrix.data,
             bitmatrix.data.stride(0),
+            bitmatrix.data.stride(1),
             bitmatrix.data.shape[1],
             n_expts_tot // simulated_ep,
             BLOCK_N=512,
@@ -80,10 +80,9 @@ def routing(logits, n_expts_act, expt_indx=None, simulated_ep=1):
         n_expts_tot = n_expts_tot // simulated_ep
         bitmatrix.shape[-1] = n_expts_tot
     # perform compaction to update expt_scal / expt_indx
-    hist, partial_hist = sum(bitmatrix, partials_block_size=HIST_BLOCK_M, dim=0)
+    hist, partial_hist = bitmatrix.sum(partials_block_size=HIST_BLOCK_M)
     # scratchpad
     expt_offs = torch.empty(n_expts_tot, dtype=torch.int32, device=device)
-    indx_offs = torch.empty((cdiv(n_tokens, HIST_BLOCK_M), n_expts_tot), dtype=torch.int32, device=device)
     combined_indx = torch.empty(n_gates * 2, dtype=torch.int32, device=device)
     # output
     topk_indx = combined_indx[:n_gates]
@@ -93,12 +92,14 @@ def routing(logits, n_expts_act, expt_indx=None, simulated_ep=1):
         expt_offs, hist.shape[0], BLOCK_N=512)
     _routing_compute_indx_offs[(n_expts_tot, )](
         expt_offs, partial_hist,  # inputs
-        indx_offs, partial_hist.shape[0], partial_hist.stride(0),  # outputs
+        partial_hist.shape[0], partial_hist.stride(0), partial_hist.stride(1),  # outputs
         BLOCK_M=INDX_OFFS_BLOCK_M,  # tunable parameters
     )
+    indx_offs = partial_hist
+
     _routing_compute_indx[(cdiv(n_tokens, HIST_BLOCK_M), )](
         topk_indx, gate_indx, gate_scal,  # outputs
-        expt_scal, expt_indx, indx_offs, indx_offs.stride(0), n_gates,  # input
+        expt_scal, expt_indx, indx_offs, indx_offs.stride(0), indx_offs.stride(1), n_gates,  # input
         BLOCK_M=HIST_BLOCK_M,  # tunable parameters
         N_EXPTS_ACT=n_expts_act,  # constants
         num_warps=1 if HIST_BLOCK_M * n_expts_act // 32 < 4 else 4)
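
The scratchpad saving above comes from `_routing_compute_indx_offs` (next file) now writing its result back into `partial_hist`, so the separate `indx_offs` allocation disappears. Conceptually, each expert column is replaced by its exclusive cumulative sum over row blocks, shifted by that expert's global start offset. A PyTorch sketch of that transformation, illustrative only and not the kernel itself:

```python
import torch

def indx_offs_reference(partial_hist: torch.Tensor, expt_offs: torch.Tensor) -> torch.Tensor:
    # partial_hist: (n_blocks, n_experts) per-block token counts
    # expt_offs:    (n_experts,) starting position of each expert in the sorted order
    inclusive = torch.cumsum(partial_hist, dim=0)
    exclusive = inclusive - partial_hist
    return exclusive + expt_offs[None, :]
```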

python/triton_kernels/triton_kernels/routing_details/_routing_compute.py

Lines changed: 9 additions & 11 deletions
@@ -18,21 +18,19 @@ def _routing_compute_expt_offs(ExpertHist, FinalExpertOffs, hist_size,  # histog
 
 
 @triton.jit
-def _routing_compute_indx_offs(TokensStart, PartialHist, PartialOffs, shape_pm, stride_pm, BLOCK_M: tl.constexpr):
+def _routing_compute_indx_offs(TokensStart, PartialHist, shape_pm, stride_pm, stride_pn, BLOCK_M: tl.constexpr):
     expt_id = tl.program_id(0)
     offs_m = tl.arange(0, BLOCK_M)
     # initialize first row of the output
     start = tl.load(TokensStart + expt_id)
-    tl.store(PartialOffs + expt_id, start)
     # iterate over input data
     curr_sum = start
     for _ in range(0, shape_pm, BLOCK_M):
-        offs = offs_m * stride_pm + expt_id
+        offs = offs_m * stride_pm + expt_id * stride_pn
         curr = tl.load(PartialHist + offs, mask=offs_m < shape_pm)
         out = tl.cumsum(curr, 0) + curr_sum
         curr_sum += tl.sum(curr, 0)
-        offs = (1 + offs_m) * stride_pm + expt_id
-        tl.store(PartialOffs + offs, out, mask=offs_m < shape_pm - 1)
+        tl.store(PartialHist + offs, out - curr, mask=offs_m < shape_pm)
         offs_m += BLOCK_M
@@ -49,8 +47,8 @@ def _keyed_add(x, y):
 
 
 @triton.jit
-def _routing_compute_indx(GatherIndx, ScatterIndx, GateScal, ExptScal, ExptIndx, PartialOffs, stride_pm, n_gates,
-                          BLOCK_M: tl.constexpr, N_EXPTS_ACT: tl.constexpr):
+def _routing_compute_indx(GatherIndx, ScatterIndx, GateScal, ExptScal, ExptIndx, PartialOffs, stride_pm, stride_pn,
+                          n_gates, BLOCK_M: tl.constexpr, N_EXPTS_ACT: tl.constexpr):
 
     pid_m = tl.program_id(0)
 
@@ -73,7 +71,7 @@ def _routing_compute_indx(GatherIndx, ScatterIndx, GateScal, ExptScal, ExptIndx,
     expts_and_inclusive_run_lengths = tl.associative_scan(x, 0, _keyed_add)
     exclusive_run_lengths = (expts_and_inclusive_run_lengths - 1) & 0xffff
 
-    gates = tl.load(PartialOffs + pid_m * stride_pm + expert, mask=(expert != 0xffff))
+    gates = tl.load(PartialOffs + pid_m * stride_pm + expert * stride_pn, mask=(expert != 0xffff))
     gates += exclusive_run_lengths
 
     tl.store(ScatterIndx + offs, gates, mask=mask)
@@ -82,17 +80,17 @@ def _routing_compute_indx(GatherIndx, ScatterIndx, GateScal, ExptScal, ExptIndx,
 
 
 @triton.jit
-def _routing_clear_bitmatrix(Bitmatrix, stride_bm, shape_bn, cutoff, BLOCK_N: tl.constexpr):
+def _routing_clear_bitmatrix(Bitmatrix, stride_bm, stride_bn, shape_bn, cutoff, BLOCK_N: tl.constexpr):
     pid_m = tl.program_id(0)
     cutoff_word = cutoff // 32
     cutoff_bit = cutoff % 32
     cutoff_mask = (1 << (cutoff_bit)) - 1
     for start_n in range(0, shape_bn, BLOCK_N):
         offs_n = start_n + tl.arange(0, BLOCK_N)
-        values = tl.load(Bitmatrix + pid_m * stride_bm + offs_n, mask=offs_n < shape_bn)
+        values = tl.load(Bitmatrix + pid_m * stride_bm + offs_n * stride_bn, mask=offs_n < shape_bn)
         values = tl.where(offs_n == cutoff_word, values & cutoff_mask, values)
         values = tl.where(offs_n > cutoff_word, 0, values)
-        tl.store(Bitmatrix + pid_m * stride_bm + offs_n, values, mask=offs_n < shape_bn)
+        tl.store(Bitmatrix + pid_m * stride_bm + offs_n * stride_bn, values, mask=offs_n < shape_bn)
 
 
 @triton.jit

Lines changed: 18 additions & 7 deletions
@@ -1,12 +1,13 @@
 import torch
 from .topk_details._topk import _topk
-from triton_kernels import Bitmatrix
+from .bitmatrix import Bitmatrix
 
 
 def topk(x, k, dim=1, return_bitmatrix=True):
     cdiv = lambda a, b: (a + b - 1) // b
-    BLOCK_M = 8
-    BLOCK_N = 128
+    BLOCK_M = 32
+    BLOCK_N = 32
+    BLOCK_S = 128
     assert x.ndim == 2
     assert x.shape[-1] < 32768
     assert dim == 1
@@ -19,13 +20,23 @@ def topk(x, k, dim=1, return_bitmatrix=True):
     # NOTE: these are not returned
     y_vals = torch.empty((n_rows, k), dtype=x.dtype, device=dev)
     y_indx = torch.empty((n_rows, k), dtype=torch.int16, device=dev)
-    bitmatrix = torch.empty((n_rows, n_cols_words), dtype=torch.uint32, device=dev)
-    _topk[(cdiv(n_rows, BLOCK_M), )](
+
+    # create bitmatrix in transposed memory layout:
+    bitmatrix = torch.empty((n_cols_words, cdiv(n_rows, 32) * 32), dtype=torch.uint32, device=dev)
+    bitmatrix = torch.transpose(bitmatrix, 0, 1)[:n_rows]
+    s_blocks = cdiv(n_cols, BLOCK_S)
+    s_cols = s_blocks * BLOCK_S
+    S = torch.empty((s_cols, ), dtype=torch.int32, device=dev)
+
+    pids = max(cdiv(n_rows, BLOCK_M), s_blocks)
+
+    _topk[(pids, )](
         x, x.stride(0),  # inputs
         y_vals, y_indx, y_vals.stride(0),  # output [topk]
-        bitmatrix, bitmatrix.stride(0),  # output [bitmatrix]
+        bitmatrix, bitmatrix.stride(0), bitmatrix.stride(1),  # output [bitmatrix]
         n_rows, n_cols,  # shapes
+        S, BLOCK_S, s_blocks,  # thing to memset to zero
         BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,  # tunable parameter
         N_EXPTS_PAD=n_cols_pad, N_EXPTS_ACT=k,  # constants
     )
-    return y_vals, y_indx, Bitmatrix(bitmatrix, [n_rows, n_cols])
+    return y_vals, y_indx, Bitmatrix(bitmatrix, [n_rows, n_cols], S)
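
A quick illustration of the transposed allocation above (made-up sizes, int32 as a stand-in for uint32): the buffer is allocated word-column-major and then viewed as (n_rows, n_cols_words), so consecutive tokens of the same 32-expert word sit next to each other in memory, which is the access pattern `_sum_bitmatrix_rows` now reads.

```python
import torch

n_rows, n_cols_words = 1000, 4
padded_rows = (n_rows + 31) // 32 * 32
buf = torch.empty((n_cols_words, padded_rows), dtype=torch.int32)
bitmatrix = torch.transpose(buf, 0, 1)[:n_rows]  # logical shape: (n_rows, n_cols_words)
print(bitmatrix.stride())  # (1, padded_rows): the row dimension is the contiguous one
```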

python/triton_kernels/triton_kernels/topk_details/_topk.py

Lines changed: 14 additions & 5 deletions
@@ -40,9 +40,18 @@ def streaming_topk(X, stride_xm, n_expts_tot, offs_m, mask_m, N_EXPTS_PAD: tl.co
 @triton.jit
 def _topk(X, stride_xm,  # inputs
           Yv, Yi, stride_ym,  # topk values/indices
-          Bits, stride_rm, n_rows,  # bitmatrix
-          n_expts_tot, BLOCK_M: tl.constexpr, N_EXPTS_PAD: tl.constexpr, N_EXPTS_ACT: tl.constexpr,
-          BLOCK_N: tl.constexpr):
+          Bits, stride_rm: tl.constexpr, stride_rn: tl.constexpr, n_rows,  # bitmatrix
+          n_expts_tot, S, BLOCK_S: tl.constexpr, s_blocks,  # thing to memset
+          BLOCK_M: tl.constexpr, N_EXPTS_PAD: tl.constexpr, N_EXPTS_ACT: tl.constexpr, BLOCK_N: tl.constexpr):
+
+    pid = tl.program_id(0)
+
+    if pid < s_blocks:
+        tl.store(S + BLOCK_S * pid + tl.arange(0, BLOCK_S), tl.zeros([BLOCK_S], tl.int32))
+
+    if pid * BLOCK_M >= n_rows:
+        # early exit:
+        return
 
     tl.static_assert(BLOCK_N % 32 == 0)
     tl.static_assert(N_EXPTS_PAD % BLOCK_N == 0)
@@ -52,7 +61,7 @@ def _topk(X, stride_xm,  # inputs
     x_ultype: tl.constexpr = tl.dtype(f"uint{2*x_nbits}")
 
     # load logits
-    offs_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
+    offs_m = pid * BLOCK_M + tl.arange(0, BLOCK_M)
     mask_m = offs_m[:, None] < n_rows
     y = streaming_topk(X, stride_xm, n_expts_tot, offs_m, mask_m, N_EXPTS_PAD, N_EXPTS_ACT, BLOCK_N)
     y = y.to(x_ultype, bitcast=True)
@@ -79,5 +88,5 @@ def _topk(X, stride_xm,  # inputs
         offs_r_n = tl.arange(0, BLOCK_N // 32) + i * (BLOCK_N // 32)
         y2 = tl.where(y_div[:, :, None] == offs_r_n[None, None, :], (1 << y_rem)[:, :, None], 0)
         r = tl.reduce_or(y2, axis=1)
-        BitsPtrs = Bits + offs_m[:, None] * stride_rm + offs_r_n[None, :]
+        BitsPtrs = Bits + offs_m[:, None] * stride_rm + offs_r_n[None, :] * stride_rn
         tl.store(BitsPtrs, r, mask=mask_m)
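
The fused memset above removes a standalone kernel launch: programs `0..s_blocks-1` zero their `BLOCK_S`-wide slice of `S`, and any program whose row block falls past `n_rows` exits right after. The grid therefore has to cover whichever job needs more programs; a small sketch of that sizing rule (it mirrors the `pids = max(...)` line in `topk`, the helper name is made up):

```python
def topk_grid(n_rows: int, n_cols: int, BLOCK_M: int = 32, BLOCK_S: int = 128) -> int:
    cdiv = lambda a, b: (a + b - 1) // b
    s_blocks = cdiv(n_cols, BLOCK_S)    # programs that also clear S
    row_blocks = cdiv(n_rows, BLOCK_M)  # programs that do top-k work
    return max(row_blocks, s_blocks)
```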
