@@ -223,7 +223,7 @@ def moe_align1(
223223def moe_align_fused_kernel (
224224 topk_ids_ptr , # [token_num, topk]
225225 topk_weights_ptr , # [token_num, topk]
226- expert_to_index_ptr , # [expert_num, token_num * topk]
226+ expert_to_token_index_ptr , # [expert_num, token_num * topk]
227227 expert_to_weight_ptr , # [expert_num, token_num * topk]
228228 expert_token_num_ptr , # [expert_num]
229229 token_num ,
@@ -242,7 +242,7 @@ def moe_align_fused_kernel(
242242
243243 # Write the index and weight entries in token order
244244 tl .store (
245- expert_to_index_ptr + expert_ids * (token_num * topk_num ) + write_pos ,
245+ expert_to_token_index_ptr + expert_ids * (token_num * topk_num ) + write_pos ,
246246 offs ,
247247 mask = mask ,
248248 )
@@ -268,7 +268,7 @@ def _get_moe_align_fused_configs():
268268 "BLOCK_SIZE" : bt ,
269269 "num_warps" : nw ,
270270 }
271- for nw in [4 , 8 ]
271+ for nw in [1 , 2 , 4 , 8 ]
272272 for bt in [128 , 256 , 512 , 1024 , 2048 ]
273273 ]
274274
@@ -278,10 +278,10 @@ def _get_moe_align_fused_configs():
278278 configs_gen_func = _get_moe_align_fused_configs ,
279279 static_key_func = _get_moe_align_fused_static_key ,
280280 run_key_func = lambda topk_ids : topk_ids .shape [0 ],
281- mutates_args = ["expert_to_index " , "expert_to_weight" , "expert_token_num" ],
281+ mutates_args = ["expert_to_token_index " , "expert_to_weight" , "expert_token_num" ],
282282)
283283def moe_align_fused (
284- expert_to_index , expert_to_weight , expert_token_num , topk_ids , topk_weights , run_config : Optional [dict ] = None
284+ expert_to_token_index , expert_to_weight , expert_token_num , topk_ids , topk_weights , run_config : Optional [dict ] = None
285285):
286286 token_num , topk_num = topk_ids .shape
287287 if run_config is None :
@@ -293,15 +293,15 @@ def moe_align_fused(
293293 moe_align_fused_kernel [grid ](
294294 topk_ids ,
295295 topk_weights ,
296- expert_to_index ,
296+ expert_to_token_index ,
297297 expert_to_weight ,
298298 expert_token_num ,
299299 token_num ,
300300 topk_num ,
301301 BLOCK_SIZE = BLOCK_SIZE ,
302302 num_warps = num_warps ,
303303 )
304- return expert_to_index , expert_to_weight , expert_token_num
304+ return expert_to_token_index , expert_to_weight , expert_token_num
305305
306306
307307@triton .jit
@@ -805,7 +805,13 @@ def fused_experts_impl(
805805 expert_to_tokens = torch .empty ((E , topk_num * tokens_in_chunk ), dtype = torch .int32 , device = "cuda" )
806806 expert_to_weights = torch .empty ((E , topk_num * tokens_in_chunk ), dtype = torch .float32 , device = "cuda" )
807807 expert_to_token_num = torch .zeros ((E ,), dtype = torch .int32 , device = "cuda" )
808- moe_align_fused (expert_to_tokens , expert_to_weights , expert_to_token_num , curr_topk_ids , curr_topk_weights )
808+ moe_align_fused (
809+ expert_to_token_index = expert_to_tokens ,
810+ expert_to_weight = expert_to_weights ,
811+ expert_token_num = expert_to_token_num ,
812+ topk_ids = curr_topk_ids ,
813+ topk_weights = curr_topk_weights ,
814+ )
809815
810816 reused_mblock_infos = grouped_matmul (
811817 curr_topk_ids .numel (),
0 commit comments