
Commit 1e1ee31

refactor models (#612)
Authored by sufubaoshihaobai
Co-authored-by: baishihao <[email protected]>
1 parent: 89c248f

12 files changed: +108, -92 lines

lightllm/models/bloom/layer_infer/transformer_layer_infer.py

Lines changed: 1 addition & 1 deletion
@@ -98,7 +98,7 @@ def _get_o(self, input, infer_state: InferStateInfo, layer_weight: BloomTransfor
         return o_tensor
 
     def _ffn(self, input, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight) -> torch.Tensor:
-        ffn1_out = layer_weight.up_proj.mm(input.view(-1, self.embed_dim_))
+        ffn1_out = layer_weight.gate_up_proj.mm(input.view(-1, self.embed_dim_))
         input = None
         gelu_out = torch.nn.functional.gelu(ffn1_out, approximate="tanh")
         ffn1_out = None
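
For orientation, here is a minimal sketch of the computation `_ffn` performs; the rename from `up_proj` to `gate_up_proj` does not change the math. The tensor shapes and the plain-matmul form are illustrative assumptions, not the lightllm `MMWeight` API.

# Minimal sketch (not the lightllm code) of the Bloom FFN in _ffn above,
# written with plain torch tensors; sizes are made up for illustration.
import torch

embed_dim, inter_dim = 64, 256                   # hypothetical sizes
x = torch.randn(4, embed_dim)                    # a batch of hidden states
w_gate_up = torch.randn(embed_dim, inter_dim)    # what gate_up_proj (formerly up_proj) wraps
w_down = torch.randn(inter_dim, embed_dim)       # what down_proj wraps

ffn1_out = x @ w_gate_up                                            # gate_up_proj.mm(...)
gelu_out = torch.nn.functional.gelu(ffn1_out, approximate="tanh")   # same activation as above
ffn2_out = gelu_out @ w_down                                        # the down projection that typically follows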

lightllm/models/bloom/layer_weights/transformer_layer_weight.py

Lines changed: 29 additions & 33 deletions
@@ -49,46 +49,51 @@ def get_slopes_power_of_2(n):
 
 class BloomTransformerLayerWeight(LlamaTransformerLayerWeight):
     def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode, quant_cfg=None):
-        super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode, quant_cfg, layer_prefix="h")
+        super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode, quant_cfg)
         return
 
-    def _init_config(self):
+    def _parse_config(self):
         self.n_embed = self.network_config_["n_embed"]
         self.n_head = self.network_config_["num_attention_heads"]
         self.n_inter = self.network_config_["n_embed"] * 4
         self.n_kv_head = self.network_config_["num_attention_heads"]
         self.head_dim = self.network_config_.get("head_dim", self.n_embed // self.n_head)
+        # compute and generate alibi
+        assert self.n_head % self.world_size_ == 0
+        tp_head_num = self.n_head // self.world_size_
+        tmp_alibi = generate_alibi(self.n_head, dtype=torch.float32)
+        self.tp_alibi = tmp_alibi[self.tp_rank_ * tp_head_num : (self.tp_rank_ + 1) * tp_head_num].contiguous().cuda()
 
     def _init_weight_names(self):
-        self._q_weight_name = f"{self.layer_name}.self_attention.q_proj.weight"
-        self._q_bias_name = f"{self.layer_name}.self_attention.q_proj.bias"
-        self._k_weight_name = f"{self.layer_name}.self_attention.k_proj.weight"
-        self._k_bias_name = f"{self.layer_name}.self_attention.k_proj.bias"
-        self._v_weight_name = f"{self.layer_name}.self_attention.v_proj.weight"
-        self._v_bias_name = f"{self.layer_name}.self_attention.v_proj.bias"
-        self._o_weight_name = f"{self.layer_name}.self_attention.o_proj.weight"
-        self._o_bias_name = f"{self.layer_name}.self_attention.o_proj.bias"
-
-        self._up_weight_name = f"{self.layer_name}.mlp.dense_h_to_4h.weight"
-        self._up_bias_name = f"{self.layer_name}.mlp.dense_h_to_4h.bias"
-        self._down_weight_name = f"{self.layer_name}.mlp.dense_4h_to_h.weight"
-        self._down_bias_name = f"{self.layer_name}.mlp.dense_4h_to_h.bias"
-
-        self.att_norm_weight_name = f"{self.layer_name}.input_layernorm.weight"
-        self.att_norm_bias_name = f"{self.layer_name}.input_layernorm.bias"
-        self.ffn_norm_weight_name = f"{self.layer_name}.post_attention_layernorm.weight"
-        self.ffn_norm_bias_name = f"{self.layer_name}.post_attention_layernorm.bias"
+        self._q_weight_name = f"h.{self.layer_num_}.self_attention.q_proj.weight"
+        self._q_bias_name = f"h.{self.layer_num_}.self_attention.q_proj.bias"
+        self._k_weight_name = f"h.{self.layer_num_}.self_attention.k_proj.weight"
+        self._k_bias_name = f"h.{self.layer_num_}.self_attention.k_proj.bias"
+        self._v_weight_name = f"h.{self.layer_num_}.self_attention.v_proj.weight"
+        self._v_bias_name = f"h.{self.layer_num_}.self_attention.v_proj.bias"
+        self._o_weight_name = f"h.{self.layer_num_}.self_attention.dense.weight"
+        self._o_bias_name = f"h.{self.layer_num_}.self_attention.dense.bias"
+
+        self._gate_up_weight_name = f"h.{self.layer_num_}.mlp.dense_h_to_4h.weight"
+        self._gate_up_bias_name = f"h.{self.layer_num_}.mlp.dense_h_to_4h.bias"
+        self._down_weight_name = f"h.{self.layer_num_}.mlp.dense_4h_to_h.weight"
+        self._down_bias_name = f"h.{self.layer_num_}.mlp.dense_4h_to_h.bias"
+
+        self._att_norm_weight_name = f"h.{self.layer_num_}.input_layernorm.weight"
+        self._att_norm_bias_name = f"h.{self.layer_num_}.input_layernorm.bias"
+        self._ffn_norm_weight_name = f"h.{self.layer_num_}.post_attention_layernorm.weight"
+        self._ffn_norm_bias_name = f"h.{self.layer_num_}.post_attention_layernorm.bias"
 
     def _preprocess_weight(self, weights):
-        qkv_weight_name = f"{self.layer_name}.self_attention.query_key_value.weight"
+        qkv_weight_name = f"h.{self.layer_num_}.self_attention.query_key_value.weight"
         if qkv_weight_name in weights:
             att_qkv_dense_weight = weights[qkv_weight_name].reshape(self.n_head, 3, -1, self.n_embed)
             weights[self._q_weight_name] = att_qkv_dense_weight[:, 0, :, :].reshape(-1, self.n_embed)
             weights[self._k_weight_name] = att_qkv_dense_weight[:, 1, :, :].reshape(-1, self.n_embed)
             weights[self._v_weight_name] = att_qkv_dense_weight[:, 2, :, :].reshape(-1, self.n_embed)
             del weights[qkv_weight_name]
 
-        qkv_bias_name = f"{self.layer_name}.self_attention.query_key_value.bias"
+        qkv_bias_name = f"h.{self.layer_num_}.self_attention.query_key_value.bias"
         if qkv_bias_name in weights:
             att_qkv_dense_bias = weights[qkv_bias_name].reshape(self.n_head, 3, -1)
             weights[self._q_bias_name] = att_qkv_dense_bias[:, 0, :].reshape(-1)
@@ -101,19 +106,10 @@ def load_hf_weights(self, weights):
         super().load_hf_weights(weights)
         return
 
-    def init_static_params(self):
-        # compute and generate alibi
-        head_num = self.network_config_["num_attention_heads"]
-        tp_head_num = head_num // self.world_size_
-        tmp_alibi = generate_alibi(head_num, dtype=torch.float32)
-        assert head_num % self.world_size_ == 0
-        self.tp_alibi = tmp_alibi[self.tp_rank_ * tp_head_num : (self.tp_rank_ + 1) * tp_head_num].contiguous().cuda()
-        return
-
     def _init_ffn(self):
         split_inter_size = self.n_inter // self.world_size_
-        self.up_proj = ROWMMWeight(
-            self._up_weight_name, self.data_type_, split_inter_size, bias_name=self._up_bias_name, wait_fuse=True
+        self.gate_up_proj = ROWMMWeight(
+            self._gate_up_weight_name, self.data_type_, split_inter_size, bias_name=self._gate_up_bias_name
         )
         self.down_proj = COLMMWeight(
            self._down_weight_name, self.data_type_, split_inter_size, bias_name=self._down_bias_name
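
The alibi slope computation moves from `init_static_params` into `_parse_config`, and weight names now spell out the `h.{layer}` prefix instead of `{self.layer_name}`. For reference, a minimal sketch of the fused-QKV split that `_preprocess_weight` performs on a dummy checkpoint dict; the head count, head dim, and layer index below are made-up illustration values.

# Minimal sketch of the fused-QKV split in _preprocess_weight above (not the lightllm code).
import torch

n_head, head_dim = 4, 8
n_embed = n_head * head_dim
weights = {"h.0.self_attention.query_key_value.weight": torch.randn(3 * n_embed, n_embed)}

qkv = weights.pop("h.0.self_attention.query_key_value.weight")
qkv = qkv.reshape(n_head, 3, head_dim, n_embed)   # Bloom interleaves q/k/v per head
weights["h.0.self_attention.q_proj.weight"] = qkv[:, 0, :, :].reshape(-1, n_embed)
weights["h.0.self_attention.k_proj.weight"] = qkv[:, 1, :, :].reshape(-1, n_embed)
weights["h.0.self_attention.v_proj.weight"] = qkv[:, 2, :, :].reshape(-1, n_embed)
assert weights["h.0.self_attention.q_proj.weight"].shape == (n_embed, n_embed)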

lightllm/models/chatglm2/layer_weights/transformer_layer_weight.py

Lines changed: 31 additions & 13 deletions
@@ -14,30 +14,35 @@ def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mo
             network_config,
             mode,
             quant_cfg,
-            layer_prefix="transformer.encoder.layers",
         )
         return
 
     def _preprocess_weight(self, weights):
         n_kv_embed = self.head_dim * self.n_kv_head
-
-        qkv_weight_name = f"{self.layer_name}.self_attention.query_key_value.weight"
+        qkv_weight_name = f"transformer.encoder.layers.{self.layer_num_}.self_attention.query_key_value.weight"
         if qkv_weight_name in weights:
             qkv_weight_ = weights[qkv_weight_name]
             weights[self._q_weight_name] = qkv_weight_[: self.n_embed, :]
             weights[self._k_weight_name] = qkv_weight_[self.n_embed : self.n_embed + n_kv_embed, :]
             weights[self._v_weight_name] = qkv_weight_[self.n_embed + n_kv_embed : self.n_embed + 2 * n_kv_embed, :]
             del weights[qkv_weight_name]
 
-        qkv_bias_name = f"{self.layer_name}.self_attention.query_key_value.bias"
+        qkv_bias_name = f"transformer.encoder.layers.{self.layer_num_}.self_attention.query_key_value.bias"
         if qkv_bias_name in weights:
             qkv_bias_ = weights[qkv_bias_name]
             weights[self._q_bias_name] = qkv_bias_[: self.n_embed]
             weights[self._k_bias_name] = qkv_bias_[self.n_embed : self.n_embed + n_kv_embed]
             weights[self._v_bias_name] = qkv_bias_[self.n_embed + n_kv_embed : self.n_embed + 2 * n_kv_embed]
             del weights[qkv_bias_name]
 
-    def _init_config(self):
+        gate_up_weight_name = f"transformer.encoder.layers.{self.layer_num_}.mlp.dense_h_to_4h.weight"
+        if gate_up_weight_name in weights:
+            gate_up_weight_ = weights[gate_up_weight_name]
+            weights[self._gate_weight_name] = gate_up_weight_[: self.n_inter, :]
+            weights[self._up_weight_name] = gate_up_weight_[self.n_inter : 2 * self.n_inter, :]
+            del weights[gate_up_weight_name]
+
+    def _parse_config(self):
         self.n_embed = self.network_config_["hidden_size"]
         self.n_head = self.network_config_["num_attention_heads"]
         self.n_inter = self.network_config_["ffn_hidden_size"]
@@ -49,11 +54,24 @@ def load_hf_weights(self, weights):
         super().load_hf_weights(weights)
         return
 
-    def _init_ffn(self):
-        split_inter_size = self.n_inter // self.world_size_
-        self.up_proj = ROWMMWeight(
-            self._up_weight_name, self.data_type_, split_inter_size, bias_name=self._up_bias_name, wait_fuse=True
-        )
-        self.down_proj = COLMMWeight(
-            self._down_weight_name, self.data_type_, split_inter_size, bias_name=self._down_bias_name
-        )
+    def _init_weight_names(self):
+        self._q_weight_name = f"transformer.encoder.layers.{self.layer_num_}.self_attention.q_proj.weight"
+        self._q_bias_name = f"transformer.encoder.layers.{self.layer_num_}.self_attention.q_proj.bias"
+        self._k_weight_name = f"transformer.encoder.layers.{self.layer_num_}.self_attention.k_proj.weight"
+        self._k_bias_name = f"transformer.encoder.layers.{self.layer_num_}.self_attention.k_proj.bias"
+        self._v_weight_name = f"transformer.encoder.layers.{self.layer_num_}.self_attention.v_proj.weight"
+        self._v_bias_name = f"transformer.encoder.layers.{self.layer_num_}.self_attention.v_proj.bias"
+        self._o_weight_name = f"transformer.encoder.layers.{self.layer_num_}.self_attention.dense.weight"
+        self._o_bias_name = None
+
+        self._gate_weight_name = f"transformer.encoder.layers.{self.layer_num_}.mlp.gate_proj.weight"
+        self._gate_bias_name = None
+        self._up_weight_name = f"transformer.encoder.layers.{self.layer_num_}.mlp.up_proj.weight"
+        self._up_bias_name = None
+        self._down_weight_name = f"transformer.encoder.layers.{self.layer_num_}.mlp.dense_4h_to_h.weight"
+        self._down_bias_name = None
+
+        self._att_norm_weight_name = f"transformer.encoder.layers.{self.layer_num_}.input_layernorm.weight"
+        self._att_norm_bias_name = None
+        self._ffn_norm_weight_name = f"transformer.encoder.layers.{self.layer_num_}.post_attention_layernorm.weight"
+        self._ffn_norm_bias_name = None
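
The new `_preprocess_weight` branch above splits ChatGLM2's stacked `dense_h_to_4h` tensor into the synthesized `gate_proj`/`up_proj` names declared in `_init_weight_names`. A minimal sketch of that split on a dummy dict; the hidden and FFN sizes are illustrative, not from a real ChatGLM2 config.

# Minimal sketch of the dense_h_to_4h split shown in the diff above (not the lightllm code).
import torch

hidden_size, ffn_hidden_size = 32, 128
weights = {
    "transformer.encoder.layers.0.mlp.dense_h_to_4h.weight": torch.randn(2 * ffn_hidden_size, hidden_size)
}

fused = weights.pop("transformer.encoder.layers.0.mlp.dense_h_to_4h.weight")
# first half of dim 0 is the gate projection, second half the up projection
weights["transformer.encoder.layers.0.mlp.gate_proj.weight"] = fused[:ffn_hidden_size, :]
weights["transformer.encoder.layers.0.mlp.up_proj.weight"] = fused[ffn_hidden_size : 2 * ffn_hidden_size, :]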

lightllm/models/cohere/layer_weights/transformer_layer_weight.py

Lines changed: 5 additions & 2 deletions
@@ -10,15 +10,18 @@
 
 class CohereTransformerLayerWeight(LlamaTransformerLayerWeight):
     def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[], quant_cfg=None):
-        self.use_qk_norm = network_config.get("use_qk_norm", False)
         super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode, quant_cfg)
         return
 
+    def _parse_config(self):
+        super()._parse_config()
+        self.use_qk_norm = self.network_config_.get("use_qk_norm", False)
+
     def _init_norm(self, weights):
         q_split_head = self.network_config_["num_attention_heads"] // self.world_size_
         k_split_head = self.network_config_["num_key_value_heads"] // self.world_size_
 
-        self.att_norm_weight_ = NormWeight(self.att_norm_weight_name, self.data_type_)
+        self.att_norm_weight_ = NormWeight(self._att_norm_weight_name, self.data_type_)
 
         if self.use_qk_norm:
             self.q_norm_weight_ = TpNormWeight(
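
The Cohere change shows the overall shape of the refactor: config-derived flags like `use_qk_norm` move out of the subclass `__init__` and into a `_parse_config` hook the base class invokes. Below is a hypothetical sketch of that template-method pattern; `BaseLayerWeight` and its body are assumptions for illustration, not the actual lightllm base class in lightllm.common.basemodel.

# Hypothetical sketch of the hook pattern used by this refactor (not the real base class).
class BaseLayerWeight:
    def __init__(self, network_config):
        self.network_config_ = network_config
        self._parse_config()        # subclasses add config-derived fields here
        self._init_weight_names()   # ...and checkpoint key templates here

    def _parse_config(self):
        pass

    def _init_weight_names(self):
        pass


class CohereLikeLayerWeight(BaseLayerWeight):
    def _parse_config(self):
        super()._parse_config()
        self.use_qk_norm = self.network_config_.get("use_qk_norm", False)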

lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py

Lines changed: 2 additions & 2 deletions
@@ -22,5 +22,5 @@ def _init_qkv(self):
         )
 
     def _init_norm(self):
-        self.att_norm_weight_ = GEMMANormWeight(self.att_norm_weight_name, self.data_type_)
-        self.ffn_norm_weight_ = GEMMANormWeight(self.ffn_norm_weight_name, self.data_type_)
+        self.att_norm_weight_ = GEMMANormWeight(self._att_norm_weight_name, self.data_type_)
+        self.ffn_norm_weight_ = GEMMANormWeight(self._ffn_norm_weight_name, self.data_type_)

lightllm/models/internlm2/layer_weights/transformer_layer_weight.py

Lines changed: 2 additions & 2 deletions
@@ -33,5 +33,5 @@ def _init_weight_names(self):
         self._gate_weight_name = f"model.layers.{self.layer_num_}.feed_forward.w1.weight"
         self._up_weight_name = f"model.layers.{self.layer_num_}.feed_forward.w3.weight"
         self._down_weight_name = f"model.layers.{self.layer_num_}.feed_forward.w2.weight"
-        self.att_norm_weight_name = f"model.layers.{self.layer_num_}.attention_norm.weight"
-        self.ffn_norm_weight_name = f"model.layers.{self.layer_num_}.ffn_norm.weight"
+        self._att_norm_weight_name = f"model.layers.{self.layer_num_}.attention_norm.weight"
+        self._ffn_norm_weight_name = f"model.layers.{self.layer_num_}.ffn_norm.weight"

lightllm/models/llama/layer_weights/transformer_layer_weight.py

Lines changed: 6 additions & 6 deletions
@@ -49,10 +49,10 @@ def _init_weight_names(self):
         self._down_weight_name = f"model.layers.{self.layer_num_}.mlp.down_proj.weight"
         self._down_bias_name = None
 
-        self.att_norm_weight_name = f"model.layers.{self.layer_num_}.input_layernorm.weight"
-        self.att_norm_bias_name = None
-        self.ffn_norm_weight_name = f"model.layers.{self.layer_num_}.post_attention_layernorm.weight"
-        self.ffn_norm_bias_name = None
+        self._att_norm_weight_name = f"model.layers.{self.layer_num_}.input_layernorm.weight"
+        self._att_norm_bias_name = None
+        self._ffn_norm_weight_name = f"model.layers.{self.layer_num_}.post_attention_layernorm.weight"
+        self._ffn_norm_bias_name = None
 
     def _init_qkv(self):
         q_split_n_embed = self.head_dim * self.n_head // self.world_size_
@@ -83,8 +83,8 @@ def _init_ffn(self):
 
     def _init_norm(self):
         self.att_norm_weight_ = NormWeight(
-            self.att_norm_weight_name, self.data_type_, bias_name=self.att_norm_bias_name
+            self._att_norm_weight_name, self.data_type_, bias_name=self._att_norm_bias_name
         )
         self.ffn_norm_weight_ = NormWeight(
-            self.ffn_norm_weight_name, self.data_type_, bias_name=self.ffn_norm_bias_name
+            self._ffn_norm_weight_name, self.data_type_, bias_name=self._ffn_norm_bias_name
         )

lightllm/models/mixtral/layer_weights/transformer_layer_weight.py

Lines changed: 6 additions & 9 deletions
@@ -2,7 +2,7 @@
 import math
 import numpy as np
 from lightllm.utils.log_utils import init_logger
-from lightllm.models.bloom.layer_weights.transformer_layer_weight import BloomTransformerLayerWeight
+from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight
 from lightllm.common.basemodel.layer_weights.meta_weights import (
     ROWMMWeight,
     COLMMWeight,
@@ -13,7 +13,7 @@
 logger = init_logger(__name__)
 
 
-class MixtralTransformerLayerWeight(BloomTransformerLayerWeight):
+class MixtralTransformerLayerWeight(LlamaTransformerLayerWeight):
     def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[], quant_cfg=None):
         super().__init__(
             layer_num,
@@ -23,23 +23,20 @@ def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mo
             network_config,
             mode,
             quant_cfg=quant_cfg,
-            layer_prefix="model.layers",
         )
-
-        self._init_moe()
         return
 
-    def _init_config(self):
+    def _parse_config(self):
         super()._init_config()
         self.n_routed_experts = self.network_config_["num_local_experts"]
 
     def _init_weight_names(self):
         super()._init_weight_names()
-        self.moe_gate_weight_name = f"{self.layer_name}.mlp.gate.weight"
+        self.moe_gate_weight_name = f"model.layers.{self.layer_num_}.block_sparse_moe.gate.weight"
        self.moe_gate_bias_name = None
 
     def _init_ffn(self, weights):
-        pass
+        self._init_moe(weights)
 
     def _init_moe(self, weights):
         inter_size = self.network_config_["intermediate_size"]
@@ -53,7 +50,7 @@ def _init_moe(self, weights):
             gate_proj_name="w1",
             down_proj_name="w2",
             up_proj_name="w3",
-            weight_prefix=f"{self.layer_name}.block_sparse_moe.experts",
+            weight_prefix=f"model.layers.{self.layer_num_}.block_sparse_moe.experts",
             n_routed_experts=self.n_routed_experts,
             split_inter_size=split_inter_size,
             data_type=self.data_type_,
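
For reference, a minimal sketch of how the `weight_prefix` and the `w1`/`w2`/`w3` projection names above compose into per-expert checkpoint keys; the layer index and expert count are made up, and the exact composition inside the MoE weight wrapper may differ from this.

# Minimal sketch of per-expert key construction implied by the diff above (not the lightllm code).
layer_num_, n_routed_experts = 0, 8
weight_prefix = f"model.layers.{layer_num_}.block_sparse_moe.experts"
expert_keys = [
    f"{weight_prefix}.{expert_id}.{proj}.weight"
    for expert_id in range(n_routed_experts)
    for proj in ("w1", "w2", "w3")   # gate / down / up projections
]
# e.g. "model.layers.0.block_sparse_moe.experts.0.w1.weight"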

lightllm/models/qwen/layer_weights/transformer_layer_weight.py

Lines changed: 2 additions & 2 deletions
@@ -45,5 +45,5 @@ def _init_weight_names(self):
         self._gate_weight_name = f"transformer.h.{self.layer_num_}.mlp.w2.weight"
         self._up_weight_name = f"transformer.h.{self.layer_num_}.mlp.w1.weight"
         self._down_weight_name = f"transformer.h.{self.layer_num_}.mlp.c_proj.weight"
-        self.att_norm_weight_name = f"transformer.h.{self.layer_num_}.ln_1.weight"
-        self.ffn_norm_weight_name = f"transformer.h.{self.layer_num_}.ln_2.weight"
+        self._att_norm_weight_name = f"transformer.h.{self.layer_num_}.ln_1.weight"
+        self._ffn_norm_weight_name = f"transformer.h.{self.layer_num_}.ln_2.weight"

lightllm/models/stablelm/layer_weights/transformer_layer_weight.py

Lines changed: 2 additions & 2 deletions
@@ -9,5 +9,5 @@ def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mo
 
     def _init_weight_names(self):
         super()._init_weight_names()
-        self.att_norm_bias_name = f"model.layers.{self.layer_num_}.input_layernorm.bias"
-        self.ffn_norm_bias_name = f"model.layers.{self.layer_num_}.post_attention_layernorm.bias"
+        self._att_norm_bias_name = f"model.layers.{self.layer_num_}.input_layernorm.bias"
+        self._ffn_norm_bias_name = f"model.layers.{self.layer_num_}.post_attention_layernorm.bias"
