 from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd
 from functools import partial
 from lightllm.utils.log_utils import init_logger
+from lightllm.utils.dist_utils import get_global_world_size
 
 logger = init_logger(__name__)
 
@@ -27,6 +28,7 @@ def __init__(self, layer_num, network_config, mode=[]):
         )
         self.num_experts_per_tok = network_config["num_experts_per_tok"]
         self.norm_topk_prob = network_config["norm_topk_prob"]
+        self.n_shared_experts = network_config.get("n_shared_experts", None)
         super().__init__(layer_num, network_config, mode)
         self.head_dim_ = network_config["head_dim"]
         self.tp_k_head_num_ = max(self.tp_k_head_num_, 1)
@@ -120,3 +122,274 @@ def _moe_ffn_edp(
 
         ep_output = ep_output.view(token_num, hidden_dim)
         return ep_output
+
+    def overlap_tpsp_token_forward(
+        self,
+        input_embdings: torch.Tensor,
+        input_embdings1: torch.Tensor,
+        infer_state: LlamaInferStateInfo,
+        infer_state1: LlamaInferStateInfo,
+        layer_weight: Qwen3MOETransformerLayerWeight,
+    ):
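+        # Decode-path micro-batch overlap: the two micro-batches (prefixes _0_ / _1_) are
+        # interleaved so that the expert dispatch/combine communication of one micro-batch
+        # runs while the attention and MoE GEMMs of the other are computed. Completion
+        # callbacks are parked on infer_state.hook / infer_state1.hook and invoked just
+        # before their results are needed.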
+        if not self.is_moe:
+            return super().overlap_tpsp_token_forward(
+                input_embdings, input_embdings1, infer_state, infer_state1, layer_weight
+            )
+        # 0 attention
+        _0_input1 = self._att_norm(input_embdings, infer_state, layer_weight)
+        _0_cache_kv = self._pre_cache_kv(infer_state, layer_weight)
+        _0_q, _0_cache_kv = self._tpsp_get_qkv(_0_input1, _0_cache_kv, infer_state, layer_weight)
+        _0_input1 = None
+        self._post_cache_kv(_0_cache_kv, infer_state, layer_weight)
+        _0_o = self._token_attention_kernel(_0_q, infer_state, layer_weight)
+        _0_q = None
+        _0_o = self._tpsp_get_o(_0_o, infer_state, layer_weight)
+        input_embdings.add_(_0_o.view(-1, self.embed_dim_))
+        _0_o = None
+        _0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
+        _0_router_logits = layer_weight.moe_gate.mm(_0_input1)
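+        # Drain any hook still pending for micro-batch 1 (typically the previous layer's
+        # deferred combine) before issuing new communication for micro-batch 0.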
+        # 1 hook
+        if getattr(infer_state1, "hook", None) is not None:
+            infer_state1.hook()
+            infer_state1.hook = None
+
+        # 0 shared expert
+        if self.n_shared_experts is not None:
+            _0_shared_output = LlamaTransformerLayerInfer._ffn(self, _0_input1, infer_state, layer_weight)
+
+        # 0 dispatch
+        (
+            _0_recv_x,
+            _0_masked_m,
+            _0_topk_idx,
+            _0_topk_weight,
+            _0_handle,
+            _0_hook,
+        ) = layer_weight.experts.low_latency_dispatch(_0_input1, _0_router_logits)
+        infer_state.hook = _0_hook
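+        # low_latency_dispatch is asynchronous: its completion hook is stored here and only
+        # invoked after micro-batch 1's attention below, overlapping the expert all-to-all
+        # with that compute.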
+
+        # 1 attention
+        _1_input1 = self._att_norm(input_embdings1, infer_state1, layer_weight)
+        _1_cache_kv = self._pre_cache_kv(infer_state1, layer_weight)
+        _1_q, _1_cache_kv = self._tpsp_get_qkv(_1_input1, _1_cache_kv, infer_state1, layer_weight)
+        _1_input1 = None
+        self._post_cache_kv(_1_cache_kv, infer_state1, layer_weight)
+        _1_o = self._token_attention_kernel(_1_q, infer_state1, layer_weight)
+        _1_q = None
+        _1_o = self._tpsp_get_o(_1_o, infer_state1, layer_weight)
+        input_embdings1.add_(_1_o.view(-1, self.embed_dim_))
+        _1_o = None
+        _1_input1 = self._ffn_norm(input_embdings1, infer_state1, layer_weight)
+        # 1 gate and dispatch
+
+        _1_router_logits = layer_weight.moe_gate.mm(_1_input1)
+        # 0 hook
+        if getattr(infer_state, "hook", None) is not None:
+            infer_state.hook()
+            infer_state.hook = None
+
+        # 1 shared expert
+        if self.n_shared_experts is not None:
+            _1_shared_output = LlamaTransformerLayerInfer._ffn(self, _1_input1, infer_state1, layer_weight)
+
+        # 1 dispatch
+        (
+            _1_recv_x,
+            _1_masked_m,
+            _1_topk_idx,
+            _1_topk_weight,
+            _1_handle,
+            _1_hook,
+        ) = layer_weight.experts.low_latency_dispatch(_1_input1, _1_router_logits)
+        infer_state1.hook = _1_hook
+
+        # 0 moe calc
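+        # expected_m estimates how many tokens each routed expert receives on average
+        # (tokens per rank * world size * top-k, divided by the number of routed experts);
+        # it is passed to the masked grouped GEMM below.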
+        expected_m = triton.cdiv(
+            input_embdings.shape[0] * get_global_world_size() * self.num_experts_per_tok, self.n_routed_experts
+        )
+        _0_moe_out = layer_weight.experts.masked_group_gemm(_0_recv_x, _0_masked_m, input_embdings.dtype, expected_m)
+
+        # 1 hook
+        if getattr(infer_state1, "hook", None) is not None:
+            infer_state1.hook()
+            infer_state1.hook = None
+
+        # 0 combine
+        _0_ffn_out, _0_hook = layer_weight.experts.low_latency_combine(
+            _0_moe_out, _0_topk_idx, _0_topk_weight, _0_handle
+        )
+
+        infer_state.hook = _0_hook
+
+        # 1 moe calc
+        _1_moe_out = layer_weight.experts.masked_group_gemm(_1_recv_x, _1_masked_m, input_embdings1.dtype, expected_m)
+
+        # 0 hook
+        if getattr(infer_state, "hook", None) is not None:
+            infer_state.hook()
+            # _0_ffn_out *= self.routed_scaling_factor
+            if self.n_shared_experts is not None:
+                _0_ffn_out.add_(_0_shared_output)
+            input_embdings.add_(_0_ffn_out.view(-1, self.embed_dim_))
+            infer_state.hook = None
+
+        # 1 combine
+        _1_ffn_out, _1_hook = layer_weight.experts.low_latency_combine(
+            _1_moe_out, _1_topk_idx, _1_topk_weight, _1_handle
+        )
+
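+        # The wait for micro-batch 1's combine and its residual add are deferred into a
+        # closure stored on infer_state1.hook; it is executed at the start of the next
+        # layer (or by the caller), so this layer returns without synchronizing on it.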
+        def _1_hook_post():
+            _1_hook()
+            nonlocal _1_ffn_out
+            # _1_ffn_out *= self.routed_scaling_factor
+            if self.n_shared_experts is not None:
+                _1_ffn_out.add_(_1_shared_output)
+            input_embdings1.add_(_1_ffn_out.view(-1, self.embed_dim_))
+            return
+
+        infer_state1.hook = _1_hook_post
+
+        return input_embdings, input_embdings1
+
+    def overlap_tpsp_context_forward(
+        self,
+        input_embdings: torch.Tensor,
+        input_embdings1: torch.Tensor,
+        infer_state: LlamaInferStateInfo,
+        infer_state1: LlamaInferStateInfo,
+        layer_weight: Qwen3MOETransformerLayerWeight,
+    ):
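+        # Prefill-path micro-batch overlap: uses the regular (non-low-latency) DeepEP
+        # dispatch/combine. Stream events captured via Buffer.capture() are handed to
+        # dispatch/combine as overlap events so that each micro-batch's all-to-all runs
+        # while the other micro-batch's attention and grouped GEMMs are computed; hooks
+        # again defer the waits until the results are needed.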
+        if not self.is_moe:
+            return super().overlap_tpsp_context_forward(
+                input_embdings, input_embdings1, infer_state, infer_state1, layer_weight
+            )
+        # 0 attention
+        _0_input1 = self._att_norm(input_embdings, infer_state, layer_weight)
+        _0_cache_kv = self._pre_cache_kv(infer_state, layer_weight)
+        _0_q, _0_cache_kv = self._tpsp_get_qkv(_0_input1, _0_cache_kv, infer_state, layer_weight)
+        _0_input1 = None
+        self._post_cache_kv(_0_cache_kv, infer_state, layer_weight)
+        _0_o = self._context_attention_kernel(_0_q, _0_cache_kv, infer_state, layer_weight)
+        _0_q = None
+        _0_o = self._tpsp_get_o(_0_o, infer_state, layer_weight)
+        input_embdings.add_(_0_o.view(-1, self.embed_dim_))
+        _0_o = None
+        _0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
+        _0_router_logits = layer_weight.moe_gate.mm(_0_input1)
+
+        # wait last 1 combine
+        if getattr(infer_state1, "hook", None) is not None:
+            infer_state1.hook()
+            infer_state1.hook = None
+
+        _0_topk_weight, _0_topk_idx, _0_qinput_tensor = layer_weight.experts.select_experts_and_quant_input(
+            _0_input1, _0_router_logits
+        )
+        from deep_ep import Buffer
+
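+        # Capture a stream event covering the work queued so far (micro-batch 0's quantized
+        # dispatch input); it is passed to dispatch as overlap_event so that communication
+        # can proceed while micro-batch 1's attention below is issued.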
+        _0_overlap_event = Buffer.capture()
+
+        # 1 attention
+        _1_input1 = self._att_norm(input_embdings1, infer_state1, layer_weight)
+        _1_cache_kv = self._pre_cache_kv(infer_state1, layer_weight)
+        _1_q, _1_cache_kv = self._tpsp_get_qkv(_1_input1, _1_cache_kv, infer_state1, layer_weight)
+        _1_input1 = None
+        self._post_cache_kv(_1_cache_kv, infer_state1, layer_weight)
+        _1_o = self._context_attention_kernel(_1_q, _1_cache_kv, infer_state1, layer_weight)
+        _1_q = None
+        _1_o = self._tpsp_get_o(_1_o, infer_state1, layer_weight)
+        input_embdings1.add_(_1_o.view(-1, self.embed_dim_))
+        _1_o = None
+        _1_input1 = self._ffn_norm(input_embdings1, infer_state1, layer_weight)
+        # 1 gate and dispatch
+
+        _1_router_logits = layer_weight.moe_gate.mm(_1_input1)
+
+        # 0 dispatch execute
+        (
+            _0_recv_x,
+            _0_recv_topk_idx,
+            _0_recv_topk_weight,
+            _0_num_recv_tokens_per_expert_list,
+            _0_handle,
+            _0_hook,
+        ) = layer_weight.experts.dispatch(_0_qinput_tensor, _0_topk_idx, _0_topk_weight, overlap_event=_0_overlap_event)
+        infer_state.hook = _0_hook
+
+        # wait 0 dispatch
+        if getattr(infer_state, "hook", None) is not None:
+            infer_state.hook()
+            infer_state.hook = None
+
+        _1_topk_weight, _1_topk_idx, _1_qinput_tensor = layer_weight.experts.select_experts_and_quant_input(
+            _1_input1, _1_router_logits
+        )
+
+        _1_overlap_event = Buffer.capture()
+
+        # 0 shared expert
+        if self.n_shared_experts is not None:
+            _0_shared_output = LlamaTransformerLayerInfer._ffn(self, _0_input1, infer_state, layer_weight)
+
+        # 1 shared expert
+        if self.n_shared_experts is not None:
+            _1_shared_output = LlamaTransformerLayerInfer._ffn(self, _1_input1, infer_state1, layer_weight)
+
+        # 0 moe calc
+        _0_moe_out = layer_weight.experts.prefilled_group_gemm(
+            _0_num_recv_tokens_per_expert_list, _0_recv_x, _0_recv_topk_idx, _0_recv_topk_weight
+        )
+
+        # 1 dispatch execute
+        (
+            _1_recv_x,
+            _1_recv_topk_idx,
+            _1_recv_topk_weight,
+            _1_num_recv_tokens_per_expert_list,
+            _1_handle,
+            _1_hook,
+        ) = layer_weight.experts.dispatch(_1_qinput_tensor, _1_topk_idx, _1_topk_weight, overlap_event=_1_overlap_event)
+        infer_state1.hook = _1_hook
+
+        # wait 1 dispatch
+        if getattr(infer_state1, "hook", None) is not None:
+            infer_state1.hook()
+            infer_state1.hook = None
+
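+        # Micro-batch 0's combine is issued here; waiting on it is overlapped with
+        # micro-batch 1's grouped GEMM below.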
+        _0_combine_event = Buffer.capture()
+        # 0 combine execute
+        _0_ffn_out, _0_hook = layer_weight.experts.combine(_0_moe_out, _0_handle, _0_combine_event)
+        infer_state.hook = _0_hook
+
+        # 1 moe calc
+        _1_moe_out = layer_weight.experts.prefilled_group_gemm(
+            _1_num_recv_tokens_per_expert_list, _1_recv_x, _1_recv_topk_idx, _1_recv_topk_weight
+        )
+
+        # wait 0 combine
+        if getattr(infer_state, "hook", None) is not None:
+            infer_state.hook()
+            infer_state.hook = None
+
+        _1_combine_event = Buffer.capture()
+
+        # _0_ffn_out *= self.routed_scaling_factor
+        if self.n_shared_experts is not None:
+            _0_ffn_out.add_(_0_shared_output)
+        input_embdings.add_(_0_ffn_out.view(-1, self.embed_dim_))
+
+        # 1 combine execute
+        _1_ffn_out, _1_hook = layer_weight.experts.combine(_1_moe_out, _1_handle, _1_combine_event)
+
+        def _1_hook_post():
+            _1_hook()
+            nonlocal _1_ffn_out
+            # _1_ffn_out *= self.routed_scaling_factor
+            if self.n_shared_experts is not None:
+                _1_ffn_out.add_(_1_shared_output)
+            input_embdings1.add_(_1_ffn_out.view(-1, self.embed_dim_))
+            return
+
+        infer_state1.hook = _1_hook_post
+
+        return input_embdings, input_embdings1