Commit 383e5e2

add triton grouped_topk
1 parent ca7d2f9 commit 383e5e2

File tree

2 files changed: +329 −0 lines

Lines changed: 231 additions & 0 deletions
@@ -0,0 +1,231 @@
# adapted from https://github.com/triton-lang/triton/issues/3698#issuecomment-2067681396
import torch
import triton
import triton.language as tl
from triton.language.standard import _log2, sum, zeros_like


@triton.jit
def _compare_and_swap(x, ids, flip, i: tl.core.constexpr, n_dims: tl.core.constexpr):
    n_outer: tl.core.constexpr = x.numel >> n_dims
    shape: tl.core.constexpr = [n_outer * 2 ** i, 2, 2 ** (n_dims - i - 1)]
    y = tl.core.reshape(x, shape)
    # slice left/right with 'stride' 2**(n_dims - i - 1)
    mask = tl.core.arange(0, 2)[None, :, None]
    left = tl.core.broadcast_to(sum(y * (1 - mask), 1)[:, None, :], shape)
    right = tl.core.broadcast_to(sum(y * mask, 1)[:, None, :], shape)
    left = tl.core.reshape(left, x.shape)
    right = tl.core.reshape(right, x.shape)

    # idx
    y_idx = tl.core.reshape(ids, shape)
    left_idx = tl.core.broadcast_to(sum(y_idx * (1 - mask), 1)[:, None, :], shape)
    right_idx = tl.core.broadcast_to(sum(y_idx * mask, 1)[:, None, :], shape)
    left_idx = tl.core.reshape(left_idx, x.shape)
    right_idx = tl.core.reshape(right_idx, x.shape)

    # actual compare-and-swap
    idtype = tl.core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
    ileft = left.to(idtype, bitcast=True)
    iright = right.to(idtype, bitcast=True)
    ix = x.to(idtype, bitcast=True)

    cond = (left > right) ^ flip

    ret = ix ^ tl.core.where(cond, ileft ^ iright, zeros_like(ix))

    new_ids = ids ^ tl.core.where(cond, left_idx ^ right_idx, zeros_like(ids))

    return ret.to(x.dtype, bitcast=True), new_ids
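
The XOR trick above swaps each left/right pair branchlessly on bit-cast integers: x ^ (left ^ right) maps left to right and vice versa, and XOR-ing with zero is a no-op. A minimal torch sketch of the same idea (illustration only, not part of the commit):

    import torch

    a = torch.tensor([3, 1, 4, 1], dtype=torch.int32)
    b = torch.tensor([2, 7, 1, 8], dtype=torch.int32)
    cond = a > b                                            # swap where the pair is out of order
    delta = torch.where(cond, a ^ b, torch.zeros_like(a))   # a ^ delta == b exactly where cond holds
    new_a, new_b = a ^ delta, b ^ delta                     # each pair is now in ascending order
    assert torch.equal(new_a, torch.minimum(a, b))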

@triton.jit
def _bitonic_merge(x, ids, stage: tl.core.constexpr, order: tl.core.constexpr, n_dims: tl.core.constexpr):
    """
    order_type 0 == ascending
    order_type 1 == descending
    order_type 2 == alternating
    """
    n_outer: tl.core.constexpr = x.numel >> n_dims
    tl.core.static_assert(stage <= n_dims)
    # flip denotes whether to re-arrange sub-sequences of elements in ascending or
    # descending order.
    # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage
    # if flip = 00110011... then all the elements will be re-arranged alternatingly (with
    # a stride of 2) at this stage
    if order == 2:
        shape: tl.core.constexpr = [n_outer * 2 ** (n_dims - 1 - stage), 2, 2 ** stage]
        flip = tl.core.reshape(tl.core.broadcast_to(tl.core.arange(0, 2)[None, :, None], shape), x.shape)
    else:
        flip = order
    # perform `stage` rounds of `compare-and-swap`
    for i in tl.core.static_range(stage):
        x, ids = _compare_and_swap(x, ids, flip, i + (n_dims - stage), n_dims)
    return x, ids

@triton.jit
def argsort(x, ids, dim: tl.core.constexpr = None, descending: tl.core.constexpr = tl.core.CONSTEXPR_0):
    # handle default dimension or check that it is the most minor dim
    _dim: tl.core.constexpr = len(x.shape) - 1 if dim is None else dim
    tl.core.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported")
    # iteratively run bitonic merge-sort steps
    n_dims: tl.core.constexpr = _log2(x.shape[_dim])

    for i in tl.core.static_range(1, n_dims + 1):
        x, ids = _bitonic_merge(x, ids, i, 2 if i < n_dims else descending, n_dims)
    return x, ids
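
`argsort` is a bitonic network, so the sorted dimension must be a power of two. A minimal standalone sketch of driving it, with a hypothetical demo kernel name (not part of the commit):

    @triton.jit
    def _argsort_demo(x_ptr, val_ptr, idx_ptr, N: tl.constexpr):
        offs = tl.arange(0, N)
        vals, idxs = argsort(tl.load(x_ptr + offs), offs, descending=True)
        tl.store(val_ptr + offs, vals)
        tl.store(idx_ptr + offs, idxs)

    x = torch.randn(16, device="cuda")
    vals = torch.empty_like(x)
    idxs = torch.empty(16, dtype=torch.int32, device="cuda")
    _argsort_demo[(1,)](x, vals, idxs, N=16)  # vals descending, idxs the matching permutation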

@triton.jit
def grouped_topk_kernel(
    gating_output_ptr,
    gating_output_stride_m,
    gating_output_stride_n,
    correction_bias_ptr,
    scores_buffer_ptr,  # [token_num, total_expert_num]
    scores_stride_m,
    scores_stride_n,
    scores_stride_token_m,
    scores_stride_group,
    scores_stride_group_v,
    out_topk_weights,
    out_topk_weights_stride_m,
    out_topk_weights_stride_n,
    out_topk_ids,
    out_topk_ids_stride_m,
    out_topk_ids_stride_n,
    group_num,
    group_expert_num,
    total_expert_num,  # group_num * group_expert_num == total_expert_num
    topk_num,
    group_topk_num,
    IS_SIGMOID: tl.constexpr,
    HAS_CORRECTION_BIAS: tl.constexpr,
    EXPERT_BLOCK_SIZE: tl.constexpr,  # triton.next_power_of_2(total_expert_num)
    EXPERT_GROUP_NUM: tl.constexpr,  # triton.next_power_of_2(group_num)
    EXPERT_GROUP_SIZE: tl.constexpr,  # triton.next_power_of_2(group_expert_num)
    RENORMALIZE: tl.constexpr,
):
    token_index = tl.program_id(axis=0)
    offs_n = tl.arange(0, EXPERT_BLOCK_SIZE)
    hidden_states = tl.load(
        gating_output_ptr + token_index * gating_output_stride_m + offs_n,
        mask=offs_n < total_expert_num,
        other=-10000000.0,
    )
    if IS_SIGMOID:
        scores = tl.sigmoid(hidden_states)
    else:
        scores = tl.softmax(hidden_states)

    if HAS_CORRECTION_BIAS:
        scores += tl.load(correction_bias_ptr + offs_n, mask=offs_n < total_expert_num, other=-10000000.0)

    offs_group = tl.arange(0, EXPERT_GROUP_NUM)
    offs_group_v = tl.arange(0, EXPERT_GROUP_SIZE)
    tl.store(scores_buffer_ptr + scores_stride_m * token_index + offs_n, scores, mask=offs_n < total_expert_num)
    group_scores = tl.load(
        scores_buffer_ptr
        + scores_stride_token_m * token_index
        + offs_group[:, None] * scores_stride_group
        + offs_group_v[None, :] * scores_stride_group_v,
        mask=(offs_group < group_num)[:, None] & (offs_group_v < group_expert_num)[None, :],
        other=-10000000.0,
    )  # [group, group_size]

    group_value = tl.max(group_scores, axis=1)  # [group,]
    sorted_group_value = tl.sort(group_value, descending=True)
    group_topk_value = tl.sum(tl.where(offs_group == group_topk_num - 1, sorted_group_value, 0.0))
    mask_group_scores = tl.where(
        ((group_value >= group_topk_value)[:, None]) & ((offs_group_v < group_expert_num)[None, :]),
        group_scores,
        -10000000.0,
    )
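    # sorted_group_value is descending, so selecting index group_topk_num - 1 yields the
    # group_topk_num-th largest per-group score; every group scoring below that threshold
    # has all of its experts pushed down to the -1e7 sentinel.
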
    tl.store(
        scores_buffer_ptr
        + scores_stride_token_m * token_index
        + offs_group[:, None] * scores_stride_group
        + offs_group_v[None, :] * scores_stride_group_v,
        mask_group_scores,
        mask=((offs_group < group_num)[:, None]) & ((offs_group_v < group_expert_num)[None, :]),
    )  # [group, group_size]

    mask_scores = tl.load(
        scores_buffer_ptr + scores_stride_m * token_index + offs_n, mask=offs_n < total_expert_num, other=-10000000.0
    )
    sorted_scores, sorted_indexes = argsort(mask_scores, offs_n, descending=True)
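    # argsort sorts the entire padded row (EXPERT_BLOCK_SIZE lanes) even though only the
    # top topk_num entries are read back below; masked-out lanes hold the -1e7 sentinel
    # and therefore sort to the tail.
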
    if RENORMALIZE:
        sum_scores = tl.sum(tl.where(offs_n < topk_num, sorted_scores, 0.0))
        renormalize_scores = sorted_scores / sum_scores

        tl.store(
            out_topk_weights + token_index * out_topk_weights_stride_m + offs_n,
            renormalize_scores,
            mask=offs_n < topk_num,
        )
        tl.store(out_topk_ids + token_index * out_topk_ids_stride_m + offs_n, sorted_indexes, mask=offs_n < topk_num)
    else:
        tl.store(
            out_topk_weights + token_index * out_topk_weights_stride_m + offs_n, sorted_scores, mask=offs_n < topk_num
        )
        tl.store(out_topk_ids + token_index * out_topk_ids_stride_m + offs_n, sorted_indexes, mask=offs_n < topk_num)
    return


def triton_grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    correction_bias: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
    scoring_func: str = "softmax",
):
    has_correction_bias = correction_bias is not None

    token_num, total_expert_num = gating_output.shape
    if gating_output.dtype == torch.float64:
        dtype = torch.float64
    else:
        dtype = torch.float32

    scores_buffer = torch.empty((token_num, total_expert_num), dtype=dtype, device="cuda")
    out_topk_weights = torch.empty((token_num, topk), dtype=torch.float32, device="cuda")
    out_topk_ids = torch.empty((token_num, topk), dtype=torch.int32, device="cuda")

    assert total_expert_num % num_expert_group == 0

    grouped_topk_kernel[(token_num,)](
        gating_output,
        *gating_output.stride(),
        correction_bias,
        scores_buffer,
        *scores_buffer.stride(),
        *scores_buffer.view(token_num, num_expert_group, -1).stride(),
        out_topk_weights,
        *out_topk_weights.stride(),
        out_topk_ids,
        *out_topk_ids.stride(),
        group_num=num_expert_group,
        group_expert_num=total_expert_num // num_expert_group,
        total_expert_num=total_expert_num,
        topk_num=topk,
        group_topk_num=topk_group,
        IS_SIGMOID=scoring_func == "sigmoid",
        HAS_CORRECTION_BIAS=has_correction_bias,
        EXPERT_BLOCK_SIZE=triton.next_power_of_2(total_expert_num),
        EXPERT_GROUP_NUM=triton.next_power_of_2(num_expert_group),
        EXPERT_GROUP_SIZE=triton.next_power_of_2(total_expert_num // num_expert_group),
        RENORMALIZE=renormalize,
        num_warps=1,
        num_stages=1,
    )
    return out_topk_weights, out_topk_ids
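
For reference, a minimal host-side sketch of calling the wrapper (shapes are illustrative only, mirroring the 256-expert / 8-group case exercised in the test below):

    gating = torch.randn((4, 256), dtype=torch.float32, device="cuda")
    bias = torch.zeros((256,), dtype=torch.float32, device="cuda")
    weights, ids = triton_grouped_topk(
        None,  # hidden_states is accepted for API parity but unused by the Triton path
        gating_output=gating,
        correction_bias=bias,
        topk=8,
        renormalize=True,
        num_expert_group=8,
        topk_group=4,
        scoring_func="sigmoid",
    )
    # weights: [4, 8] float32; ids: [4, 8] int32
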
Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
import torch
import time
import pytest
import numpy as np
from lightllm.common.fused_moe.topk_select import grouped_topk
from lightllm.common.fused_moe.grouped_topk import triton_grouped_topk
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)

seed = 42
torch.manual_seed(seed)

if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


@pytest.mark.parametrize(
    "expert_num, topk_group, group_num, topk_num, scoring_func, token_num",
    [
        (*a, b, c)
        for a in [(256, 4, 8, 8), (160, 3, 8, 6)]
        for b in ["softmax", "sigmoid"]
        for c in [1, 8, 256, 1024, 2048, 4096, 8192]
    ],
)
def test_grouped_topk(expert_num, topk_group, group_num, topk_num, scoring_func, token_num):
    print("test", expert_num, topk_group, group_num, topk_num, scoring_func, token_num)
    dtype = torch.float32
    hidden_state = torch.randn((token_num, 1), dtype=dtype, device="cuda")
    gating_output = torch.randn((token_num, expert_num), dtype=dtype, device="cuda") * 10
    correction_bias = torch.randn((expert_num,), dtype=dtype, device="cuda")
    correction_bias[correction_bias <= 0.0] = 0.0

    # run both implementations once on the same inputs (also serves as warmup)
    old_topk_weights, old_topk_ids = grouped_topk(
        hidden_state,
        gating_output=gating_output,
        correction_bias=correction_bias,
        topk=topk_num,
        renormalize=True,
        num_expert_group=group_num,
        topk_group=topk_group,
        scoring_func=scoring_func,
    )

    new_topk_weights, new_topk_ids = triton_grouped_topk(
        None,
        gating_output=gating_output,
        correction_bias=correction_bias,
        topk=topk_num,
        renormalize=True,
        num_expert_group=group_num,
        topk_group=topk_group,
        scoring_func=scoring_func,
    )

    # simple wall-clock benchmark: 60 iterations of each implementation
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(60):
        old_topk_weights, old_topk_ids = grouped_topk(
            hidden_state,
            gating_output=gating_output,
            correction_bias=correction_bias,
            topk=topk_num,
            renormalize=True,
            num_expert_group=group_num,
            topk_group=topk_group,
            scoring_func=scoring_func,
        )
    torch.cuda.synchronize()
    print(f"old cost time {time.time() - start} s")

    torch.cuda.synchronize()
    start = time.time()
    for _ in range(60):
        new_topk_weights, new_topk_ids = triton_grouped_topk(
            None,
            gating_output=gating_output,
            correction_bias=correction_bias,
            topk=topk_num,
            renormalize=True,
            num_expert_group=group_num,
            topk_group=topk_group,
            scoring_func=scoring_func,
        )
    torch.cuda.synchronize()
    print(f"new cost time {time.time() - start} s")

    # the two implementations may order ids differently, so compare sorted rows
    assert torch.equal(torch.sort(old_topk_ids, dim=1)[0], torch.sort(new_topk_ids, dim=1)[0])
    assert torch.allclose(
        torch.sort(old_topk_weights, dim=1)[0], torch.sort(new_topk_weights, dim=1)[0], atol=1e-4, rtol=0
    )
    return


if __name__ == "__main__":
    pytest.main()
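
To see the timing prints, run with pytest output capture disabled (the committed file path is not shown here):

    pytest -q -s <path-to-this-test-file>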
