
Commit 6fe5879

qwen: repeat KV weights when tp_size > kv_head_num
1 parent 7a18054 commit 6fe5879


6 files changed: +78 −84 lines


lightllm/models/qwen2/layer_weights/transformer_layer_weight.py

Lines changed: 39 additions & 0 deletions
@@ -13,3 +13,42 @@ def _init_weight_names(self):
         self._q_bias_name = f"model.layers.{self.layer_num_}.self_attn.q_proj.bias"
         self._k_bias_name = f"model.layers.{self.layer_num_}.self_attn.k_proj.bias"
         self._v_bias_name = f"model.layers.{self.layer_num_}.self_attn.v_proj.bias"
+
+    def _parse_config(self):
+        self.tp_q_head_num_ = self.network_config_["num_attention_heads"] // self.tp_world_size_
+        self.tp_k_head_num_ = max(self.network_config_["num_key_value_heads"] // self.tp_world_size_, 1)
+        self.tp_v_head_num_ = self.tp_k_head_num_
+        self.tp_o_head_num_ = self.tp_q_head_num_
+        head_dim = self.network_config_["hidden_size"] // self.network_config_["num_attention_heads"]
+        self.head_dim = self.network_config_.get("head_dim", head_dim)
+        assert self.tp_k_head_num_ * self.tp_world_size_ % self.network_config_["num_key_value_heads"] == 0
+
+    def _repeat_weight(self, name, weights):
+        # for tp_world_size_ > num_key_value_heads
+        if name not in weights:
+            return
+
+        tensor = weights[name]
+        num_kv_heads = self.network_config_["num_key_value_heads"]
+        repeat_size = self.tp_k_head_num_ * self.tp_world_size_ // num_kv_heads
+
+        if tensor.ndim == 1:
+            # Bias (1D tensor)
+            tensor = tensor.reshape(num_kv_heads, -1).unsqueeze(1).repeat(1, repeat_size, 1).reshape(-1)
+        else:
+            # Weight (2D tensor)
+            tensor = (
+                tensor.reshape(num_kv_heads, -1, tensor.shape[-1])
+                .unsqueeze(1)
+                .repeat(1, repeat_size, 1, 1)
+                .reshape(-1, tensor.shape[-1])
+            )
+        weights[name] = tensor
+
+    def load_hf_weights(self, weights):
+        self._repeat_weight(self._k_weight_name, weights)
+        self._repeat_weight(self._v_weight_name, weights)
+        if self._k_bias_name is not None and self._v_bias_name is not None:
+            self._repeat_weight(self._k_bias_name, weights)
+            self._repeat_weight(self._v_bias_name, weights)
+        return super().load_hf_weights(weights)
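
Editor's note (not part of the commit): a standalone sketch of what the 2D branch of the new _repeat_weight does, using made-up shapes. With num_key_value_heads=2 and tp_world_size_=4, each rank still keeps one KV head, so the K/V projection rows are duplicated repeat_size = 1 * 4 // 2 = 2 times along the head axis before the usual tensor-parallel split.

# Standalone sketch, all shapes are illustrative assumptions:
# hidden_size=64, head_dim=16, num_key_value_heads=2, tp_world_size=4.
import torch

num_kv_heads, head_dim, hidden_size, tp_world_size = 2, 16, 64, 4

tp_k_head_num = max(num_kv_heads // tp_world_size, 1)        # 1 KV head kept per rank
repeat_size = tp_k_head_num * tp_world_size // num_kv_heads  # 1 * 4 // 2 = 2 copies per head

k_proj = torch.randn(num_kv_heads * head_dim, hidden_size)   # [32, 64] as stored in the checkpoint
repeated = (
    k_proj.reshape(num_kv_heads, -1, k_proj.shape[-1])       # [2, 16, 64]
    .unsqueeze(1)                                            # [2, 1, 16, 64]
    .repeat(1, repeat_size, 1, 1)                            # [2, 2, 16, 64]
    .reshape(-1, k_proj.shape[-1])                           # [64, 64]
)
# The row count now matches tp_world_size ranks * tp_k_head_num heads each,
# so every rank's slice is a full copy of one original KV head.
assert repeated.shape[0] == tp_world_size * tp_k_head_num * head_dim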

lightllm/models/qwen2/model.py

Lines changed: 26 additions & 2 deletions
@@ -2,6 +2,7 @@
 from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight
 from lightllm.models.qwen2.layer_weights.transformer_layer_weight import Qwen2TransformerLayerWeight
 from lightllm.models.llama.model import LlamaTpPartModel
+from lightllm.common.mem_utils import select_mem_manager_class
 
 
 @ModelRegistry("qwen2")
@@ -22,7 +23,30 @@ def _init_config(self):
         return
 
     def _verify_params(self):
-        assert self.load_way in ["HF"], "mistral only supports HF format to load Now!"
-        assert self.config["num_key_value_heads"] % self.tp_world_size_ == 0
+        assert self.load_way in ["HF", "DS"], "llama only supports HF and DS format to load Now!"
         assert self.config["num_attention_heads"] % self.tp_world_size_ == 0
         return
+
+    def _init_some_value(self):
+        # Dealing with head_dim_!=n_embed // num_attention_heads scenarios, such as mistral 13B
+        head_dim_ = self.config["n_embed"] // self.config["num_attention_heads"]
+        self.head_dim_ = self.config.get("head_dim", head_dim_)
+        self.tp_k_head_num_ = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1)
+        self.tp_v_head_num_ = self.tp_k_head_num_
+        self.layers_num = self.config["n_layer"]
+        self.vocab_size = self.config["vocab_size"]
+        return
+
+    def _init_mem_manager(self):
+        head_dim_ = self.config["hidden_size"] // self.config["num_attention_heads"]
+        head_dim_ = self.config.get("head_dim", head_dim_)
+        tp_k_head_num_ = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1)
+        self.mem_manager = select_mem_manager_class(self.mode)(
+            self.max_total_token_num,
+            dtype=self.data_type,
+            head_num=tp_k_head_num_,
+            head_dim=head_dim_,
+            layer_num=self.config["num_hidden_layers"],
+            mem_fraction=self.mem_fraction,
+        )
+        return
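
Editor's note (not part of the commit): the _init_some_value / _init_mem_manager overrides moved here use the same max(..., 1) clamp to size the per-rank KV cache, so a TP degree larger than the KV head count now allocates one (repeated) KV head per rank instead of tripping the old num_key_value_heads % tp_world_size_ == 0 assert. A quick arithmetic sketch with an assumed 4-KV-head config (illustrative values, not a specific model):

# Illustrative only: per-rank KV-cache head counts the new logic would compute
# for a hypothetical config with 32 attention heads and 4 KV heads.
config = {"num_attention_heads": 32, "num_key_value_heads": 4, "hidden_size": 4096}
head_dim = config["hidden_size"] // config["num_attention_heads"]  # 128

for tp_world_size in (2, 4, 8, 16):
    tp_k_head_num = max(config["num_key_value_heads"] // tp_world_size, 1)
    repeat_size = tp_k_head_num * tp_world_size // config["num_key_value_heads"]
    print(f"tp={tp_world_size}: cache {tp_k_head_num} KV head(s)/rank, weights repeated x{repeat_size}")
# tp=2:  cache 2 KV head(s)/rank, weights repeated x1
# tp=4:  cache 1 KV head(s)/rank, weights repeated x1
# tp=8:  cache 1 KV head(s)/rank, weights repeated x2
# tp=16: cache 1 KV head(s)/rank, weights repeated x4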

lightllm/models/qwen3/layer_weights/transformer_layer_weight.py

Lines changed: 5 additions & 3 deletions
@@ -2,9 +2,8 @@
 import torch
 import math
 import numpy as np
-from lightllm.common.basemodel import TransformerLayerWeight
 from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight
-from lightllm.utils.envs_utils import enable_env_vars
+from lightllm.models.qwen2.layer_weights.transformer_layer_weight import Qwen2TransformerLayerWeight
 from lightllm.common.basemodel.layer_weights.meta_weights import (
     ROWMMWeight,
     MultiROWMMWeight,
@@ -17,7 +16,7 @@
 from functools import partial
 
 
-class Qwen3TransformerLayerWeight(LlamaTransformerLayerWeight):
+class Qwen3TransformerLayerWeight(Qwen2TransformerLayerWeight):
     def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None):
         super().__init__(layer_num, data_type, network_config, mode, quant_cfg)
         return
@@ -26,6 +25,9 @@ def _init_weight_names(self):
         super()._init_weight_names()
         self._q_norm_name = f"model.layers.{self.layer_num_}.self_attn.q_norm.weight"
         self._k_norm_name = f"model.layers.{self.layer_num_}.self_attn.k_norm.weight"
+        self._q_bias_name = None
+        self._k_bias_name = None
+        self._v_bias_name = None
 
     def _init_norm(self):
         super()._init_norm()

lightllm/models/qwen3/model.py

Lines changed: 2 additions & 17 deletions
@@ -3,16 +3,15 @@
 from lightllm.models.registry import ModelRegistry
 from lightllm.models.qwen3.layer_infer.transformer_layer_infer import Qwen3TransformerLayerInfer
 from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight
-from lightllm.models.llama.model import LlamaTpPartModel
+from lightllm.models.qwen2.model import Qwen2TpPartModel
 from lightllm.utils.log_utils import init_logger
-from lightllm.common.mem_utils import select_mem_manager_class
 
 
 logger = init_logger(__name__)
 
 
 @ModelRegistry("qwen3")
-class Qwen3TpPartModel(LlamaTpPartModel):
+class Qwen3TpPartModel(Qwen2TpPartModel):
     # weight class
     transformer_weight_class = Qwen3TransformerLayerWeight
 
@@ -22,17 +21,3 @@ class Qwen3TpPartModel(LlamaTpPartModel):
     def __init__(self, kvargs):
         super().__init__(kvargs)
         return
-
-    def _init_mem_manager(self):
-        head_dim_ = self.config["hidden_size"] // self.config["num_attention_heads"]
-        head_dim_ = self.config.get("head_dim", head_dim_)
-        tp_k_head_num_ = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1)
-        self.mem_manager = select_mem_manager_class(self.mode)(
-            self.max_total_token_num,
-            dtype=self.data_type,
-            head_num=tp_k_head_num_,
-            head_dim=head_dim_,
-            layer_num=self.config["num_hidden_layers"],
-            mem_fraction=self.mem_fraction,
-        )
-        return

lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py

Lines changed: 4 additions & 30 deletions
@@ -3,7 +3,7 @@
 import math
 import numpy as np
 from lightllm.common.basemodel import TransformerLayerWeight
-from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight
+from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight
 from lightllm.utils.envs_utils import enable_env_vars
 from lightllm.common.basemodel.layer_weights.meta_weights import (
     ROWMMWeight,
@@ -17,7 +17,7 @@
 from functools import partial
 
 
-class Qwen3MOETransformerLayerWeight(LlamaTransformerLayerWeight):
+class Qwen3MOETransformerLayerWeight(Qwen3TransformerLayerWeight):
     def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None):
         self.n_routed_experts = network_config["num_experts"]
         self.is_moe = (
@@ -46,36 +46,15 @@ def _init_weight_names(self):
         self._ffn_norm_weight_name = f"model.layers.{self.layer_num_}.post_attention_layernorm.weight"
         self._ffn_norm_bias_name = None
 
-    def _parse_config(self):
-        self.tp_q_head_num_ = self.network_config_["num_attention_heads"] // self.tp_world_size_
-        self.tp_k_head_num_ = max(self.network_config_["num_key_value_heads"] // self.tp_world_size_, 1)
-        self.tp_v_head_num_ = self.tp_k_head_num_
-        self.tp_o_head_num_ = self.tp_q_head_num_
-        self.head_dim = self.network_config_["head_dim"]
-        assert self.tp_k_head_num_ * self.tp_world_size_ % self.network_config_["num_key_value_heads"] == 0
-
-    def _repeat_weight(self, name, weights):
-        repeat_size = self.tp_k_head_num_ * self.tp_world_size_ // self.network_config_["num_key_value_heads"]
-        repeat_params = (1, repeat_size, 1, 1)
-        if name in weights:
-            weights[name] = (
-                weights[name]
-                .reshape(self.network_config_["num_key_value_heads"], -1, weights[name].shape[1])
-                .unsqueeze(1)
-                .repeat(repeat_params)
-                .reshape(-1, weights[name].shape[1])
-            )
-
     def load_hf_weights(self, weights):
-        self._repeat_weight(self._k_weight_name, weights)
-        self._repeat_weight(self._v_weight_name, weights)
+        super().load_hf_weights(weights)
         kv_b_quant_method = self.quant_cfg.get_quant_method(self.layer_num_, "kv_b_proj")
         if self.quant_cfg.quantized_weight:
             _k_scale_weight_name = self._k_weight_name.replace("weight", kv_b_quant_method.weight_scale_suffix)
             self._repeat_weight(_k_scale_weight_name, weights)
             _v_scale_weight_name = self._v_weight_name.replace("weight", kv_b_quant_method.weight_scale_suffix)
             self._repeat_weight(_v_scale_weight_name, weights)
-        return super().load_hf_weights(weights)
+        return
@@ -127,8 +106,3 @@ def _init_moe(self):
             )
         else:
             raise ValueError(f"Unsupported moe mode: {moe_mode}")
-
-    def _init_norm(self):
-        super()._init_norm()
-        self.q_norm_weight_ = NormWeight(weight_name=self._q_norm_name, data_type=self.data_type_)
-        self.k_norm_weight_ = NormWeight(weight_name=self._k_norm_name, data_type=self.data_type_)
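
Editor's note (not part of the commit): after this refactor the MoE weight class only repeats its quantization-scale tensors locally and defers the K/V head repeat to the shared Qwen2 implementation through super().load_hf_weights(...). A toy sketch of that delegation pattern, using stub classes rather than the real lightllm hierarchy:

# Toy stand-ins, not the real lightllm classes: the MoE subclass touches only its
# extra tensors and lets the shared base handle the K/V head repeat via super().
class Qwen2Like:
    def load_hf_weights(self, weights):
        weights["k_proj.weight"] *= 2  # stands in for _repeat_weight on K/V projections
        return weights

class Qwen3Like(Qwen2Like):
    pass  # inherits the repeat behaviour unchanged

class Qwen3MoeLike(Qwen3Like):
    def load_hf_weights(self, weights):
        super().load_hf_weights(weights)     # shared repeat logic runs first
        weights["k_proj.weight_scale"] *= 2  # then the MoE-only quant scales are repeated
        return

weights = {"k_proj.weight": 1, "k_proj.weight_scale": 1}
Qwen3MoeLike().load_hf_weights(weights)
print(weights)  # {'k_proj.weight': 2, 'k_proj.weight_scale': 2}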

lightllm/models/qwen3_moe/model.py

Lines changed: 2 additions & 32 deletions
@@ -3,16 +3,15 @@
 from lightllm.models.registry import ModelRegistry
 from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer
 from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight
-from lightllm.models.llama.model import LlamaTpPartModel
-from lightllm.common.mem_utils import select_mem_manager_class
+from lightllm.models.qwen3.model import Qwen3TpPartModel
 from lightllm.utils.log_utils import init_logger
 
 
 logger = init_logger(__name__)
 
 
 @ModelRegistry("qwen3_moe")
-class Qwen3MOEModel(LlamaTpPartModel):
+class Qwen3MOEModel(Qwen3TpPartModel):
     # weight class
     transformer_weight_class = Qwen3MOETransformerLayerWeight
 
@@ -22,32 +21,3 @@ class Qwen3MOEModel(LlamaTpPartModel):
     def __init__(self, kvargs):
         super().__init__(kvargs)
         return
-
-    def _verify_params(self):
-        assert self.load_way in ["HF", "DS"], "llama only supports HF and DS format to load Now!"
-        assert self.config["num_attention_heads"] % self.tp_world_size_ == 0
-        return
-
-    def _init_some_value(self):
-        # Dealing with head_dim_!=n_embed // num_attention_heads scenarios, such as mistral 13B
-        head_dim_ = self.config["n_embed"] // self.config["num_attention_heads"]
-        self.head_dim_ = self.config.get("head_dim", head_dim_)
-        self.tp_k_head_num_ = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1)
-        self.tp_v_head_num_ = self.tp_k_head_num_
-        self.layers_num = self.config["n_layer"]
-        self.vocab_size = self.config["vocab_size"]
-        return
-
-    def _init_mem_manager(self):
-        head_dim_ = self.config["hidden_size"] // self.config["num_attention_heads"]
-        head_dim_ = self.config.get("head_dim", head_dim_)
-        tp_k_head_num_ = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1)
-        self.mem_manager = select_mem_manager_class(self.mode)(
-            self.max_total_token_num,
-            dtype=self.data_type,
-            head_num=tp_k_head_num_,
-            head_dim=head_dim_,
-            layer_num=self.config["num_hidden_layers"],
-            mem_fraction=self.mem_fraction,
-        )
-        return
