File tree: 1 file changed, +8 -2 lines changed
lightllm/common/fused_moe: 1 file changed, +8 -2 lines changed
Original file line number | Diff line number | Diff line change
@@ -140,6 +140,7 @@ def grouped_topk_kernel(
140140 offs_group = tl .arange (0 , EXPERT_GROUP_NUM )
141141 offs_group_v = tl .arange (0 , EXPERT_GROUP_SIZE )
142142 tl .store (scores_buffer_ptr + scores_stride_m * token_index + offs_n , scores , mask = offs_n < total_expert_num )
143+ tl .debug_barrier ()
143144 group_scores = tl .load (
144145 scores_buffer_ptr
145146 + scores_stride_token_m * token_index
@@ -174,7 +175,7 @@ def grouped_topk_kernel(
174175 mask_group_scores ,
175176 mask = ((offs_group < group_num )[:, None ]) & ((offs_group_v < group_expert_num )[None , :]),
176177 ) # [group, group_size]
177-
178+ tl . debug_barrier ()
178179 mask_scores = tl .load (
179180 scores_buffer_ptr + scores_stride_m * token_index + offs_n , mask = offs_n < total_expert_num , other = - 10000000.0
180181 )
@@ -227,6 +228,11 @@ def triton_grouped_topk(
227228
228229 assert total_expert_num % num_expert_group == 0
229230
231+ if token_num <= 256 :
232+ num_warps = 4
233+ else :
234+ num_warps = 1
235+
230236 grouped_topk_kernel [(token_num ,)](
231237 gating_output ,
232238 * gating_output .stride (),
@@ -250,7 +256,7 @@ def triton_grouped_topk(
250256 EXPERT_GROUP_SIZE = triton .next_power_of_2 (total_expert_num // num_expert_group ),
251257 RENORMALIZE = renormalize ,
252258 GROUP_SCORE_USED_TOPK_NUM = group_score_used_topk_num ,
253- num_warps = 1 ,
259+ num_warps = num_warps ,
254260 num_stages = 1 ,
255261 )
256262 return out_topk_weights , out_topk_ids
You can’t perform that action at this time.
0 commit comments