@@ -618,6 +618,12 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b":
             # ref: https://huggingface.co/THUDM/glm-4-9b-chat
             res = "chatglm-bpe"
+        if chkhsh == "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2":
+            # ref: https://huggingface.co/THUDM/glm-4-9b-hf
+            res = "glm4"
+        if chkhsh == "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902":
+            # ref: https://huggingface.co/zai-org/GLM-4.5-Air, https://huggingface.co/zai-org/GLM-4.5
+            res = "gpt-2"
         if chkhsh == "7fc505bd3104ca1083b150b17d088b59534ede9bde81f0dd2090967d7fe52cee":
             # ref: https://huggingface.co/LumiOpen/Viking-7B
             res = "viking"
@@ -3948,6 +3954,214 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
             return [(self.map_tensor_name(name), data_torch)]
         return super().modify_tensors(data_torch, name, bid)
 
+@Model.register("Glm4MoeForCausalLM")
+class Glm4MoeModel(Model):
+    model_arch = gguf.MODEL_ARCH.GLM4_MOE
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # GLM4_MOE has num_hidden_layers + 1 actual layers (including NextN layer)
+        self.block_count = self.hparams["num_hidden_layers"] + 1
+        self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
+
+    def set_vocab(self):
+        from transformers import AutoTokenizer
+
+        tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
+        tokens, toktypes, tokpre = self.get_vocab_base()
+        self.gguf_writer.add_tokenizer_model("gpt2")
+        self.gguf_writer.add_tokenizer_pre(tokpre)
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+
+        # Set special tokens
+        special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"])
+        special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"])
+        special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"])
+        special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"])
+
+        special_vocab.add_to_gguf(self.gguf_writer)
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        if (rope_dim := self.hparams.get("head_dim")) is None:
+            rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
+        self.gguf_writer.add_rope_dimension_count(
+            int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5))
+        )
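+        # For illustration only: a hypothetical head_dim of 128 combined with the
+        # default partial_rotary_factor of 0.5 yields 64 rotary dimensions.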
+
+        # MoE parameters
+        if (n_experts := self.hparams.get("n_routed_experts")) is not None:
+            self.gguf_writer.add_expert_count(n_experts)
+        # Note: expert_used_count is already set by parent class using num_experts_per_tok
+        if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
+            self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
+        if (n_shared_experts := self.hparams.get("n_shared_experts")) is not None:
+            self.gguf_writer.add_expert_shared_count(n_shared_experts)
+        if (first_k_dense_replace := self.hparams.get("first_k_dense_replace")) is not None:
+            self.gguf_writer.add_leading_dense_block_count(first_k_dense_replace)
+
+        # Expert gating function (sigmoid for GLM4_MOE)
+        self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
+
+        # Routed scaling factor
+        if (routed_scaling_factor := self.hparams.get("routed_scaling_factor")) is not None:
+            self.gguf_writer.add_expert_weights_scale(routed_scaling_factor)
+
+        # Normalise topk probabilities
+        if (norm_topk_prob := self.hparams.get("norm_topk_prob")) is not None:
+            self.gguf_writer.add_expert_weights_norm(norm_topk_prob)
+
+    _experts: list[dict[str, Tensor]] | None = None
+    _shared_experts: list[dict[str, Tensor]] | None = None
+
+    def modify_tensors(
+        self, data_torch: Tensor, name: str, bid: int | None
+    ) -> Iterable[tuple[str, Tensor]]:
+        if name.startswith("model.visual."):  # ignore visual part
+            return []
+        elif name.startswith("model.language_model."):
+            name = name.replace("language_model.", "")  # for multimodal variants
+
+        # Handle main token embedding (but not layer-specific NextN embeddings)
+        if name == "model.embed_tokens.weight":
+            return [(self.map_tensor_name("token_embd.weight"), data_torch)]
+
+        # Handle routed experts
+        if name.find("mlp.experts") != -1 and "shared_experts" not in name:
+            n_experts = self.hparams["n_routed_experts"]
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            # Extend experts array if needed (for models where actual layers > num_hidden_layers)
+            while len(self._experts) <= bid:
+                self._experts.append({})
+
+            self._experts[bid][name] = data_torch
+
+            # emit only once all three projections (gate/up/down) of every expert in this block have been collected
+            if len(self._experts[bid]) >= n_experts * 3:
+                tensors: list[tuple[str, Tensor]] = []
+
+                # merge the experts into a single 3d tensor
+                for w_name in ["down_proj", "gate_proj", "up_proj"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
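+                    # For illustration: stacking e.g. 8 per-expert [rows, cols] matrices along dim 0
+                    # gives one [8, rows, cols] tensor per projection (8 is a hypothetical expert count)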
+                    # Generate GGUF tensor names for merged experts
+                    if w_name == "down_proj":
+                        new_name = f"blk.{bid}.ffn_down_exps.weight"
+                    elif w_name == "gate_proj":
+                        new_name = f"blk.{bid}.ffn_gate_exps.weight"
+                    elif w_name == "up_proj":
+                        new_name = f"blk.{bid}.ffn_up_exps.weight"
+                    else:
+                        merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+                        new_name = self.map_tensor_name(merged_name)
+                    tensors.append((new_name, data_torch))
+                return tensors
+            else:
+                return []
+
+        # Handle expert gating input (routing gate)
+        if ".mlp.gate.e_score_correction_bias" in name:
+            new_name = name.replace("model.layers.", "blk.").replace(
+                ".mlp.gate.e_score_correction_bias", ".ffn_gate_inp.bias"  # *NOTE* this is ".exp_probs_b" in mainline PR
+            )
+            return [(new_name, data_torch)]
+        elif ".mlp.gate.weight" in name:
+            new_name = name.replace("model.layers.", "blk.").replace(
+                ".mlp.gate.weight", ".ffn_gate_inp.weight"
+            )
+            return [(new_name, data_torch)]
+
+        # Handle shared expert tensors
+        if ".mlp.shared_experts." in name:
+            new_name = name.replace("model.layers.", "blk.").replace(".mlp.shared_experts.", ".ffn_")
+            if "gate_proj" in new_name:
+                new_name = new_name.replace("gate_proj", "gate_shexp")
+            elif "down_proj" in new_name:
+                new_name = new_name.replace("down_proj", "down_shexp")
+            elif "up_proj" in new_name:
+                new_name = new_name.replace("up_proj", "up_shexp")
+            return [(new_name, data_torch)]
+
+        # Handle regular dense FFN layers (for hybrid dense/MoE architecture)
+        if ".mlp." in name and "experts" not in name and "_shexp" not in name:
+            if "gate_proj" in name:
+                new_name = name.replace("model.layers.", "blk.").replace(
+                    ".mlp.gate_proj.weight", ".ffn_gate.weight"
+                )
+            elif "up_proj" in name:
+                new_name = name.replace("model.layers.", "blk.").replace(
+                    ".mlp.up_proj.weight", ".ffn_up.weight"
+                )
+            elif "down_proj" in name:
+                new_name = name.replace("model.layers.", "blk.").replace(
+                    ".mlp.down_proj.weight", ".ffn_down.weight"
+                )
+            else:
+                new_name = name
+            return [(self.map_tensor_name(new_name), data_torch)]
+
+        # Handle special NextN tensors - preserve for future MTP support - see https://github.com/ggml-org/llama.cpp/pull/13236
+        if (
+            ".embed_tokens." in name
+            or ".shared_head." in name
+            or ".eh_proj." in name
+            or ".enorm." in name
+            or ".hnorm." in name
+        ):
+            new_name = name.replace("model.layers.", "blk.").replace("model.", "").replace(".weight", "")
+            # logger.debug(f"Skipping MTP tensor: {new_name}")
+            return [(new_name, data_torch)]
+
+        # GLM tensor mapping - handle directly without map_tensor_name
+        if ".input_layernorm." in name:
+            new_name = name.replace("model.layers.", "blk.").replace(".input_layernorm.", ".attn_norm.")
+            return [(new_name, data_torch)]
+        elif ".post_attention_layernorm." in name:
+            new_name = name.replace("model.layers.", "blk.").replace(".post_attention_layernorm.", ".ffn_norm.")
+            return [(new_name, data_torch)]
+        elif ".self_attn." in name:
+            # Map GLM self_attn to standard attention naming
+            new_name = name.replace("model.layers.", "blk.").replace(".self_attn.", ".attn_")
+            if "q_proj" in new_name:
+                new_name = new_name.replace("q_proj", "q")
+            elif "k_proj" in new_name:
+                new_name = new_name.replace("k_proj", "k")
+            elif "v_proj" in new_name:
+                new_name = new_name.replace("v_proj", "v")
+            elif "o_proj" in new_name:
+                new_name = new_name.replace("o_proj", "output")
+            return [(new_name, data_torch)]
+
+        return super().modify_tensors(data_torch, name, bid)
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
 
 @Model.register("ChatGLMModel", "ChatGLMForConditionalGeneration")
 class ChatGLMModel(Model):