@@ -227,15 +227,13 @@ def moe_align_fused_kernel(
     expert_to_weight_ptr,  # [expert_num, token_num * topk]
     expert_token_num_ptr,  # [expert_num]
     token_num,
-    topk: tl.constexpr,
-    BLOCK_TOK: tl.constexpr,
+    topk_num: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
 ):
     token_block = tl.program_id(0)
-    offs = token_block * BLOCK_TOK + tl.arange(0, BLOCK_TOK)
-    mask = offs < token_num * topk
+    offs = token_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = offs < token_num * topk_num
 
-    # iterate over topk
-    # for k in range(topk):
     expert_ids = tl.load(topk_ids_ptr + offs, mask=mask, other=0)
     weights = tl.load(topk_weights_ptr + offs, mask=mask, other=0.0)
 
@@ -244,12 +242,12 @@ def moe_align_fused_kernel(
 
     # write index and weight in token order
     tl.store(
-        expert_to_index_ptr + expert_ids * (token_num * topk) + write_pos,
+        expert_to_index_ptr + expert_ids * (token_num * topk_num) + write_pos,
         offs,
         mask=mask,
     )
     tl.store(
-        expert_to_weight_ptr + expert_ids * (token_num * topk) + write_pos,
+        expert_to_weight_ptr + expert_ids * (token_num * topk_num) + write_pos,
         weights,
         mask=mask,
     )
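For reference, here is a minimal pure-PyTorch sketch (an illustration, not part of this commit) of the layout the kernel populates: each expert row of `expert_to_index` collects the flat `(token, k)` slots routed to that expert, `expert_to_weight` holds the matching routing weights, and `expert_token_num` counts the slots per expert. The helper name `ref_moe_align` and its strictly sequential write order are assumptions; the kernel's actual per-expert write position (`write_pos`) is computed in code outside these hunks.

```python
import torch

def ref_moe_align(topk_ids: torch.Tensor, topk_weights: torch.Tensor, expert_num: int):
    # topk_ids / topk_weights: [token_num, topk_num]
    token_num, topk_num = topk_ids.shape
    flat_ids = topk_ids.reshape(-1)      # flat (token, k) slot -> expert id
    flat_w = topk_weights.reshape(-1)
    expert_to_index = torch.zeros((expert_num, token_num * topk_num), dtype=torch.int32)
    expert_to_weight = torch.zeros((expert_num, token_num * topk_num), dtype=torch.float32)
    expert_token_num = torch.zeros((expert_num,), dtype=torch.int32)
    for slot in range(token_num * topk_num):
        e = int(flat_ids[slot])
        pos = int(expert_token_num[e])
        expert_to_index[e, pos] = slot            # which flat slot this expert consumes
        expert_to_weight[e, pos] = flat_w[slot]   # its routing weight
        expert_token_num[e] += 1                  # slots seen so far for expert e
    return expert_to_index, expert_to_weight, expert_token_num
```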
@@ -258,20 +256,18 @@ def moe_align_fused_kernel(
 def _get_moe_align_fused_static_key(
     topk_weights: torch.Tensor,
 ) -> dict:
-    topk = topk_weights.shape[1]
+    topk_num = topk_weights.shape[1]
     return {
-        "topk": topk,
+        "topk_num": topk_num,
     }
 
 
 def _get_moe_align_fused_configs():
     return [
         {
-            "BLOCK_TOK": bt,
+            "BLOCK_SIZE": bt,
             "num_warps": nw,
-            "num_stages": ns,
         }
-        for ns in [2, 3, 4, 5]
         for nw in [4, 8]
         for bt in [128, 256, 512, 1024, 2048]
     ]
@@ -285,27 +281,25 @@ def _get_moe_align_fused_configs():
     mutates_args=["expert_to_index", "expert_to_weight", "expert_token_num"],
 )
 def moe_align_fused(
-    expert_to_index, expert_to_weight, expert_token_num, topk_ids, topk_weights, topk, run_config: Optional[dict] = None
+    expert_to_index, expert_to_weight, expert_token_num, topk_ids, topk_weights, run_config: Optional[dict] = None
 ):
-    token_num, topk = topk_ids.shape
+    token_num, topk_num = topk_ids.shape
     if run_config is None:
         run_config = {}
-    BLOCK_TOK = run_config.get("BLOCK_TOK", 256)
+    BLOCK_SIZE = run_config.get("BLOCK_SIZE", 256)
     num_warps = run_config.get("num_warps", 4)
-    num_stages = run_config.get("num_stages", 3)
 
-    grid = (triton.cdiv(token_num * topk, BLOCK_TOK),)
+    grid = (triton.cdiv(token_num * topk_num, BLOCK_SIZE),)
     moe_align_fused_kernel[grid](
         topk_ids,
         topk_weights,
         expert_to_index,
         expert_to_weight,
         expert_token_num,
         token_num,
-        topk,
-        BLOCK_TOK=BLOCK_TOK,
+        topk_num,
+        BLOCK_SIZE=BLOCK_SIZE,
         num_warps=num_warps,
-        num_stages=num_stages,
     )
     return expert_to_index, expert_to_weight, expert_token_num
 
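A hedged usage sketch of the updated wrapper: `topk_num` is now inferred from `topk_ids.shape`, and the tunable keys are `"BLOCK_SIZE"` and `"num_warps"` (defaults 256 and 4 in the hunk above). The expert count `E` and the buffer names below are illustrative and mirror the `fused_experts_impl` hunk that follows.

```python
E = 8                                           # assumed expert count for the example
token_num, topk_num = topk_ids.shape            # topk_num no longer passed explicitly
expert_to_index = torch.empty((E, token_num * topk_num), dtype=torch.int32, device="cuda")
expert_to_weight = torch.empty((E, token_num * topk_num), dtype=torch.float32, device="cuda")
expert_token_num = torch.zeros((E,), dtype=torch.int32, device="cuda")
moe_align_fused(
    expert_to_index,
    expert_to_weight,
    expert_token_num,
    topk_ids,
    topk_weights,
    run_config={"BLOCK_SIZE": 512, "num_warps": 8},  # overrides the defaults read above
)
```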
@@ -811,9 +805,7 @@ def fused_experts_impl(
         expert_to_tokens = torch.empty((E, topk_num * tokens_in_chunk), dtype=torch.int32, device="cuda")
         expert_to_weights = torch.empty((E, topk_num * tokens_in_chunk), dtype=torch.float32, device="cuda")
         expert_to_token_num = torch.zeros((E,), dtype=torch.int32, device="cuda")
-        moe_align_fused(
-            expert_to_tokens, expert_to_weights, expert_to_token_num, curr_topk_ids, curr_topk_weights, topk=topk_num
-        )
+        moe_align_fused(expert_to_tokens, expert_to_weights, expert_to_token_num, curr_topk_ids, curr_topk_weights)
 
         reused_mblock_infos = grouped_matmul(
             curr_topk_ids.numel(),