Skip to content

Commit e6f6b8d

Browse files
committed
speed group_topk kernel
1 parent 48ce26b commit e6f6b8d

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

lightllm/common/fused_moe/grouped_topk.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -140,6 +140,7 @@ def grouped_topk_kernel(
140140
offs_group = tl.arange(0, EXPERT_GROUP_NUM)
141141
offs_group_v = tl.arange(0, EXPERT_GROUP_SIZE)
142142
tl.store(scores_buffer_ptr + scores_stride_m * token_index + offs_n, scores, mask=offs_n < total_expert_num)
143+
tl.debug_barrier()
143144
group_scores = tl.load(
144145
scores_buffer_ptr
145146
+ scores_stride_token_m * token_index
@@ -174,7 +175,7 @@ def grouped_topk_kernel(
174175
mask_group_scores,
175176
mask=((offs_group < group_num)[:, None]) & ((offs_group_v < group_expert_num)[None, :]),
176177
) # [group, group_size]
177-
178+
tl.debug_barrier()
178179
mask_scores = tl.load(
179180
scores_buffer_ptr + scores_stride_m * token_index + offs_n, mask=offs_n < total_expert_num, other=-10000000.0
180181
)
@@ -227,6 +228,11 @@ def triton_grouped_topk(
227228

228229
assert total_expert_num % num_expert_group == 0
229230

231+
if token_num <= 256:
232+
num_warps = 4
233+
else:
234+
num_warps = 1
235+
230236
grouped_topk_kernel[(token_num,)](
231237
gating_output,
232238
*gating_output.stride(),
@@ -250,7 +256,7 @@ def triton_grouped_topk(
250256
EXPERT_GROUP_SIZE=triton.next_power_of_2(total_expert_num // num_expert_group),
251257
RENORMALIZE=renormalize,
252258
GROUP_SCORE_USED_TOPK_NUM=group_score_used_topk_num,
253-
num_warps=1,
259+
num_warps=num_warps,
254260
num_stages=1,
255261
)
256262
return out_topk_weights, out_topk_ids

0 commit comments

Comments
 (0)