@@ -132,27 +132,27 @@ def _load_kb(self, kv_b_proj_):
         k_b_proj_ = kv_b_proj_.view(self.num_attention_heads, self.qk_nope_head_dim * 2, self.kv_lora_rank)[
             :, : self.qk_nope_head_dim, :
         ]
-        return k_b_proj_.contiguous().to(self.data_type_)
+        return k_b_proj_.contiguous().to(kv_b_proj_.dtype)

     def _load_kb_scale(self, kv_b_proj_, block_size):
         k_b_proj_scale_ = kv_b_proj_.view(
             self.num_attention_heads, self.qk_nope_head_dim * 2 // block_size, self.kv_lora_rank // block_size
         )[:, : self.qk_nope_head_dim // block_size, :]
-        return k_b_proj_scale_.contiguous().to(self.data_type_)
+        return k_b_proj_scale_.contiguous().to(kv_b_proj_.dtype)

     def _load_vb(self, kv_b_proj_):
         v_b_proj_ = kv_b_proj_.T.view(self.kv_lora_rank, self.num_attention_heads, self.qk_nope_head_dim * 2)[
             :, :, self.qk_nope_head_dim :
         ].transpose(0, 1)
-        return v_b_proj_.contiguous().to(self.data_type_)
+        return v_b_proj_.contiguous().to(kv_b_proj_.dtype)

     def _load_vb_scale(self, kv_b_proj_scale_, block_size):
         v_b_proj_scale_ = kv_b_proj_scale_.T.view(
             self.kv_lora_rank // block_size,
             self.num_attention_heads,
             self.qk_nope_head_dim * 2 // block_size,
         )[:, :, self.qk_nope_head_dim // block_size :].transpose(0, 1)
-        return v_b_proj_scale_.contiguous().to(self.data_type_)
+        return v_b_proj_scale_.contiguous().to(kv_b_proj_scale_.dtype)

     def load_hf_weights(self, weights):
         if f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight" in weights:
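The change in this hunk makes each helper cast the split tensor to the source weight's own dtype instead of the fixed `self.data_type_`, so the checkpoint's stored precision is preserved. As a sanity check, here is a minimal, self-contained sketch (not the repository's code; head counts and ranks are small assumed values) of the same view/slice pattern, confirming that it recovers the K and V halves of `kv_b_proj` and keeps the input dtype:

```python
import torch

# Assumed toy dimensions, standing in for self.num_attention_heads,
# self.qk_nope_head_dim and self.kv_lora_rank.
num_heads, qk_nope_head_dim, kv_lora_rank = 4, 8, 16

# Checkpoint layout: (num_heads * (k_half + v_half), kv_lora_rank),
# with k_half == v_half == qk_nope_head_dim, stored here in bfloat16.
kv_b_proj_ = torch.randn(num_heads * qk_nope_head_dim * 2, kv_lora_rank, dtype=torch.bfloat16)

# _load_kb pattern: first qk_nope_head_dim rows of each head's block.
k_b = kv_b_proj_.view(num_heads, qk_nope_head_dim * 2, kv_lora_rank)[:, :qk_nope_head_dim, :]

# _load_vb pattern: transpose first so V comes out with shape
# (num_heads, kv_lora_rank, qk_nope_head_dim).
v_b = kv_b_proj_.T.view(kv_lora_rank, num_heads, qk_nope_head_dim * 2)[
    :, :, qk_nope_head_dim:
].transpose(0, 1)

# Both slices match a direct split of the per-head blocks.
ref = kv_b_proj_.view(num_heads, qk_nope_head_dim * 2, kv_lora_rank)
assert torch.equal(k_b, ref[:, :qk_nope_head_dim, :])
assert torch.equal(v_b, ref[:, qk_nope_head_dim:, :].transpose(1, 2))

# The fix: casting to kv_b_proj_.dtype is a no-op here and keeps bfloat16,
# whereas a fixed self.data_type_ could silently re-cast the weight.
assert k_b.contiguous().to(kv_b_proj_.dtype).dtype == torch.bfloat16
```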
@@ -166,6 +166,8 @@ def load_hf_weights(self, weights):
             weights[f"model.layers.{self.layer_num_}.self_attn.cc_v_b_proj.weight"] = (
                 self._load_vb(kv_b_proj_).transpose(0, 1).reshape(self.kv_lora_rank, -1).transpose(0, 1).contiguous()
             )
+            # print(weights[f"model.layers.{self.layer_num_}.self_attn.cc_v_b_proj.weight"].dtype)
+            # print(weights[f"model.layers.{self.layer_num_}.self_attn.v_b_proj.weight"].dtype)

         if (
             self.quant_cfg.quantized_weight
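For reference, a small sketch (assumed toy shapes, hypothetical variable names) of what the `cc_v_b_proj` transpose/reshape chain above produces: the per-head V factors folded into one 2-D weight with `kv_lora_rank` as the inner dimension:

```python
import torch

num_heads, qk_nope_head_dim, kv_lora_rank = 4, 8, 16

# Stand-in for self._load_vb(kv_b_proj_): (num_heads, kv_lora_rank, qk_nope_head_dim).
v_b = torch.randn(num_heads, kv_lora_rank, qk_nope_head_dim)

# The transpose/reshape/transpose chain from the hunk above.
cc_v_b = v_b.transpose(0, 1).reshape(kv_lora_rank, -1).transpose(0, 1).contiguous()
assert cc_v_b.shape == (num_heads * qk_nope_head_dim, kv_lora_rank)
```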