@@ -82,10 +82,13 @@ def _bind_ffn(self):
8282 moe_mode = os .environ .get ("MOE_MODE" , "TP" )
8383 if moe_mode == "EP" :
8484 self ._ffn = partial (Deepseek2TransformerLayerInfer ._moe_ffn_edp , self )
85+ self ._tpsp_ffn = self ._tpsp_ffn_ep
8586 else :
8687 self ._ffn = partial (Deepseek2TransformerLayerInfer ._moe_ffn , self )
88+ self ._tpsp_ffn = self ._tpsp_ffn_tp
8789 else :
8890 self ._ffn = partial (LlamaTransformerLayerInfer ._ffn , self )
91+ self ._tpsp_ffn = self ._tpsp_ffn_tp
8992
9093 def _bind_attention (self ):
9194 if "triton_fp8kv" in self .mode :
@@ -737,7 +740,10 @@ def _moe_ffn_edp(
737740 ep_output = ep_output .view (token_num , hidden_dim )
738741 return ep_output
739742
740- def _tpsp_ffn (
743+ def _tpsp_ffn (self , input , infer_state : Deepseek2InferStateInfo , layer_weight : Deepseek2TransformerLayerWeight ):
744+ raise Exception ("need bind to real impl" )
745+
746+ def _tpsp_ffn_tp (
741747 self , input , infer_state : Deepseek2InferStateInfo , layer_weight : Deepseek2TransformerLayerWeight
742748 ) -> torch .Tensor :
743749 input = input .view (- 1 , self .embed_dim_ )
@@ -762,6 +768,15 @@ def _tpsp_ffn(
762768 ffn2_out = reduce_o_tensor
763769 return ffn2_out
764770
771+ def _tpsp_ffn_ep (
772+ self , input , infer_state : Deepseek2InferStateInfo , layer_weight : Deepseek2TransformerLayerWeight
773+ ) -> torch .Tensor :
774+ input = input .view (- 1 , self .embed_dim_ )
775+
776+ ffn2_out = self ._ffn (input = input , infer_state = infer_state , layer_weight = layer_weight )
777+
778+ return ffn2_out
779+
765780 def overlap_tpsp_token_forward (
766781 self ,
767782 input_embdings : torch .Tensor ,
0 commit comments