@@ -66,6 +66,7 @@ def fuse_vb_o(self, layer_weight):
 class Deepseek2TransformerLayerWeight(TransformerLayerWeight):
     def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[], quant_cfg=None):
         self.enable_dp = os.getenv("ENABLE_DP", "0").upper() in ["ON", "TRUE", "1"]
+        self.enable_cc_method = os.getenv("ENABLE_CC_METHOD", "False").upper() in ["ON", "TRUE", "1"]
         super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode, quant_cfg)
         return
 
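The new flag follows the existing ENABLE_DP parsing convention: any of "ON", "TRUE", or "1" (case-insensitive) enables it, so the default string "False" leaves it off. A minimal sketch of that behavior (illustrative only, not part of the diff):

# Hedged sketch: how the ENABLE_CC_METHOD parsing behaves.
import os

os.environ["ENABLE_CC_METHOD"] = "on"  # e.g. exported by a launch script
enabled = os.getenv("ENABLE_CC_METHOD", "False").upper() in ["ON", "TRUE", "1"]
assert enabled  # "on".upper() == "ON"; the default "False" would parse as disabled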
@@ -132,34 +133,44 @@ def _load_kb(self, kv_b_proj_):
         k_b_proj_ = kv_b_proj_.view(self.num_attention_heads, self.qk_nope_head_dim * 2, self.kv_lora_rank)[
             :, : self.qk_nope_head_dim, :
         ]
-        return k_b_proj_.contiguous().to(self.data_type_)
+        return k_b_proj_.contiguous().to(kv_b_proj_.dtype)
 
     def _load_kb_scale(self, kv_b_proj_, block_size):
         k_b_proj_scale_ = kv_b_proj_.view(
             self.num_attention_heads, self.qk_nope_head_dim * 2 // block_size, self.kv_lora_rank // block_size
         )[:, : self.qk_nope_head_dim // block_size, :]
-        return k_b_proj_scale_.contiguous().to(self.data_type_)
+        return k_b_proj_scale_.contiguous().to(kv_b_proj_.dtype)
 
     def _load_vb(self, kv_b_proj_):
         v_b_proj_ = kv_b_proj_.T.view(self.kv_lora_rank, self.num_attention_heads, self.qk_nope_head_dim * 2,)[
             :, :, self.qk_nope_head_dim :
         ].transpose(0, 1)
-        return v_b_proj_.contiguous().to(self.data_type_)
+        return v_b_proj_.contiguous().to(kv_b_proj_.dtype)
 
     def _load_vb_scale(self, kv_b_proj_scale_, block_size):
         v_b_proj_scale_ = kv_b_proj_scale_.T.view(
             self.kv_lora_rank // block_size,
             self.num_attention_heads,
             self.qk_nope_head_dim * 2 // block_size,
         )[:, :, self.qk_nope_head_dim // block_size :].transpose(0, 1)
-        return v_b_proj_scale_.contiguous().to(self.data_type_)
+        return v_b_proj_scale_.contiguous().to(kv_b_proj_scale_.dtype)
 
     def load_hf_weights(self, weights):
         if f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight" in weights:
             kv_b_proj_ = weights[f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight"]
             weights[f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight"] = self._load_kb(kv_b_proj_)
             weights[f"model.layers.{self.layer_num_}.self_attn.v_b_proj.weight"] = self._load_vb(kv_b_proj_)
-
+            if self.enable_cc_method:
+                weights[f"model.layers.{self.layer_num_}.self_attn.cc_k_b_proj.weight"] = self._load_kb(
+                    kv_b_proj_
+                ).reshape(-1, self.kv_lora_rank)
+                weights[f"model.layers.{self.layer_num_}.self_attn.cc_v_b_proj.weight"] = (
+                    self._load_vb(kv_b_proj_)
+                    .transpose(0, 1)
+                    .reshape(self.kv_lora_rank, -1)
+                    .transpose(0, 1)
+                    .contiguous()
+                )
         if (
             self.quant_cfg.quantized_weight
             and f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." + self.weight_scale_suffix in weights
@@ -177,6 +188,17 @@ def load_hf_weights(self, weights):
             weights[
                 f"model.layers.{self.layer_num_}.self_attn.v_b_proj." + self.weight_scale_suffix
             ] = self._load_vb_scale(kv_b_proj_scale_, block_size)
+            if self.enable_cc_method:
+                weights[
+                    f"model.layers.{self.layer_num_}.self_attn.cc_k_b_proj." + self.weight_scale_suffix
+                ] = self._load_kb_scale(kv_b_proj_scale_, block_size).reshape(-1, self.kv_lora_rank // block_size)
+                weights[f"model.layers.{self.layer_num_}.self_attn.cc_v_b_proj." + self.weight_scale_suffix] = (
+                    self._load_vb_scale(kv_b_proj_scale_, block_size)
+                    .transpose(0, 1)
+                    .reshape(self.kv_lora_rank // block_size, -1)
+                    .transpose(0, 1)
+                    .contiguous()
+                )
 
         return super().load_hf_weights(weights)
 
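The cc_* variants fold the per-head k_b/v_b projections into single 2-D matrices of shape (num_heads * qk_nope_head_dim, kv_lora_rank), so each projection becomes one matmul on the cc-method path. A minimal shape check of the two reshapes, using small illustrative dimensions (assumptions, not values from the model config):

# Hedged sketch of the cc_k_b_proj / cc_v_b_proj reshapes above.
import torch

num_heads, qk_nope_head_dim, kv_lora_rank = 4, 8, 16
kv_b_proj = torch.randn(num_heads * qk_nope_head_dim * 2, kv_lora_rank)

# _load_kb equivalent -> (num_heads, qk_nope_head_dim, kv_lora_rank)
k_b = kv_b_proj.view(num_heads, qk_nope_head_dim * 2, kv_lora_rank)[:, :qk_nope_head_dim, :].contiguous()
# _load_vb equivalent -> (num_heads, kv_lora_rank, qk_nope_head_dim)
v_b = (
    kv_b_proj.T.view(kv_lora_rank, num_heads, qk_nope_head_dim * 2)[:, :, qk_nope_head_dim:]
    .transpose(0, 1)
    .contiguous()
)

# cc_k_b_proj: heads fold directly into rows.
cc_k_b = k_b.reshape(-1, kv_lora_rank)
# cc_v_b_proj: kv_lora_rank sits before the head dim in v_b, so transpose,
# flatten the head axis, then transpose back to the same row-major layout.
cc_v_b = v_b.transpose(0, 1).reshape(kv_lora_rank, -1).transpose(0, 1).contiguous()

assert cc_k_b.shape == (num_heads * qk_nope_head_dim, kv_lora_rank)
assert cc_v_b.shape == (num_heads * qk_nope_head_dim, kv_lora_rank)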
@@ -236,6 +258,21 @@ def _init_qkvo(self):
             weight_scale_suffix=self.weight_scale_suffix,
             act_scale_suffix=self.act_scale_suffix,
         )
+        if self.enable_cc_method:
+            self.cc_k_b_proj_ = ROWMMWeight(
+                f"model.layers.{self.layer_num_}.self_attn.cc_k_b_proj.weight",
+                self.data_type_,
+                split_n_embed=self.tp_q_head_num_ * self.qk_nope_head_dim,
+                weight_scale_suffix=self.weight_scale_suffix,
+                act_scale_suffix=self.act_scale_suffix,
+            )
+            self.cc_v_b_proj_ = ROWMMWeight(
+                f"model.layers.{self.layer_num_}.self_attn.cc_v_b_proj.weight",
+                self.data_type_,
+                split_n_embed=self.tp_q_head_num_ * self.qk_nope_head_dim,
+                weight_scale_suffix=self.weight_scale_suffix,
+                act_scale_suffix=self.act_scale_suffix,
+            )
 
         self.o_weight_ = COLMMWeight(
             f"model.layers.{self.layer_num_}.self_attn.o_proj.weight",
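Registering the folded weights as ROWMMWeight splits them along the first (row) dimension for tensor parallelism; split_n_embed = tp_q_head_num_ * qk_nope_head_dim is exactly the folded row count owned by one rank. Illustrative arithmetic with assumed DeepSeek-V2-like sizes (hypothetical values, not read from this diff):

# Hedged sketch of the per-rank row split for the cc_* weights.
num_attention_heads = 128  # assumed total head count
world_size = 8             # assumed tensor-parallel degree
qk_nope_head_dim = 128     # assumed per-head nope dim

tp_q_head_num = num_attention_heads // world_size  # 16 heads per rank
split_n_embed = tp_q_head_num * qk_nope_head_dim   # 2048 rows per rank
# Each rank then holds a (2048, kv_lora_rank) slice of cc_k_b_proj / cc_v_b_proj.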