@@ -28,7 +28,6 @@ def __init__(self, layer_num, network_config, mode=[]):
         )
         self.num_experts_per_tok = network_config["num_experts_per_tok"]
         self.norm_topk_prob = network_config["norm_topk_prob"]
-        self.n_shared_experts = network_config.get("n_shared_experts", None)
         super().__init__(layer_num, network_config, mode)
         self.head_dim_ = network_config["head_dim"]
         self.tp_k_head_num_ = max(self.tp_k_head_num_, 1)
@@ -153,10 +152,6 @@ def overlap_tpsp_token_forward(
             infer_state1.hook()
             infer_state1.hook = None
 
-        # 0 shared expert
-        if self.n_shared_experts is not None:
-            _0_shared_output = LlamaTransformerLayerInfer._ffn(self, _0_input1, infer_state, layer_weight)
-
         # 0 dispatch
         (
             _0_recv_x,
@@ -188,10 +183,6 @@ def overlap_tpsp_token_forward(
             infer_state.hook()
             infer_state.hook = None
 
-        # 1 shared expert
-        if self.n_shared_experts is not None:
-            _1_shared_output = LlamaTransformerLayerInfer._ffn(self, _1_input1, infer_state1, layer_weight)
-
         # 1 dispatch
         (
             _1_recv_x,
@@ -227,9 +218,6 @@ def overlap_tpsp_token_forward(
         # 0 hook
         if getattr(infer_state, "hook", None) is not None:
             infer_state.hook()
-            # _0_ffn_out *= self.routed_scaling_factor
-            if self.n_shared_experts is not None:
-                _0_ffn_out.add_(_0_shared_output)
             input_embdings.add_(_0_ffn_out.view(-1, self.embed_dim_))
             infer_state.hook = None
 
@@ -241,9 +229,6 @@ def overlap_tpsp_token_forward(
         def _1_hook_post():
             _1_hook()
             nonlocal _1_ffn_out
-            # _1_ffn_out *= self.routed_scaling_factor
-            if self.n_shared_experts is not None:
-                _1_ffn_out.add_(_1_shared_output)
             input_embdings1.add_(_1_ffn_out.view(-1, self.embed_dim_))
             return
 
@@ -327,14 +312,6 @@ def overlap_tpsp_context_forward(
 
         _1_overlap_event = Buffer.capture()
 
-        # 0 shared expert
-        if self.n_shared_experts is not None:
-            _0_shared_output = LlamaTransformerLayerInfer._ffn(self, _0_input1, infer_state, layer_weight)
-
-        # 1 shared expert
-        if self.n_shared_experts is not None:
-            _1_shared_output = LlamaTransformerLayerInfer._ffn(self, _1_input1, infer_state1, layer_weight)
-
         # 0 moe calu
         _0_moe_out = layer_weight.experts.prefilled_group_gemm(
             _0_num_recv_tokens_per_expert_list, _0_recv_x, _0_recv_topk_idx, _0_recv_topk_weight
@@ -373,9 +350,6 @@ def overlap_tpsp_context_forward(
 
         _1_combine_event = Buffer.capture()
 
-        # _0_ffn_out *= self.routed_scaling_factor
-        if self.n_shared_experts is not None:
-            _0_ffn_out.add_(_0_shared_output)
         input_embdings.add_(_0_ffn_out.view(-1, self.embed_dim_))
 
         # 1 combine execute
@@ -384,9 +358,6 @@ def overlap_tpsp_context_forward(
         def _1_hook_post():
             _1_hook()
             nonlocal _1_ffn_out
-            # _1_ffn_out *= self.routed_scaling_factor
-            if self.n_shared_experts is not None:
-                _1_ffn_out.add_(_1_shared_output)
             input_embdings1.add_(_1_ffn_out.view(-1, self.embed_dim_))
             return
 
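Note: the hunks above all delete the same shared-expert pattern. Below is a minimal Python sketch (not this repository's code; the function name combine_routed_and_shared and the tensor shapes are assumptions) of what the removed lines did: when n_shared_experts is set, the shared-expert FFN output is added in place to the routed MoE output before that sum is accumulated into the residual hidden states.

# Hypothetical sketch of the removed shared-expert handling.
import torch

def combine_routed_and_shared(ffn_out, shared_out, residual, embed_dim):
    # ffn_out: routed-expert (MoE) output, shape (num_tokens, embed_dim)
    # shared_out: shared-expert FFN output, or None when no shared experts exist
    if shared_out is not None:
        ffn_out.add_(shared_out)  # mirrors _0_ffn_out.add_(_0_shared_output)
    # mirrors input_embdings.add_(_0_ffn_out.view(-1, self.embed_dim_))
    residual.add_(ffn_out.view(-1, embed_dim))
    return residual

# Example usage with random tensors:
tokens, dim = 4, 8
out = combine_routed_and_shared(
    torch.randn(tokens, dim), torch.randn(tokens, dim), torch.zeros(tokens, dim), dim
)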