Skip to content

Commit f68a306

Browse files
authored
fix
1 parent 5e2dfc9 commit f68a306

File tree

3 files changed

+15
-2
lines changed

3 files changed

+15
-2
lines changed

lightllm/models/llama/layer_infer/post_layer_infer.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from lightllm.common.basemodel import PostLayerInferTpl
1313
from lightllm.utils.infer_utils import mark_cost_time
1414
from lightllm.distributed.communication_op import all_gather
15-
from lightllm.utils.envs_utils import get_env_start_args
1615

1716

1817
class LlamaPostLayerInfer(PostLayerInferTpl):

lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,11 @@ def _tpsp_get_qkv(
120120
infer_state.position_cos,
121121
infer_state.position_sin,
122122
)
123+
124+
if infer_state.need_dp_prefill_balance:
125+
q = infer_state._all_to_all_unbalance_get(data=q)
126+
cache_kv = infer_state._all_to_all_unbalance_get(data=cache_kv)
127+
123128
return q, cache_kv
124129

125130
def _moe_ffn(

lightllm/models/stablelm/layer_infer/transformer_layer_infer.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch.distributed as dist
44
import numpy as np
55
from functools import partial
6-
6+
from typing import Tuple
77
from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
88
from lightllm.models.bloom.triton_kernel.layernorm import layernorm_forward
99
from lightllm.models.stablelm.layer_weights.transformer_layer_weight import StablelmTransformerLayerWeight
@@ -38,6 +38,10 @@ def _get_qkv(
3838
)
3939
return q, cache_kv
4040

41+
def _tpsp_get_qkv(self, input, infer_state, layer_weight) -> Tuple[torch.Tensor, torch.Tensor]:
42+
# TODO
43+
raise Exception("not impl")
44+
4145
def _get_o(
4246
self, input, infer_state: LlamaInferStateInfo, layer_weight: StablelmTransformerLayerWeight
4347
) -> torch.Tensor:
@@ -46,6 +50,11 @@ def _get_o(
4650
)
4751
return o_tensor
4852

53+
def _tpsp_get_o(self, input, infer_state, layer_weight) -> Tuple[torch.Tensor, torch.Tensor]:
54+
# TODO
55+
raise Exception("not impl")
56+
57+
4958
def _att_norm(
5059
self, input, infer_state: LlamaInferStateInfo, layer_weight: StablelmTransformerLayerWeight
5160
) -> torch.Tensor:

0 commit comments

Comments
 (0)