@@ -219,6 +219,91 @@ def moe_align1(
     )
 
 
+@triton.jit
+def moe_align_fused_kernel(
+    topk_ids_ptr,  # [token_num, topk]
+    topk_weights_ptr,  # [token_num, topk]
+    expert_to_token_index_ptr,  # [expert_num, token_num * topk]
+    expert_to_weight_ptr,  # [expert_num, token_num * topk]
+    expert_token_num_ptr,  # [expert_num]
+    token_num,
+    topk_num: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    token_block = tl.program_id(0)
+    offs = token_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = offs < token_num * topk_num
+
+    expert_ids = tl.load(topk_ids_ptr + offs, mask=mask, other=0)
+    weights = tl.load(topk_weights_ptr + offs, mask=mask, other=0.0)
+
+    # use atomic_add to allocate a write position for each expert
+    write_pos = tl.atomic_add(expert_token_num_ptr + expert_ids, 1, mask=mask)
+
+    # write the index and weight in token order
+    tl.store(
+        expert_to_token_index_ptr + expert_ids * (token_num * topk_num) + write_pos,
+        offs,
+        mask=mask,
+    )
+    tl.store(
+        expert_to_weight_ptr + expert_ids * (token_num * topk_num) + write_pos,
+        weights,
+        mask=mask,
+    )
+
+
+def _get_moe_align_fused_static_key(
+    topk_weights: torch.Tensor,
+) -> dict:
+    topk_num = topk_weights.shape[1]
+    return {
+        "topk_num": topk_num,
+    }
+
+
+def _get_moe_align_fused_configs():
+    return [
+        {
+            "BLOCK_SIZE": bt,
+            "num_warps": nw,
+        }
+        for nw in [1, 2, 4, 8]
+        for bt in [128, 256, 512, 1024, 2048]
+    ]
+
+
+@autotune(
+    kernel_name="moe_align_fused:v1",
+    configs_gen_func=_get_moe_align_fused_configs,
+    static_key_func=_get_moe_align_fused_static_key,
+    run_key_func=lambda topk_ids: topk_ids.shape[0],
+    mutates_args=["expert_to_token_index", "expert_to_weight", "expert_token_num"],
+)
+def moe_align_fused(
+    expert_to_token_index, expert_to_weight, expert_token_num, topk_ids, topk_weights, run_config: Optional[dict] = None
+):
+    token_num, topk_num = topk_ids.shape
+    if run_config is None:
+        run_config = {}
+    BLOCK_SIZE = run_config.get("BLOCK_SIZE", 256)
+    num_warps = run_config.get("num_warps", 4)
+
+    grid = (triton.cdiv(token_num * topk_num, BLOCK_SIZE),)
+    moe_align_fused_kernel[grid](
+        topk_ids,
+        topk_weights,
+        expert_to_token_index,
+        expert_to_weight,
+        expert_token_num,
+        token_num,
+        topk_num,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+    )
+    return expert_to_token_index, expert_to_weight, expert_token_num
+
+
 @triton.jit
 def moe_align2_kernel(
     experts_token_num_ptr,  # [expert_num,]
@@ -719,9 +804,14 @@ def fused_experts_impl(
 
         expert_to_tokens = torch.empty((E, topk_num * tokens_in_chunk), dtype=torch.int32, device="cuda")
         expert_to_weights = torch.empty((E, topk_num * tokens_in_chunk), dtype=torch.float32, device="cuda")
-        moe_align(topk_ids=curr_topk_ids, out=expert_to_tokens)
-        expert_to_token_num = torch.empty((E,), dtype=torch.int32, device="cuda")
-        moe_align1(expert_to_tokens, curr_topk_weights, expert_to_weights, expert_to_token_num, topk=topk_num)
+        expert_to_token_num = torch.zeros((E,), dtype=torch.int32, device="cuda")
+        moe_align_fused(
+            expert_to_token_index=expert_to_tokens,
+            expert_to_weight=expert_to_weights,
+            expert_token_num=expert_to_token_num,
+            topk_ids=curr_topk_ids,
+            topk_weights=curr_topk_weights,
+        )
 
         reused_mblock_infos = grouped_matmul(
             curr_topk_ids.numel(),
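A minimal standalone sketch of how the fused alignment path above can be exercised and checked, not part of the commit. It assumes moe_align_fused is importable from the modified module; the tensor shapes and the reference comparison are illustrative only. Note that expert_token_num must be zero-initialized, since the kernel hands out write slots via tl.atomic_add on that buffer, which is why the call site switches from torch.empty to torch.zeros.

import torch

# Hypothetical sizes for illustration; requires CUDA and the module that this diff modifies,
# from which moe_align_fused is assumed to be importable.
token_num, topk_num, expert_num = 8, 2, 4

topk_ids = torch.randint(0, expert_num, (token_num, topk_num), dtype=torch.int32, device="cuda")
topk_weights = torch.rand((token_num, topk_num), dtype=torch.float32, device="cuda")

expert_to_tokens = torch.empty((expert_num, token_num * topk_num), dtype=torch.int32, device="cuda")
expert_to_weights = torch.empty((expert_num, token_num * topk_num), dtype=torch.float32, device="cuda")
# Must be zeros: the kernel uses atomic_add on this buffer to allocate per-expert write slots.
expert_to_token_num = torch.zeros((expert_num,), dtype=torch.int32, device="cuda")

moe_align_fused(
    expert_to_token_index=expert_to_tokens,
    expert_to_weight=expert_to_weights,
    expert_token_num=expert_to_token_num,
    topk_ids=topk_ids,
    topk_weights=topk_weights,
)

# Per-expert counts must match the histogram of topk_ids.
ref_counts = torch.bincount(topk_ids.view(-1).long(), minlength=expert_num).int()
assert torch.equal(expert_to_token_num, ref_counts)

# Each expert row holds flat indices into topk_ids/topk_weights (the order within a row
# may vary between runs because slots are claimed atomically across blocks).
for e in range(expert_num):
    n = int(expert_to_token_num[e])
    flat = expert_to_tokens[e, :n].long()
    assert (topk_ids.view(-1)[flat] == e).all()
    assert torch.allclose(expert_to_weights[e, :n], topk_weights.view(-1)[flat])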