@@ -2264,13 +2264,6 @@ def set_vocab(self):
22642264
22652265        special_vocab .add_to_gguf (self .gguf_writer )
22662266
2267-     def  _hf_permute_qk (self , weights , n_head : int , n_head_kv : int ):
2268-         if  n_head_kv  is  not None  and  n_head  !=  n_head_kv :
2269-             n_head  =  n_head_kv 
2270-         return  (weights .reshape (n_head , 2 , weights .shape [0 ] //  n_head  //  2 , * weights .shape [1 :])
2271-                 .swapaxes (1 , 2 )
2272-                 .reshape (weights .shape ))
2273- 
22742267    def  set_gguf_parameters (self ):
22752268        self .gguf_writer .add_name ("InternLM2" )
22762269        self .gguf_writer .add_context_length (self .hparams ["max_position_embeddings" ])
@@ -2290,26 +2283,22 @@ def set_gguf_parameters(self):
22902283    def  modify_tensors (self , data_torch : Tensor , name : str , bid : int  |  None ) ->  Iterable [tuple [str , Tensor ]]:
22912284        num_heads  =  self .hparams ["num_attention_heads" ]
22922285        num_kv_heads  =  self .hparams ["num_key_value_heads" ]
2293-         hidden_size  =  self .hparams ["hidden_size" ]
2286+         n_embd  =  self .hparams ["hidden_size" ]
22942287        q_per_kv  =  num_heads  //  num_kv_heads 
2295-         head_dim  =  hidden_size  //  num_heads 
2288+         head_dim  =  n_embd  //  num_heads 
22962289        num_groups  =  num_heads  //  q_per_kv 
22972290
2298-         qkv_pattern  =  r"model\.layers\.(\d+)\.attention\.wqkv" 
2299- 
2300-         if  re .match (qkv_pattern , name ):
2301-             bid  =  re .findall (qkv_pattern , name )[0 ]
2291+         if  bid  is  not None  and  f"model.layers.{bid}.attention.wqkv"  in  name :
23022292            qkv  =  data_torch 
2303-             # qkv = rearrange(qkv.T, " o (g n i) ->o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim) 
2304-             qkv  =  qkv .T .reshape ((- 1 , num_groups , q_per_kv  +  2 , head_dim ))
2305-             q , k , v  =  qkv [..., : q_per_kv , :], qkv [..., q_per_kv : q_per_kv  +  1 , :], qkv [..., q_per_kv  +  1 : q_per_kv  +  2 , :]
2293+ 
2294+             qkv  =  qkv .reshape ((num_groups , q_per_kv  +  2 , head_dim , n_embd ))
2295+             q , k , v  =  qkv [:, : q_per_kv ], qkv [:, - 2 ], qkv [:, - 1 ]
2296+ 
23062297            # The model weights of q and k require additional reshape. 
2307-             # q = self._hf_permute_qk(rearrange(q, " o g n i ->  o (g n i)").T, num_heads, num_heads) 
2308-             q  =  self ._hf_permute_qk (q .reshape ((q .shape [0 ], - 1 )).T , num_heads , num_heads )
2309-             # k = self._hf_permute_qk(rearrange(k, " o g n i ->  o (g n i)").T, num_heads, num_kv_heads) 
2310-             k  =  self ._hf_permute_qk (k .reshape ((k .shape [0 ], - 1 )).T , num_heads , num_kv_heads )
2311-             # v = rearrange(v, " o g n i ->  o (g n i)").T 
2312-             v  =  v .reshape ((v .shape [0 ], - 1 )).T 
2298+             q  =  LlamaModel .permute (q .reshape ((- 1 , q .shape [- 1 ])), num_heads , num_heads )
2299+             k  =  LlamaModel .permute (k .reshape ((- 1 , k .shape [- 1 ])), num_heads , num_kv_heads )
2300+             v  =  v .reshape ((- 1 , v .shape [- 1 ]))
2301+ 
23132302            return  [
23142303                (self .format_tensor_name (gguf .MODEL_TENSOR .ATTN_Q , bid ), q ),
23152304                (self .format_tensor_name (gguf .MODEL_TENSOR .ATTN_K , bid ), k ),
@@ -3585,6 +3574,7 @@ def main() -> None:
35853574                                     small_first_shard = args .no_tensor_first_split )
35863575
35873576        logger .info ("Set model parameters" )
3577+         model_instance .gguf_writer .add_type (gguf .GGUFType .MODEL )
35883578        model_instance .set_gguf_parameters ()
35893579
35903580        logger .info ("Set model tokenizer" )
0 commit comments