@@ -618,6 +618,15 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b":
             # ref: https://huggingface.co/THUDM/glm-4-9b-chat
             res = "chatglm-bpe"
+        if chkhsh == "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516":
+            # ref: https://huggingface.co/THUDM/glm-4-9b-chat
+            res = "chatglm-bpe"
+        if chkhsh == "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2":
+            # ref: https://huggingface.co/THUDM/glm-4-9b-hf
+            res = "glm4"
+        if chkhsh == "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902":
+            # ref: https://huggingface.co/zai-org/GLM-4.5-Air, https://huggingface.co/zai-org/GLM-4.5
+            res = "glm4"
         if chkhsh == "7fc505bd3104ca1083b150b17d088b59534ede9bde81f0dd2090967d7fe52cee":
             # ref: https://huggingface.co/LumiOpen/Viking-7B
             res = "viking"
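For context, each `chkhsh` above is the SHA-256 digest of the token IDs the model's tokenizer produces for a fixed probe string; this is how the converter fingerprints pre-tokenizers. A minimal sketch of the mechanism, with the probe string elided (the real logic lives earlier in `get_vocab_base_pre()`):

```python
from hashlib import sha256
from transformers import AutoTokenizer

chktxt = "..."  # the converter's fixed probe string, elided here
tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4-9b-hf")
chkhsh = sha256(str(tokenizer.encode(chktxt)).encode()).hexdigest()
# an unmatched hash makes the converter abort with an "unknown pre-tokenizer" error,
# which is why each new tokenizer variant needs an entry like the ones added above
```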
@@ -3948,6 +3957,137 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
             return [(self.map_tensor_name(name), data_torch)]
         return super().modify_tensors(data_torch, name, bid)

+@Model.register("Glm4MoeForCausalLM")
+class Glm4MoeModel(Model):
+    model_arch = gguf.MODEL_ARCH.GLM4_MOE
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # GLM4_MOE has num_hidden_layers + num_nextn_predict_layers actual layers (the NextN/MTP layers come last)
+        self.block_count = self.hparams["num_hidden_layers"] + self.hparams.get("num_nextn_predict_layers", 0)
+        self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
+
+    def set_vocab(self):
+        from transformers import AutoTokenizer
+
+        tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
+        tokens, toktypes, tokpre = self.get_vocab_base()
+        self.gguf_writer.add_tokenizer_model("gpt2")
+        self.gguf_writer.add_tokenizer_pre(tokpre)
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+
+        # Special tokens
+        # Note: Using <|endoftext|> (151329) for eot causes endless generation
+        special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"])  # 151331
+        special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"])  # 151336
+        special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"])  # 151329
+        special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"])  # 151338
+
+        # Patch broken chat template
+        if isinstance(special_vocab.chat_template, str) and "visible_text(m.content).endswith" in special_vocab.chat_template:
+            special_vocab.chat_template = special_vocab.chat_template.replace(
+                """{{ visible_text(m.content) }}\n{{- '/nothink' if (enable_thinking is defined and not enable_thinking and not visible_text(m.content).endswith("/nothink")) else '' -}}""",
+                """{% set content = visible_text(m.content) %}{{ content }}\n{{- '/nothink' if (enable_thinking is defined and not enable_thinking and not content.endswith("/nothink")) else '' -}}""")
+
+        special_vocab.add_to_gguf(self.gguf_writer)
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        if (rope_dim := self.hparams.get("head_dim")) is None:
+            rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
+        # GLM4 applies RoPE to only part of each head: half of head_dim by default
+        self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)))
+
+        # MoE parameters - Use only routed expert count (shared experts handled separately)
+        if (n_routed_experts := self.hparams.get("n_routed_experts")) is not None:
+            self.gguf_writer.add_expert_count(n_routed_experts)
+        if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
+            self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
+        if (n_shared_experts := self.hparams.get("n_shared_experts")) is not None:
+            self.gguf_writer.add_expert_shared_count(n_shared_experts)
+        if (first_k_dense_replace := self.hparams.get("first_k_dense_replace")) is not None:
+            self.gguf_writer.add_leading_dense_block_count(first_k_dense_replace)
+
+        # Expert gating function (sigmoid for GLM4_MOE)
+        self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
+
+        # Routed scaling factor
+        if (routed_scaling_factor := self.hparams.get("routed_scaling_factor")) is not None:
+            self.gguf_writer.add_expert_weights_scale(routed_scaling_factor)
+
+        # Normalise topk probabilities
+        if (norm_topk_prob := self.hparams.get("norm_topk_prob")) is not None:
+            self.gguf_writer.add_expert_weights_norm(norm_topk_prob)
+
+        # NextN/MTP prediction layers
+        if (num_nextn_predict_layers := self.hparams.get("num_nextn_predict_layers")) is not None:
+            self.gguf_writer.add_nextn_predict_layers(num_nextn_predict_layers)
+
+    _experts: list[dict[str, Tensor]] | None = None
+
+    def modify_tensors(
+        self, data_torch: Tensor, name: str, bid: int | None
+    ) -> Iterable[tuple[str, Tensor]]:
+        if name.startswith("model.visual."):  # ignore visual part
+            return []
+        elif name.startswith("model.language_model."):
+            name = name.replace("language_model.", "")  # for multimodal variants
+
+        # Handle main token embedding (but not layer-specific NextN embeddings)
+        if name == "model.embed_tokens.weight" and ".layers." not in name:
+            return [(self.map_tensor_name("token_embd.weight"), data_torch)]
+
+        # Handle routed experts
+        if name.find("mlp.experts") != -1:
+            n_experts = self.hparams["n_routed_experts"]
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
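+                # all three projections (gate_proj, up_proj, down_proj) have now arrived for every expert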
+                tensors: list[tuple[str, Tensor]] = []
+
+                # merge the experts into a single 3d tensor
+                for w_name in ["down_proj", "gate_proj", "up_proj"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+
+                    merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+
+                    new_name = self.map_tensor_name(merged_name)
+                    tensors.append((new_name, data_torch))
+                return tensors
+            else:
+                return []
+
+        if name.endswith("e_score_correction_bias"):
+            name = name.replace("e_score_correction_bias", "e_score_correction.bias")
+
+        new_name = self.map_tensor_name(name)
+
+        return [(new_name, data_torch)]
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")

 @Model.register("ChatGLMModel", "ChatGLMForConditionalGeneration")
 class ChatGLMModel(Model):
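To make the expert-merging step above concrete, here is a standalone sketch (toy shapes, illustrative names) of what `modify_tensors` does once every per-expert weight for a block has been collected: the `n_experts` 2-D matrices of each projection are stacked along a new leading dimension, giving the single 3-D tensor per projection that the GGUF MoE layout stores:

```python
import torch

n_experts, out_dim, in_dim = 4, 8, 16  # toy sizes, illustrative only

# per-expert tensors, as stored in the HF checkpoint:
#   model.layers.{bid}.mlp.experts.{xid}.gate_proj.weight -> [out_dim, in_dim]
per_expert = [torch.randn(out_dim, in_dim) for _ in range(n_experts)]

# merged tensor written to the GGUF:
#   model.layers.{bid}.mlp.experts.gate_proj.weight -> [n_experts, out_dim, in_dim]
merged = torch.stack(per_expert, dim=0)
assert merged.shape == (n_experts, out_dim, in_dim)
```

Conversion itself is invoked as before, e.g. `python convert_hf_to_gguf.py /path/to/GLM-4.5-Air --outfile glm-4.5-air.gguf`, which should now resolve the `glm4` pre-tokenizer and the `GLM4_MOE` architecture.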