Commit 552900a

Author: sangchengmeng
Commit message: Merge remote-tracking branch 'origin/main' into grouped_topk_cuda
2 parents 0a900c8 + c181e7a, commit 552900a

File tree: 6 files changed, +412 -35 lines changed
lightllm/common/fused_moe/grouped_topk.py (new file)

Lines changed: 231 additions & 0 deletions
@@ -0,0 +1,231 @@
# adapted from https://github.com/triton-lang/triton/issues/3698#issuecomment-2067681396
import torch
import triton
import triton.language as tl
from triton.language.standard import _log2, sum, zeros_like


@triton.jit
def _compare_and_swap(x, ids, flip, i: tl.core.constexpr, n_dims: tl.core.constexpr):
    n_outer: tl.core.constexpr = x.numel >> n_dims
    shape: tl.core.constexpr = [n_outer * 2 ** i, 2, 2 ** (n_dims - i - 1)]
    y = tl.core.reshape(x, shape)
    # slice left/right with 'stride' 2**(n_dims - i - 1)
    mask = tl.core.arange(0, 2)[None, :, None]
    left = tl.core.broadcast_to(sum(y * (1 - mask), 1)[:, None, :], shape)
    right = tl.core.broadcast_to(sum(y * mask, 1)[:, None, :], shape)
    left = tl.core.reshape(left, x.shape)
    right = tl.core.reshape(right, x.shape)

    # the indices travel together with the values
    y_idx = tl.core.reshape(ids, shape)
    left_idx = tl.core.broadcast_to(sum(y_idx * (1 - mask), 1)[:, None, :], shape)
    right_idx = tl.core.broadcast_to(sum(y_idx * mask, 1)[:, None, :], shape)
    left_idx = tl.core.reshape(left_idx, x.shape)
    right_idx = tl.core.reshape(right_idx, x.shape)

    # actual compare-and-swap, done branchlessly on the integer bit patterns
    idtype = tl.core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
    ileft = left.to(idtype, bitcast=True)
    iright = right.to(idtype, bitcast=True)
    ix = x.to(idtype, bitcast=True)

    cond = (left > right) ^ flip

    ret = ix ^ tl.core.where(cond, ileft ^ iright, zeros_like(ix))

    new_ids = ids ^ tl.core.where(cond, left_idx ^ right_idx, zeros_like(ids))

    return ret.to(x.dtype, bitcast=True), new_ids
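
# How the XOR trick above swaps a pair without branching: for a pair (l, r)
# reinterpreted as integers, v ^ (bits(l) ^ bits(r)) maps l to r and r to l,
# so ix ^ where(cond, bits(l) ^ bits(r), 0) swaps the two lanes exactly when
# cond is set. Small numeric example for one pair with flip = 0:
#   left = 5.0, right = 1.0  ->  cond = (5.0 > 1.0) ^ 0 = True
#   left lane:  bits(5.0) ^ (bits(5.0) ^ bits(1.0)) = bits(1.0)  ->  1.0
#   right lane: bits(1.0) ^ (bits(5.0) ^ bits(1.0)) = bits(5.0)  ->  5.0
# The same XOR is applied to `ids`, so indices move together with their values.
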
@triton.jit
def _bitonic_merge(x, ids, stage: tl.core.constexpr, order: tl.core.constexpr, n_dims: tl.core.constexpr):
    """
    order 0 == ascending
    order 1 == descending
    order 2 == alternating
    """
    n_outer: tl.core.constexpr = x.numel >> n_dims
    tl.core.static_assert(stage <= n_dims)
    # flip denotes whether to re-arrange sub-sequences of elements in ascending or
    # descending order.
    # if flip = 00000000... then all elements are re-arranged in ascending order at this stage
    # if flip = 00110011... then the elements are re-arranged alternately (with
    # a stride of 2) at this stage
    if order == 2:
        shape: tl.core.constexpr = [n_outer * 2 ** (n_dims - 1 - stage), 2, 2 ** stage]
        flip = tl.core.reshape(tl.core.broadcast_to(tl.core.arange(0, 2)[None, :, None], shape), x.shape)
    else:
        flip = order
    # perform `stage` rounds of `compare-and-swap`
    for i in tl.core.static_range(stage):
        x, ids = _compare_and_swap(x, ids, flip, i + (n_dims - stage), n_dims)
    return x, ids


@triton.jit
def argsort(x, ids, dim: tl.core.constexpr = None, descending: tl.core.constexpr = tl.core.CONSTEXPR_0):
    # handle the default dimension, and check that it is the innermost (most minor) dim
    _dim: tl.core.constexpr = len(x.shape) - 1 if dim is None else dim
    tl.core.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported")
    # iteratively run bitonic merge-sort steps
    n_dims: tl.core.constexpr = _log2(x.shape[_dim])

    for i in tl.core.static_range(1, n_dims + 1):
        x, ids = _bitonic_merge(x, ids, i, 2 if i < n_dims else descending, n_dims)
    return x, ids

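
# Example trace for one length-4 row with descending=True:
#   x = [3, 1, 4, 2], ids = [0, 1, 2, 3], n_dims = 2
#   stage 1 (order=2, alternating): adjacent pairs are sorted in opposite
#     directions -> x = [1, 3, 4, 2], a bitonic sequence
#   stage 2 (order=descending): compare-and-swap with stride 2, then stride 1
#     stride 2 -> x = [4, 3, 1, 2]; stride 1 -> x = [4, 3, 2, 1], ids = [2, 0, 3, 1]
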
@triton.jit
def grouped_topk_kernel(
    gating_output_ptr,
    gating_output_stride_m,
    gating_output_stride_n,
    correction_bias_ptr,
    scores_buffer_ptr,  # [token_num, total_expert_num]
    scores_stride_m,
    scores_stride_n,
    scores_stride_token_m,
    scores_stride_group,
    scores_stride_group_v,
    out_topk_weights,
    out_topk_weights_stride_m,
    out_topk_weights_stride_n,
    out_topk_ids,
    out_topk_ids_stride_m,
    out_topk_ids_stride_n,
    group_num,
    group_expert_num,
    total_expert_num,  # group_num * group_expert_num == total_expert_num
    topk_num,
    group_topk_num,
    IS_SIGMOID: tl.constexpr,
    HAS_CORRECTION_BIAS: tl.constexpr,
    EXPERT_BLOCK_SIZE: tl.constexpr,  # triton.next_power_of_2(total_expert_num)
    EXPERT_GROUP_NUM: tl.constexpr,  # triton.next_power_of_2(group_num)
    EXPERT_GROUP_SIZE: tl.constexpr,  # triton.next_power_of_2(group_expert_num)
    RENORMALIZE: tl.constexpr,
):
    # one program handles the routing for one token
    token_index = tl.program_id(axis=0)
    offs_n = tl.arange(0, EXPERT_BLOCK_SIZE)
    hidden_states = tl.load(
        gating_output_ptr + token_index * gating_output_stride_m + offs_n,
        mask=offs_n < total_expert_num,
        other=-10000000.0,
    )
    # 1) score every expert
    if IS_SIGMOID:
        scores = tl.sigmoid(hidden_states)
    else:
        scores = tl.softmax(hidden_states)

    if HAS_CORRECTION_BIAS:
        scores += tl.load(correction_bias_ptr + offs_n, mask=offs_n < total_expert_num, other=-10000000.0)

    # 2) stage the scores and reload them as [group_num, group_expert_num]
    offs_group = tl.arange(0, EXPERT_GROUP_NUM)
    offs_group_v = tl.arange(0, EXPERT_GROUP_SIZE)
    tl.store(scores_buffer_ptr + scores_stride_m * token_index + offs_n, scores, mask=offs_n < total_expert_num)
    group_scores = tl.load(
        scores_buffer_ptr
        + scores_stride_token_m * token_index
        + offs_group[:, None] * scores_stride_group
        + offs_group_v[None, :] * scores_stride_group_v,
        mask=(offs_group < group_num)[:, None] & (offs_group_v < group_expert_num)[None, :],
        other=-10000000.0,
    )  # [group, group_size]

    # 3) keep the group_topk_num best groups, ranked by each group's max score
    group_value = tl.max(group_scores, axis=1)  # [group,]
    sorted_group_value = tl.sort(group_value, descending=True)
    group_topk_value = tl.sum(tl.where(offs_group == group_topk_num - 1, sorted_group_value, 0.0))
    mask_group_scores = tl.where(
        ((group_value >= group_topk_value)[:, None]) & ((offs_group_v < group_expert_num)[None, :]),
        group_scores,
        -10000000.0,
    )

    tl.store(
        scores_buffer_ptr
        + scores_stride_token_m * token_index
        + offs_group[:, None] * scores_stride_group
        + offs_group_v[None, :] * scores_stride_group_v,
        mask_group_scores,
        mask=((offs_group < group_num)[:, None]) & ((offs_group_v < group_expert_num)[None, :]),
    )  # [group, group_size]

    # 4) sort the masked scores and emit the global top-k experts
    mask_scores = tl.load(
        scores_buffer_ptr + scores_stride_m * token_index + offs_n, mask=offs_n < total_expert_num, other=-10000000.0
    )
    sorted_scores, sorted_indexes = argsort(mask_scores, offs_n, descending=True)

    if RENORMALIZE:
        sum_scores = tl.sum(tl.where(offs_n < topk_num, sorted_scores, 0.0))
        renormalized_scores = sorted_scores / sum_scores

        tl.store(
            out_topk_weights + token_index * out_topk_weights_stride_m + offs_n,
            renormalized_scores,
            mask=offs_n < topk_num,
        )
        tl.store(out_topk_ids + token_index * out_topk_ids_stride_m + offs_n, sorted_indexes, mask=offs_n < topk_num)
    else:
        tl.store(
            out_topk_weights + token_index * out_topk_weights_stride_m + offs_n, sorted_scores, mask=offs_n < topk_num
        )
        tl.store(out_topk_ids + token_index * out_topk_ids_stride_m + offs_n, sorted_indexes, mask=offs_n < topk_num)
    return


def triton_grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    correction_bias: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
    scoring_func: str = "softmax",
):
    if correction_bias is not None:
        has_correction_bias = True
    else:
        has_correction_bias = False

    token_num, total_expert_num = gating_output.shape
    if gating_output.dtype == torch.float64:
        dtype = torch.float64
    else:
        dtype = torch.float32

    scores_buffer = torch.empty((token_num, total_expert_num), dtype=dtype, device="cuda")
    out_topk_weights = torch.empty((token_num, topk), dtype=torch.float32, device="cuda")
    out_topk_ids = torch.empty((token_num, topk), dtype=torch.int32, device="cuda")

    assert total_expert_num % num_expert_group == 0

    # one program per token; block sizes are padded to powers of two for tl.arange
    grouped_topk_kernel[(token_num,)](
        gating_output,
        *gating_output.stride(),
        correction_bias,
        scores_buffer,
        *scores_buffer.stride(),
        *scores_buffer.view(token_num, num_expert_group, -1).stride(),
        out_topk_weights,
        *out_topk_weights.stride(),
        out_topk_ids,
        *out_topk_ids.stride(),
        group_num=num_expert_group,
        group_expert_num=total_expert_num // num_expert_group,
        total_expert_num=total_expert_num,
        topk_num=topk,
        group_topk_num=topk_group,
        IS_SIGMOID=scoring_func == "sigmoid",
        HAS_CORRECTION_BIAS=has_correction_bias,
        EXPERT_BLOCK_SIZE=triton.next_power_of_2(total_expert_num),
        EXPERT_GROUP_NUM=triton.next_power_of_2(num_expert_group),
        EXPERT_GROUP_SIZE=triton.next_power_of_2(total_expert_num // num_expert_group),
        RENORMALIZE=renormalize,
        num_warps=1,
        num_stages=1,
    )
    return out_topk_weights, out_topk_ids
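
A minimal usage sketch for the wrapper above, with illustrative sizes only; the import path follows the import added to topk_select.py in this commit, and only gating_output, correction_bias and the integer arguments are actually consumed by the wrapper:

    import torch
    from lightllm.common.fused_moe.grouped_topk import triton_grouped_topk

    token_num, num_expert_group, experts_per_group = 4, 8, 32
    total_expert_num = num_expert_group * experts_per_group  # 256
    hidden_states = torch.randn(token_num, 16, device="cuda", dtype=torch.float16)  # unused by this wrapper
    gating_output = torch.randn(token_num, total_expert_num, device="cuda", dtype=torch.float32)
    correction_bias = torch.zeros(total_expert_num, device="cuda", dtype=torch.float32)

    topk_weights, topk_ids = triton_grouped_topk(
        hidden_states=hidden_states,
        gating_output=gating_output,
        correction_bias=correction_bias,
        topk=8,
        renormalize=True,
        num_expert_group=num_expert_group,
        topk_group=4,
        scoring_func="sigmoid",
    )
    # topk_weights: [4, 8] float32, topk_ids: [4, 8] int32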

lightllm/common/fused_moe/topk_select.py

Lines changed: 24 additions & 22 deletions
@@ -24,6 +24,7 @@
 
 use_cuda_grouped_topk = os.environ.get("GROUPED_TOPK_CUDA", "false").lower()
 
+
 def fused_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -63,7 +64,7 @@ def grouped_topk(
     topk_group: int = 0,
     scoring_func: str = "softmax",
 ):
-
+
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
     if scoring_func == "sigmoid":
         scores = torch.sigmoid(gating_output)
@@ -91,8 +92,9 @@ def grouped_topk(
 
     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
 
+
 # This is used by the Deepseek-V2 model
-def grouped_topk_cuda(
+def cuda_grouped_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
     correction_bias: torch.Tensor,
@@ -105,27 +107,26 @@ def grouped_topk_cuda(
 
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
     num_tokens = gating_output.shape[0]
-    num_experts = gating_output.shape[-1]
     topk_weights = torch.empty(num_tokens, topk, device=hidden_states.device, dtype=torch.float32)
     topk_indices = torch.empty(num_tokens, topk, device=hidden_states.device, dtype=torch.int32)
     token_expert_indices = torch.empty(num_tokens, topk_group, device=hidden_states.device, dtype=torch.int32)
-    group_scores = torch.empty(num_tokens, num_expert_group, device=hidden_states.device, dtype=torch.float32)
-    if correction_bias is None:
-        correction_bias = torch.zeros_like(gating_output,dtype=torch.float32)
+    group_scores = torch.empty(num_tokens, num_expert_group, device=hidden_states.device, dtype=torch.float32)
+    if correction_bias is None:
+        correction_bias = torch.zeros_like(gating_output, dtype=torch.float32)
     ops.grouped_topk(
-        topk_weights,
-        correction_bias,
-        topk_indices,
-        token_expert_indices,
-        gating_output.float(),
-        num_expert_group,
-        topk_group,
-        topk,
-        renormalize,
-        scoring_func,
-        group_scores
+        topk_weights,
+        correction_bias,
+        topk_indices,
+        token_expert_indices,
+        gating_output.float(),
+        num_expert_group,
+        topk_group,
+        topk,
+        renormalize,
+        scoring_func,
+        group_scores,
     )
-
+
     return topk_weights, topk_indices
 
 
@@ -141,14 +142,15 @@ def select_experts(
     scoring_func: str = "softmax",
     custom_routing_function: Optional[Callable] = None,
 ):
-    from lightllm.common.fused_moe.topk_select import fused_topk, grouped_topk
+    from lightllm.common.fused_moe.topk_select import fused_topk
+    from lightllm.common.fused_moe.grouped_topk import triton_grouped_topk
+
     # DeekSeekv2 uses grouped_top_k
     if use_grouped_topk:
         assert topk_group is not None
         assert num_expert_group is not None
         if use_cuda_grouped_topk == "true":
-            from lightllm.common.vllm_kernel import _custom_ops as ops
-            topk_weights, topk_ids = grouped_topk_cuda(
+            topk_weights, topk_ids = cuda_grouped_topk(
                 hidden_states=hidden_states,
                 gating_output=router_logits,
                 correction_bias=correction_bias,
@@ -159,7 +161,7 @@ def select_experts(
                 scoring_func=scoring_func,
             )
         else:
-            topk_weights, topk_ids = grouped_topk(
+            topk_weights, topk_ids = triton_grouped_topk(
                 hidden_states=hidden_states,
                 gating_output=router_logits,
                 correction_bias=correction_bias,
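
Because the `use_cuda_grouped_topk` flag is read once at import time from the GROUPED_TOPK_CUDA environment variable, the variable has to be set before topk_select is imported. A minimal sketch of toggling the CUDA path:

    import os

    # must be set before lightllm.common.fused_moe.topk_select is imported
    os.environ["GROUPED_TOPK_CUDA"] = "true"  # any other value falls back to the Triton kernel

    from lightllm.common.fused_moe import topk_select
    # a select_experts call with grouped top-k enabled now routes through cuda_grouped_topk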

lightllm/common/vllm_kernel/_ops.py

Lines changed: 14 additions & 3 deletions
@@ -760,6 +760,7 @@ def topk_softmax(
 ) -> None:
     torch.ops.vllm_moe.topk_softmax(topk_weights, topk_ids, token_expert_indicies, gating_output)
 
+
 def grouped_topk(
     topk_weights: torch.Tensor,
     correction_bias: torch.Tensor,
@@ -771,13 +772,23 @@ def grouped_topk(
     topk: int,
     renormalize: bool,
     scoring_func: str,
-    group_scores: torch.Tensor = None
+    group_scores: torch.Tensor = None,
 ) -> None:
     torch.ops.vllm_moe.grouped_topk(
-        topk_weights, correction_bias, topk_indices, group_indices, gating_output, num_expert_group,
-        topk_group, topk, renormalize, scoring_func, group_scores
+        topk_weights,
+        correction_bias,
+        topk_indices,
+        group_indices,
+        gating_output,
+        num_expert_group,
+        topk_group,
+        topk,
+        renormalize,
+        scoring_func,
+        group_scores,
     )
 
+
 def reshape_and_cache(
     key: torch.Tensor,
     value: torch.Tensor,

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 3 additions & 5 deletions
@@ -154,7 +154,7 @@ def _decompress_kv(self, kv, infer_state: Deepseek2InferStateInfo, layer_weight:
         )
 
         # CC
-        compressed_kv = compressed_kv.view(-1, layer_weight.kv_lora_rank)
+        compressed_kv = compressed_kv.view(-1, layer_weight.kv_lora_rank).contiguous()
         k_nope = self.alloc_tensor(
             [compressed_kv.shape[0], self.tp_q_head_num_, self.qk_nope_head_dim],
             dtype=compressed_kv.dtype,
@@ -163,10 +163,8 @@ def _decompress_kv(self, kv, infer_state: Deepseek2InferStateInfo, layer_weight:
             k_nope.shape,
             dtype=compressed_kv.dtype,
         )
-        wk = layer_weight.k_b_proj_.weight.view(-1, layer_weight.kv_lora_rank).T
-        wv = layer_weight.v_b_proj_.weight.transpose(0, 1).reshape(layer_weight.kv_lora_rank, -1)
-        torch.mm(compressed_kv, wk, out=k_nope.reshape(compressed_kv.shape[0], -1))
-        torch.mm(compressed_kv, wv, out=v.reshape(compressed_kv.shape[0], -1))
+        layer_weight.cc_k_b_proj_.mm(compressed_kv, out=k_nope.reshape(compressed_kv.shape[0], -1))
+        layer_weight.cc_v_b_proj_.mm(compressed_kv, out=v.reshape(compressed_kv.shape[0], -1))
         return k_nope, k_rope, v
 
     def _context_attention_kernel_with_CC(
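
The change above assumes layer_weight.cc_k_b_proj_.mm / cc_v_b_proj_.mm reproduce the two removed torch.mm calls with the weight reshaping folded into the projection objects. A rough sketch of the computation the removed lines performed, with illustrative dimensions only and a stand-in weight tensor:

    import torch

    num_tokens, kv_lora_rank, heads, head_dim = 3, 512, 16, 128
    compressed_kv = torch.randn(num_tokens, kv_lora_rank)
    k_b_weight = torch.randn(heads * head_dim, kv_lora_rank)  # stand-in for k_b_proj_.weight
    k_nope = torch.empty(num_tokens, heads, head_dim)

    wk = k_b_weight.view(-1, kv_lora_rank).T  # [kv_lora_rank, heads * head_dim]
    torch.mm(compressed_kv, wk, out=k_nope.reshape(num_tokens, -1))
    # the new code replaces this with layer_weight.cc_k_b_proj_.mm(compressed_kv, out=...)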
