Skip to content

Commit 3226a86

Browse files
committed
fix group topk
1 parent 918aa00 commit 3226a86

File tree

3 files changed

+25
-4
lines changed

3 files changed

+25
-4
lines changed

lightllm/common/fused_moe/grouped_topk.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ def grouped_topk_kernel(
118118
EXPERT_GROUP_NUM: tl.constexpr, # tl.next_power_two_of(group_num)
119119
EXPERT_GROUP_SIZE: tl.constexpr, # tl.next_power_two_of(group_expert_num)
120120
RENORMALIZE: tl.constexpr,
121+
GROUP_SCORE_USED_TOPK_NUM: tl.constexpr,
121122
):
122123
token_index = tl.program_id(axis=0)
123124
offs_n = tl.arange(0, EXPERT_BLOCK_SIZE)
@@ -148,7 +149,15 @@ def grouped_topk_kernel(
148149
other=-10000000.0,
149150
) # [group, group_size]
150151

151-
group_value = tl.max(group_scores, axis=1) # [group,]
152+
group_value = tl.sum(
153+
tl.where(
154+
(offs_group < group_num)[:, None] & (offs_group_v < GROUP_SCORE_USED_TOPK_NUM)[None, :],
155+
tl.sort(group_scores, dim=1, descending=True),
156+
0.0,
157+
),
158+
axis=1,
159+
)
160+
152161
sorted_group_value = tl.sort(group_value, descending=True)
153162
group_topk_value = tl.sum(tl.where(offs_group == group_topk_num - 1, sorted_group_value, 0.0))
154163
mask_group_scores = tl.where(
@@ -198,6 +207,7 @@ def triton_grouped_topk(
198207
num_expert_group: int = 0,
199208
topk_group: int = 0,
200209
scoring_func: str = "softmax",
210+
group_score_used_topk_num=2,
201211
):
202212

203213
if correction_bias is not None:
@@ -239,6 +249,7 @@ def triton_grouped_topk(
239249
EXPERT_GROUP_NUM=triton.next_power_of_2(num_expert_group),
240250
EXPERT_GROUP_SIZE=triton.next_power_of_2(total_expert_num // num_expert_group),
241251
RENORMALIZE=renormalize,
252+
GROUP_SCORE_USED_TOPK_NUM=group_score_used_topk_num,
242253
num_warps=1,
243254
num_stages=1,
244255
)

lightllm/common/fused_moe/topk_select.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def grouped_topk(
9393
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
9494

9595

96+
# biased_grouped_topk adapt from sgl-project/sglang/python/sglang/srt/layers/moe/topk.py
9697
def biased_grouped_topk(
9798
hidden_states: torch.Tensor,
9899
gating_output: torch.Tensor,
@@ -196,7 +197,12 @@ def select_experts(
196197
scoring_func=scoring_func,
197198
)
198199
else:
199-
topk_weights, topk_ids = biased_grouped_topk(
200+
group_score_topk_num = 1
201+
# for deepseek v3
202+
if topk_group == 4 and num_expert_group == 8 and top_k == 8:
203+
group_score_topk_num = 2
204+
205+
topk_weights, topk_ids = triton_grouped_topk(
200206
hidden_states=hidden_states,
201207
gating_output=router_logits,
202208
correction_bias=correction_bias,
@@ -205,7 +211,9 @@ def select_experts(
205211
num_expert_group=num_expert_group,
206212
topk_group=topk_group,
207213
scoring_func=scoring_func,
214+
group_score_used_topk_num=group_score_topk_num,
208215
)
216+
209217
elif custom_routing_function is None:
210218
topk_weights, topk_ids = fused_topk(
211219
hidden_states=hidden_states, gating_output=router_logits, topk=top_k, renormalize=renormalize

unit_tests/common/fused_moe/test_grouped_topk.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
import time
33
import pytest
44
import numpy as np
5-
from lightllm.common.fused_moe.topk_select import grouped_topk
65
from lightllm.common.fused_moe.grouped_topk import triton_grouped_topk
6+
from lightllm.common.fused_moe.topk_select import biased_grouped_topk as grouped_topk
77
from lightllm.utils.log_utils import init_logger
88

99
logger = init_logger(__name__)
@@ -21,7 +21,9 @@
2121
[
2222
(*a, b, c)
2323
for a in [(256, 4, 8, 8), (160, 3, 8, 6)]
24-
for b in ["softmax", "sigmoid"]
24+
for b in [
25+
"sigmoid",
26+
]
2527
for c in [1, 8, 256, 1024, 2048, 4096, 8192]
2628
],
2729
)

0 commit comments

Comments (0)