@@ -114,29 +114,37 @@ def moe_align1_kernel(
114114 experts_topk_weight , # [expert_num, token_num * topk_num]
115115 experts_topk_weight_stride0 ,
116116 experts_topk_weight_stride1 ,
117- TOKEN_BLOCK_N : tl .constexpr ,
117+ TOKEN_BLOCK_SIZE : tl .constexpr ,
118+ NUM_STAGE : tl .constexpr ,
118119):
119120
120121 expert_id = tl .program_id (axis = 0 )
121- n_range = tl .arange (0 , TOKEN_BLOCK_N )
122122
123- topk_weights_data = tl .load (topk_weights + n_range , mask = n_range < experts_info_n , other = 0 )
124- expert_data = tl .load (
125- experts_info_ptr + expert_id * experts_info_stride0 + n_range , mask = n_range < experts_info_n , other = 0
126- )
127- cumsum_expert_data = tl .cumsum (expert_data )
123+ off_n = tl .arange (0 , TOKEN_BLOCK_SIZE )
128124
129- tl .store (expert_token_num_ptr + expert_id , tl .max (cumsum_expert_data ))
130- tl .store (
131- experts_info_ptr + expert_id * experts_info_stride0 + cumsum_expert_data - 1 ,
132- n_range ,
133- mask = (expert_data == 1 ) & (n_range < experts_info_n ),
134- )
135- tl .store (
136- experts_topk_weight + expert_id * experts_topk_weight_stride0 + cumsum_expert_data - 1 ,
137- topk_weights_data ,
138- mask = (expert_data == 1 ) & (n_range < experts_info_n ),
139- )
125+ pre_sum = 0
126+
127+ for start_loc in tl .range (0 , experts_info_n , TOKEN_BLOCK_SIZE , num_stages = NUM_STAGE ):
128+ n_range = start_loc + off_n
129+ topk_weights_data = tl .load (topk_weights + n_range , mask = n_range < experts_info_n , other = 0 )
130+ expert_data = tl .load (
131+ experts_info_ptr + expert_id * experts_info_stride0 + n_range , mask = n_range < experts_info_n , other = 0
132+ )
133+ cumsum_expert_data = tl .cumsum (expert_data ) + pre_sum
134+ pre_sum = tl .max (cumsum_expert_data )
135+ tl .store (
136+ experts_info_ptr + expert_id * experts_info_stride0 + cumsum_expert_data - 1 ,
137+ n_range ,
138+ mask = (expert_data == 1 ) & (n_range < experts_info_n ),
139+ )
140+ tl .store (
141+ experts_topk_weight + expert_id * experts_topk_weight_stride0 + cumsum_expert_data - 1 ,
142+ topk_weights_data ,
143+ mask = (expert_data == 1 ) & (n_range < experts_info_n ),
144+ )
145+
146+ tl .store (expert_token_num_ptr + expert_id , pre_sum )
147+ return
140148
141149
142150def moe_align1 (
@@ -184,7 +192,11 @@ def moe_align1(
184192 assert token_num_mul_topk <= FFN_MOE_CHUNK_SIZE * topk_num , "need split to handle seq len too long"
185193 assert exports_token_num .shape [0 ] == expert_num
186194 assert topk_weights .is_contiguous ()
187- TOKEN_BLOCK_N = triton .next_power_of_2 (token_num_mul_topk )
195+ if token_num_mul_topk <= 512 :
196+ TOKEN_BLOCK_SIZE = 256
197+ else :
198+ TOKEN_BLOCK_SIZE = 512 if token_num_mul_topk <= 4 * 1024 else 2048
199+
188200 grid = (expert_num ,)
189201 moe_align1_kernel [grid ](
190202 experts_info ,
@@ -197,7 +209,8 @@ def moe_align1(
197209 experts_weight_info ,
198210 experts_weight_info .stride (0 ),
199211 experts_weight_info .stride (1 ),
200- TOKEN_BLOCK_N = TOKEN_BLOCK_N ,
212+ TOKEN_BLOCK_SIZE = TOKEN_BLOCK_SIZE ,
213+ NUM_STAGE = 4 ,
201214 num_warps = 8 ,
202215 num_stages = 1 ,
203216 )
0 commit comments