@@ -169,78 +169,84 @@ def ep_scatter(
169169
@triton.jit
def _fwd_kernel_ep_gather(
    input_tensor,
    input_tensor_stride0,
    input_tensor_stride1,
    recv_topk,
    recv_topk_stride0,
    recv_topk_stride1,
    recv_topk_weight,
    recv_topk_weight_stride0,
    recv_topk_weight_stride1,
    input_index,
    input_index_stride0,
    input_index_stride1,
    expert_start_loc,
    output_tensor,
    output_tensor_stride0,
    output_tensor_stride1,
    topk_col: tl.constexpr,
    BLOCK_D: tl.constexpr,
    HIDDEN_SIZE: tl.constexpr,
    HIDDEN_SIZE_PAD: tl.constexpr,
):
    """Gather expert outputs back into per-token rows as a weighted sum.

    One program handles one destination token (``program_id(0)``): for each
    of its top-k slots holding a valid expert id (>= 0), the source row is
    located inside that expert's contiguous segment of ``input_tensor``
    (segment start from ``expert_start_loc`` plus the within-segment index
    from ``input_index``), scaled by the slot's routing weight, and
    accumulated in float32 before a single store to
    ``output_tensor[token_id]``.

    NOTE(review): the inner dimension is addressed without
    ``input_tensor_stride1`` / ``output_tensor_stride1`` — assumes both
    tensors are contiguous along the hidden dimension (same assumption the
    pre-diff version made).
    """
    token_id = tl.program_id(0)
    # HIDDEN_SIZE_PAD = next_power_of_2(HIDDEN_SIZE); lanes past HIDDEN_SIZE
    # must be masked so we never touch memory beyond the row end.
    offset = tl.arange(0, HIDDEN_SIZE_PAD)
    mask = offset < HIDDEN_SIZE
    accumulator = tl.zeros([HIDDEN_SIZE_PAD], dtype=tl.float32)

    for start_topk in range(0, topk_col):
        cur_expert = tl.load(recv_topk + token_id * recv_topk_stride0 + start_topk)
        # A negative expert id marks an empty / dropped top-k slot.
        if cur_expert >= 0:
            # Source row index = expert segment start + within-expert index.
            start_ = tl.load(expert_start_loc + cur_expert)
            dst = tl.load(input_index + token_id * input_index_stride0 + start_topk) + start_

            weight = tl.load(recv_topk_weight + token_id * recv_topk_weight_stride0 + start_topk)
            # BUGFIX: `dst` is a row index, so it must be scaled by the row
            # stride (the pre-diff code did `row * input_tensor_stride0`),
            # and the load must be masked because `offset` spans the padded
            # power-of-two range.
            tmp = tl.load(
                input_tensor + dst * input_tensor_stride0 + offset,
                mask=mask,
                other=0.0,
            )
            accumulator += tmp.to(tl.float32) * weight

    # BUGFIX: mask the store as well — `mask` was previously computed but
    # never used, writing past the row whenever HIDDEN_SIZE is not a power
    # of two.
    tl.store(
        output_tensor + token_id * output_tensor_stride0 + offset,
        accumulator.to(output_tensor.dtype.element_ty),
        mask=mask,
    )
210212
211213
@torch.no_grad()
def ep_gather(
    input_tensor: torch.Tensor,
    recv_topk: torch.Tensor,
    recv_topk_weight: torch.Tensor,
    input_index: torch.Tensor,
    expert_start_loc: torch.Tensor,
    output_tensor: torch.Tensor,
):
    """Launch ``_fwd_kernel_ep_gather`` with one program per output token.

    Args:
        input_tensor: expert outputs, rows grouped per expert; shape
            ``(*, hidden_size)``.
        recv_topk: per-token expert ids, shape ``(num_tokens, topk)``;
            negative entries mark unused slots.
        recv_topk_weight: routing weights aligned with ``recv_topk``.
        input_index: per-(token, slot) within-expert row index.
        expert_start_loc: start row of each expert's segment in
            ``input_tensor``.
        output_tensor: destination, shape ``(num_tokens, hidden_size)``;
            written in place.
    """
    BLOCK_D = 128  # block size of quantization
    num_warps = 4
    num_tokens = output_tensor.shape[0]
    hidden_size = input_tensor.shape[1]
    # BUGFIX: the kernel has no grid-stride loop over tokens, so every token
    # needs its own program. Capping the grid at 65535 silently left tokens
    # beyond the grid unwritten; the CUDA x grid dimension supports up to
    # 2**31 - 1 blocks, so launch the full count.
    grid = (num_tokens,)

    _fwd_kernel_ep_gather[grid](
        input_tensor,
        input_tensor.stride(0),
        input_tensor.stride(1),
        recv_topk,
        recv_topk.stride(0),
        recv_topk.stride(1),
        recv_topk_weight,
        recv_topk_weight.stride(0),
        recv_topk_weight.stride(1),
        input_index,
        input_index.stride(0),
        input_index.stride(1),
        expert_start_loc,
        output_tensor,
        output_tensor.stride(0),
        output_tensor.stride(1),
        topk_col=recv_topk.shape[1],
        num_warps=num_warps,
        BLOCK_D=BLOCK_D,
        HIDDEN_SIZE=hidden_size,
        # Pad the hidden dim to a power of two for tl.arange; the kernel
        # masks the padded lanes.
        HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
    )
    return
0 commit comments