
Commit 58a7cb2

Author: wangzaijun
Commit message: fix
Parent: 5236f9c

File tree

1 file changed: +16, -1 lines changed

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 16 additions & 1 deletion
@@ -82,10 +82,13 @@ def _bind_ffn(self):
             moe_mode = os.environ.get("MOE_MODE", "TP")
             if moe_mode == "EP":
                 self._ffn = partial(Deepseek2TransformerLayerInfer._moe_ffn_edp, self)
+                self._tpsp_ffn = self._tpsp_ffn_ep
             else:
                 self._ffn = partial(Deepseek2TransformerLayerInfer._moe_ffn, self)
+                self._tpsp_ffn = self._tpsp_ffn_tp
         else:
             self._ffn = partial(LlamaTransformerLayerInfer._ffn, self)
+            self._tpsp_ffn = self._tpsp_ffn_tp
 
     def _bind_attention(self):
         if "triton_fp8kv" in self.mode:
@@ -737,7 +740,10 @@ def _moe_ffn_edp(
         ep_output = ep_output.view(token_num, hidden_dim)
         return ep_output
 
-    def _tpsp_ffn(
+    def _tpsp_ffn(self, input, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight):
+        raise Exception("need bind to real impl")
+
+    def _tpsp_ffn_tp(
         self, input, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
     ) -> torch.Tensor:
         input = input.view(-1, self.embed_dim_)
@@ -762,6 +768,15 @@ def _tpsp_ffn(
         ffn2_out = reduce_o_tensor
         return ffn2_out
 
+    def _tpsp_ffn_ep(
+        self, input, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
+    ) -> torch.Tensor:
+        input = input.view(-1, self.embed_dim_)
+
+        ffn2_out = self._ffn(input=input, infer_state=infer_state, layer_weight=layer_weight)
+
+        return ffn2_out
+
     def overlap_tpsp_token_forward(
         self,
         input_embdings: torch.Tensor,
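
A side note on the two binding styles visible in the first hunk: the pre-existing _ffn hooks are bound with functools.partial(Cls.method, self), while the new _tpsp_ffn assignments use a plain bound method (self._tpsp_ffn_ep). Both yield a callable with self already filled in. A quick self-contained check, using a made-up Demo class:

from functools import partial


class Demo:
    def hook(self, x):
        return (id(self), x)


d = Demo()
via_partial = partial(Demo.hook, d)  # style used for the _ffn hooks
via_bound = d.hook                   # style used for the new _tpsp_ffn hook

# Same behaviour either way: self is pre-bound, only x remains to be passed.
assert via_partial("tok") == via_bound("tok")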
