Skip to content

Commit d8aff9c

Browse files
author
sangchengmeng
committed
fix grouped_topk tl.sort when numel=1
1 parent 2d95b73 commit d8aff9c

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

lightllm/common/fused_moe/grouped_topk.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,9 @@ def grouped_topk_kernel(
159159
axis=1,
160160
)
161161

162-
sorted_group_value = tl.sort(group_value, descending=True)
163-
group_topk_value = tl.sum(tl.where(offs_group == group_topk_num - 1, sorted_group_value, 0.0))
162+
if EXPERT_GROUP_NUM > 1:
163+
group_value = tl.sort(group_value, descending=True)
164+
group_topk_value = tl.sum(tl.where(offs_group == group_topk_num - 1, group_value, 0.0))
164165
mask_group_scores = tl.where(
165166
((group_value >= group_topk_value)[:, None]) & ((offs_group_v < group_expert_num)[None, :]),
166167
group_scores,

0 commit comments

Comments
 (0)