
Commit f69e2ec

Fix a lot (#610)
1 parent 18a0f08 commit f69e2ec

12 files changed, +18 -27 lines changed

lightllm/common/basemodel/layer_weights/transformer_layer_weight.py

Lines changed: 3 additions & 2 deletions
@@ -2,7 +2,7 @@
 
 # from lightllm.common.layers.mm import MM
 from .base_layer_weight import BaseLayerWeight
-from .meta_weights import MMWeight, FusedMoeWeight
+from .meta_weights import MMWeight, ROWMMWeight, FusedMoeWeight
 from lightllm.utils.log_utils import init_logger
 
 logger = init_logger(__name__)
@@ -20,6 +20,7 @@ def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mo
         self.quant_cfg = quant_cfg
         self.init_static_params()
         self.fuse_pairs = {"k_proj&v_proj": "kv_proj"}
+        self.kv_proj: ROWMMWeight = None
         return
 
     def load_hf_weights(self, weights):
@@ -30,7 +31,7 @@ def fuse_weights(self):
         for pair_name, fuse_name in self.fuse_pairs.items():
             attr1_name, attr2_name = pair_name.split("&")
             with self.lock:
-                if hasattr(self, fuse_name):
+                if getattr(self, fuse_name, None) is not None:
                     continue
                 attr1 = getattr(self, attr1_name)
                 attr2 = getattr(self, attr2_name)
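
The guard change matters because the base class now pre-declares the fused attribute: with self.kv_proj set to None in __init__, hasattr(self, fuse_name) is always True and fusion would be skipped even before the fused weight exists. A minimal standalone sketch (not part of the commit) of the difference between the two checks:

class _Demo:
    def __init__(self):
        # attribute is declared up front, mirroring self.kv_proj = None in the diff
        self.kv_proj = None

d = _Demo()
print(hasattr(d, "kv_proj"))                    # True  -> old guard would skip fusion
print(getattr(d, "kv_proj", None) is not None)  # False -> new guard still runs fusion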

lightllm/models/baichuan13b/layer_infer/transformer_layer_infer.py

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ def _bind_func(self):
         return
 
     def _get_qkv(self, input, cache_kv, infer_state, layer_weight: BaiChuan13bTransformerLayerWeight) -> torch.Tensor:
-        q = layer_weight.q_proj.mm(input)
+        q = layer_weight.q_proj.mm(input.view(-1, self.embed_dim_))
         cache_kv = layer_weight.kv_proj.mm(
             input, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_)
         ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
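
This commit repeatedly flattens input to a 2-D [num_tokens, embed_dim] view before the projection matmul (the same change appears in the bloom and gemma_2b hunks below), presumably because the MMWeight.mm wrappers expect a 2-D operand. A rough shape sketch with illustrative sizes, not taken from any model config:

import torch

hidden = torch.randn(4, 1, 2048)        # e.g. [num_tokens, 1, embed_dim]; sizes are made up
q_weight = torch.randn(2048, 2048)      # stand-in for a q_proj weight
q = hidden.view(-1, 2048).mm(q_weight)  # .view(-1, embed_dim) guarantees a 2-D left operand
assert q.shape == (4, 2048)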

lightllm/models/baichuan13b/layer_weights/transformer_layer_weight.py

Lines changed: 0 additions & 5 deletions
@@ -13,8 +13,3 @@ def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mo
 
     def init_static_params(self):
         return BloomTransformerLayerWeight.init_static_params(self)
-
-    def verify_load(self):
-        super().verify_load()
-        assert self.tp_alibi is not None, "load error"
-        return

lightllm/models/baichuan2_7b/layer_infer/transformer_layer_infer.py

Lines changed: 0 additions & 1 deletion
@@ -15,7 +15,6 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
     def _get_qkv(
         self, input, cache_kv: torch.Tensor, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight
     ) -> torch.Tensor:
-
         q = layer_weight.q_proj.mm(input.view(-1, self.embed_dim_)).view(-1, self.tp_q_head_num_, self.head_dim_)
         cache_kv = layer_weight.kv_proj.mm(
             input, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_)

lightllm/models/baichuan7b/layer_weights/transformer_layer_weight.py

Lines changed: 1 addition & 5 deletions
@@ -12,11 +12,7 @@ def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mo
 
     def _init_config(self):
         self.network_config_["num_key_value_heads"] = self.network_config_["num_attention_heads"]
-        self.n_embed = self.network_config_["hidden_size"]
-        self.n_head = self.network_config_["num_attention_heads"]
-        self.n_inter = self.network_config_["intermediate_size"]
-        self.n_kv_head = self.network_config_["num_key_value_heads"]
-        self.head_dim = self.network_config_.get("head_dim", self.n_embed // self.n_head)
+        super()._init_config()
 
     def load_hf_weights(self, weights):
         qkv_weight_name = f"{self.layer_name}.self_attn.W_pack.weight"

lightllm/models/bloom/layer_infer/transformer_layer_infer.py

Lines changed: 3 additions & 5 deletions
@@ -46,7 +46,7 @@ def _ffn_norm(self, input, infer_state: InferStateInfo, layer_weight: BloomTrans
     def _get_qkv(
         self, input, cache_kv, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight
     ) -> torch.Tensor:
-        q = layer_weight.q_proj.mm(input)
+        q = layer_weight.q_proj.mm(input.view(-1, self.embed_dim_))
         cache_kv = layer_weight.kv_proj.mm(
             input, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_)
         ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
@@ -94,13 +94,11 @@ def _token_attention_kernel(
         return o_tensor
 
     def _get_o(self, input, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight) -> torch.Tensor:
-        input = input.view(-1, self.tp_o_head_num_ * self.head_dim_)
-        o_tensor = layer_weight.o_proj.mm(input)
+        o_tensor = layer_weight.o_proj.mm(input.view(-1, self.tp_o_head_num_ * self.head_dim_))
         return o_tensor
 
     def _ffn(self, input, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight) -> torch.Tensor:
-        input = input.view(-1, self.embed_dim_)
-        ffn1_out = layer_weight.up_proj.mm(input)
+        ffn1_out = layer_weight.up_proj.mm(input.view(-1, self.embed_dim_))
         input = None
         gelu_out = torch.nn.functional.gelu(ffn1_out, approximate="tanh")
         ffn1_out = None

lightllm/models/bloom/layer_weights/transformer_layer_weight.py

Lines changed: 0 additions & 1 deletion
@@ -50,7 +50,6 @@ def get_slopes_power_of_2(n):
 class BloomTransformerLayerWeight(LlamaTransformerLayerWeight):
     def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode, quant_cfg=None):
         super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode, quant_cfg, layer_prefix="h")
-        self.init_static_params()
         return
 
     def _init_config(self):

lightllm/models/chatglm2/layer_weights/transformer_layer_weight.py

Lines changed: 5 additions & 5 deletions
@@ -24,17 +24,17 @@ def _preprocess_weight(self, weights):
         qkv_weight_name = f"{self.layer_name}.self_attention.query_key_value.weight"
         if qkv_weight_name in weights:
             qkv_weight_ = weights[qkv_weight_name]
-            weights[self._q_weight_name] = qkv_weight_[:, : self.n_embed]
-            weights[self._k_weight_name] = qkv_weight_[:, self.n_embed : self.n_embed + n_kv_embed]
-            weights[self._v_weight_name] = qkv_weight_[:, self.n_embed + n_kv_embed : self.n_embed + 2 * n_kv_embed]
+            weights[self._q_weight_name] = qkv_weight_[: self.n_embed, :]
+            weights[self._k_weight_name] = qkv_weight_[self.n_embed : self.n_embed + n_kv_embed, :]
+            weights[self._v_weight_name] = qkv_weight_[self.n_embed + n_kv_embed : self.n_embed + 2 * n_kv_embed, :]
             del weights[qkv_weight_name]
 
         qkv_bias_name = f"{self.layer_name}.self_attention.query_key_value.bias"
         if qkv_bias_name in weights:
             qkv_bias_ = weights[qkv_bias_name]
             weights[self._q_bias_name] = qkv_bias_[: self.n_embed]
-            weights[self._k_bias_name] = qkv_bias_[:, self.n_embed : self.n_embed + n_kv_embed]
-            weights[self._v_bias_name] = qkv_bias_[:, self.n_embed + n_kv_embed : self.n_embed + 2 * n_kv_embed]
+            weights[self._k_bias_name] = qkv_bias_[self.n_embed : self.n_embed + n_kv_embed]
+            weights[self._v_bias_name] = qkv_bias_[self.n_embed + n_kv_embed : self.n_embed + 2 * n_kv_embed]
             del weights[qkv_bias_name]
 
     def _init_config(self):
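
The slicing fix assumes the fused query_key_value weight is stored in the usual PyTorch Linear layout, [out_features, in_features], so Q, K and V occupy consecutive blocks of rows (dim 0) rather than columns, and the 1-D bias should not be indexed with a leading ":" at all. A toy illustration with invented sizes (not the real ChatGLM2 dimensions):

import torch

n_embed, n_kv_embed, in_dim = 8, 2, 8                # toy sizes only
qkv = torch.randn(n_embed + 2 * n_kv_embed, in_dim)  # assumed [out_features, in_features] layout
q = qkv[: n_embed, :]                                # first n_embed output rows
k = qkv[n_embed : n_embed + n_kv_embed, :]
v = qkv[n_embed + n_kv_embed : n_embed + 2 * n_kv_embed, :]
assert q.shape == (n_embed, in_dim) and k.shape == v.shape == (n_kv_embed, in_dim)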

lightllm/models/gemma_2b/layer_infer/transformer_layer_infer.py

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
     def _ffn(
         self, input, infer_state: LlamaInferStateInfo, layer_weight: Gemma_2bTransformerLayerWeight
     ) -> torch.Tensor:
-        up_gate_out = layer_weight.gate_up_proj.mm(input)
+        up_gate_out = layer_weight.gate_up_proj.mm(input.view(-1, self.embed_dim_))
         ffn1_out = self.alloc_tensor((input.size(0), up_gate_out.size(1) // 2), input.dtype)
         gelu_and_mul_fwd(up_gate_out, ffn1_out)
         input = None

lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mo
 
     def _init_qkv(self):
         q_split_n_embed = self.head_dim * self.n_head // self.world_size_
-        kv_split_n_embed = self.head_dim * self.n_kv_head // self.world_size_
+        kv_split_n_embed = self.head_dim * self.n_kv_head
         self.q_proj = ROWMMWeight(self._q_weight_name, self.data_type_, q_split_n_embed, bias_name=self._q_bias_name)
         self.k_proj = ROWMMWeight(
             self._k_weight_name,
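
Dropping the // self.world_size_ here looks like a multi-query-attention fix: Gemma-2B has a single KV head, which cannot be split across tensor-parallel ranks, so each rank keeps the full KV projection width. Illustrative arithmetic (numbers are assumptions, not read from the repo config):

head_dim, n_kv_head, world_size = 256, 1, 2
old_split = head_dim * n_kv_head // world_size  # 128: would slice the only KV head in half
new_split = head_dim * n_kv_head                # 256: each rank keeps the whole KV head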
