@@ -160,6 +160,13 @@ def load_hf_weights(self, weights):
160160 weights [f"model.layers.{ self .layer_num_ } .self_attn.k_b_proj.weight" ] = self ._load_kb (kv_b_proj_ )
161161 weights [f"model.layers.{ self .layer_num_ } .self_attn.v_b_proj.weight" ] = self ._load_vb (kv_b_proj_ )
162162
163+ weights [f"model.layers.{ self .layer_num_ } .self_attn.cc_k_b_proj.weight" ] = self ._load_kb (kv_b_proj_ ).reshape (
164+ - 1 , self .kv_lora_rank
165+ )
166+ weights [f"model.layers.{ self .layer_num_ } .self_attn.cc_v_b_proj.weight" ] = (
167+ self ._load_vb (kv_b_proj_ ).transpose (0 , 1 ).reshape (self .kv_lora_rank , - 1 ).transpose (0 , 1 ).contiguous ()
168+ )
169+
163170 if (
164171 self .quant_cfg .quantized_weight
165172 and f"model.layers.{ self .layer_num_ } .self_attn.kv_b_proj." + self .weight_scale_suffix in weights
@@ -178,6 +185,17 @@ def load_hf_weights(self, weights):
178185 f"model.layers.{ self .layer_num_ } .self_attn.v_b_proj." + self .weight_scale_suffix
179186 ] = self ._load_vb_scale (kv_b_proj_scale_ , block_size )
180187
188+ weights [
189+ f"model.layers.{ self .layer_num_ } .self_attn.cc_k_b_proj." + self .weight_scale_suffix
190+ ] = self ._load_kb_scale (kv_b_proj_scale_ , block_size ).reshape (- 1 , self .kv_lora_rank // block_size )
191+ weights [f"model.layers.{ self .layer_num_ } .self_attn.cc_v_b_proj." + self .weight_scale_suffix ] = (
192+ self ._load_vb_scale (kv_b_proj_scale_ , block_size )
193+ .transpose (0 , 1 )
194+ .reshape (self .kv_lora_rank // block_size , - 1 )
195+ .transpose (0 , 1 )
196+ .contiguous ()
197+ )
198+
181199 return super ().load_hf_weights (weights )
182200
183201 def _set_quantization (self ):
@@ -237,6 +255,21 @@ def _init_qkvo(self):
237255 act_scale_suffix = self .act_scale_suffix ,
238256 )
239257
258+ self .cc_k_b_proj_ = ROWMMWeight (
259+ f"model.layers.{ self .layer_num_ } .self_attn.cc_k_b_proj.weight" ,
260+ self .data_type_ ,
261+ split_n_embed = self .tp_q_head_num_ * self .qk_nope_head_dim ,
262+ weight_scale_suffix = self .weight_scale_suffix ,
263+ act_scale_suffix = self .act_scale_suffix ,
264+ )
265+ self .cc_v_b_proj_ = ROWMMWeight (
266+ f"model.layers.{ self .layer_num_ } .self_attn.cc_v_b_proj.weight" ,
267+ self .data_type_ ,
268+ split_n_embed = self .tp_q_head_num_ * self .qk_nope_head_dim ,
269+ weight_scale_suffix = self .weight_scale_suffix ,
270+ act_scale_suffix = self .act_scale_suffix ,
271+ )
272+
240273 self .o_weight_ = COLMMWeight (
241274 f"model.layers.{ self .layer_num_ } .self_attn.o_proj.weight" ,
242275 self .data_type_ ,
0 commit comments