@@ -346,6 +346,8 @@ def prepare_tensors(self):
             data_qtype = gguf.GGMLQuantizationType.BF16
         elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0:
             data_qtype = gguf.GGMLQuantizationType.Q8_0
+        elif self.ftype == gguf.LlamaFileType.MOSTLY_Q4_0:
+            data_qtype = gguf.GGMLQuantizationType.Q4_0
         elif self.ftype == gguf.LlamaFileType.MOSTLY_TQ1_0:
             data_qtype = gguf.GGMLQuantizationType.TQ1_0
         elif self.ftype == gguf.LlamaFileType.MOSTLY_TQ2_0:
@@ -6394,24 +6396,22 @@ def set_gguf_parameters(self):
 
 
 @ModelBase.register("HunYuanMoEV1ForCausalLM")
-class HunYuanMoEModel(LlamaModel):
+class HunYuanMoEModel(TextModel):
     model_arch = gguf.MODEL_ARCH.HUNYUAN_MOE
-    undo_permute = False
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        # For handling tied embeddings
+        self._tok_embd = None
 
     def set_vocab(self):
-        self._set_vocab_gpt2()
-
-    def get_vocab_base(self) -> tuple[list[str], list[int], str]:
-        tokens: list[str] = []
-        toktypes: list[int] = []
-
         from transformers import AutoTokenizer
         tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
 
-        # merge logic is copied from QwenModel, maybe incorrect
+        # 1. Get the pre-tokenizer identifier hash
+        tokpre = self.get_vocab_base_pre(tokenizer)
+
+        # 2. Reverse-engineer the merges list from mergeable_ranks
         merges = []
         vocab = {}
         mergeable_ranks = tokenizer.mergeable_ranks
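The merges list here is reverse-engineered from the tiktoken-style `mergeable_ranks` table rather than read from a `tokenizer.json`. A minimal sketch of that recovery technique, using a made-up toy rank table and a simplified `bpe_parts` helper (not the PR's `QwenModel.bpe`):

```python
# Toy illustration: recover BPE merge pairs from a rank table.
def bpe_parts(ranks: dict[bytes, int], token: bytes, max_rank: int) -> list[bytes]:
    # Re-run byte-level BPE on `token`, allowing only merges of strictly lower rank.
    parts = [bytes([b]) for b in token]
    while True:
        best = None  # (index, rank) of the lowest-rank adjacent pair
        for i in range(len(parts) - 1):
            rank = ranks.get(parts[i] + parts[i + 1])
            if rank is not None and rank < max_rank and (best is None or rank < best[1]):
                best = (i, rank)
        if best is None:
            return parts
        i = best[0]
        parts = parts[:i] + [parts[i] + parts[i + 1]] + parts[i + 2:]

ranks = {b"a": 0, b"b": 1, b"ab": 2, b"abb": 3}  # hypothetical mergeable_ranks
merges = []
for token, rank in sorted(ranks.items(), key=lambda kv: kv[1]):
    if len(token) == 1:
        continue
    merged = bpe_parts(ranks, token, max_rank=rank)
    if len(merged) == 2:  # exactly the pair that merges into `token` at this rank
        merges.append(merged)
print(merges)  # [[b'a', b'b'], [b'ab', b'b']]
```

In the conversion code each recovered pair is rendered with `QwenModel.token_bytes_to_string` and joined with a space before `add_token_merges` writes the list.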
@@ -6420,75 +6420,103 @@ def get_vocab_base(self) -> tuple[list[str], list[int], str]:
             if len(token) == 1:
                 continue
             merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank)
-            if len(merged) == 2:
+            if len(merged) == 2:  # TODO: this is an assert in Qwen, why?
                 merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged)))
-        self.gguf_writer.add_token_merges(merges)
 
-        reverse_vocab = tokenizer.decoder
-        assert max(reverse_vocab.keys()) < tokenizer.vocab_size
-
-        tokpre = self.get_vocab_base_pre(tokenizer)
-        added_vocab = tokenizer.get_added_vocab()
-
-        added_tokens_decoder = tokenizer.added_tokens_decoder
-
-        for i in range(tokenizer.vocab_size):
+        # 3. Generate the tokens and toktypes lists
+        vocab_size = self.hparams["vocab_size"]
+        assert tokenizer.vocab_size == vocab_size
+        special_tokens = tokenizer.special_tokens
+        reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items()}
+        tokens: list[str] = []
+        toktypes: list[int] = []
+        for i in range(vocab_size):
             if i not in reverse_vocab:
                 tokens.append(f"[PAD{i}]")
                 toktypes.append(gguf.TokenType.UNUSED)
             else:
-                token: str = reverse_vocab[i]
-                if token in added_vocab:
-                    # The tokenizer in llama.cpp assumes the CONTROL and USER_DEFINED tokens are pre-normalized.
-                    # To avoid unexpected issues - we make sure to normalize non-normalized tokens
-                    if not added_tokens_decoder[i].normalized:
-                        previous_token = token
-                        token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False))
-                        if previous_token != token:
-                            logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer")
-
-                    if added_tokens_decoder[i].special or self.does_token_look_special(token):
-                        toktypes.append(gguf.TokenType.CONTROL)
-                    else:
-                        # NOTE: this was added for Gemma.
-                        # Encoding and decoding the tokens above isn't sufficient for this case.
-                        token = token.replace(b"\xe2\x96\x81".decode("utf-8"), " ")  # pre-normalize user-defined spaces
-                        toktypes.append(gguf.TokenType.USER_DEFINED)
+                token = reverse_vocab[i]
+                tokens.append(token)
+                if i in special_tokens.values():
+                    toktypes.append(gguf.TokenType.CONTROL)
                 else:
                     toktypes.append(gguf.TokenType.NORMAL)
-                tokens.append(token)
 
-        return tokens, toktypes, tokpre
+        # 4. Write all vocab-related fields to the GGUF writer
+        self.gguf_writer.add_tokenizer_model("gpt2")
+        self.gguf_writer.add_tokenizer_pre(tokpre)
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+        self.gguf_writer.add_token_merges(merges)
+
+        # 5. Add special tokens and chat templates
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
+        special_vocab.add_to_gguf(self.gguf_writer)
+        # FIX for BOS token: overwrite the incorrect id read from config.json
+        self.gguf_writer.add_bos_token_id(127959)  # <|bos|>
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
+        hparams = self.hparams
 
-        self.gguf_writer.add_expert_count(self.hparams["num_experts"])
-        self.gguf_writer.add_expert_shared_feed_forward_length(self.hparams["intermediate_size"])
+        self.gguf_writer.add_expert_count(hparams["num_experts"])
+        self.gguf_writer.add_expert_shared_feed_forward_length(hparams["intermediate_size"])
 
-        moe_intermediate_size = self.hparams["moe_intermediate_size"]
+        moe_intermediate_size = hparams["moe_intermediate_size"]
         assert all(n == moe_intermediate_size[0] for n in moe_intermediate_size)
         self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size[0])
 
-        moe_topk = self.hparams["moe_topk"]
+        moe_topk = hparams["moe_topk"]
         assert all(topk == moe_topk[0] for topk in moe_topk)
         self.gguf_writer.add_expert_used_count(moe_topk[0])
 
+        moe_shared_expert = hparams["num_shared_expert"]
+        assert all(n == moe_shared_expert[0] for n in moe_shared_expert)
+        self.gguf_writer.add_expert_shared_count(moe_shared_expert[0])
+
+        # RoPE
+        rope_scaling = hparams.get("rope_scaling", {})
+        if rope_scaling.get("type") == "dynamic":
+            # HunYuan uses NTK-aware alpha-based scaling. Original implementation: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
+            # An alpha of 1000 corresponds to a usable context length of 256k (https://github.com/Tencent-Hunyuan/Hunyuan-A13B/blob/main/report/Hunyuan_A13B_Technical_Report.pdf)
+            alpha = rope_scaling.get("alpha", 1000)
+            base = hparams.get("rope_theta", 10000.0)
+            dim = (hparams["hidden_size"] // hparams["num_attention_heads"])  # 128
+            scaled_base = base * (alpha ** (dim / (dim - 2)))  # 10000 * (1000 ** (128 / 126)) ≈ 11158839.9251
+            self.gguf_writer.add_rope_freq_base(scaled_base)
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
+            self.gguf_writer.add_rope_scaling_factor(1)
+            # There is no consistent way to calculate ctx from alpha, and the config is incorrectly set to 32k
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(256 * 1024)  # 256k context length
+            self.gguf_writer.add_context_length(256 * 1024)  # 256k context length
+
+            # If any of our assumptions about these values are wrong, something has changed and this may need to be updated
+            assert alpha == 1000 and base == 10000.0 and dim == 128 and self.hparams["max_position_embeddings"] in [32 * 1024, 256 * 1024], \
+                "HunYuan dynamic RoPE scaling assumptions changed, please update the logic or context length manually"
+
+    _experts: list[dict[str, Tensor]] | None = None
+
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-        # process the experts separately
+        if name == "model.embed_tokens.weight":
+            self._tok_embd = data_torch.clone()
+
+        if name == "lm_head.weight":
+            if self.hparams.get("tie_word_embeddings", False):
+                logger.info("Skipping tied output layer 'lm_head.weight'")
+                return []
+
         if name.find("mlp.experts") != -1:
             n_experts = self.hparams["num_experts"]
             assert bid is not None
 
-            tensors: list[tuple[str, Tensor]] = []
-
             if self._experts is None:
                 self._experts = [{} for _ in range(self.block_count)]
 
             self._experts[bid][name] = data_torch
 
             if len(self._experts[bid]) >= n_experts * 3:
                 # merge the experts into a single 3d tensor
+                tensors: list[tuple[str, Tensor]] = []
                 for w_name in ["down_proj", "gate_proj", "up_proj"]:
                     datas: list[Tensor] = []
 
@@ -6498,11 +6526,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
                         del self._experts[bid][ename]
 
                     data_torch = torch.stack(datas, dim=0)
-
                     merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
-
                     new_name = self.map_tensor_name(merged_name)
-
                     tensors.append((new_name, data_torch))
 
                 return tensors
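The expert handling in `modify_tensors` follows the usual MoE conversion pattern: per-expert 2D projection weights are buffered per layer and only emitted once all `n_experts * 3` of them have arrived, stacked into one 3D tensor per projection. A small self-contained sketch with made-up shapes:

```python
import torch

# Hypothetical sizes; the real tensors come from the HunYuan checkpoint.
n_experts, n_ff, n_embd = 4, 8, 16
buffered = {
    f"model.layers.0.mlp.experts.{xid}.gate_proj.weight": torch.randn(n_ff, n_embd)
    for xid in range(n_experts)
}

# Stack experts in index order, mirroring the inner for-xid loop above.
datas = [buffered[f"model.layers.0.mlp.experts.{xid}.gate_proj.weight"] for xid in range(n_experts)]
merged = torch.stack(datas, dim=0)
print(merged.shape)  # torch.Size([4, 8, 16]) -> one (n_expert, n_ff, n_embd) tensor per projection
```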
@@ -6511,6 +6536,13 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 
         return [(self.map_tensor_name(name), data_torch)]
 
+    def prepare_tensors(self):
+        super().prepare_tensors()
+        if self._experts is not None:
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
 ###### CONVERSION LOGIC ######
 
 
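Looping back to the RoPE handling in `set_gguf_parameters`, the NTK-aware alpha scaling reduces to a single rescale of the frequency base. A standalone check of that arithmetic, plugging in the values the diff assumes for Hunyuan-A13B (alpha 1000, rope_theta 10000, head dim 128):

```python
# Sanity check of the scaled RoPE base computed in set_gguf_parameters above.
alpha = 1000    # rope_scaling["alpha"] default
base = 10000.0  # rope_theta
dim = 128       # hidden_size // num_attention_heads

scaled_base = base * (alpha ** (dim / (dim - 2)))
print(scaled_base)  # ≈ 11158839.9251, the value noted in the inline comment
```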
@@ -6600,7 +6632,7 @@ def parse_args() -> argparse.Namespace:
         help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
     )
     parser.add_argument(
-        "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "tq1_0", "tq2_0", "auto"], default="f16",
+        "--outtype", type=str, choices=["f32", "f16", "bf16", "q4_0", "q8_0", "tq1_0", "tq2_0", "auto"], default="f16",
         help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, tq1_0 or tq2_0 for ternary, and auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
     )
     parser.add_argument(
@@ -6732,6 +6764,7 @@ def main() -> None:
67326764 "f32" : gguf .LlamaFileType .ALL_F32 ,
67336765 "f16" : gguf .LlamaFileType .MOSTLY_F16 ,
67346766 "bf16" : gguf .LlamaFileType .MOSTLY_BF16 ,
6767+ "q4_0" : gguf .LlamaFileType .MOSTLY_Q4_0 ,
67356768 "q8_0" : gguf .LlamaFileType .MOSTLY_Q8_0 ,
67366769 "tq1_0" : gguf .LlamaFileType .MOSTLY_TQ1_0 ,
67376770 "tq2_0" : gguf .LlamaFileType .MOSTLY_TQ2_0 ,