@@ -2819,6 +2819,66 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return [(self.map_tensor_name(name), data_torch)]


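+# Converter for InternLM3 (InternLM3ForCausalLM); the model maps onto the existing LLAMA GGUF architecture.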
+@Model.register("InternLM3ForCausalLM")
+class InternLM3Model(Model):
+    model_arch = gguf.MODEL_ARCH.LLAMA
+
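+    # Build the SentencePiece vocab and pick up special tokens (e.g. the <|im_end|> chat EOS) from tokenizer_config.json.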
+    def set_vocab(self):
+        tokens, scores, toktypes = self._create_vocab_sentencepiece()
+
+        self.gguf_writer.add_tokenizer_model("llama")
+        self.gguf_writer.add_tokenizer_pre("default")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_scores(scores)
+        self.gguf_writer.add_token_types(toktypes)
+
+        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
+
+        tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
+        if tokenizer_config_file.is_file():
+            with open(tokenizer_config_file, "r", encoding="utf-8") as f:
+                tokenizer_config_json = json.load(f)
+                if "add_prefix_space" in tokenizer_config_json:
+                    self.gguf_writer.add_add_space_prefix(tokenizer_config_json["add_prefix_space"])
+
+                if "added_tokens_decoder" in tokenizer_config_json:
+                    for token_id, token_data in tokenizer_config_json["added_tokens_decoder"].items():
+                        if token_data.get("special"):
+                            token_id = int(token_id)
+                            token = token_data["content"]
+                            special_vocab._set_special_token(token, token_id)
+                            # update eos token
+                            if token == '<|im_end|>' and "eos" in special_vocab.special_token_ids:
+                                special_vocab.special_token_ids["eos"] = token_id
+
+        special_vocab.add_to_gguf(self.gguf_writer)
+
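+    # Standard LLaMA hyperparameters plus vocab size; the rope dimension falls back to hidden_size // num_attention_heads when head_dim is absent.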
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        hparams = self.hparams
+        self.gguf_writer.add_vocab_size(hparams["vocab_size"])
+
+        if "head_dim" in hparams:
+            rope_dim = hparams["head_dim"]
+        else:
+            rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
+        self.gguf_writer.add_rope_dimension_count(rope_dim)
+
+        if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
+            if self.hparams["rope_scaling"].get("type") == "linear" or self.hparams["rope_scaling"].get("rope_type") == "linear":
+                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
+                self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
+
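+    # Q/K projections get the same permutation as in LlamaModel so the weights match llama.cpp's attention layout.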
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        n_head = self.hparams["num_attention_heads"]
+        n_kv_head = self.hparams.get("num_key_value_heads")
+        if name.endswith(("q_proj.weight", "q_proj.bias")):
+            data_torch = LlamaModel.permute(data_torch, n_head, n_head)
+        if name.endswith(("k_proj.weight", "k_proj.bias")):
+            data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)
+        return [(self.map_tensor_name(name), data_torch)]
+
+
 @Model.register("BertModel", "BertForMaskedLM", "CamembertModel")
 class BertModel(Model):
     model_arch = gguf.MODEL_ARCH.BERT