
Commit 6e880c1

fix qwen3moe overlap
1 parent 83d358f

File tree

  lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py
  lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py

2 files changed: 49 additions, 6 deletions

lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py

Lines changed: 6 additions & 6 deletions
@@ -1,7 +1,7 @@
 import os
 import torch
 from abc import abstractmethod
-from typing import Optional, Tuple, List, Dict, Union
+from typing import Optional, Tuple, List, Dict, Union, Type
 from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
 from lightllm.common.quantization.quantize_method import QuantizationMethod
 from lightllm.common.basemodel.layer_weights.meta_weights.base_weight import BaseWeightTpl
@@ -88,9 +88,9 @@ def _load_weights(self, weights: Dict[str, torch.Tensor]) -> None:
 class MultiMMWeightTpl(MMWeightTpl):
     def __init__(
         self,
-        weight_names: str,
+        weight_names: list[str],
         data_type: torch.dtype,
-        bias_names: Optional[str] = None,
+        bias_names: Optional[list[str]] = None,
         quant_method: QuantizationMethod = None,
         tp_rank: int = None,
         tp_world_size: int = None,
@@ -183,6 +183,6 @@ def _get_quant_method(cls, quant_cfg: Quantcfg, layer_num_: int, name: str) -> Q
 
     @classmethod
     def _get_mmcls(
-        cls, quant_method: QuantizationMethod
-    ) -> Optional[Union[MMWeightTpl, MultiMMWeightTpl, BMMWeightTpl]]:
-        return None
+        cls, quant_method: QuantizationMethod, quantized_weight: bool
+    ) -> Type[Union[MMWeightTpl, MultiMMWeightTpl, BMMWeightTpl]]:
+        raise NotImplementedError("Subclasses must implement _get_mmcls method")
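Note: with this change `_get_mmcls` takes a second `quantized_weight` flag and must return a concrete weight class rather than optionally returning `None`. A minimal sketch of a subclass satisfying the new contract (not taken from the repo; the subclass name and selection logic are hypothetical, and the template classes are assumed to already be in scope inside mm_weight.py):

class MyMMWeight(MMWeightTpl):  # hypothetical subclass, for illustration only
    @classmethod
    def _get_mmcls(
        cls, quant_method: QuantizationMethod, quantized_weight: bool
    ) -> Type[Union[MMWeightTpl, MultiMMWeightTpl, BMMWeightTpl]]:
        # Assumption: fall back to the plain template when nothing is quantized,
        # otherwise dispatch to a quantization-aware template class.
        if quant_method is None and not quantized_weight:
            return MMWeightTpl
        return MyQuantizedMMWeightTpl  # hypothetical quantized template class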

lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py

Lines changed: 43 additions & 0 deletions
@@ -14,6 +14,7 @@
 from functools import partial
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.dist_utils import get_global_world_size
+from lightllm.distributed.communication_op import all_gather_into_tensor
 
 logger = init_logger(__name__)
 
@@ -82,6 +83,48 @@ def _get_qkv(
         )
         return q, cache_kv
 
+    def _tpsp_get_qkv(
+        self,
+        input: torch.Tensor,
+        cache_kv,
+        infer_state: LlamaInferStateInfo,
+        layer_weight: Qwen3MOETransformerLayerWeight,
+    ) -> torch.Tensor:
+        if self.tp_world_size_ > 1:
+            sp_token_num, hidden_dim = input.shape
+            gather_input = self.alloc_tensor(
+                (sp_token_num * self.tp_world_size_, hidden_dim), dtype=input.dtype, device=input.device
+            )
+            all_gather_into_tensor(gather_input, input, group=infer_state.dist_group, async_op=False)
+            input = gather_input[0 : len(infer_state.position_cos), :]
+
+        input = input.view(-1, self.embed_dim_)
+        q = layer_weight.q_proj.mm(input)
+        cache_kv = layer_weight.kv_proj.mm(
+            input, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_)
+        ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
+
+        rmsnorm_forward(
+            q.view(-1, self.head_dim_),
+            weight=layer_weight.q_norm_weight_.weight,
+            eps=self.eps_,
+            out=q.view(-1, self.head_dim_),
+        )
+
+        cache_kv[:, : self.tp_k_head_num_, :] = rmsnorm_forward(
+            cache_kv[:, : self.tp_k_head_num_, :].reshape(-1, cache_kv.shape[-1]),
+            weight=layer_weight.k_norm_weight_.weight,
+            eps=self.eps_,
+        ).view(-1, self.tp_k_head_num_, cache_kv.shape[-1])
+
+        rotary_emb_fwd(
+            q.view(-1, self.tp_q_head_num_, self.head_dim_),
+            cache_kv[:, : self.tp_k_head_num_, :],
+            infer_state.position_cos,
+            infer_state.position_sin,
+        )
+        return q, cache_kv
+
     def _moe_ffn(
         self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3MOETransformerLayerWeight
     ) -> torch.Tensor:

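Note on the new `_tpsp_get_qkv` path: under sequence parallelism each rank holds only `sp_token_num` rows of the hidden states, so the method first all-gathers every rank's shard into one buffer and then slices it down to the real token count (`len(infer_state.position_cos)`) before running the Q/KV projections, per-head RMSNorm, and rotary embedding. A minimal sketch of that gather-then-truncate step using plain `torch.distributed` (function and variable names here are illustrative; lightllm's `all_gather_into_tensor` wrapper is assumed to behave like `torch.distributed.all_gather_into_tensor`):

import torch
import torch.distributed as dist

def gather_sp_hidden(local_hidden: torch.Tensor, true_token_num: int, group=None) -> torch.Tensor:
    """Gather per-rank sequence-parallel shards and drop the padded tail.

    local_hidden: (sp_token_num, hidden_dim) shard owned by this rank.
    true_token_num: total number of real tokens; may be smaller than
    sp_token_num * world_size if the last shard is padded.
    """
    world_size = dist.get_world_size(group=group)
    sp_token_num, hidden_dim = local_hidden.shape
    gathered = torch.empty(
        (sp_token_num * world_size, hidden_dim),
        dtype=local_hidden.dtype,
        device=local_hidden.device,
    )
    # Rank i's shard lands at rows [i * sp_token_num, (i + 1) * sp_token_num).
    dist.all_gather_into_tensor(gathered, local_hidden, group=group, async_op=False)
    return gathered[:true_token_num, :]

The committed code does the same thing but takes the gather buffer from `self.alloc_tensor`, presumably so it is reused through the cache tensor manager rather than freshly allocated on every layer.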