@@ -66,6 +66,7 @@ def fuse_vb_o(self, layer_weight):
 class Deepseek2TransformerLayerWeight(TransformerLayerWeight):
     def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[], quant_cfg=None):
         self.enable_dp = os.getenv("ENABLE_DP", "0").upper() in ["ON", "TRUE", "1"]
+        self.enable_cc_method = os.getenv("ENABLE_CC_METHOD", "False").upper() in ["ON", "TRUE", "1"]
         super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode, quant_cfg)
         return
7172
@@ -159,13 +160,17 @@ def load_hf_weights(self, weights):
         kv_b_proj_ = weights[f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight"]
         weights[f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight"] = self._load_kb(kv_b_proj_)
         weights[f"model.layers.{self.layer_num_}.self_attn.v_b_proj.weight"] = self._load_vb(kv_b_proj_)
-
-        weights[f"model.layers.{self.layer_num_}.self_attn.cc_k_b_proj.weight"] = self._load_kb(kv_b_proj_).reshape(
-            -1, self.kv_lora_rank
-        )
-        weights[f"model.layers.{self.layer_num_}.self_attn.cc_v_b_proj.weight"] = (
-            self._load_vb(kv_b_proj_).transpose(0, 1).reshape(self.kv_lora_rank, -1).transpose(0, 1).contiguous()
-        )
+        if self.enable_cc_method:
+            weights[f"model.layers.{self.layer_num_}.self_attn.cc_k_b_proj.weight"] = self._load_kb(
+                kv_b_proj_
+            ).reshape(-1, self.kv_lora_rank)
+            weights[f"model.layers.{self.layer_num_}.self_attn.cc_v_b_proj.weight"] = (
+                self._load_vb(kv_b_proj_)
+                .transpose(0, 1)
+                .reshape(self.kv_lora_rank, -1)
+                .transpose(0, 1)
+                .contiguous()
+            )
         if (
             self.quant_cfg.quantized_weight
             and f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." + self.weight_scale_suffix in weights
@@ -183,17 +188,17 @@ def load_hf_weights(self, weights):
             weights[
                 f"model.layers.{self.layer_num_}.self_attn.v_b_proj." + self.weight_scale_suffix
             ] = self._load_vb_scale(kv_b_proj_scale_, block_size)
-
-            weights[
-                f"model.layers.{self.layer_num_}.self_attn.cc_k_b_proj." + self.weight_scale_suffix
-            ] = self._load_kb_scale(kv_b_proj_scale_, block_size).reshape(-1, self.kv_lora_rank // block_size)
-            weights[f"model.layers.{self.layer_num_}.self_attn.cc_v_b_proj." + self.weight_scale_suffix] = (
-                self._load_vb_scale(kv_b_proj_scale_, block_size)
-                .transpose(0, 1)
-                .reshape(self.kv_lora_rank // block_size, -1)
-                .transpose(0, 1)
-                .contiguous()
-            )
+            if self.enable_cc_method:
+                weights[
+                    f"model.layers.{self.layer_num_}.self_attn.cc_k_b_proj." + self.weight_scale_suffix
+                ] = self._load_kb_scale(kv_b_proj_scale_, block_size).reshape(-1, self.kv_lora_rank // block_size)
+                weights[f"model.layers.{self.layer_num_}.self_attn.cc_v_b_proj." + self.weight_scale_suffix] = (
+                    self._load_vb_scale(kv_b_proj_scale_, block_size)
+                    .transpose(0, 1)
+                    .reshape(self.kv_lora_rank // block_size, -1)
+                    .transpose(0, 1)
+                    .contiguous()
+                )

         return super().load_hf_weights(weights)
199204
@@ -253,21 +258,21 @@ def _init_qkvo(self):
             weight_scale_suffix=self.weight_scale_suffix,
             act_scale_suffix=self.act_scale_suffix,
         )
-
-        self.cc_k_b_proj_ = ROWMMWeight(
-            f"model.layers.{self.layer_num_}.self_attn.cc_k_b_proj.weight",
-            self.data_type_,
-            split_n_embed=self.tp_q_head_num_ * self.qk_nope_head_dim,
-            weight_scale_suffix=self.weight_scale_suffix,
-            act_scale_suffix=self.act_scale_suffix,
-        )
-        self.cc_v_b_proj_ = ROWMMWeight(
-            f"model.layers.{self.layer_num_}.self_attn.cc_v_b_proj.weight",
-            self.data_type_,
-            split_n_embed=self.tp_q_head_num_ * self.qk_nope_head_dim,
-            weight_scale_suffix=self.weight_scale_suffix,
-            act_scale_suffix=self.act_scale_suffix,
-        )
+        if self.enable_cc_method:
+            self.cc_k_b_proj_ = ROWMMWeight(
+                f"model.layers.{self.layer_num_}.self_attn.cc_k_b_proj.weight",
+                self.data_type_,
+                split_n_embed=self.tp_q_head_num_ * self.qk_nope_head_dim,
+                weight_scale_suffix=self.weight_scale_suffix,
+                act_scale_suffix=self.act_scale_suffix,
+            )
+            self.cc_v_b_proj_ = ROWMMWeight(
+                f"model.layers.{self.layer_num_}.self_attn.cc_v_b_proj.weight",
+                self.data_type_,
+                split_n_embed=self.tp_q_head_num_ * self.qk_nope_head_dim,
+                weight_scale_suffix=self.weight_scale_suffix,
+                act_scale_suffix=self.act_scale_suffix,
+            )

         self.o_weight_ = COLMMWeight(
             f"model.layers.{self.layer_num_}.self_attn.o_proj.weight",
0 commit comments