
Commit 7c305e4

update gemma
1 parent 5fc5194 commit 7c305e4


9 files changed, +57 -55 lines changed


lightllm/common/basemodel/layer_weights/meta_weights/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -3,8 +3,10 @@
     MMWeight,
     MultiMMWeight,
     ROWMMWeight,
+    ROWMMWeightNoTP,
     COLMMWeight,
     MultiROWMMWeight,
+    MultiROWMMWeightNoTP,
     CustomMMWeight,
     CustomBMMWeight,
 )
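
Re-exporting the two new classes here lets model code import them straight from the meta_weights package, as the gemma_2b change further down does. A minimal sketch of that import (only MultiROWMMWeightNoTP is actually imported elsewhere in this commit; the ROWMMWeightNoTP line is illustrative):

from lightllm.common.basemodel.layer_weights.meta_weights import (
    ROWMMWeightNoTP,       # a row weight kept whole on every tensor-parallel rank
    MultiROWMMWeightNoTP,  # several such weights loaded together and fused
)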

lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py

Lines changed: 21 additions & 12 deletions
@@ -7,7 +7,8 @@ class MMWeightTpl(BaseWeightTpl):
     def __init__(self, data_type, split_n_embed):
         super().__init__()
         self.data_type_ = data_type
-        self.split_n_embed = split_n_embed
+        self.start = split_n_embed * self.tp_rank_
+        self.end = split_n_embed * (self.tp_rank_ + 1)
         self.quant_method = None
         self.weight = None
         self.bias = None
@@ -58,32 +59,35 @@ def __init__(self, weight_name, data_type, split_n_embed, bias_name=None):
         super().__init__(weight_name, data_type, split_n_embed, bias_name)

     def load_hf_weights(self, weights):
-        start = self.split_n_embed * self.tp_rank_
-        end = self.split_n_embed * (self.tp_rank_ + 1)
         weight = None
         if self.weight_name in weights:
             weight = weights[self.weight_name].to(self.data_type_)
-            self.weight = weight[start:end]
+            self.weight = weight[self.start : self.end]
         if self.bias_name in weights:
-            bias = weights[self.bias_name].to(self.data_type_)[start:end]
+            bias = weights[self.bias_name].to(self.data_type_)[self.start : self.end]
             self.bias = bias.cuda(self.tp_rank_)
         if weight is None:
             return
         self._post_load_weights()
         return


+class ROWMMWeightNoTP(MMWeight):
+    def __init__(self, weight_name, data_type, split_n_embed, bias_name=None):
+        super().__init__(weight_name, data_type, split_n_embed, bias_name)
+        self.start = 0
+        self.end = split_n_embed
+
+
 class COLMMWeight(MMWeight):
     def __init__(self, weight_name, data_type, split_n_embed, bias_name=None):
         super().__init__(weight_name, data_type, split_n_embed, bias_name)

     def load_hf_weights(self, weights):
-        start = self.split_n_embed * self.tp_rank_
-        end = self.split_n_embed * (self.tp_rank_ + 1)
         weight = None
         if self.weight_name in weights:
             weight = weights[self.weight_name].to(self.data_type_)
-            self.weight = weight[:, start:end]
+            self.weight = weight[:, self.start : self.end]
         if self.bias_name in weights:
             bias = weights[self.bias_name].to(self.data_type_)
             self.bias = (bias / self.world_size_).cuda(self.tp_rank_)
@@ -126,20 +130,25 @@ def _fuse(self):
         return self

     def load_hf_weights(self, weights):
-        start = self.split_n_embed * self.tp_rank_
-        end = self.split_n_embed * (self.tp_rank_ + 1)
         weight = None
         for i in range(len(self.weight_names)):
             if self.weight_names[i] in weights:
                 weight = weights[self.weight_names[i]].to(self.data_type_)
-                self.weights[i] = weight[start:end]
+                self.weights[i] = weight[self.start : self.end]
             if self.has_bias and self.bias_names[i] in weights:
                 bias = weights[self.bias_names[i]].to(self.data_type_)
-                self.biases[i] = bias[start:end]
+                self.biases[i] = bias[self.start : self.end]
         self._fuse()
         return


+class MultiROWMMWeightNoTP(MultiROWMMWeight):
+    def __init__(self, weight_names, data_type, split_n_embed, bias_names=None):
+        super().__init__(weight_names, data_type, split_n_embed, bias_names)
+        self.start = 0
+        self.end = split_n_embed
+
+
 class CustomMMWeight(ROWMMWeight):
     def __init__(
         self,
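
The NoTP subclasses work purely by overriding the slice bounds that MMWeightTpl.__init__ now precomputes. A plain-Python sketch of that slicing rule with made-up sizes (lists stand in for weight tensors; this is not lightllm code):

# A TP-sharded row weight keeps rows [rank * split : (rank + 1) * split];
# the NoTP variants reset start/end so every rank keeps the full [0 : split].
split_n_embed = 4
tp_rank = 1
tp_start, tp_end = split_n_embed * tp_rank, split_n_embed * (tp_rank + 1)
notp_start, notp_end = 0, split_n_embed
full_weight = list(range(16))
print(full_weight[tp_start:tp_end])      # rank 1's shard: [4, 5, 6, 7]
print(full_weight[notp_start:notp_end])  # replicated slice: [0, 1, 2, 3]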

lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py

Lines changed: 16 additions & 13 deletions
@@ -96,21 +96,25 @@ def __init__(
         disable_qk_absorb=False,
         disable_vo_absorb=False,
     ):
+        self.disable_qk_absorb = disable_qk_absorb
+        self.disable_vo_absorb = disable_vo_absorb
         super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode, quant_cfg)
+        return
+
+    def _parse_config(self):
+        super()._parse_config()
         self.is_moe = (
             self.network_config_["n_routed_experts"] is not None
             and self.layer_num_ >= self.network_config_["first_k_dense_replace"]
             and self.layer_num_ % self.network_config_["moe_layer_freq"] == 0
         )
-        self.tp_q_head_num_ = network_config["num_attention_heads"] // self.world_size_
+        self.tp_q_head_num_ = self.network_config_["num_attention_heads"] // self.world_size_
         self.n_routed_experts = self.network_config_["n_routed_experts"]
         self.q_lora_rank = self.network_config_["q_lora_rank"]
         self.qk_nope_head_dim = self.network_config_["qk_nope_head_dim"]
         self.qk_rope_head_dim = self.network_config_["qk_rope_head_dim"]
         self.num_attention_heads = self.network_config_["num_attention_heads"]
         self.kv_lora_rank = self.network_config_["kv_lora_rank"]
-        self.disable_qk_absorb = disable_qk_absorb
-        self.disable_vo_absorb = disable_vo_absorb
         self.fuse_pairs = {}
         if not self.disable_qk_absorb:
             if self.q_lora_rank is None:
@@ -125,16 +129,15 @@ def __init__(
                 }
             )

-        self.init_qkvo()
+    def _init_weight(self):
+        self._init_qkvo()
         if self.is_moe:
-            self.init_moe()
+            self._init_moe()
         else:
-            self.init_ffn()
-        self.init_norm()
-        self.set_quantization()
-        return
+            self._init_ffn()
+        self._init_norm()

-    def init_qkvo(self):
+    def _init_qkvo(self):
         q_split_n_embed = self.qk_nope_head_dim * self.tp_q_head_num_
         q_split_n_embed_with_rope = (
             (self.qk_nope_head_dim + self.qk_rope_head_dim) * self.num_attention_heads // self.world_size_
@@ -201,7 +204,7 @@ def _load_mlp(self, mlp_prefix, split_inter_size):
         self.up_proj = ROWMMWeight(f"{mlp_prefix}.up_proj.weight", self.data_type_, split_inter_size, wait_fuse=True)
         self.down_proj = COLMMWeight(f"{mlp_prefix}.down_proj.weight", self.data_type_, split_inter_size)

-    def init_moe(self):
+    def _init_moe(self):
         moe_intermediate_size = self.network_config_["moe_intermediate_size"]
         self.moe_gate = ROWMMWeight(
             f"model.layers.{self.layer_num_}.mlp.gate.weight", self.data_type_, moe_intermediate_size, disable_tp=True
@@ -220,12 +223,12 @@ def init_moe(self):
             data_type=self.data_type_,
         )

-    def init_ffn(self):
+    def _init_ffn(self):
         inter_size = self.network_config_["intermediate_size"]
         split_inter_size = inter_size // self.world_size_
         self._load_mlp(f"model.layers.{self.layer_num_}.mlp", split_inter_size)

-    def init_norm(self):
+    def _init_norm(self):
         self.att_norm_weight_ = NormWeight(f"model.layers.{self.layer_num_}.input_layernorm.weight", self.data_type_)
         self.ffn_norm_weight_ = NormWeight(
             f"model.layers.{self.layer_num_}.post_attention_layernorm.weight", self.data_type_

lightllm/models/deepseek2/model.py

Lines changed: 2 additions & 0 deletions
@@ -22,6 +22,8 @@ class Deepseek2TpPartModel(LlamaTpPartModel):

     def __init__(self, kvargs):
         super().__init__(kvargs)
+        self.disable_qk_absorb = int(os.getenv("DISABLE_QK_ABSORB", 0))
+        self.disable_vo_absorb = int(os.getenv("DISABLE_VO_ABSORB", 0))
         return

     def _init_some_value(self):
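
The absorption switches now come from the environment instead of CLI flags (removed from api_cli.py below). A small sketch of the parsing semantics the two new lines rely on; only the variable names are taken from the diff:

import os

# Unset or "0" keeps absorption enabled (value 0); exporting DISABLE_QK_ABSORB=1
# or DISABLE_VO_ABSORB=1 before starting the server turns the absorption off.
disable_qk_absorb = int(os.getenv("DISABLE_QK_ABSORB", 0))  # 0 when unset
disable_vo_absorb = int(os.getenv("DISABLE_VO_ABSORB", 0))
print(disable_qk_absorb, disable_vo_absorb)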

lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py

Lines changed: 4 additions & 14 deletions
@@ -2,7 +2,7 @@
 import math
 import numpy as np
 from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight
-from lightllm.common.basemodel.layer_weights.meta_weights import GEMMANormWeight, ROWMMWeight
+from lightllm.common.basemodel.layer_weights.meta_weights import GEMMANormWeight, ROWMMWeight, MultiROWMMWeightNoTP


 class Gemma_2bTransformerLayerWeight(LlamaTransformerLayerWeight):
@@ -14,21 +14,11 @@ def _init_qkv(self):
         q_split_n_embed = self.head_dim * self.n_head // self.world_size_
         kv_split_n_embed = self.head_dim * self.n_kv_head
         self.q_proj = ROWMMWeight(self._q_weight_name, self.data_type_, q_split_n_embed, bias_name=self._q_bias_name)
-        self.k_proj = ROWMMWeight(
-            self._k_weight_name,
+        self.kv_proj = MultiROWMMWeightNoTP(
+            [self._k_weight_name, self._v_weight_name],
             self.data_type_,
             kv_split_n_embed,
-            bias_name=self._k_bias_name,
-            wait_fuse=True,
-            disable_tp=True,
-        )
-        self.v_proj = ROWMMWeight(
-            self._v_weight_name,
-            self.data_type_,
-            kv_split_n_embed,
-            bias_name=self._v_bias_name,
-            wait_fuse=True,
-            disable_tp=True,
+            bias_names=[self._k_bias_name, self._v_bias_name],
         )

     def _init_norm(self):
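
The separate k_proj and v_proj collapse into one MultiROWMMWeightNoTP, which keeps both projections whole on every rank and fuses them into a single kv_proj. A plain-Python sketch of what the fused weight amounts to (lists stand in for tensors; not lightllm code):

# K and V are each loaded unsplit (the NoTP slice is [0 : kv_split_n_embed])
# and the fuse step concatenates them along the output dimension.
k_rows = [[1, 2], [3, 4]]   # stand-in for the K projection weight
v_rows = [[5, 6], [7, 8]]   # stand-in for the V projection weight
kv_rows = k_rows + v_rows   # fused kv_proj: K rows followed by V rows
print(kv_rows)              # [[1, 2], [3, 4], [5, 6], [7, 8]]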

lightllm/models/gemma_2b/model.py

Lines changed: 12 additions & 0 deletions
@@ -5,6 +5,7 @@
 from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer
 from lightllm.models.llama.infer_struct import LlamaInferStateInfo
 from lightllm.models.llama.model import LlamaTpPartModel
+from lightllm.common.mem_utils import select_mem_manager_class


 class Gemma_2bTpPartModel(LlamaTpPartModel):
@@ -33,3 +34,14 @@ def _verify_params(self):
         # assert self.config["num_key_value_heads"] % self.world_size_ == 0
         assert self.config["num_attention_heads"] % self.world_size_ == 0
         return
+
+    def _init_mem_manager(self):
+        self.mem_manager = select_mem_manager_class(self.mode)(
+            self.max_total_token_num,
+            dtype=self.data_type,
+            head_num=self.config["num_key_value_heads"],
+            head_dim=self.config["hidden_size"] // self.config["num_attention_heads"],
+            layer_num=self.config["num_hidden_layers"],
+            mem_fraction=self.mem_fraction,
+        )
+        return
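
Consistent with the unsplit K/V weights, the override hands the memory manager the full num_key_value_heads rather than a per-rank share. A quick numeric sketch; the config values are illustrative gemma-2b-like numbers, not taken from this commit:

# The point is that head_num is the whole num_key_value_heads (replicated KV),
# not num_key_value_heads // world_size as a TP-sharded cache would use.
config = {"num_key_value_heads": 1, "num_attention_heads": 8,
          "hidden_size": 2048, "num_hidden_layers": 18}
head_num = config["num_key_value_heads"]                           # 1 on every rank
head_dim = config["hidden_size"] // config["num_attention_heads"]  # 256
print(head_num, head_dim, config["num_hidden_layers"])             # 1 256 18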

lightllm/server/api_cli.py

Lines changed: 0 additions & 12 deletions
@@ -234,16 +234,4 @@ def make_argument_parser() -> argparse.ArgumentParser:
         help="""Path of quantization config. It can be used for mixed quantization.
         Examples can be found in lightllm/common/quantization/configs.""",
     )
-    parser.add_argument(
-        "--disable_qk_absorb",
-        default=False,
-        action="store_true",
-        help="Disable mla qk weight absorption",
-    )
-    parser.add_argument(
-        "--disable_vo_absorb",
-        default=False,
-        action="store_true",
-        help="Disable mla vo weight absorption",
-    )
     return parser
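
With the flags removed from the parser, the toggle has to be exported into the server process's environment before launch. One hedged way to do that from Python; the entry-point module and --model_dir argument are assumptions about the usual lightllm launch command, not part of this commit:

import os
import subprocess

# Only the two environment variable names come from deepseek2/model.py above;
# the launch command itself is assumed.
env = dict(os.environ, DISABLE_QK_ABSORB="1", DISABLE_VO_ABSORB="1")
subprocess.run(
    ["python", "-m", "lightllm.server.api_server", "--model_dir", "/path/to/model"],
    env=env,
)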

lightllm/server/router/manager.py

Lines changed: 0 additions & 2 deletions
@@ -133,8 +133,6 @@ async def wait_to_model_ready(self):
                 "batch_max_tokens": self.args.batch_max_tokens,
                 "quant_type": self.args.quant_type,
                 "quant_cfg": self.args.quant_cfg,
-                "disable_qk_absorb": self.args.disable_qk_absorb,
-                "disable_vo_absorb": self.args.disable_vo_absorb,
                 "pd_rpyc_port": self.args.pd_tp_infer_rpyc_ports[rank_id],  # can be left unset outside pd mode
             }
             init_model_ret.append(self.model_rpcs[rank_id].init_model(kvargs))

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 0 additions & 2 deletions
@@ -109,8 +109,6 @@ def init_model(self, kvargs):
             "batch_max_tokens": kvargs.get("batch_max_tokens", None),
             "quant_type": kvargs.get("quant_type", None),
             "quant_cfg": kvargs.get("quant_cfg", None),
-            "disable_qk_absorb": kvargs.get("disable_qk_absorb", False),
-            "disable_vo_absorb": kvargs.get("disable_vo_absorb", False),
             "run_mode": self.run_mode,
         }
