@@ -815,6 +815,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
815815 if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35" :
816816 # ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0
817817 res = "minerva-7b"
818+ if chkhsh == "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664" :
819+ # ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct
820+ res = "hunyuan"
818821
819822 if res is None :
820823 logger .warning ("\n " )
@@ -6666,6 +6669,156 @@ def set_gguf_parameters(self):
66666669 # Add any other Falcon Mamba2 specific configuration
66676670 self .gguf_writer .add_rope_freq_base (self .find_hparam (["rope_theta" ]))
66686671
6672+
6673+ @ModelBase .register ("HunYuanMoEV1ForCausalLM" )
6674+ class HunYuanMoEModel (TextModel ):
6675+ model_arch = gguf .MODEL_ARCH .HUNYUAN_MOE
6676+
6677+ def __init__ (self , * args , ** kwargs ):
6678+ super ().__init__ (* args , ** kwargs )
6679+ # For handling tied embeddings
6680+ self ._tok_embd = None
6681+
6682+ def set_vocab (self ):
6683+ from transformers import AutoTokenizer
6684+ tokenizer = AutoTokenizer .from_pretrained (self .dir_model , trust_remote_code = True )
6685+
6686+ # 1. Get the pre-tokenizer identifier hash
6687+ tokpre = self .get_vocab_base_pre (tokenizer )
6688+
6689+ # 2. Reverse-engineer the merges list from mergeable_ranks
6690+ merges = []
6691+ vocab = {}
6692+ mergeable_ranks = tokenizer .mergeable_ranks
6693+ for token , rank in mergeable_ranks .items ():
6694+ vocab [QwenModel .token_bytes_to_string (token )] = rank
6695+ if len (token ) == 1 :
6696+ continue
6697+ merged = QwenModel .bpe (mergeable_ranks , token , max_rank = rank )
6698+ if len (merged ) == 2 : # todo this is an assert in Qwen, why?
6699+ merges .append (' ' .join (map (QwenModel .token_bytes_to_string , merged )))
6700+
6701+ # 3. Generate the tokens and toktypes lists
6702+ vocab_size = self .hparams ["vocab_size" ]
6703+ assert tokenizer .vocab_size == vocab_size
6704+ special_tokens = tokenizer .special_tokens
6705+ reverse_vocab = {id_ : encoded_tok for encoded_tok , id_ in {** vocab , ** special_tokens }.items ()}
6706+ tokens : list [str ] = []
6707+ toktypes : list [int ] = []
6708+ for i in range (vocab_size ):
6709+ if i not in reverse_vocab :
6710+ tokens .append (f"[PAD{ i } ]" )
6711+ toktypes .append (gguf .TokenType .UNUSED )
6712+ else :
6713+ token = reverse_vocab [i ]
6714+ tokens .append (token )
6715+ if i in special_tokens .values ():
6716+ toktypes .append (gguf .TokenType .CONTROL )
6717+ else :
6718+ toktypes .append (gguf .TokenType .NORMAL )
6719+
6720+ # 4. Write all vocab-related fields to the GGUF writer
6721+ self .gguf_writer .add_tokenizer_model ("gpt2" )
6722+ self .gguf_writer .add_tokenizer_pre (tokpre )
6723+ self .gguf_writer .add_token_list (tokens )
6724+ self .gguf_writer .add_token_types (toktypes )
6725+ self .gguf_writer .add_token_merges (merges )
6726+
6727+ # 5. Add special tokens and chat templates
6728+ special_vocab = gguf .SpecialVocab (self .dir_model , load_merges = False )
6729+ special_vocab .add_to_gguf (self .gguf_writer )
6730+ # FIX for BOS token: Overwrite incorrect id read from config.json
6731+ self .gguf_writer .add_bos_token_id (127959 ) # <|bos|>
6732+
6733+ def set_gguf_parameters (self ):
6734+ super ().set_gguf_parameters ()
6735+ hparams = self .hparams
6736+
6737+ self .gguf_writer .add_expert_count (hparams ["num_experts" ])
6738+ self .gguf_writer .add_expert_shared_feed_forward_length (hparams ["intermediate_size" ])
6739+
6740+ moe_intermediate_size = hparams ["moe_intermediate_size" ]
6741+ assert all (n == moe_intermediate_size [0 ] for n in moe_intermediate_size )
6742+ self .gguf_writer .add_expert_feed_forward_length (moe_intermediate_size [0 ])
6743+
6744+ moe_topk = hparams ["moe_topk" ]
6745+ assert all (topk == moe_topk [0 ] for topk in moe_topk )
6746+ self .gguf_writer .add_expert_used_count (moe_topk [0 ])
6747+
6748+ moe_shared_expert = hparams ["num_shared_expert" ]
6749+ assert all (n == moe_shared_expert [0 ] for n in moe_shared_expert )
6750+ self .gguf_writer .add_expert_shared_count (moe_shared_expert [0 ])
6751+
6752+ # Rope
6753+ rope_scaling = hparams .get ("rope_scaling" , {})
6754+ if rope_scaling .get ("type" ) == "dynamic" :
6755+ # HunYuan uses NTK Aware Alpha based scaling. Original implementation: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
6756+ # 1000 corresponds to a usable context length of 256k (https://github.com/Tencent-Hunyuan/Hunyuan-A13B/blob/main/report/Hunyuan_A13B_Technical_Report.pdf)
6757+ alpha = rope_scaling .get ("alpha" , 1000 )
6758+ base = hparams .get ("rope_theta" , 10000.0 )
6759+ dim = (hparams ["hidden_size" ] // hparams ["num_attention_heads" ]) # 128
6760+ scaled_base = base * (alpha ** (dim / (dim - 2 ))) # 10000 * (1000 ** (128 / 126)) = 11158839.9251
6761+ self .gguf_writer .add_rope_freq_base (scaled_base )
6762+ self .gguf_writer .add_rope_scaling_type (gguf .RopeScalingType .NONE )
6763+ self .gguf_writer .add_rope_scaling_factor (1 )
6764+ # There is no consistent way to calculate ctx from alpha, and the config is incorrectly set to 32k
6765+ self .gguf_writer .add_rope_scaling_orig_ctx_len (256 * 1024 ) # 256k context length
6766+ self .gguf_writer .add_context_length (256 * 1024 ) # 256k context length
6767+
6768+ # if any of our assumptions about the values are wrong, something has changed and this may need to be updated
6769+ assert alpha == 1000 and base == 10000.0 and dim == 128 and self .hparams ["max_position_embeddings" ] in [32 * 1024 , 256 * 1024 ] , \
6770+ "HunYuan dynamic RoPE scaling assumptions changed, please update the logic or context length manually"
6771+
6772+ _experts : list [dict [str , Tensor ]] | None = None
6773+
6774+ def modify_tensors (self , data_torch : Tensor , name : str , bid : int | None ) -> Iterable [tuple [str , Tensor ]]:
6775+ if name == "model.embed_tokens.weight" :
6776+ self ._tok_embd = data_torch .clone ()
6777+
6778+ if name == "lm_head.weight" :
6779+ if self .hparams .get ("tie_word_embeddings" , False ):
6780+ logger .info ("Skipping tied output layer 'lm_head.weight'" )
6781+ return []
6782+
6783+ if name .find ("mlp.experts" ) != - 1 :
6784+ n_experts = self .hparams ["num_experts" ]
6785+ assert bid is not None
6786+
6787+ if self ._experts is None :
6788+ self ._experts = [{} for _ in range (self .block_count )]
6789+
6790+ self ._experts [bid ][name ] = data_torch
6791+
6792+ if len (self ._experts [bid ]) >= n_experts * 3 :
6793+ # merge the experts into a single 3d tensor
6794+ tensors : list [tuple [str , Tensor ]] = []
6795+ for w_name in ["down_proj" , "gate_proj" , "up_proj" ]:
6796+ datas : list [Tensor ] = []
6797+
6798+ for xid in range (n_experts ):
6799+ ename = f"model.layers.{ bid } .mlp.experts.{ xid } .{ w_name } .weight"
6800+ datas .append (self ._experts [bid ][ename ])
6801+ del self ._experts [bid ][ename ]
6802+
6803+ data_torch = torch .stack (datas , dim = 0 )
6804+ merged_name = f"model.layers.{ bid } .mlp.experts.{ w_name } .weight"
6805+ new_name = self .map_tensor_name (merged_name )
6806+ tensors .append ((new_name , data_torch ))
6807+
6808+ return tensors
6809+ else :
6810+ return []
6811+
6812+ return [(self .map_tensor_name (name ), data_torch )]
6813+
6814+ def prepare_tensors (self ):
6815+ super ().prepare_tensors ()
6816+ if self ._experts is not None :
6817+ experts = [k for d in self ._experts for k in d .keys ()]
6818+ if len (experts ) > 0 :
6819+ raise ValueError (f"Unprocessed experts: { experts } " )
6820+
6821+
66696822###### CONVERSION LOGIC ######
66706823
66716824
0 commit comments