@@ -185,27 +185,31 @@ def _fwd_kernel_ep_gather(
     output_tensor,
     output_tensor_stride0,
     output_tensor_stride1,
+    HIDDEN_SIZE: tl.constexpr,
+    HIDDEN_SIZE_PAD: tl.constexpr,
     topk_num: tl.constexpr,
     BLOCK_D: tl.constexpr,
 ):
-    cur_block = tl.program_id(0)
-    start_cur_token = tl.program_id(1)
-    grid_num = tl.num_programs(1)
+    start_cur_token = tl.program_id(0)
+    grid_num = tl.num_programs(0)
+
+    offset_d = tl.arange(0, HIDDEN_SIZE_PAD)
+    mask = offset_d < HIDDEN_SIZE

     for cur_token in range(start_cur_token, total_token_num, grid_num):
-        off_d = tl.arange(0, BLOCK_D)
-        accumulator = tl.zeros([BLOCK_D], dtype=tl.float32)
+        accumulator = tl.zeros([HIDDEN_SIZE_PAD], dtype=tl.float32)
         for topk_index in range(0, topk_num):
             expert_id = tl.load(recv_topk_ids + cur_token * recv_topk_ids_stride0 + topk_index)
             if expert_id >= 0:
                 source_token_index = tl.load(input_index + cur_token * input_index_stride0 + topk_index)
                 acc_weight = tl.load(recv_topk_weight + cur_token * recv_topk_weight_stride0 + topk_index)
-                tmp = tl.load(input_tensor + source_token_index * input_tensor_stride0 + cur_block * BLOCK_D + off_d)
+                tmp = tl.load(input_tensor + source_token_index * input_tensor_stride0 + offset_d, mask=mask)
                 accumulator += tmp.to(tl.float32) * acc_weight

         tl.store(
-            output_tensor + cur_token * output_tensor_stride0 + cur_block * BLOCK_D + off_d,
+            output_tensor + cur_token * output_tensor_stride0 + offset_d,
             accumulator.to(output_tensor.dtype.element_ty),
+            mask=mask,
         )


@@ -221,7 +225,7 @@ def ep_gather(
     num_warps = 4
     num_tokens = output_tensor.shape[0]
     hidden_size = input_tensor.shape[1]
-    grid = (triton.cdiv(hidden_size, BLOCK_D), min(num_tokens, 1024))
+    grid = (min(num_tokens, 32768),)
     _fwd_kernel_ep_gather[grid](
         num_tokens,
         input_tensor,
@@ -240,6 +244,8 @@ def ep_gather(
         output_tensor.stride(0),
         output_tensor.stride(1),
         topk_num=recv_topk_ids.shape[1],
+        HIDDEN_SIZE=hidden_size,
+        HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
         num_warps=num_warps,
         BLOCK_D=BLOCK_D,
     )
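
A minimal standalone sketch (not part of the patch) of the pattern the kernel switches to: a one-dimensional grid over rows, with the hidden dimension padded to the next power of two so tl.arange is legal and a mask guarding the padded lanes on every load and store. The kernel and helper names below are illustrative only, not from the repository.

import torch
import triton
import triton.language as tl


@triton.jit
def _copy_rows_kernel(
    src_ptr,
    dst_ptr,
    src_stride0,
    dst_stride0,
    HIDDEN_SIZE: tl.constexpr,
    HIDDEN_SIZE_PAD: tl.constexpr,
):
    # One program per row; the full hidden dimension is handled in one shot.
    row = tl.program_id(0)
    offset_d = tl.arange(0, HIDDEN_SIZE_PAD)  # padded to a power of two
    mask = offset_d < HIDDEN_SIZE             # disable the padded tail lanes
    vals = tl.load(src_ptr + row * src_stride0 + offset_d, mask=mask)
    tl.store(dst_ptr + row * dst_stride0 + offset_d, vals, mask=mask)


def copy_rows(src: torch.Tensor) -> torch.Tensor:
    dst = torch.empty_like(src)
    hidden_size = src.shape[1]
    grid = (src.shape[0],)  # 1D grid over rows, as in the patched ep_gather launch
    _copy_rows_kernel[grid](
        src,
        dst,
        src.stride(0),
        dst.stride(0),
        HIDDEN_SIZE=hidden_size,
        HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
    )
    return dst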