
Commit 513cc72

Commit message: v3-fix

1 parent: b8cfd70

File tree: 3 files changed (+38, -6 lines)


lightllm/common/basemodel/basemodel.py

Lines changed: 2 additions & 1 deletion

@@ -99,7 +99,8 @@ def _init_config(self):
         repair_config(self.config, same_names=["num_hidden_layers", "n_layer"])
         if self.finetune_config:
             self.config["vocab_size"] = self.finetune_config.vocab_size
-
+        # self.config["num_hidden_layers"] = 4
+        # self.config["n_layer"] = 4
         return

     @final
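
Note: the two added lines stay commented out; they read like a debugging switch that truncates the model to 4 transformer layers for quick smoke tests, and both alias keys are set together so the override holds whichever key downstream code reads. A minimal sketch of the idea, assuming repair_config only mirrors a value across its alias keys (the helper below is illustrative, not lightllm's implementation):

# Illustrative stand-in for the alias-repair idea; not lightllm's repair_config.
def repair_config(config: dict, same_names: list) -> None:
    # Copy the first value found under any alias key to all the other aliases.
    value = next((config[name] for name in same_names if name in config), None)
    if value is not None:
        for name in same_names:
            config[name] = value

config = {"n_layer": 61}
repair_config(config, same_names=["num_hidden_layers", "n_layer"])
assert config["num_hidden_layers"] == 61

# Uncommenting the two lines from the diff would shrink the model for debugging:
# config["num_hidden_layers"] = 4
# config["n_layer"] = 4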

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 3 additions & 5 deletions

@@ -154,7 +154,7 @@ def _decompress_kv(self, kv, infer_state: Deepseek2InferStateInfo, layer_weight:
         )

         # CC
-        compressed_kv = compressed_kv.view(-1, layer_weight.kv_lora_rank)
+        compressed_kv = compressed_kv.view(-1, layer_weight.kv_lora_rank).contiguous()
         k_nope = self.alloc_tensor(
             [compressed_kv.shape[0], self.tp_q_head_num_, self.qk_nope_head_dim],
             dtype=compressed_kv.dtype,
@@ -163,10 +163,8 @@ def _decompress_kv(self, kv, infer_state: Deepseek2InferStateInfo, layer_weight:
             k_nope.shape,
             dtype=compressed_kv.dtype,
         )
-        wk = layer_weight.k_b_proj_.weight.view(-1, layer_weight.kv_lora_rank).T
-        wv = layer_weight.v_b_proj_.weight.transpose(0, 1).reshape(layer_weight.kv_lora_rank, -1)
-        torch.mm(compressed_kv, wk, out=k_nope.reshape(compressed_kv.shape[0], -1))
-        torch.mm(compressed_kv, wv, out=v.reshape(compressed_kv.shape[0], -1))
+        layer_weight.cc_k_b_proj_.mm(compressed_kv, out=k_nope.reshape(compressed_kv.shape[0], -1))
+        layer_weight.cc_v_b_proj_.mm(compressed_kv, out=v.reshape(compressed_kv.shape[0], -1))
         return k_nope, k_rope, v

     def _context_attention_kernel_with_CC(
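
The removed lines rebuilt wk and wv from k_b_proj_ / v_b_proj_ with a view/transpose on every _decompress_kv call; the new cc_k_b_proj_ / cc_v_b_proj_ weights are the same matrices pre-reshaped once at load time (see the weight-loading diff below), so the hot path is reduced to two plain matmuls. A standalone sketch of the intended equivalence, with toy sizes, assumed 3-D layouts for the loaded k_b / v_b weights, and the assumption that ROWMMWeight.mm(x, out=...) computes x @ weight.T:

import torch

# Toy sizes for illustration only (real DeepSeek-V2/V3 dimensions are much larger).
n_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank, n_tok = 4, 8, 8, 16, 3

k_b = torch.randn(n_heads, qk_nope_head_dim, kv_lora_rank)  # assumed k_b_proj_.weight layout
v_b = torch.randn(n_heads, kv_lora_rank, v_head_dim)        # assumed v_b_proj_.weight layout
compressed_kv = torch.randn(n_tok, kv_lora_rank)

# Old path: reshape/transpose the weights on every call.
wk = k_b.view(-1, kv_lora_rank).T
wv = v_b.transpose(0, 1).reshape(kv_lora_rank, -1)
k_nope_old = compressed_kv @ wk
v_old = compressed_kv @ wv

# New path: fold the layout change into the weights once (mirrors the cc_* loading
# code in the next file), leaving only an x @ W.T matmul at runtime.
cc_k_b = k_b.reshape(-1, kv_lora_rank)
cc_v_b = v_b.transpose(0, 1).reshape(kv_lora_rank, -1).transpose(0, 1).contiguous()
k_nope_new = compressed_kv @ cc_k_b.T
v_new = compressed_kv @ cc_v_b.T

assert torch.allclose(k_nope_old, k_nope_new, atol=1e-5)
assert torch.allclose(v_old, v_new, atol=1e-5)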

lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py

Lines changed: 33 additions & 0 deletions

@@ -160,6 +160,13 @@ def load_hf_weights(self, weights):
         weights[f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight"] = self._load_kb(kv_b_proj_)
         weights[f"model.layers.{self.layer_num_}.self_attn.v_b_proj.weight"] = self._load_vb(kv_b_proj_)

+        weights[f"model.layers.{self.layer_num_}.self_attn.cc_k_b_proj.weight"] = self._load_kb(kv_b_proj_).reshape(
+            -1, self.kv_lora_rank
+        )
+        weights[f"model.layers.{self.layer_num_}.self_attn.cc_v_b_proj.weight"] = (
+            self._load_vb(kv_b_proj_).transpose(0, 1).reshape(self.kv_lora_rank, -1).transpose(0, 1).contiguous()
+        )
+
         if (
             self.quant_cfg.quantized_weight
             and f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." + self.weight_scale_suffix in weights
@@ -178,6 +185,17 @@ def load_hf_weights(self, weights):
                 f"model.layers.{self.layer_num_}.self_attn.v_b_proj." + self.weight_scale_suffix
             ] = self._load_vb_scale(kv_b_proj_scale_, block_size)

+            weights[
+                f"model.layers.{self.layer_num_}.self_attn.cc_k_b_proj." + self.weight_scale_suffix
+            ] = self._load_kb_scale(kv_b_proj_scale_, block_size).reshape(-1, self.kv_lora_rank // block_size)
+            weights[f"model.layers.{self.layer_num_}.self_attn.cc_v_b_proj." + self.weight_scale_suffix] = (
+                self._load_vb_scale(kv_b_proj_scale_, block_size)
+                .transpose(0, 1)
+                .reshape(self.kv_lora_rank // block_size, -1)
+                .transpose(0, 1)
+                .contiguous()
+            )
+
         return super().load_hf_weights(weights)

     def _set_quantization(self):
@@ -237,6 +255,21 @@ def _init_qkvo(self):
             act_scale_suffix=self.act_scale_suffix,
         )

+        self.cc_k_b_proj_ = ROWMMWeight(
+            f"model.layers.{self.layer_num_}.self_attn.cc_k_b_proj.weight",
+            self.data_type_,
+            split_n_embed=self.tp_q_head_num_ * self.qk_nope_head_dim,
+            weight_scale_suffix=self.weight_scale_suffix,
+            act_scale_suffix=self.act_scale_suffix,
+        )
+        self.cc_v_b_proj_ = ROWMMWeight(
+            f"model.layers.{self.layer_num_}.self_attn.cc_v_b_proj.weight",
+            self.data_type_,
+            split_n_embed=self.tp_q_head_num_ * self.qk_nope_head_dim,
+            weight_scale_suffix=self.weight_scale_suffix,
+            act_scale_suffix=self.act_scale_suffix,
+        )
+
         self.o_weight_ = COLMMWeight(
             f"model.layers.{self.layer_num_}.self_attn.o_proj.weight",
             self.data_type_,
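
Two details worth noting in the loading code. First, both cc_* weights end up in a standard [out_features, in_features] layout ([n_heads * head_dim, kv_lora_rank]), which is what lets them be registered as ROWMMWeight with split_n_embed = tp_q_head_num_ * qk_nope_head_dim, i.e. split across tensor-parallel ranks along the head dimension. Second, in the quantized branch the block-wise weight scales get the same layout transform as the weights, with each dimension divided by block_size. A shape-only sketch under those assumptions (the _load_kb_scale layout is assumed to match _load_kb with every axis divided by block_size):

import torch

# Toy sizes; assume block-wise quantization with one scale per block_size x block_size tile.
n_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank, block_size = 4, 128, 128, 512, 128

k_b = torch.randn(n_heads, qk_nope_head_dim, kv_lora_rank)   # assumed _load_kb output layout
v_b = torch.randn(n_heads, kv_lora_rank, v_head_dim)         # assumed _load_vb output layout

cc_k_b = k_b.reshape(-1, kv_lora_rank)
cc_v_b = v_b.transpose(0, 1).reshape(kv_lora_rank, -1).transpose(0, 1).contiguous()
print(cc_k_b.shape, cc_v_b.shape)   # torch.Size([512, 512]) torch.Size([512, 512])

# Scale grid for cc_k_b: same transform with every axis divided by block_size,
# hence the reshape(-1, kv_lora_rank // block_size) in the diff above.
k_scale = torch.rand(n_heads, qk_nope_head_dim // block_size, kv_lora_rank // block_size)
cc_k_scale = k_scale.reshape(-1, kv_lora_rank // block_size)
print(cc_k_scale.shape)             # torch.Size([4, 4])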
