Commit f3d8e61

Qwen3MOE for tp8 (repeat kv) and qwen3 dense fix (#877)
Co-authored-by: baishihao <baishihao@sensetime.com>
Co-authored-by: wangzaijun <wzjhelloworld@qq.com>
1 parent 37e5071 · commit f3d8e61


7 files changed: +67, -7 lines


lightllm/models/qwen3/layer_weights/transformer_layer_weight.py

Lines changed: 0 additions & 6 deletions
@@ -19,12 +19,6 @@
 
 class Qwen3TransformerLayerWeight(LlamaTransformerLayerWeight):
     def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None):
-        self.n_routed_experts = network_config["num_experts"]
-        self.is_moe = (
-            network_config["num_experts"] > 0
-            and layer_num not in network_config["mlp_only_layers"]
-            and (layer_num + 1) % network_config["decoder_sparse_step"] == 0
-        )
         super().__init__(layer_num, data_type, network_config, mode, quant_cfg)
         return
 

lightllm/models/qwen3/model.py

Lines changed: 15 additions & 0 deletions
@@ -4,6 +4,7 @@
 from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight
 from lightllm.models.llama.model import LlamaTpPartModel
 from lightllm.utils.log_utils import init_logger
+from lightllm.common.mem_utils import select_mem_manager_class
 
 
 logger = init_logger(__name__)
@@ -19,3 +20,17 @@ class Qwen3TpPartModel(LlamaTpPartModel):
     def __init__(self, kvargs):
         super().__init__(kvargs)
         return
+
+    def _init_mem_manager(self):
+        head_dim_ = self.config["hidden_size"] // self.config["num_attention_heads"]
+        head_dim_ = self.config.get("head_dim", head_dim_)
+        tp_k_head_num_ = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1)
+        self.mem_manager = select_mem_manager_class(self.mode)(
+            self.max_total_token_num,
+            dtype=self.data_type,
+            head_num=tp_k_head_num_,
+            head_dim=head_dim_,
+            layer_num=self.config["num_hidden_layers"],
+            mem_fraction=self.mem_fraction,
+        )
+        return
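
A minimal standalone sketch of what the new _init_mem_manager computes, using made-up config numbers (not taken from any real Qwen3 checkpoint): some Qwen3 configs set an explicit "head_dim" that differs from hidden_size // num_attention_heads, so the KV cache has to be sized from the config value rather than the derived one.

    # Hypothetical config values, for illustration only.
    config = {
        "hidden_size": 5120,
        "num_attention_heads": 64,
        "num_key_value_heads": 8,
        "head_dim": 128,  # explicitly set; not equal to 5120 // 64 == 80
        "num_hidden_layers": 64,
    }
    tp_world_size = 8

    head_dim = config["hidden_size"] // config["num_attention_heads"]       # 80, wrong for this model
    head_dim = config.get("head_dim", head_dim)                             # 128, the value the KV cache now uses
    tp_k_head_num = max(config["num_key_value_heads"] // tp_world_size, 1)  # at least 1 KV head per rank
    print(head_dim, tp_k_head_num)  # -> 128 1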

lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py

Lines changed: 2 additions & 0 deletions
@@ -29,6 +29,8 @@ def __init__(self, layer_num, network_config, mode=[]):
         self.norm_topk_prob = network_config["norm_topk_prob"]
         super().__init__(layer_num, network_config, mode)
         self.head_dim_ = network_config["head_dim"]
+        self.tp_k_head_num_ = max(self.tp_k_head_num_, 1)
+        self.tp_v_head_num_ = max(self.tp_v_head_num_, 1)
         return
 
     def _bind_func(self):

lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py

Lines changed: 25 additions & 1 deletion
@@ -46,6 +46,31 @@ def _init_weight_names(self):
         self._ffn_norm_weight_name = f"model.layers.{self.layer_num_}.post_attention_layernorm.weight"
         self._ffn_norm_bias_name = None
 
+    def _parse_config(self):
+        self.tp_q_head_num_ = self.network_config_["num_attention_heads"] // self.tp_world_size_
+        self.tp_k_head_num_ = max(self.network_config_["num_key_value_heads"] // self.tp_world_size_, 1)
+        self.tp_v_head_num_ = self.tp_k_head_num_
+        self.tp_o_head_num_ = self.tp_q_head_num_
+        self.head_dim = self.network_config_["head_dim"]
+        assert self.tp_k_head_num_ * self.tp_world_size_ % self.network_config_["num_key_value_heads"] == 0
+
+    def _repeat_weight(self, name, weights):
+        repeat_size = self.tp_k_head_num_ * self.tp_world_size_ // self.network_config_["num_key_value_heads"]
+        repeat_params = (1, repeat_size, 1, 1)
+        if name in weights:
+            weights[name] = (
+                weights[name]
+                .reshape(self.network_config_["num_key_value_heads"], self.head_dim, -1)
+                .unsqueeze(1)
+                .repeat(repeat_params)
+                .reshape(self.network_config_["num_key_value_heads"] * self.head_dim * repeat_size, -1)
+            )
+
+    def load_hf_weights(self, weights):
+        self._repeat_weight(self._k_weight_name, weights)
+        self._repeat_weight(self._v_weight_name, weights)
+        return super().load_hf_weights(weights)
+
     def _init_weight(self):
         self._init_qkv()
         self._init_o()
@@ -99,6 +124,5 @@ def _init_moe(self):
 
     def _init_norm(self):
         super()._init_norm()
-
         self.q_norm_weight_ = NormWeight(weight_name=self._q_norm_name, data_type=self.data_type_)
         self.k_norm_weight_ = NormWeight(weight_name=self._k_norm_name, data_type=self.data_type_)
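
To illustrate the repeat-kv path above, a self-contained sketch with hypothetical shapes (4 KV heads, head_dim 128, hidden_size 2048, TP world size 8; none of these numbers come from the commit): each rank is clamped to max(4 // 8, 1) == 1 KV head, so every original KV head's block of the k/v projection is duplicated repeat_size == 2 times before the weight is sliced across ranks.

    import torch

    # Hypothetical shapes, for illustration only.
    num_kv_heads, head_dim, hidden_size, tp_world_size = 4, 128, 2048, 8

    tp_k_head_num = max(num_kv_heads // tp_world_size, 1)        # 1 KV head kept per rank
    repeat_size = tp_k_head_num * tp_world_size // num_kv_heads  # 8 // 4 == 2 copies of each KV head

    k_proj = torch.randn(num_kv_heads * head_dim, hidden_size)   # [512, 2048] as stored in the checkpoint
    k_proj_repeated = (
        k_proj.reshape(num_kv_heads, head_dim, -1)               # [4, 128, 2048]: one block per KV head
        .unsqueeze(1)                                            # [4, 1, 128, 2048]
        .repeat(1, repeat_size, 1, 1)                            # [4, 2, 128, 2048]: duplicate each head
        .reshape(num_kv_heads * head_dim * repeat_size, -1)      # [1024, 2048]: 8 heads' worth of rows
    )
    assert k_proj_repeated.shape[0] == tp_world_size * tp_k_head_num * head_dim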

lightllm/models/qwen3_moe/model.py

Lines changed: 20 additions & 0 deletions
@@ -3,6 +3,7 @@
 from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer
 from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight
 from lightllm.models.llama.model import LlamaTpPartModel
+from lightllm.common.mem_utils import select_mem_manager_class
 from lightllm.utils.log_utils import init_logger
 
 
@@ -19,3 +20,22 @@ class Qwen3MOEModel(LlamaTpPartModel):
     def __init__(self, kvargs):
         super().__init__(kvargs)
         return
+
+    def _verify_params(self):
+        assert self.load_way in ["HF", "DS"], "llama only supports HF and DS format to load Now!"
+        assert self.config["num_attention_heads"] % self.tp_world_size_ == 0
+        return
+
+    def _init_mem_manager(self):
+        head_dim_ = self.config["hidden_size"] // self.config["num_attention_heads"]
+        head_dim_ = self.config.get("head_dim", head_dim_)
+        tp_k_head_num_ = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1)
+        self.mem_manager = select_mem_manager_class(self.mode)(
+            self.max_total_token_num,
+            dtype=self.data_type,
+            head_num=tp_k_head_num_,
+            head_dim=head_dim_,
+            layer_num=self.config["num_hidden_layers"],
+            mem_fraction=self.mem_fraction,
+        )
+        return
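
A quick arithmetic check of the relaxed constraint (head counts below are illustrative, not from the commit): the new _verify_params only requires the query heads to split evenly across ranks; the KV heads may be fewer than the TP world size, because the repeat-kv path above fills the gap.

    # Illustrative head counts only.
    num_attention_heads, num_kv_heads, tp_world_size = 64, 4, 8

    assert num_attention_heads % tp_world_size == 0  # enforced: 8 query heads per rank
    # No divisibility assert for KV heads: 4 // 8 == 0 is clamped to 1 head per rank,
    # and the weights are repeated so that 1 * 8 is a multiple of the original 4 heads.
    assert max(num_kv_heads // tp_world_size, 1) * tp_world_size % num_kv_heads == 0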

lightllm/server/api_models.py

Lines changed: 1 addition & 0 deletions
@@ -76,6 +76,7 @@ class ChatCompletionRequest(BaseModel):
     ignore_eos: Optional[bool] = False
     role_settings: Optional[Dict[str, str]] = None
     character_settings: Optional[List[Dict[str, str]]] = None
+    chat_template_kwargs: Optional[Dict[str, bool]] = None
 
 
 class FunctionResponse(BaseModel):

lightllm/server/build_prompt.py

Lines changed: 4 additions & 0 deletions
@@ -16,6 +16,10 @@ async def build_prompt(request, tools) -> str:
         kwargs["character_settings"] = request.character_settings
     if request.role_settings:
         kwargs["role_setting"] = request.role_settings
+
+    if request.chat_template_kwargs:
+        kwargs.update(request.chat_template_kwargs)
+
     try:
         input_str = tokenizer.apply_chat_template(**kwargs, tokenize=False, add_generation_prompt=True, tools=tools)
     except:
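
With the two server changes above, boolean flags under chat_template_kwargs are forwarded as extra keyword arguments to tokenizer.apply_chat_template. A hedged example request body (the model name and the enable_thinking flag are illustrative; whether a given chat template honors them depends on the tokenizer):

    import json

    payload = {
        "model": "qwen3-moe",  # hypothetical served model name
        "messages": [{"role": "user", "content": "Hello"}],
        # Forwarded verbatim into apply_chat_template(**kwargs, ...); values must be
        # booleans per the Optional[Dict[str, bool]] field added to ChatCompletionRequest.
        "chat_template_kwargs": {"enable_thinking": False},
    }
    print(json.dumps(payload, indent=2))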
