@@ -1014,6 +1014,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664":
             # ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct
             res = "hunyuan"
+        if chkhsh == "bba3b3366b646dbdded5dbc42d59598b849371afc42f7beafa914afaa5b70aa6":
+            # ref: https://huggingface.co/tencent/Hunyuan-4B-Instruct
+            res = "hunyuan-dense"
         if chkhsh == "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6":
             # ref: https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base
             res = "falcon-h1"
@@ -7883,11 +7886,6 @@ def set_gguf_parameters(self):
 class HunYuanMoEModel(TextModel):
     model_arch = gguf.MODEL_ARCH.HUNYUAN_MOE

-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        # For handling tied embeddings
-        self._tok_embd = None
-
     def set_vocab(self):
         from transformers import AutoTokenizer
         tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
@@ -7981,9 +7979,6 @@ def set_gguf_parameters(self):
     _experts: list[dict[str, Tensor]] | None = None

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-        if name == "model.embed_tokens.weight":
-            self._tok_embd = data_torch.clone()
-
         if name == "lm_head.weight":
             if self.hparams.get("tie_word_embeddings", False):
                 logger.info("Skipping tied output layer 'lm_head.weight'")
@@ -8028,6 +8023,98 @@ def prepare_tensors(self):
                 raise ValueError(f"Unprocessed experts: {experts}")


+@ModelBase.register("HunYuanDenseV1ForCausalLM")
+class HunYuanModel(TextModel):
+    model_arch = gguf.MODEL_ARCH.HUNYUAN_DENSE
+
+    def set_vocab(self):
+        if (self.dir_model / "tokenizer.json").is_file():
+            self._set_vocab_gpt2()
+        else:
+            from transformers import AutoTokenizer
+            tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
+
+            # 1. Get the pre-tokenizer identifier hash
+            tokpre = self.get_vocab_base_pre(tokenizer)
+
+            # 2. Reverse-engineer the merges list from mergeable_ranks
+            merges = []
+            vocab = {}
+            mergeable_ranks = tokenizer.mergeable_ranks
+            for token, rank in mergeable_ranks.items():
+                vocab[QwenModel.token_bytes_to_string(token)] = rank
+                if len(token) == 1:
+                    continue
+                merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank)
+                if len(merged) == 2:
+                    merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged)))
+
+            # 3. Generate the tokens and toktypes lists
+            vocab_size = self.hparams["vocab_size"]
+            assert tokenizer.vocab_size == vocab_size
+            special_tokens = tokenizer.special_tokens
+            reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items()}
+            tokens: list[str] = []
+            toktypes: list[int] = []
+            for i in range(vocab_size):
+                if i not in reverse_vocab:
+                    tokens.append(f"[PAD{i}]")
+                    toktypes.append(gguf.TokenType.UNUSED)
+                else:
+                    token = reverse_vocab[i]
+                    tokens.append(token)
+                    if i in special_tokens.values():
+                        toktypes.append(gguf.TokenType.CONTROL)
+                    else:
+                        toktypes.append(gguf.TokenType.NORMAL)
+
+            # 4. Write all vocab-related fields to the GGUF writer
+            self.gguf_writer.add_tokenizer_model("gpt2")
+            self.gguf_writer.add_tokenizer_pre(tokpre)
+            self.gguf_writer.add_token_list(tokens)
+            self.gguf_writer.add_token_types(toktypes)
+            self.gguf_writer.add_token_merges(merges)
+
+            # 5. Add special tokens and chat templates
+            special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
+            special_vocab.add_to_gguf(self.gguf_writer)
+            # FIX for BOS token: Overwrite incorrect id read from config.json
+            if self.hparams['hidden_size'] == 4096:
+                self.gguf_writer.add_bos_token_id(127958)  # only for 7b dense, fix <|bos|> token
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        hparams = self.hparams
+
+        # Rope
+        rope_scaling = hparams.get("rope_scaling", {})
+        if rope_scaling.get("type") == "dynamic":
+            # HunYuan uses NTK Aware Alpha based scaling. Original implementation: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
+            # 1000 corresponds to a usable context length of 256k (https://github.com/Tencent-Hunyuan/Hunyuan-A13B/blob/main/report/Hunyuan_A13B_Technical_Report.pdf)
+            alpha = rope_scaling.get("alpha", 50)
+            base = hparams.get("rope_theta", 10000.0)
+            dim = hparams["head_dim"]
+            scaled_base = base * (alpha ** (dim / (dim - 2)))
+            self.gguf_writer.add_rope_freq_base(scaled_base)
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
+            self.gguf_writer.add_rope_scaling_factor(1)
+            # There is no consistent way to calculate ctx from alpha, and the config is incorrectly set to 32k
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(256 * 1024)  # 256k context length
+            self.gguf_writer.add_context_length(256 * 1024)  # 256k context length
+
+            # if any of our assumptions about the values are wrong, something has changed and this may need to be updated
+            assert base == 10000.0 and self.hparams["max_position_embeddings"] in [32 * 1024, 256 * 1024], \
+                "HunYuan dynamic RoPE scaling assumptions changed, please update the logic or context length manually"
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if name == "lm_head.weight":
+            if self.hparams.get("tie_word_embeddings", False):
+                logger.info("Skipping tied output layer 'lm_head.weight'")
+                return []
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+
 @ModelBase.register("SmolLM3ForCausalLM")
 class SmolLM3Model(LlamaModel):
     model_arch = gguf.MODEL_ARCH.SMOLLM3
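
Side note on step 2 of the new `set_vocab` (reconstructing merges from `mergeable_ranks`): tiktoken-style tokenizers ship only a `bytes -> rank` table rather than an explicit BPE merges list, so the converter re-derives each merge by re-running byte-level BPE on every multi-byte token while only allowing merges of strictly lower rank. Below is a minimal standalone sketch of that technique; the helper name `bpe_split` and the toy rank table are made up for illustration, and the actual conversion goes through `QwenModel.bpe` as shown in the diff.

```python
def bpe_split(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: int) -> list[bytes]:
    # Re-run byte-level BPE on `token`, allowing only merges with rank < max_rank.
    # For a token of rank r this recovers the two parts whose merge created it.
    parts = [bytes([b]) for b in token]
    while True:
        best_idx, best_rank = None, None
        for i in range(len(parts) - 1):
            rank = mergeable_ranks.get(parts[i] + parts[i + 1])
            if rank is not None and (best_rank is None or rank < best_rank):
                best_idx, best_rank = i, rank
        if best_rank is None or best_rank >= max_rank:
            break
        parts = parts[:best_idx] + [parts[best_idx] + parts[best_idx + 1]] + parts[best_idx + 2:]
    return parts

# Toy rank table: "ab" (rank 3) was merged from a+b, "abc" (rank 4) from ab+c.
ranks = {b"a": 0, b"b": 1, b"c": 2, b"ab": 3, b"abc": 4}
print(bpe_split(ranks, b"abc", max_rank=4))  # [b'ab', b'c'] -> emits the merge "ab c"
```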
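The dynamic-NTK branch in the new `set_gguf_parameters` folds the alpha scaling into a single adjusted RoPE base, `base * alpha ** (dim / (dim - 2))`, and then disables any further runtime scaling. A quick numeric sanity check, assuming the defaults that branch expects (`rope_theta = 10000`, `alpha = 50`) and a hypothetical `head_dim` of 128:

```python
# Worked example of the NTK-aware alpha -> rope_freq_base conversion above.
# head_dim = 128 is an assumed value for illustration; the converter reads it from config.json.
base = 10000.0
alpha = 50
dim = 128

scaled_base = base * (alpha ** (dim / (dim - 2)))
print(f"{scaled_base:,.0f}")  # roughly 532,000, the value written as rope_freq_base
```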