From c74d1619120075acfe3678bec2e828483987e004 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thireus=20=E2=98=A0?= Date: Fri, 1 Aug 2025 21:06:21 +0100 Subject: [PATCH 01/13] GLM-4.5 --- convert_hf_to_gguf.py | 223 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 223 insertions(+) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index a4de237e0..c27488220 100644 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -618,6 +618,9 @@ def get_vocab_base_pre(self, tokenizer) -> str: if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b": # ref: https://huggingface.co/THUDM/glm-4-9b-chat res = "chatglm-bpe" + if chkhsh == "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902": + # ref: https://huggingface.co/zai-org/GLM-4.5-Air, https://huggingface.co/zai-org/GLM-4.5 + res = "gpt-2" if chkhsh == "7fc505bd3104ca1083b150b17d088b59534ede9bde81f0dd2090967d7fe52cee": # ref: https://huggingface.co/LumiOpen/Viking-7B res = "viking" @@ -3948,6 +3951,226 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): return [(self.map_tensor_name(name), data_torch)] return super().modify_tensors(data_torch, name, bid) +@ModelBase.register("Glm4MoeForCausalLM") +class Glm4MoeModel(TextModel): + model_arch = gguf.MODEL_ARCH.GLM4_MOE + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # GLM4_MOE has num_hidden_layers + 1 actual layers (including NextN layer) + self.block_count = self.hparams["num_hidden_layers"] + 1 + self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + + def set_vocab(self): + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained( + self.dir_model, trust_remote_code=True + ) + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) + tokens, toktypes, tokpre = self.get_vocab_base() + self.gguf_writer.add_tokenizer_model("gpt2") + self.gguf_writer.add_tokenizer_pre(tokpre) + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_types(toktypes) + + # Set special tokens + special_vocab._set_special_token( + "eos", tokenizer.get_added_vocab()["<|endoftext|>"] + ) + special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) + special_vocab._set_special_token( + "unk", tokenizer.get_added_vocab()["<|endoftext|>"] + ) + special_vocab._set_special_token( + "bos", tokenizer.get_added_vocab()["<|endoftext|>"] + ) + special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338 + + # Fix chat template syntax error in GLM-4.5 models + if special_vocab.chat_template and isinstance(special_vocab.chat_template, str): + # Fix multiple syntax issues in GLM-4.5 chat template + template = special_vocab.chat_template + # Fix nested double quotes issue + template = template.replace('endswith("/nothink")', "endswith('/nothink')") + # Fix any other potential parentheses/tuple issues + template = template.replace( + "not visible_text(m.content).endswith('/nothink'))", + "not visible_text(m.content).endswith('/nothink')" + ) + special_vocab.chat_template = template + special_vocab.add_to_gguf(self.gguf_writer) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + if (rope_dim := self.hparams.get("head_dim")) is None: + rope_dim = ( + self.hparams["hidden_size"] // self.hparams["num_attention_heads"] + ) + self.gguf_writer.add_rope_dimension_count( + int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)) + ) + + # MoE parameters + if (n_experts := 
self.hparams.get("n_routed_experts")) is not None: + self.gguf_writer.add_expert_count(n_experts) + # Note: expert_used_count is already set by parent class using num_experts_per_tok + if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None: + self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size) + if (n_shared_experts := self.hparams.get("n_shared_experts")) is not None: + self.gguf_writer.add_expert_shared_count(n_shared_experts) + if (first_k_dense_replace := self.hparams.get("first_k_dense_replace")) is not None: + self.gguf_writer.add_leading_dense_block_count(first_k_dense_replace) + + # Expert gating function (sigmoid for GLM4_MOE) + self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID) + + # Routed scaling factor + if (routed_scaling_factor := self.hparams.get("routed_scaling_factor")) is not None: + self.gguf_writer.add_expert_weights_scale(routed_scaling_factor) + + # Normalise topk probabilities + if (norm_topk_prob := self.hparams.get("norm_topk_prob")) is not None: + self.gguf_writer.add_expert_weights_norm(norm_topk_prob) + + _experts: list[dict[str, Tensor]] | None = None + _shared_experts: list[dict[str, Tensor]] | None = None + + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: + if name.startswith("model.visual."): # ignore visual part + return [] + elif name.startswith("model.language_model."): + name = name.replace("language_model.", "") # for multimodal variants + + # Handle main token embedding (but not layer-specific NextN embeddings) + if name == "model.embed_tokens.weight" and ".layers." not in name: + return [(self.map_tensor_name("token_embd.weight"), data_torch)] + + # Handle routed experts + if name.find("mlp.experts") != -1 and "shared_experts" not in name: + n_experts = self.hparams["n_routed_experts"] + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + # Extend experts array if needed (for models where actual layers > num_hidden_layers) + while len(self._experts) <= bid: + self._experts.append({}) + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + tensors: list[tuple[str, Tensor]] = [] + + # merge the experts into a single 3d tensor + for w_name in ["down_proj", "gate_proj", "up_proj"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + # Generate GGUF tensor names for merged experts + if w_name == "down_proj": + new_name = f"blk.{bid}.ffn_down_exps.weight" + elif w_name == "gate_proj": + new_name = f"blk.{bid}.ffn_gate_exps.weight" + elif w_name == "up_proj": + new_name = f"blk.{bid}.ffn_up_exps.weight" + else: + merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" + new_name = self.map_tensor_name(merged_name) + tensors.append((new_name, data_torch)) + return tensors + else: + return [] + + # Handle expert gating input (routing gate) + if ".mlp.gate.e_score_correction_bias" in name: + new_name = name.replace("model.layers.", "blk.").replace( + ".mlp.gate.e_score_correction_bias", ".ffn_gate_inp.bias" + ) + return [(new_name, data_torch)] + elif ".mlp.gate.weight" in name: + new_name = name.replace("model.layers.", "blk.").replace( + ".mlp.gate.weight", ".ffn_gate_inp.weight" + ) + return [(new_name, data_torch)] + + # Handle 
shared expert tensors + if ".mlp.shared_experts." in name: + new_name = name.replace("model.layers.", "blk.").replace(".mlp.shared_experts.", ".ffn_") + if "gate_proj" in new_name: + new_name = new_name.replace("gate_proj", "gate_shexp") + elif "down_proj" in new_name: + new_name = new_name.replace("down_proj", "down_shexp") + elif "up_proj" in new_name: + new_name = new_name.replace("up_proj", "up_shexp") + return [(new_name, data_torch)] + + # Handle regular dense FFN layers (for hybrid dense/MoE architecture) + if ".mlp." in name and "experts" not in name and "_shexp" not in name: + if "gate_proj" in name: + new_name = name.replace("model.layers.", "blk.").replace( + ".mlp.gate_proj.weight", ".ffn_gate.weight" + ) + elif "up_proj" in name: + new_name = name.replace("model.layers.", "blk.").replace( + ".mlp.up_proj.weight", ".ffn_up.weight" + ) + elif "down_proj" in name: + new_name = name.replace("model.layers.", "blk.").replace( + ".mlp.down_proj.weight", ".ffn_down.weight" + ) + else: + new_name = name + return [(self.map_tensor_name(new_name), data_torch)] + + # Handle special NextN tensors - preserve for future MTP support + if ( + ".embed_tokens." in name + or ".shared_head." in name + or ".eh_proj." in name + or ".enorm." in name + or ".hnorm." in name + ): + new_name = name.replace("model.layers.", "blk.").replace("model.", "").replace(".weight", "") + return [(new_name, data_torch)] + + # GLM tensor mapping - handle directly without map_tensor_name + if ".input_layernorm." in name: + new_name = name.replace("model.layers.", "blk.").replace(".input_layernorm.", ".attn_norm.") + return [(new_name, data_torch)] + elif ".post_attention_layernorm." in name: + new_name = name.replace("model.layers.", "blk.").replace(".post_attention_layernorm.", ".ffn_norm.") + return [(new_name, data_torch)] + elif ".self_attn." 
in name: + # Map GLM self_attn to standard attention naming + new_name = name.replace("model.layers.", "blk.").replace(".self_attn.", ".attn_") + if "q_proj" in new_name: + new_name = new_name.replace("q_proj", "q") + elif "k_proj" in new_name: + new_name = new_name.replace("k_proj", "k") + elif "v_proj" in new_name: + new_name = new_name.replace("v_proj", "v") + elif "o_proj" in new_name: + new_name = new_name.replace("o_proj", "output") + return [(new_name, data_torch)] + + return super().modify_tensors(data_torch, name, bid) + + def prepare_tensors(self): + super().prepare_tensors() + if self._experts is not None: + # flatten `list[dict[str, Tensor]]` into `list[str]` + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") @Model.register("ChatGLMModel", "ChatGLMForConditionalGeneration") class ChatGLMModel(Model): From 4e93d346de4a3a0ce8264641179380ab8a8a4ca7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thireus=20=E2=98=A0?= Date: Fri, 1 Aug 2025 21:06:47 +0100 Subject: [PATCH 02/13] GLM-4.5 --- gguf-py/gguf/constants.py | 51 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 32a667e26..92722dc32 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -220,6 +220,7 @@ class MODEL_ARCH(IntEnum): OPENELM = auto() ARCTIC = auto() DEEPSEEK2 = auto() + GLM4_MOE = auto() CHATGLM = auto() BITNET = auto() BITNET_25 = auto() @@ -262,6 +263,9 @@ class MODEL_TENSOR(IntEnum): FFN_GATE_EXP = auto() FFN_DOWN_EXP = auto() FFN_UP_EXP = auto() + FFN_GATE_EXPS = auto() # merged experts + FFN_DOWN_EXPS = auto() # merged experts + FFN_UP_EXPS = auto() # merged experts FFN_GATE_SHEXP = auto() FFN_DOWN_SHEXP = auto() FFN_UP_SHEXP = auto() @@ -314,6 +318,12 @@ class MODEL_TENSOR(IntEnum): ENC_FFN_DOWN = auto() ENC_FFN_UP = auto() ENC_OUTPUT_NORM = auto() + NEXTN_EH_PROJ = auto() # nextn tensors (glm4moe) + NEXTN_EMBED_TOKENS = auto() # nextn tensors (glm4moe) + NEXTN_ENORM = auto() # nextn tensors (glm4moe) + NEXTN_HNORM = auto() # nextn tensors (glm4moe) + NEXTN_SHARED_HEAD_HEAD = auto() # nextn tensors (glm4moe) + NEXTN_SHARED_HEAD_NORM = auto() # nextn tensors (glm4moe) MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { @@ -358,6 +368,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.ARCTIC: "arctic", MODEL_ARCH.DEEPSEEK2: "deepseek2", MODEL_ARCH.CHATGLM: "chatglm", + MODEL_ARCH.GLM4_MOE: "glm4moe", MODEL_ARCH.BITNET: "bitnet", MODEL_ARCH.BITNET_25: "bitnet-25", MODEL_ARCH.T5: "t5", @@ -404,6 +415,9 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps", MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps", MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps", + MODEL_TENSOR.FFN_GATE_EXPS: "blk.{bid}.ffn_gate_exps", # merged experts + MODEL_TENSOR.FFN_DOWN_EXPS: "blk.{bid}.ffn_down_exps", # merged experts + MODEL_TENSOR.FFN_UP_EXPS: "blk.{bid}.ffn_up_exps", # merged experts MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b", MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm", MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in", @@ -451,6 +465,13 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ENC_FFN_DOWN: "enc.blk.{bid}.ffn_down", MODEL_TENSOR.ENC_FFN_UP: "enc.blk.{bid}.ffn_up", MODEL_TENSOR.ENC_OUTPUT_NORM: "enc.output_norm", + # NextN/MTP tensors (GLM4_MOE) + MODEL_TENSOR.NEXTN_EH_PROJ: "blk.{bid}.eh_proj", + MODEL_TENSOR.NEXTN_EMBED_TOKENS: "blk.{bid}.embed_tokens", + MODEL_TENSOR.NEXTN_ENORM: 
"blk.{bid}.enorm", + MODEL_TENSOR.NEXTN_HNORM: "blk.{bid}.hnorm", + MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD: "blk.{bid}.shared_head.head", + MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM: "blk.{bid}.shared_head.norm", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -1070,6 +1091,36 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.GLM4_MOE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, # dense layers + MODEL_TENSOR.FFN_DOWN, # dense layers + MODEL_TENSOR.FFN_UP, # dense layers + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXPS, + MODEL_TENSOR.FFN_DOWN_EXPS, + MODEL_TENSOR.FFN_UP_EXPS, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + # NextN/MTP tensors - preserved but unused + MODEL_TENSOR.NEXTN_EH_PROJ, + MODEL_TENSOR.NEXTN_EMBED_TOKENS, + MODEL_TENSOR.NEXTN_ENORM, + MODEL_TENSOR.NEXTN_HNORM, + MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD, + MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM, + ], MODEL_ARCH.BITNET: [ MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, From a253f2fec1d6320db691b334a5cb8342126a90ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thireus=20=E2=98=A0?= Date: Fri, 1 Aug 2025 21:07:11 +0100 Subject: [PATCH 03/13] GLM-4.5 --- src/llama.cpp | 419 +++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 413 insertions(+), 6 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 27647c9d2..5fe20a8fe 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -226,6 +226,7 @@ enum llm_arch { LLM_ARCH_DEEPSEEK2, LLM_ARCH_CHATGLM, LLM_ARCH_GLM4, + LLM_ARCH_GLM4_MOE, LLM_ARCH_BITNET, LLM_ARCH_BITNET_25, LLM_ARCH_BITNET_B158, @@ -284,6 +285,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_DEEPSEEK2, "deepseek2" }, { LLM_ARCH_CHATGLM, "chatglm" }, { LLM_ARCH_GLM4, "glm4" }, + { LLM_ARCH_GLM4_MOE, "glm4moe" }, { LLM_ARCH_BITNET, "bitnet" }, { LLM_ARCH_BITNET_25, "bitnet-25" }, { LLM_ARCH_BITNET_B158, "bitnet-b1.58" }, @@ -609,6 +611,12 @@ enum llm_tensor { LLM_TENSOR_ENC_FFN_DOWN, LLM_TENSOR_ENC_FFN_UP, LLM_TENSOR_ENC_OUTPUT_NORM, + LLM_TENSOR_NEXTN_EH_PROJ, + LLM_TENSOR_NEXTN_EMBED_TOKENS, + LLM_TENSOR_NEXTN_ENORM, + LLM_TENSOR_NEXTN_HNORM, + LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, + LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, }; static const std::map> LLM_TENSOR_NAMES = { @@ -1407,6 +1415,39 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, }, }, + { + LLM_ARCH_GLM4_MOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, // dense layers + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, // dense layers + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, // dense layers + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { 
LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + // NextN/MTP tensors - preserved but unused (in final layer, dynamic layer number) + { LLM_TENSOR_NEXTN_EH_PROJ, "blk.%d.eh_proj" }, + { LLM_TENSOR_NEXTN_EMBED_TOKENS, "blk.%d.embed_tokens" }, + { LLM_TENSOR_NEXTN_ENORM, "blk.%d.enorm" }, + { LLM_TENSOR_NEXTN_HNORM, "blk.%d.hnorm" }, + { LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "blk.%d.shared_head.head" }, + { LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "blk.%d.shared_head.norm" }, + }, + }, { LLM_ARCH_BITNET, { @@ -2615,6 +2656,8 @@ enum e_model { MODEL_70B, MODEL_142B, MODEL_236B, + MODEL_106B_A12B, + MODEL_355B_A32B, MODEL_314B, MODEL_405B, MODEL_671B, @@ -3511,6 +3554,26 @@ static bool llama_kv_cache_init( buft_layer_count[llama_default_buffer_type_cpu(true)] = n_layer; } + //if (cparams.fused_moe_up_gate) { + // int nbad = 0; + // for (int i = 0; i < (int) n_layer; i++) { + // auto& layer = model.layers[i]; + // if (layer.ffn_gate_exps && layer.ffn_up_exps && layer.ffn_gate_exps->type != layer.ffn_up_exps->type) { + // ++nbad; + // } + // } + // if (nbad > 0) { + // if (nbad == (int)n_layer) { + // LLAMA_LOG_WARN("=============== ffn_up and ffn_gate are of different type => disabling fmoe\n"); + // const_cast(cparams).fused_moe_up_gate = false; + // } + // else { + // LLAMA_LOG_WARN("=============== ffn_up and ffn_gate are of different in %d out of %d layers, where fmoe will be disabled\n", + // nbad, (int)n_layer); + // } + // } + //} + // create a context for each buffer type std::map ctx_map; for (auto & it : buft_layer_count) { @@ -5272,6 +5335,8 @@ static const char * llama_model_type_name(e_model type) { case MODEL_70B: return "70B"; case MODEL_142B: return "142B"; case MODEL_236B: return "236B"; + case MODEL_106B_A12B: return "106B.A12B"; + case MODEL_355B_A32B: return "355B.A32B"; case MODEL_314B: return "314B"; case MODEL_405B: return "405B"; case MODEL_671B: return "671B"; @@ -6027,6 +6092,31 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_GLM4_MOE: + { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + // MoE parameters + ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, 0); + ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, 0); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared, 0); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, 0); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + + // Expert gating function (GLM4_MOE uses sigmoid) + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + if (hparams.expert_gating_func == 0) { + hparams.expert_gating_func = LLM_EXPERT_GATING_FUNC_SIGMOID; + } + + switch (hparams.n_layer) { + case 47: model.type = e_model::MODEL_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 NextN layer) + case 93: model.type = e_model::MODEL_355B_A32B; break; // GLM-4.5 (92 layers + 1 NextN layer) + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; case LLM_ARCH_BITNET: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -8927,6 +9017,135 @@ static bool llm_load_tensors( } } } break; + case LLM_ARCH_GLM4_MOE: + { + const int64_t n_expert = 
hparams.n_expert; + const int64_t n_expert_used = hparams.n_expert_used; + const int64_t n_expert_shared = hparams.n_expert_shared; + + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + { + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + } + // if output is NULL, init from the input tok embed + if (model.output == NULL) { + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } + + // --- NextN / MTP tensors (preserved but unused), on the final layer --- + { + const int final_layer = n_layer - 1; + // EH_PROJ: [2*embd, embd] + create_tensor(ctx_for_layer(final_layer), + tn(LLM_TENSOR_NEXTN_EH_PROJ, final_layer), + { 2*n_embd, n_embd }, + llama_model_loader::TENSOR_NOT_REQUIRED); + // EMBED_TOKENS: [embd, vocab] + create_tensor(ctx_for_layer(final_layer), + tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, final_layer), + { n_embd, n_vocab }, + llama_model_loader::TENSOR_NOT_REQUIRED); + // ENORM, HNORM: [embd] + create_tensor(ctx_for_layer(final_layer), + tn(LLM_TENSOR_NEXTN_ENORM, final_layer), + { n_embd }, + llama_model_loader::TENSOR_NOT_REQUIRED); + create_tensor(ctx_for_layer(final_layer), + tn(LLM_TENSOR_NEXTN_HNORM, final_layer), + { n_embd }, + llama_model_loader::TENSOR_NOT_REQUIRED); + // SHARED_HEAD_HEAD: [embd, vocab] + create_tensor(ctx_for_layer(final_layer), + tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, final_layer), + { n_embd, n_vocab }, + llama_model_loader::TENSOR_NOT_REQUIRED); + // SHARED_HEAD_NORM: [embd] + create_tensor(ctx_for_layer(final_layer), + tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, final_layer), + { n_embd }, + llama_model_loader::TENSOR_NOT_REQUIRED); + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + // GLM-style attention with bias terms + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd_head_k * n_head }, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_k_gqa }, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_v_gqa }, llama_model_loader::TENSOR_NOT_REQUIRED); + + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + // K/Q norm tensors (optional for GLM-4.5 355B variant) + layer.attn_q_norm = create_tensor(ctx_layer, + tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_k_norm = create_tensor(ctx_layer, + tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, llama_model_loader::TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), { 
n_embd }, 0); + + // Check if this layer uses MoE or dense FFN based on n_layer_dense_lead + // GLM 4.5 uses hybrid architecture: layer 0 is dense, layers 1+ are MoE + const bool use_moe = + (hparams.n_expert > 0) && (static_cast(i) >= hparams.n_layer_dense_lead); + + if (use_moe) { + // MoE layers + layer.ffn_gate_inp = + create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); + // gate bias + layer.ffn_exp_probs_b = + create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "bias", i), { n_expert }, + llama_model_loader::TENSOR_NOT_REQUIRED); + + if (n_expert == 0) { + GGML_ASSERT(hparams.n_expert > 0 && "n_expert must be > 0 for GLM4_MOE MoE layers"); + } + if (n_expert_used == 0) { + GGML_ASSERT(hparams.n_expert_used > 0 && + "n_expert_used must be > 0 for GLM4_MOE MoE layers"); + } + + // MoE branch + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + layer.ffn_gate_exps = create_tensor(ctx_split, + tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + layer.ffn_down_exps = create_tensor(ctx_split, + tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); + layer.ffn_up_exps = create_tensor(ctx_split, + tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + + // Shared expert + if (n_expert_shared > 0) { + const int64_t n_ff_shexp = n_ff_exp * n_expert_shared; + layer.ffn_gate_shexp = create_tensor(ctx_split, + tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); + layer.ffn_down_shexp = create_tensor(ctx_split, + tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0); + layer.ffn_up_shexp = + create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); + } + } else { + // Dense layers (first k layers) - GLM uses separate gate/up projections + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); + } + } + } + break; case LLM_ARCH_BITNET: { model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); @@ -10065,7 +10284,7 @@ llm_expert_gating_func_type gating_op, } ggml_tensor * par; - if (lctx.cparams.fused_moe_up_gate) { + if (lctx.cparams.fused_moe_up_gate && up_exps->type == gate_exps->type) { par = ggml_moe_up_gate(ctx, up_exps, gate_exps, cur, selected_experts, type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : GGML_UNARY_OP_GELU); } else { ggml_tensor * up = llm_build_lora_mm_id(lctx, ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] @@ -10186,7 +10405,7 @@ static struct ggml_tensor * llm_build_kqv( // For DeepSeek-2, it is perfectly fine with fp16 for PP, but I get gibberish when uding fp16 for TG. // Not sure if it is really a matter of insufficient precision, or I have made a mistake in the fattn-vec-f16 kernel. 
if (use_f32_precision || model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || - (model.arch == LLM_ARCH_DEEPSEEK2 && q->ne[1] <= 8) || model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4) { + (model.arch == LLM_ARCH_DEEPSEEK2 && q->ne[1] <= 8) || model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4 || model.arch == LLM_ARCH_GLM4_MOE) { ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); } //ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); @@ -10211,7 +10430,7 @@ static struct ggml_tensor * llm_build_kqv( //ggml_mul_mat_set_prec(kq, GGML_PREC_F32); if (use_f32_precision || model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 || - model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4) { + model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4 || model.arch == LLM_ARCH_GLM4_MOE) { // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 ggml_mul_mat_set_prec(kq, GGML_PREC_F32); @@ -10271,7 +10490,7 @@ static struct ggml_tensor * llm_build_kqv( auto q_i = ggml_view_3d(ctx, q, q->ne[0], q->ne[1], this_ne12, q->nb[1], q->nb[2], q->nb[2]*i12); auto kq_i = ggml_mul_mat(ctx, k_i, q_i); if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 || - model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4) { + model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4 || model.arch == LLM_ARCH_GLM4_MOE) { ggml_mul_mat_set_prec(kq_i, GGML_PREC_F32); } if (model.arch == LLM_ARCH_GROK) { @@ -15978,6 +16197,179 @@ struct llm_build_context { return gf; } + struct ggml_cgraph * build_glm4_moe() { + // create a new graph + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + // input embeddings + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // position embeddings + struct ggml_tensor * inp_pos = build_inp_pos(); + + // attention KV cache input + //auto * inp_attn = build_attn_inp_kv_unified(); + + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + // output token IDs (for last layer cropping) + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // Pre-attention norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // Q, K, V projections + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + } + cb(Qcur, "Qcur", il); + + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + } + cb(Kcur, "Kcur", il); + + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + } + cb(Vcur, "Vcur", il); + + // reshape for multi-head + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur 
= ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + // Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // Apply Q/K norm if available (GLM-4.5 355B variant) + if (model.layers[il].attn_q_norm) { + Qcur = llm_build_norm(ctx0, Qcur, hparams, + model.layers[il].attn_q_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(Qcur, "Qcur_normed", il); + } + if (model.layers[il].attn_k_norm) { + Kcur = llm_build_norm(ctx0, Kcur, hparams, + model.layers[il].attn_k_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(Kcur, "Kcur_normed", il); + } + + // apply RoPE + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + // build attention KV (no unified cache) + cur = llm_build_kv(ctx0, lctx, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, + n_tokens, kv_head, n_kv, + 1.0f/sqrtf(float(n_embd_head)), cb, il); + } + + // crop output on last layer + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // residual connection for attention output + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // FFN / MoE + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + if ((uint32_t) il < hparams.n_layer_dense_lead) { + // dense FFN + cur = llm_build_ffn(ctx0, lctx, cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + } else { + // MoE FFN + struct ggml_tensor * moe_out = llm_build_moe_ffn(ctx0, lctx, cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + true, hparams.expert_weights_scale, + (enum llm_expert_gating_func_type) hparams.expert_gating_func, + cb, il); + cb(moe_out, "ffn_moe_out", il); + + { + struct ggml_tensor * shexp_out = llm_build_ffn(ctx0, lctx, cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(shexp_out, "ffn_shexp_out", il); + + cur = ggml_add(ctx0, moe_out, shexp_out); + cb(cur, "ffn_out", il); + } + } + + // residual and context vector + cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // prepare next layer input + inpL = cur; + } + + cur = inpL; + + // final norm + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm head + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + return gf; + } + struct ggml_cgraph * build_bitnet() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); @@ 
-17655,6 +18047,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_glm4(); } break; + case LLM_ARCH_GLM4_MOE: + { + result = llm.build_glm4_moe(); + } break; case LLM_ARCH_BITNET: { result = llm.build_bitnet(); @@ -20134,8 +20530,18 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s // - qs.n_attention_wv == 3 * model.hparams.n_layer for Encoder-Decoder models // - model.arch == LLM_ARCH_DECI for Deci-Nemotron models // - GGML_ASSERT((qs.n_attention_wv == 0 || qs.n_attention_wv == (int)model.hparams.n_layer || qs.n_attention_wv == 3 * (int)model.hparams.n_layer || model.arch == LLM_ARCH_DECI) && "n_attention_wv is unexpected"); - + //GGML_ASSERT((qs.n_attention_wv == 0 || qs.n_attention_wv == (int)model.hparams.n_layer || qs.n_attention_wv == 3 * (int)model.hparams.n_layer || model.arch == LLM_ARCH_DECI) && "n_attention_wv is unexpected"); + // allow any count for GLM4-MoE, but still enforce for all others + if (model.arch != LLM_ARCH_GLM4_MOE) { + GGML_ASSERT( + qs.n_attention_wv == 0 + || qs.n_attention_wv == (int)model.hparams.n_layer + || qs.n_attention_wv == 3 * (int)model.hparams.n_layer + || model.arch == LLM_ARCH_DECI + && "n_attention_wv is unexpected" + ); + } + size_t total_size_org = 0; size_t total_size_new = 0; @@ -21459,6 +21865,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_BERT: case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_STABLELM: + case LLM_ARCH_GLM4_MOE: case LLM_ARCH_BITNET: case LLM_ARCH_BITNET_25: case LLM_ARCH_BITNET_B158: From 3c068900e09c43e8a93112e181fd9534115050a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thireus=20=E2=98=A0?= Date: Sat, 2 Aug 2025 06:42:12 +0100 Subject: [PATCH 04/13] convert_hf_to_gguf.py compatibility bugfix with GLM-4.5 From @ubergarm - https://github.com/ikawrakow/ik_llama.cpp/pull/668#issuecomment-3145913701 --- convert_hf_to_gguf.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index c27488220..3e4685072 100644 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -618,6 +618,9 @@ def get_vocab_base_pre(self, tokenizer) -> str: if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b": # ref: https://huggingface.co/THUDM/glm-4-9b-chat res = "chatglm-bpe" + if chkhsh == "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2": + # ref: https://huggingface.co/THUDM/glm-4-9b-hf + res = "glm4" if chkhsh == "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902": # ref: https://huggingface.co/zai-org/GLM-4.5-Air, https://huggingface.co/zai-org/GLM-4.5 res = "gpt-2" @@ -3951,8 +3954,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): return [(self.map_tensor_name(name), data_torch)] return super().modify_tensors(data_torch, name, bid) -@ModelBase.register("Glm4MoeForCausalLM") -class Glm4MoeModel(TextModel): +@Model.register("Glm4MoeForCausalLM") +class Glm4MoeModel(Model): model_arch = gguf.MODEL_ARCH.GLM4_MOE def __init__(self, *args, **kwargs): From d3d3fe626e6216f2d9d7cee629e14ea0146e8cc0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thireus=20=E2=98=A0?= Date: Sat, 2 Aug 2025 06:57:02 +0100 Subject: [PATCH 05/13] Add ubergarm comments + my own --- convert_hf_to_gguf.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 3e4685072..be2eb1702 100644 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -4096,7 +4096,7 @@ def 
modify_tensors( # Handle expert gating input (routing gate) if ".mlp.gate.e_score_correction_bias" in name: new_name = name.replace("model.layers.", "blk.").replace( - ".mlp.gate.e_score_correction_bias", ".ffn_gate_inp.bias" + ".mlp.gate.e_score_correction_bias", ".ffn_gate_inp.bias" # *NOTE* this is ".exp_probs_b" in mainline PR ) return [(new_name, data_torch)] elif ".mlp.gate.weight" in name: @@ -4134,7 +4134,7 @@ def modify_tensors( new_name = name return [(self.map_tensor_name(new_name), data_torch)] - # Handle special NextN tensors - preserve for future MTP support + # Handle special NextN tensors - preserve for future MTP support - See https://github.com/ggml-org/llama.cpp/pull/13236 if ( ".embed_tokens." in name or ".shared_head." in name @@ -4143,6 +4143,7 @@ def modify_tensors( or ".hnorm." in name ): new_name = name.replace("model.layers.", "blk.").replace("model.", "").replace(".weight", "") + # logger.debug(f"Skipping MTP tensor: {new_name}") return [(new_name, data_torch)] # GLM tensor mapping - handle directly without map_tensor_name From 0a4cb10d538350b072f6b653c11f496ecb89407f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thireus=20=E2=98=A0?= Date: Sun, 3 Aug 2025 08:47:21 +0100 Subject: [PATCH 06/13] Revert to llama.cpp script version that produced good BF16 See: https://github.com/ikawrakow/ik_llama.cpp/pull/668#issuecomment-3147374559 --- convert_hf_to_gguf.py | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index be2eb1702..464973cc8 100644 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -3988,20 +3988,7 @@ def set_vocab(self): special_vocab._set_special_token( "bos", tokenizer.get_added_vocab()["<|endoftext|>"] ) - special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338 - - # Fix chat template syntax error in GLM-4.5 models - if special_vocab.chat_template and isinstance(special_vocab.chat_template, str): - # Fix multiple syntax issues in GLM-4.5 chat template - template = special_vocab.chat_template - # Fix nested double quotes issue - template = template.replace('endswith("/nothink")', "endswith('/nothink')") - # Fix any other potential parentheses/tuple issues - template = template.replace( - "not visible_text(m.content).endswith('/nothink'))", - "not visible_text(m.content).endswith('/nothink')" - ) - special_vocab.chat_template = template + special_vocab.add_to_gguf(self.gguf_writer) def set_gguf_parameters(self): @@ -4048,7 +4035,7 @@ def modify_tensors( name = name.replace("language_model.", "") # for multimodal variants # Handle main token embedding (but not layer-specific NextN embeddings) - if name == "model.embed_tokens.weight" and ".layers." 
not in name: + if name == "model.embed_tokens.weight": return [(self.map_tensor_name("token_embd.weight"), data_torch)] # Handle routed experts From 292300d251eb08a751a650ded7b4245d36abbbdf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thireus=20=E2=98=A0?= Date: Sun, 3 Aug 2025 08:48:00 +0100 Subject: [PATCH 07/13] Support for jinja chat templates See https://github.com/ikawrakow/ik_llama.cpp/pull/668#issuecomment-3148109962 --- gguf-py/gguf/vocab.py | 395 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 388 insertions(+), 7 deletions(-) diff --git a/gguf-py/gguf/vocab.py b/gguf-py/gguf/vocab.py index cca097986..e1d5aaf47 100644 --- a/gguf-py/gguf/vocab.py +++ b/gguf-py/gguf/vocab.py @@ -1,5 +1,6 @@ from __future__ import annotations +from enum import Enum import re import logging import json @@ -7,7 +8,29 @@ from pathlib import Path from typing import Any, Callable, Sequence, Mapping, Iterable, Protocol, ClassVar, runtime_checkable -from sentencepiece import SentencePieceProcessor +try: + from sentencepiece import SentencePieceProcessor +except ImportError: + SentencePieceProcessor = None + +try: + from mistral_common.tokens.tokenizers.mistral import MistralTokenizer + from mistral_common.tokens.tokenizers.tekken import Tekkenizer + from mistral_common.tokens.tokenizers.utils import ( + _filter_valid_tokenizer_files, + ) + from mistral_common.tokens.tokenizers.sentencepiece import ( + SentencePieceTokenizer, + ) +except ImportError: + _mistral_common_installed = False + MistralTokenizer = None + Tekkenizer = None + SentencePieceTokenizer = None + _filter_valid_tokenizer_files = None +else: + _mistral_common_installed = True + import gguf @@ -116,6 +139,7 @@ def _set_special_token(self, typ: str, tid: Any) -> None: logger.warning(f'Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping') def _try_load_from_tokenizer_json(self, path: Path) -> bool: + tokenizer = None tokenizer_file = path / 'tokenizer.json' if tokenizer_file.is_file(): with open(tokenizer_file, encoding = 'utf-8') as f: @@ -149,15 +173,110 @@ def _try_load_from_tokenizer_json(self, path: Path) -> bool: added_tokens = tokenizer.get('added_tokens', {}) else: added_tokens = {} + tokenizer_config = None tokenizer_config_file = path / 'tokenizer_config.json' - if not tokenizer_config_file.is_file(): + if tokenizer_config_file.is_file(): + with open(tokenizer_config_file, encoding = 'utf-8') as f: + tokenizer_config = json.load(f) + if tokenizer: + special_bos = (tokenizer_config or {}).get('bos_token') + special_cls = (tokenizer_config or {}).get('cls_token') + special_eos = (tokenizer_config or {}).get('eos_token') + special_sep = (tokenizer_config or {}).get('sep_token') + if not special_bos and special_cls and tokenizer_config: + tokenizer_config['bos_token'] = special_bos = special_cls + if not special_eos and special_sep and tokenizer_config: + tokenizer_config['eos_token'] = special_eos = special_sep + if post_processor := tokenizer.get('post_processor'): + for processor in post_processor.get('processors', [post_processor]): + if processor.get('type') == 'RobertaProcessing': + self.add_special_token['bos'] = True + self.add_special_token['eos'] = True + self.add_special_token['sep'] = True + if not special_cls and tokenizer_config: + special_cls = processor.get('cls', [special_bos])[0] + tokenizer_config['cls_token'] = special_cls + if not special_sep and tokenizer_config: + special_sep = processor.get('sep', [special_eos])[0] + tokenizer_config['sep_token'] = special_sep + continue 
+ # Crude parsing of TemplateProcessing to determine if BOS/SEP/EOS should be added + # Only works with simple templates, **will** get it wrong on unusual sequences + if processor.get('type') == 'TemplateProcessing': + tmpl_single = processor.get('single', []) + tmpl_pair = processor.get('pair', []) + special_first = None + special_last = None + if len(tmpl_single) > 1: + if special_first := tmpl_single[0].get('SpecialToken', {}).get('id'): + if not tokenizer_config: + special_bos = special_first + self.add_special_token['bos'] = True if special_first in (special_bos, special_cls) else False + if special_first not in (special_bos, special_cls): + logger.warning(f'Unknown leading special token {special_first!r} in TemplateProcessing') + if special_last := tmpl_single[-1].get('SpecialToken', {}).get('id'): + if not tokenizer_config: + special_eos = special_last + elif special_last != special_eos: + if 'eot' not in self.special_token_types: + self.special_token_types = tuple(self.special_token_types) + ('eot', ) + tokenizer_config['eot_token'] = special_eos + elif 'eom' not in self.special_token_types: + self.special_token_types = tuple(self.special_token_types) + ('eom', ) + tokenizer_config['eom_token'] = special_eos + else: + logger.warning(f'Overriding EOS token {special_eos!r} with {special_last!r} without EOT/EOM fallback!') + tokenizer_config['eos_token'] = special_eos = special_last + self.add_special_token['eos'] = True if special_last == special_eos else False + if special_last != special_eos: + logger.warning(f'Unknown trailing special token {special_last!r} in TemplateProcessing') + if tmpl_pair: + seq_start = 1 if special_first and tmpl_pair[0].get('SpecialToken', {}).get('id') == special_first else 0 + seq_stop = -1 if special_last and tmpl_pair[-1].get('SpecialToken', {}).get('id') == special_last else None + if (special_first and seq_start == 0) or (special_last and seq_stop is None): + logger.warning('TemplateProcessing leading/trailing special tokens do not match TemplateProcessing') + if tmpl_pair := tmpl_pair[slice(seq_start, seq_stop)]: + tmpl_a = tmpl_pair[0].get('Sequence', {}).get('id') + tmpl_b = tmpl_pair[-1].get('Sequence', {}).get('id') + if tmpl_a != 'A' or tmpl_b != 'B': + logger.warning(f'Unknown sequence {tmpl_a}...{tmpl_b} in TemplateProcessing') + # A [sep] [eos] B + if tmpl_a == 'A' and tmpl_b == 'B' and (tmpl_pair := tmpl_pair[1:-1]): + add_sep = False + if special_entry := tmpl_pair[0].get('SpecialToken', {}).get('id'): + if special_entry in (special_sep, special_eos) and not special_last: + add_sep = True + if special_entry not in (special_sep, special_eos): + logger.warning(f'Unknown separator token {special_entry!r} in TemplateProcessing') + else: + logger.warning(f'Unknown middle sequence {tmpl_pair[0]!r} in TemplateProcessing') + if len(tmpl_pair) == 2: + if special_entry := tmpl_pair[1].get('SpecialToken', {}).get('id'): + if special_entry in (special_sep, special_eos): + add_sep = True + if special_entry not in (special_sep, special_eos): + logger.warning(f'Unknown second separator token {special_entry!r} in TemplateProcessing') + else: + logger.warning(f'Unknown second middle sequence {tmpl_pair[1]!r} in TemplateProcessing') + self.add_special_token['sep'] = add_sep + if add_sep and not special_sep and tokenizer_config: + tokenizer_config['sep_token'] = special_eos + continue + if not tokenizer_config: return True - with open(tokenizer_config_file, encoding = 'utf-8') as f: - tokenizer_config = json.load(f) chat_template_alt = None - 
chat_template_file = path / 'chat_template.json' - if chat_template_file.is_file(): - with open(chat_template_file, encoding = 'utf-8') as f: + chat_template_json = path / 'chat_template.json' + chat_template_jinja = path / 'chat_template.jinja' + if chat_template_jinja.is_file(): + with open(chat_template_jinja, encoding = 'utf-8') as f: + chat_template_alt = f.read() + if additional_templates := list((path / 'additional_chat_templates').glob('*.jinja')): + chat_template_alt = [{'name': 'default', 'template': chat_template_alt}] + for template_path in additional_templates: + with open(template_path, encoding = 'utf-8') as fp: + chat_template_alt.append({'name': template_path.stem, 'template': fp.read()}) + elif chat_template_json.is_file(): + with open(chat_template_json, encoding = 'utf-8') as f: chat_template_alt = json.load(f).get('chat_template') chat_template = tokenizer_config.get('chat_template', chat_template_alt) if chat_template is None or isinstance(chat_template, (str, list)): @@ -302,6 +421,9 @@ class SentencePieceVocab(Vocab): name = "spm" def __init__(self, base_path: Path): + if SentencePieceProcessor is None: + raise RuntimeError("sentencepiece is not installed") + added_tokens: dict[str, int] = {} if (fname_tokenizer := base_path / 'tokenizer.model').exists(): # normal location @@ -490,3 +612,262 @@ def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: def __repr__(self) -> str: return f"" + + +class MistralTokenizerType(str, Enum): + spm = "spm" + tekken = "tekken" + + +# Copied from Transformers (Apache 2.0) +# https://github.com/huggingface/transformers/blob/main/src/transformers/convert_slow_tokenizer.py#L1544 + +def bytes_to_unicode() -> dict[int, str]: + """ + Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control + characters the bpe code barfs on. + + The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab + if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for + decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup + tables between utf-8 bytes and unicode strings. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs_str = [chr(n) for n in cs] + return dict(zip(bs, cs_str)) + + +class MistralVocab(Vocab): + tokenizer_model = "mistral" + name = "mistral" + + added_tokens_dict: dict[str, int] = {} + added_tokens_list: list[str] = [] + + def __init__(self, base_path: Path): + if not _mistral_common_installed: + raise ImportError( + "To use MistralVocab, please install the `mistral-common` package. " + "You can install it with `pip install mistral-common`." 
+ ) + assert _filter_valid_tokenizer_files is not None, "mistral_common is not installed" + assert MistralTokenizer is not None, "mistral_common is not installed" + assert Tekkenizer is not None, "mistral_common is not installed" + + logger.info(f"Loading Mistral tokenizer from {base_path}") + + # Find the tokenizer files + all_files = [f.as_posix() for f in base_path.glob("**/*") if f.is_file()] + valid_tokenizer_files = _filter_valid_tokenizer_files(all_files) + + if len(valid_tokenizer_files) == 0: + raise ValueError(f"No tokenizer file found in the directory: {base_path}") + # If there are multiple tokenizer files, we use tekken.json if it exists, otherwise the versioned one. + if len(valid_tokenizer_files) > 1: + if "tekken.json" in valid_tokenizer_files: + tokenizer_file = "tekken.json" + else: + tokenizer_file = sorted(valid_tokenizer_files)[-1] + logger.warning( + f"Multiple tokenizer files found in {base_path}. Using {tokenizer_file}" + ) + else: + tokenizer_file = valid_tokenizer_files[0] + + self.tokenizer = MistralTokenizer.from_file( + base_path / tokenizer_file + ).instruct_tokenizer.tokenizer + self.tokenizer_type = ( + MistralTokenizerType.tekken + if isinstance(self.tokenizer, Tekkenizer) + else MistralTokenizerType.spm + ) + self.vocab_size = self.tokenizer.n_words + self.fname_tokenizer = base_path / tokenizer_file + self._name = ( + "mistral-" + self.tokenizer_type.value + "-" + self.tokenizer.version + ) + + @property + def tokenizer_name(self) -> str: + return self._name + + @property + def gguf_tokenizer_model(self) -> str: + return "llama" if self.tokenizer_type == MistralTokenizerType.spm else "gpt2" + + def _sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: + assert SentencePieceTokenizer is not None, "mistral_common is not installed" + assert isinstance(self.tokenizer, SentencePieceTokenizer), ( + f"Expected SentencePieceTokenizer, got {type(self.tokenizer)}" + ) + + for i in range(self.tokenizer._model.vocab_size()): + piece = self.tokenizer._model.IdToPiece(i) + text = piece.encode("utf-8") + score: float = self.tokenizer._model.GetScore(i) + + toktype = gguf.TokenType.NORMAL + if self.tokenizer._model.IsUnknown(i): + toktype = gguf.TokenType.UNKNOWN + if self.tokenizer._model.IsControl(i): + toktype = gguf.TokenType.CONTROL + + if self.tokenizer._model.IsUnused(i): + toktype = gguf.TokenType.UNUSED + if self.tokenizer._model.IsByte(i): + toktype = gguf.TokenType.BYTE + + yield text, score, toktype + + def _tekken_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: + assert Tekkenizer is not None, "mistral_common is not installed" + assert isinstance(self.tokenizer, Tekkenizer), ( + f"Expected Tekkenizer, got {type(self.tokenizer)}" + ) + + byte_encoder = bytes_to_unicode() + for token_id in range(self.tokenizer.num_special_tokens): + yield ( + self.tokenizer.id_to_piece(token_id).encode("utf-8"), + 0, + gguf.TokenType.CONTROL + ) + for token in self.tokenizer._tekken_token2id_nospecial: + yield ( + self.token_bytes_to_string(token, byte_encoder).encode("utf-8"), + 0, + gguf.TokenType.NORMAL, + ) + + def get_token_id(self, token: str) -> int: + assert SentencePieceTokenizer is not None and Tekkenizer is not None, "mistral_common is not installed" + if self.tokenizer_type == MistralTokenizerType.spm: + assert isinstance(self.tokenizer, SentencePieceTokenizer) + return self.tokenizer._vocab.index(token) + elif self.tokenizer_type == MistralTokenizerType.tekken: + assert isinstance(self.tokenizer, Tekkenizer) + return ( + 
self.tokenizer._vocab.index(token) + self.tokenizer.num_special_tokens + ) + else: + raise ValueError(f"Unknown tokenizer type: {self.tokenizer_type}") + + @property + def bos_id(self) -> int: + return self.tokenizer.bos_id + + @property + def eos_id(self) -> int: + return self.tokenizer.eos_id + + @property + def pad_id(self) -> int: + if self.tokenizer.pad_id == -1: + return self.eos_id + return self.tokenizer.pad_id + + @property + def unk_id(self) -> int: + return self.tokenizer.unk_id + + @property + def bos_token(self) -> str: + return self.tokenizer.id_to_piece(self.tokenizer.bos_id) + + @property + def eos_token(self) -> str: + return self.tokenizer.id_to_piece(self.tokenizer.eos_id) + + @property + def pad_token(self) -> str: + return self.tokenizer.id_to_piece(self.tokenizer.pad_id) + + @property + def unk_token(self) -> str: + return self.tokenizer.id_to_piece(self.tokenizer.unk_id) + + def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: + if self.tokenizer_type == MistralTokenizerType.spm: + yield from self._sentencepiece_tokens() + + elif self.tokenizer_type == MistralTokenizerType.tekken: + yield from self._tekken_tokens() + + else: + raise ValueError(f"Unknown tokenizer type: {self.tokenizer_type}") + + @staticmethod + def token_bytes_to_string(b, byte_encoder): + return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")]) + + def extract_vocab_merges_from_model(self): + # Adapted from Transformers (Apache 2.0) + # https://github.com/huggingface/transformers/blob/main/src/transformers/convert_slow_tokenizer.py + assert Tekkenizer is not None and isinstance(self.tokenizer, Tekkenizer), ( + f"Expected Tekkenizer, got {type(self.tokenizer)}" + ) + mergeable_ranks = self.tokenizer._model._mergeable_ranks + token_bytes_map = { + rank: token_bytes for token_bytes, rank in mergeable_ranks.items() + } + merge_pairs = [] + + # Sort vocab by rank to ensure correct merge order + for i in range(256, self.vocab_size - self.tokenizer.num_special_tokens): + merged_token = token_bytes_map[i] + local = [] + for j in range(1, len(merged_token)): + left = merged_token[:j] + right = merged_token[j:] + if ( + left in mergeable_ranks + and right in mergeable_ranks + and (left + right) in mergeable_ranks + ): + local.append((left, right, i)) + if not local: + raise ValueError( + f"Could not find valid merge for token at rank {i}: {merged_token.decode('latin-1')}" + ) + local = sorted( + local, + key=lambda x: (mergeable_ranks[x[0]], mergeable_ranks[x[1]]), + reverse=False, + ) + merge_pairs.extend(local) + merge_pairs = sorted(merge_pairs, key=lambda val: val[2], reverse=False) + + byte_encoder = bytes_to_unicode() + + decoded_merge_pairs = [ + [ + self.token_bytes_to_string(val[0], byte_encoder), + self.token_bytes_to_string(val[1], byte_encoder), + ] + for val in merge_pairs + ] + + merges = [ + " ".join( + [ + # ensure the spaces are properly encoded + "".join(chr(ord(c) + 256) if c == " " else c for c in part) + for part in pair + ] + ) + for pair in decoded_merge_pairs + ] + + return merges From a90aec1f9fd650410f74de3a649c3a9959f3507a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thireus=20=E2=98=A0?= Date: Mon, 4 Aug 2025 22:59:54 +0100 Subject: [PATCH 08/13] GLM-4.5 llama.cpp final port --- convert_hf_to_gguf.py | 148 ++++---------- convert_hf_to_gguf_update.py | 2 + gguf-py/gguf/constants.py | 51 +++-- gguf-py/gguf/gguf_writer.py | 3 + gguf-py/gguf/tensor_mapping.py | 25 +++ src/llama-vocab.cpp | 24 +++ src/llama-vocab.h | 16 ++ src/llama.cpp | 363 
+++++++++++++++++++++++++-------- 8 files changed, 415 insertions(+), 217 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 464973cc8..cfabd3f97 100644 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -618,12 +618,15 @@ def get_vocab_base_pre(self, tokenizer) -> str: if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b": # ref: https://huggingface.co/THUDM/glm-4-9b-chat res = "chatglm-bpe" + if chkhsh == "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516": + # ref: https://huggingface.co/THUDM/glm-4-9b-chat + res = "chatglm-bpe" if chkhsh == "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2": # ref: https://huggingface.co/THUDM/glm-4-9b-hf res = "glm4" if chkhsh == "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902": # ref: https://huggingface.co/zai-org/GLM-4.5-Air, https://huggingface.co/zai-org/GLM-4.5 - res = "gpt-2" + res = "glm4" if chkhsh == "7fc505bd3104ca1083b150b17d088b59534ede9bde81f0dd2090967d7fe52cee": # ref: https://huggingface.co/LumiOpen/Viking-7B res = "viking" @@ -3961,15 +3964,13 @@ class Glm4MoeModel(Model): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # GLM4_MOE has num_hidden_layers + 1 actual layers (including NextN layer) - self.block_count = self.hparams["num_hidden_layers"] + 1 + self.block_count = self.hparams["num_hidden_layers"] + self.hparams.get("num_nextn_predict_layers", 0) self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) - + def set_vocab(self): from transformers import AutoTokenizer - tokenizer = AutoTokenizer.from_pretrained( - self.dir_model, trust_remote_code=True - ) + tokenizer = AutoTokenizer.from_pretrained(self.dir_model) special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) tokens, toktypes, tokpre = self.get_vocab_base() self.gguf_writer.add_tokenizer_model("gpt2") @@ -3977,17 +3978,18 @@ def set_vocab(self): self.gguf_writer.add_token_list(tokens) self.gguf_writer.add_token_types(toktypes) - # Set special tokens - special_vocab._set_special_token( - "eos", tokenizer.get_added_vocab()["<|endoftext|>"] - ) - special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) - special_vocab._set_special_token( - "unk", tokenizer.get_added_vocab()["<|endoftext|>"] - ) - special_vocab._set_special_token( - "bos", tokenizer.get_added_vocab()["<|endoftext|>"] - ) + # Special tokens + # Note: Using <|endoftext|> (151329) for eot causes endless generation + special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # 151331 + special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # 151336 + special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329 + special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338 + + # Patch broken chat template + if isinstance(special_vocab.chat_template, str) and "visible_text(m.content).endswith" in special_vocab.chat_template: + special_vocab.chat_template = special_vocab.chat_template.replace( + """{{ visible_text(m.content) }}\n{{- '/nothink' if (enable_thinking is defined and not enable_thinking and not visible_text(m.content).endswith("/nothink")) else '' -}}""", + """{% set content = visible_text(m.content) %}{{ content }}\n{{- '/nothink' if (enable_thinking is defined and not enable_thinking and not content.endswith("/nothink")) else '' -}}""") special_vocab.add_to_gguf(self.gguf_writer) @@ -4001,10 +4003,9 
@@ def set_gguf_parameters(self): int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)) ) - # MoE parameters - if (n_experts := self.hparams.get("n_routed_experts")) is not None: - self.gguf_writer.add_expert_count(n_experts) - # Note: expert_used_count is already set by parent class using num_experts_per_tok + # MoE parameters - Use only routed expert count (shared experts handled separately) + if (n_routed_experts := self.hparams.get("n_routed_experts")) is not None: + self.gguf_writer.add_expert_count(n_routed_experts) if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None: self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size) if (n_shared_experts := self.hparams.get("n_shared_experts")) is not None: @@ -4023,8 +4024,11 @@ def set_gguf_parameters(self): if (norm_topk_prob := self.hparams.get("norm_topk_prob")) is not None: self.gguf_writer.add_expert_weights_norm(norm_topk_prob) + # NextN/MTP prediction layers + if (num_nextn_predict_layers := self.hparams.get("num_nextn_predict_layers")) is not None: + self.gguf_writer.add_nextn_predict_layers(num_nextn_predict_layers) + _experts: list[dict[str, Tensor]] | None = None - _shared_experts: list[dict[str, Tensor]] | None = None def modify_tensors( self, data_torch: Tensor, name: str, bid: int | None @@ -4035,21 +4039,17 @@ def modify_tensors( name = name.replace("language_model.", "") # for multimodal variants # Handle main token embedding (but not layer-specific NextN embeddings) - if name == "model.embed_tokens.weight": + if name == "model.embed_tokens.weight" and ".layers." not in name: return [(self.map_tensor_name("token_embd.weight"), data_torch)] # Handle routed experts - if name.find("mlp.experts") != -1 and "shared_experts" not in name: + if name.find("mlp.experts") != -1: n_experts = self.hparams["n_routed_experts"] assert bid is not None if self._experts is None: self._experts = [{} for _ in range(self.block_count)] - # Extend experts array if needed (for models where actual layers > num_hidden_layers) - while len(self._experts) <= bid: - self._experts.append({}) - self._experts[bid][name] = data_torch if len(self._experts[bid]) >= n_experts * 3: @@ -4065,95 +4065,21 @@ def modify_tensors( del self._experts[bid][ename] data_torch = torch.stack(datas, dim=0) - # Generate GGUF tensor names for merged experts - if w_name == "down_proj": - new_name = f"blk.{bid}.ffn_down_exps.weight" - elif w_name == "gate_proj": - new_name = f"blk.{bid}.ffn_gate_exps.weight" - elif w_name == "up_proj": - new_name = f"blk.{bid}.ffn_up_exps.weight" - else: - merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" - new_name = self.map_tensor_name(merged_name) + + merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" + + new_name = self.map_tensor_name(merged_name) tensors.append((new_name, data_torch)) return tensors else: return [] - # Handle expert gating input (routing gate) - if ".mlp.gate.e_score_correction_bias" in name: - new_name = name.replace("model.layers.", "blk.").replace( - ".mlp.gate.e_score_correction_bias", ".ffn_gate_inp.bias" # *NOTE* this is ".exp_probs_b" in mainline PR - ) - return [(new_name, data_torch)] - elif ".mlp.gate.weight" in name: - new_name = name.replace("model.layers.", "blk.").replace( - ".mlp.gate.weight", ".ffn_gate_inp.weight" - ) - return [(new_name, data_torch)] - - # Handle shared expert tensors - if ".mlp.shared_experts." 
in name: - new_name = name.replace("model.layers.", "blk.").replace(".mlp.shared_experts.", ".ffn_") - if "gate_proj" in new_name: - new_name = new_name.replace("gate_proj", "gate_shexp") - elif "down_proj" in new_name: - new_name = new_name.replace("down_proj", "down_shexp") - elif "up_proj" in new_name: - new_name = new_name.replace("up_proj", "up_shexp") - return [(new_name, data_torch)] - - # Handle regular dense FFN layers (for hybrid dense/MoE architecture) - if ".mlp." in name and "experts" not in name and "_shexp" not in name: - if "gate_proj" in name: - new_name = name.replace("model.layers.", "blk.").replace( - ".mlp.gate_proj.weight", ".ffn_gate.weight" - ) - elif "up_proj" in name: - new_name = name.replace("model.layers.", "blk.").replace( - ".mlp.up_proj.weight", ".ffn_up.weight" - ) - elif "down_proj" in name: - new_name = name.replace("model.layers.", "blk.").replace( - ".mlp.down_proj.weight", ".ffn_down.weight" - ) - else: - new_name = name - return [(self.map_tensor_name(new_name), data_torch)] - - # Handle special NextN tensors - preserve for future MTP support - See https://github.com/ggml-org/llama.cpp/pull/13236 - if ( - ".embed_tokens." in name - or ".shared_head." in name - or ".eh_proj." in name - or ".enorm." in name - or ".hnorm." in name - ): - new_name = name.replace("model.layers.", "blk.").replace("model.", "").replace(".weight", "") - # logger.debug(f"Skipping MTP tensor: {new_name}") - return [(new_name, data_torch)] - - # GLM tensor mapping - handle directly without map_tensor_name - if ".input_layernorm." in name: - new_name = name.replace("model.layers.", "blk.").replace(".input_layernorm.", ".attn_norm.") - return [(new_name, data_torch)] - elif ".post_attention_layernorm." in name: - new_name = name.replace("model.layers.", "blk.").replace(".post_attention_layernorm.", ".ffn_norm.") - return [(new_name, data_torch)] - elif ".self_attn." 
in name: - # Map GLM self_attn to standard attention naming - new_name = name.replace("model.layers.", "blk.").replace(".self_attn.", ".attn_") - if "q_proj" in new_name: - new_name = new_name.replace("q_proj", "q") - elif "k_proj" in new_name: - new_name = new_name.replace("k_proj", "k") - elif "v_proj" in new_name: - new_name = new_name.replace("v_proj", "v") - elif "o_proj" in new_name: - new_name = new_name.replace("o_proj", "output") - return [(new_name, data_torch)] + if name.endswith("e_score_correction_bias"): + name = name.replace("e_score_correction_bias", "e_score_correction.bias") - return super().modify_tensors(data_torch, name, bid) + new_name = self.map_tensor_name(name) + + return [(new_name, data_torch)] def prepare_tensors(self): super().prepare_tensors() diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index d6541987d..6c2c46ab9 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -96,6 +96,8 @@ class TOKENIZER_TYPE(IntEnum): {"name": "smollm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/HuggingFaceTB/SmolLM-135M", }, {"name": "deepseek-v3", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-V3"}, {"name": "seed-coder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ByteDance-Seed/Seed-Coder-8B-Base", }, + {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2", }, + {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.5-Air", "chkhsh": "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902", }, {"name": "kimi-k2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/moonshotai/Kimi-K2-Base", "chkhsh": "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890", }, ] diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 92722dc32..f49234300 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -91,6 +91,7 @@ class LLM: EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale" EXPERT_WEIGHTS_NORM = "{arch}.expert_weights_norm" EXPERT_GATING_FUNC = "{arch}.expert_gating_func" + NEXTN_PREDICT_LAYERS = "{arch}.nextn_predict_layers" POOLING_TYPE = "{arch}.pooling_type" LOGIT_SCALE = "{arch}.logit_scale" DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id" @@ -159,6 +160,13 @@ class Tokenizer: CHAT_TEMPLATE_N = "tokenizer.chat_template.{name}" CHAT_TEMPLATES = "tokenizer.chat_templates" # FIM/Infill special tokens constants + FIM_PRE_ID = "tokenizer.ggml.fim_pre_token_id" + FIM_SUF_ID = "tokenizer.ggml.fim_suf_token_id" + FIM_MID_ID = "tokenizer.ggml.fim_mid_token_id" + FIM_PAD_ID = "tokenizer.ggml.fim_pad_token_id" + FIM_REP_ID = "tokenizer.ggml.fim_rep_token_id" + FIM_SEP_ID = "tokenizer.ggml.fim_sep_token_id" + # FIM/Infill special tokens constants PREFIX_ID = "tokenizer.ggml.prefix_token_id" SUFFIX_ID = "tokenizer.ggml.suffix_token_id" MIDDLE_ID = "tokenizer.ggml.middle_token_id" @@ -263,9 +271,6 @@ class MODEL_TENSOR(IntEnum): FFN_GATE_EXP = auto() FFN_DOWN_EXP = auto() FFN_UP_EXP = auto() - FFN_GATE_EXPS = auto() # merged experts - FFN_DOWN_EXPS = auto() # merged experts - FFN_UP_EXPS = auto() # merged experts FFN_GATE_SHEXP = auto() FFN_DOWN_SHEXP = auto() FFN_UP_SHEXP = auto() @@ -415,9 +420,6 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps", MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps", 
MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps", - MODEL_TENSOR.FFN_GATE_EXPS: "blk.{bid}.ffn_gate_exps", # merged experts - MODEL_TENSOR.FFN_DOWN_EXPS: "blk.{bid}.ffn_down_exps", # merged experts - MODEL_TENSOR.FFN_UP_EXPS: "blk.{bid}.ffn_up_exps", # merged experts MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b", MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm", MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in", @@ -465,13 +467,13 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ENC_FFN_DOWN: "enc.blk.{bid}.ffn_down", MODEL_TENSOR.ENC_FFN_UP: "enc.blk.{bid}.ffn_up", MODEL_TENSOR.ENC_OUTPUT_NORM: "enc.output_norm", - # NextN/MTP tensors (GLM4_MOE) - MODEL_TENSOR.NEXTN_EH_PROJ: "blk.{bid}.eh_proj", - MODEL_TENSOR.NEXTN_EMBED_TOKENS: "blk.{bid}.embed_tokens", - MODEL_TENSOR.NEXTN_ENORM: "blk.{bid}.enorm", - MODEL_TENSOR.NEXTN_HNORM: "blk.{bid}.hnorm", - MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD: "blk.{bid}.shared_head.head", - MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM: "blk.{bid}.shared_head.norm", + # NextN/MTP + MODEL_TENSOR.NEXTN_EH_PROJ: "blk.{bid}.nextn.eh_proj", + MODEL_TENSOR.NEXTN_EMBED_TOKENS: "blk.{bid}.nextn.embed_tokens", + MODEL_TENSOR.NEXTN_ENORM: "blk.{bid}.nextn.enorm", + MODEL_TENSOR.NEXTN_HNORM: "blk.{bid}.nextn.hnorm", + MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD: "blk.{bid}.nextn.shared_head_head", + MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM: "blk.{bid}.nextn.shared_head_norm", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -1096,23 +1098,24 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_POST_NORM, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, MODEL_TENSOR.ATTN_OUT, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K_NORM, - MODEL_TENSOR.FFN_NORM, - MODEL_TENSOR.FFN_GATE, # dense layers - MODEL_TENSOR.FFN_DOWN, # dense layers - MODEL_TENSOR.FFN_UP, # dense layers + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, MODEL_TENSOR.FFN_GATE_INP, - MODEL_TENSOR.FFN_GATE_EXPS, - MODEL_TENSOR.FFN_DOWN_EXPS, - MODEL_TENSOR.FFN_UP_EXPS, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, MODEL_TENSOR.FFN_GATE_SHEXP, MODEL_TENSOR.FFN_DOWN_SHEXP, MODEL_TENSOR.FFN_UP_SHEXP, + MODEL_TENSOR.FFN_EXP_PROBS_B, # NextN/MTP tensors - preserved but unused MODEL_TENSOR.NEXTN_EH_PROJ, MODEL_TENSOR.NEXTN_EMBED_TOKENS, @@ -1684,6 +1687,14 @@ def get_type(val: Any) -> GGUFValueType: KEY_TOKENIZER_MASK_ID = Keys.Tokenizer.MASK_ID KEY_TOKENIZER_HF_JSON = Keys.Tokenizer.HF_JSON KEY_TOKENIZER_RWKV = Keys.Tokenizer.RWKV + +KEY_TOKENIZER_FIM_PRE_ID = Keys.Tokenizer.FIM_PRE_ID +KEY_TOKENIZER_FIM_SUF_ID = Keys.Tokenizer.FIM_SUF_ID +KEY_TOKENIZER_FIM_MID_ID = Keys.Tokenizer.FIM_MID_ID +KEY_TOKENIZER_FIM_PAD_ID = Keys.Tokenizer.FIM_PAD_ID +KEY_TOKENIZER_FIM_REP_ID = Keys.Tokenizer.FIM_REP_ID +KEY_TOKENIZER_FIM_SEP_ID = Keys.Tokenizer.FIM_SEP_ID + KEY_TOKENIZER_PREFIX_ID = Keys.Tokenizer.PREFIX_ID KEY_TOKENIZER_SUFFIX_ID = Keys.Tokenizer.SUFFIX_ID KEY_TOKENIZER_MIDDLE_ID = Keys.Tokenizer.MIDDLE_ID diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index e31bf97b1..8b820d18a 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -677,6 +677,9 @@ def add_expert_weights_norm(self, value: bool) -> None: def add_expert_gating_func(self, value: ExpertGatingFuncType) -> None: self.add_uint32(Keys.LLM.EXPERT_GATING_FUNC.format(arch=self.arch), value.value) + def add_nextn_predict_layers(self, count: int) -> None: + 
self.add_uint32(Keys.LLM.NEXTN_PREDICT_LAYERS.format(arch=self.arch), count) + def add_layer_norm_eps(self, value: float) -> None: self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index d507725c4..22b29f141 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -592,6 +592,31 @@ class TensorNameMap: MODEL_TENSOR.ENC_OUTPUT_NORM: ( "encoder.final_layer_norm", # t5 ), + + # NextN/MTP tensors for GLM4_MOE + MODEL_TENSOR.NEXTN_EH_PROJ: ( + "model.layers.{bid}.eh_proj", + ), + + MODEL_TENSOR.NEXTN_EMBED_TOKENS: ( + "model.layers.{bid}.embed_tokens", + ), + + MODEL_TENSOR.NEXTN_ENORM: ( + "model.layers.{bid}.enorm", + ), + + MODEL_TENSOR.NEXTN_HNORM: ( + "model.layers.{bid}.hnorm", + ), + + MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD: ( + "model.layers.{bid}.shared_head.head", + ), + + MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM: ( + "model.layers.{bid}.shared_head.norm", + ), } # architecture-specific block mappings diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 109a66593..a2bc72d9b 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1546,6 +1546,30 @@ llama_token llama_token_suffix_impl(const struct llama_vocab & vocab) { return vocab.special_suffix_id; } +llama_token llama_token_fim_pre_impl(const struct llama_vocab & vocab) { + return vocab.special_fim_pre_id; +} + +llama_token llama_token_fim_suf_impl(const struct llama_vocab & vocab) { + return vocab.special_fim_suf_id; +} + +llama_token llama_token_fim_mid_impl(const struct llama_vocab & vocab) { + return vocab.special_fim_mid_id; +} + +llama_token llama_token_fim_pad_impl(const struct llama_vocab & vocab) { + return vocab.special_fim_pad_id; +} + +llama_token llama_token_fim_rep_impl(const struct llama_vocab & vocab) { + return vocab.special_fim_rep_id; +} + +llama_token llama_token_fim_sep_impl(const struct llama_vocab & vocab) { + return vocab.special_fim_sep_id; +} + llama_token llama_token_eot_impl(const struct llama_vocab & vocab) { return vocab.special_eot_id; } diff --git a/src/llama-vocab.h b/src/llama-vocab.h index a461eca0e..64ff7cc08 100644 --- a/src/llama-vocab.h +++ b/src/llama-vocab.h @@ -43,6 +43,15 @@ struct llama_vocab { id special_mask_id = -1; id linefeed_id = 13; + + // fim tokens + llama_token special_fim_pre_id = -1; + llama_token special_fim_suf_id = -1; + llama_token special_fim_mid_id = -1; + llama_token special_fim_pad_id = -1; + llama_token special_fim_rep_id = -1; // repo + llama_token special_fim_sep_id = -1; // file separator + id special_prefix_id = -1; id special_suffix_id = -1; id special_middle_id = -1; @@ -100,6 +109,13 @@ llama_token llama_token_pad_impl(const struct llama_vocab & vocab); int32_t llama_add_bos_token_impl(const struct llama_vocab & vocab); int32_t llama_add_eos_token_impl(const struct llama_vocab & vocab); +llama_token llama_token_fim_pre_impl(const struct llama_vocab & vocab); +llama_token llama_token_fim_suf_impl(const struct llama_vocab & vocab); +llama_token llama_token_fim_mid_impl(const struct llama_vocab & vocab); +llama_token llama_token_fim_pad_impl(const struct llama_vocab & vocab); +llama_token llama_token_fim_rep_impl(const struct llama_vocab & vocab); +llama_token llama_token_fim_sep_impl(const struct llama_vocab & vocab); + llama_token llama_token_prefix_impl(const struct llama_vocab & vocab); llama_token llama_token_middle_impl(const struct llama_vocab & vocab); llama_token llama_token_suffix_impl(const struct llama_vocab & vocab); 
diff --git a/src/llama.cpp b/src/llama.cpp index 5fe20a8fe..34bab47e8 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -330,6 +330,7 @@ enum llm_kv { LLM_KV_EXPERT_WEIGHTS_SCALE, LLM_KV_EXPERT_WEIGHTS_NORM, LLM_KV_EXPERT_GATING_FUNC, + LLM_KV_NEXTN_PREDICT_LAYERS, LLM_KV_POOLING_TYPE, LLM_KV_LOGIT_SCALE, LLM_KV_DECODER_START_TOKEN_ID, @@ -399,6 +400,12 @@ enum llm_kv { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, LLM_KV_TOKENIZER_HF_JSON, LLM_KV_TOKENIZER_RWKV, + LLM_KV_TOKENIZER_FIM_PRE_ID, + LLM_KV_TOKENIZER_FIM_SUF_ID, + LLM_KV_TOKENIZER_FIM_MID_ID, + LLM_KV_TOKENIZER_FIM_PAD_ID, + LLM_KV_TOKENIZER_FIM_REP_ID, + LLM_KV_TOKENIZER_FIM_SEP_ID, LLM_KV_TOKENIZER_PREFIX_ID, LLM_KV_TOKENIZER_SUFFIX_ID, LLM_KV_TOKENIZER_MIDDLE_ID, @@ -439,6 +446,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" }, { LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" }, { LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" }, + { LLM_KV_NEXTN_PREDICT_LAYERS, "%s.nextn_predict_layers" }, { LLM_KV_POOLING_TYPE , "%s.pooling_type" }, { LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" }, @@ -504,6 +512,13 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" }, { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, + { LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" }, + { LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" }, + { LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" }, + { LLM_KV_TOKENIZER_FIM_PAD_ID, "tokenizer.ggml.fim_pad_token_id" }, + { LLM_KV_TOKENIZER_FIM_REP_ID, "tokenizer.ggml.fim_rep_token_id" }, + { LLM_KV_TOKENIZER_FIM_SEP_ID, "tokenizer.ggml.fim_sep_token_id" }, + { LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" }, { LLM_KV_TOKENIZER_SUFFIX_ID, "tokenizer.ggml.suffix_token_id" }, { LLM_KV_TOKENIZER_MIDDLE_ID, "tokenizer.ggml.middle_token_id" }, @@ -1422,16 +1437,16 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, { LLM_TENSOR_OUTPUT, "output" }, { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, // dense layers - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, // dense layers - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, // dense layers + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, @@ -1439,13 +1454,14 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, // NextN/MTP tensors - preserved but unused (in final layer, dynamic layer number) - { LLM_TENSOR_NEXTN_EH_PROJ, "blk.%d.eh_proj" }, - { LLM_TENSOR_NEXTN_EMBED_TOKENS, 
"blk.%d.embed_tokens" }, - { LLM_TENSOR_NEXTN_ENORM, "blk.%d.enorm" }, - { LLM_TENSOR_NEXTN_HNORM, "blk.%d.hnorm" }, - { LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "blk.%d.shared_head.head" }, - { LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "blk.%d.shared_head.norm" }, + { LLM_TENSOR_NEXTN_EH_PROJ, "blk.%d.nextn.eh_proj" }, + { LLM_TENSOR_NEXTN_EMBED_TOKENS, "blk.%d.nextn.embed_tokens" }, + { LLM_TENSOR_NEXTN_ENORM, "blk.%d.nextn.enorm" }, + { LLM_TENSOR_NEXTN_HNORM, "blk.%d.nextn.hnorm" }, + { LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "blk.%d.nextn.shared_head_head" }, + { LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "blk.%d.nextn.shared_head_norm" }, }, }, { @@ -2654,9 +2670,9 @@ enum e_model { MODEL_40B, MODEL_65B, MODEL_70B, + MODEL_106B_A12B, MODEL_142B, MODEL_236B, - MODEL_106B_A12B, MODEL_355B_A32B, MODEL_314B, MODEL_405B, @@ -2728,6 +2744,7 @@ struct llama_hparams { float expert_weights_scale = 0.0; bool expert_weights_norm = false; uint32_t expert_gating_func = LLM_EXPERT_GATING_FUNC_SOFTMAX; + uint32_t nextn_predict_layers = 0; float f_norm_eps; float f_norm_rms_eps; @@ -2928,6 +2945,15 @@ struct llama_cparams { void * cb_eval_user_data; }; +struct llama_layer_nextn { + struct ggml_tensor * eh_proj = nullptr; + struct ggml_tensor * embed_tokens = nullptr; + struct ggml_tensor * enorm = nullptr; + struct ggml_tensor * hnorm = nullptr; + struct ggml_tensor * shared_head_head = nullptr; + struct ggml_tensor * shared_head_norm = nullptr; +}; + // TODO: separate into "llama_layer_enc" and "llama_layer_dec" struct llama_layer { // normalization @@ -3047,6 +3073,8 @@ struct llama_layer { struct ggml_tensor * ffn_up_scale; struct ggml_tensor * ffn_down_scale; + struct llama_layer_nextn nextn; + std::unique_ptr computed_wk_b; std::unique_ptr computed_wv_b; std::unique_ptr computed_wkv_b; @@ -5333,9 +5361,9 @@ static const char * llama_model_type_name(e_model type) { case MODEL_40B: return "40B"; case MODEL_65B: return "65B"; case MODEL_70B: return "70B"; + case MODEL_106B_A12B: return "106B.A12B"; case MODEL_142B: return "142B"; case MODEL_236B: return "236B"; - case MODEL_106B_A12B: return "106B.A12B"; case MODEL_355B_A32B: return "355B.A32B"; case MODEL_314B: return "314B"; case MODEL_405B: return "405B"; @@ -6094,14 +6122,14 @@ static void llm_load_hparams( } break; case LLM_ARCH_GLM4_MOE: { - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); // MoE parameters - ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, 0); - ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, 0); - ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared, 0); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, 0); + ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert); + ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); @@ -6111,6 +6139,9 @@ static void llm_load_hparams( hparams.expert_gating_func = LLM_EXPERT_GATING_FUNC_SIGMOID; } + // NextN/MTP parameters + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + switch (hparams.n_layer) { case 47: model.type = e_model::MODEL_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 
NextN layer) case 93: model.type = e_model::MODEL_355B_A32B; break; // GLM-4.5 (92 layers + 1 NextN layer) @@ -6654,16 +6685,24 @@ static void llm_load_vocab( const std::vector> special_token_types = { { LLM_KV_TOKENIZER_BOS_ID, vocab.special_bos_id }, { LLM_KV_TOKENIZER_EOS_ID, vocab.special_eos_id }, + { LLM_KV_TOKENIZER_EOT_ID, vocab.special_eot_id }, + { LLM_KV_TOKENIZER_EOM_ID, vocab.special_eom_id }, { LLM_KV_TOKENIZER_UNK_ID, vocab.special_unk_id }, { LLM_KV_TOKENIZER_SEP_ID, vocab.special_sep_id }, { LLM_KV_TOKENIZER_PAD_ID, vocab.special_pad_id }, { LLM_KV_TOKENIZER_CLS_ID, vocab.special_cls_id }, { LLM_KV_TOKENIZER_MASK_ID, vocab.special_mask_id }, + + { LLM_KV_TOKENIZER_FIM_PRE_ID, vocab.special_fim_pre_id }, + { LLM_KV_TOKENIZER_FIM_SUF_ID, vocab.special_fim_suf_id }, + { LLM_KV_TOKENIZER_FIM_MID_ID, vocab.special_fim_mid_id }, + { LLM_KV_TOKENIZER_FIM_PAD_ID, vocab.special_fim_pad_id }, + { LLM_KV_TOKENIZER_FIM_REP_ID, vocab.special_fim_rep_id }, + { LLM_KV_TOKENIZER_FIM_SEP_ID, vocab.special_fim_sep_id }, + { LLM_KV_TOKENIZER_PREFIX_ID, vocab.special_prefix_id }, { LLM_KV_TOKENIZER_SUFFIX_ID, vocab.special_suffix_id }, { LLM_KV_TOKENIZER_MIDDLE_ID, vocab.special_middle_id }, - { LLM_KV_TOKENIZER_EOT_ID, vocab.special_eot_id }, - { LLM_KV_TOKENIZER_EOM_ID, vocab.special_eom_id }, }; for (const auto & it : special_token_types) { @@ -6727,6 +6766,118 @@ static void llm_load_vocab( vocab.special_eom_id = t->second; } } + + for (const auto & t : vocab.token_to_id) { + // find FIM_PRE token: "<|fim_prefix|>", "", "
", etc.
+            if (vocab.special_fim_pre_id == -1) {
+                if (false
+                        || t.first == "<|fim_prefix|>"  // Qwen
+                        || t.first == "<fim-prefix>"
+                        || t.first == "<fim_prefix>"    // Granite
+                        || t.first == "<|fim▁begin|>" // DeepSeek
+                        || t.first == "<PRE>"
+                        || t.first == "▁<PRE>"          // CodeLlama
+                        || t.first == "<|code_prefix|>" // GLM-4.5
+                        ) {
+                    vocab.special_fim_pre_id = t.second;
+                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_SUF token: "<|fim_suffix|>", "<fim-suffix>", "<SUF>", etc.
+            if (vocab.special_fim_suf_id == -1) {
+                if (false
+                        || t.first == "<|fim_suffix|>" // Qwen
+                        || t.first == "<fim-suffix>"
+                        || t.first == "<fim_suffix>"   // Granite
+                        || t.first == "<|fim▁hole|>" // DeepSeek
+                        || t.first == "<SUF>"
+                        || t.first == "▁<SUF>"         // CodeLlama
+                        || t.first == "<|code_suffix|>" // GLM-4.5
+                        ) {
+                    vocab.special_fim_suf_id = t.second;
+                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_MID token: "<|fim_middle|>", "<fim-middle>", "<MID>", etc.
+            if (vocab.special_fim_mid_id == -1) {
+                if (false
+                        || t.first == "<|fim_middle|>" // Qwen
+                        || t.first == "<fim-middle>"
+                        || t.first == "<fim_middle>"   // Granite
+                        || t.first == "<|fim▁end|>"  // DeepSeek
+                        || t.first == "<MID>"
+                        || t.first == "▁<MID>"         // CodeLlama
+                        || t.first == "<|code_middle|>" // GLM-4.5
+                        ) {
+                    vocab.special_fim_mid_id = t.second;
+                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_PAD token: "<|fim_pad|>", "<fim-pad>", "<PAD>", etc.
+            if (vocab.special_fim_pad_id == -1) {
+                if (false
+                        || t.first == "<|fim_pad|>" // Qwen
+                        || t.first == "<fim-pad>"
+                        || t.first == "<fim_pad>"   // Granite
+                        || t.first == "<PAD>"
+                        ) {
+                    vocab.special_fim_pad_id = t.second;
+                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_REP token: "<|fim_repo|>", "<fim-repo>", "<REPO>", etc.
+            if (vocab.special_fim_rep_id == -1) {
+                if (false
+                        || t.first == "<|fim_repo|>"  // Qwen
+                        || t.first == "<|repo_name|>"
+                        || t.first == "<fim-repo>"
+                        || t.first == "<REPO>"
+                        || t.first == "<reponame>"    // Granite
+                        ) {
+                    vocab.special_fim_rep_id = t.second;
+                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_SEP token: "<|file_sep|>"
+            if (vocab.special_fim_sep_id == -1) {
+                if (false
+                        || t.first == "<|file_sep|>" // Qwen
+                        ) {
+                    vocab.special_fim_sep_id = t.second;
+                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+        }
+
     }
 
     // build special tokens cache
@@ -6948,6 +7099,14 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     if (vocab.special_mask_id   != -1) { LLAMA_LOG_INFO( "%s: MASK token       = %d '%s'\n", __func__, vocab.special_mask_id, vocab.id_to_token[vocab.special_mask_id].text.c_str() ); }
 
     if (vocab.linefeed_id       != -1) { LLAMA_LOG_INFO( "%s: LF token         = %d '%s'\n", __func__, vocab.linefeed_id,       vocab.id_to_token[vocab.linefeed_id].text.c_str() );       }
+
+    if (vocab.special_fim_pre_id != -1) { LLAMA_LOG_INFO( "%s: FIM PRE token    = %d '%s'\n", __func__, vocab.special_fim_pre_id, vocab.id_to_token.at(vocab.special_fim_pre_id).text.c_str() ); }
+    if (vocab.special_fim_suf_id != -1) { LLAMA_LOG_INFO( "%s: FIM SUF token    = %d '%s'\n", __func__, vocab.special_fim_suf_id, vocab.id_to_token.at(vocab.special_fim_suf_id).text.c_str() ); }
+    if (vocab.special_fim_mid_id != -1) { LLAMA_LOG_INFO( "%s: FIM MID token    = %d '%s'\n", __func__, vocab.special_fim_mid_id, vocab.id_to_token.at(vocab.special_fim_mid_id).text.c_str() ); }
+    if (vocab.special_fim_pad_id != -1) { LLAMA_LOG_INFO( "%s: FIM PAD token    = %d '%s'\n", __func__, vocab.special_fim_pad_id, vocab.id_to_token.at(vocab.special_fim_pad_id).text.c_str() ); }
+    if (vocab.special_fim_rep_id != -1) { LLAMA_LOG_INFO( "%s: FIM REP token    = %d '%s'\n", __func__, vocab.special_fim_rep_id, vocab.id_to_token.at(vocab.special_fim_rep_id).text.c_str() ); }
+    if (vocab.special_fim_sep_id != -1) { LLAMA_LOG_INFO( "%s: FIM SEP token    = %d '%s'\n", __func__, vocab.special_fim_sep_id, vocab.id_to_token.at(vocab.special_fim_sep_id).text.c_str() ); }
+
     if (vocab.special_prefix_id != -1) { LLAMA_LOG_INFO( "%s: PRE token        = %d '%s'\n", __func__, vocab.special_prefix_id, vocab.id_to_token[vocab.special_prefix_id].text.c_str() ); }
     if (vocab.special_suffix_id != -1) { LLAMA_LOG_INFO( "%s: SUF token        = %d '%s'\n", __func__, vocab.special_suffix_id, vocab.id_to_token[vocab.special_suffix_id].text.c_str() ); }
     if (vocab.special_middle_id != -1) { LLAMA_LOG_INFO( "%s: MID token        = %d '%s'\n", __func__, vocab.special_middle_id, vocab.id_to_token[vocab.special_middle_id].text.c_str() ); }
@@ -9023,6 +9182,9 @@ static bool llm_load_tensors(
                     const int64_t n_expert_used   = hparams.n_expert_used;
                     const int64_t n_expert_shared = hparams.n_expert_shared;
 
+                    GGML_ASSERT(hparams.n_expert > 0 && "n_expert must be > 0 for GLM4_MOE MoE layers");
+                    GGML_ASSERT(hparams.n_expert_used > 0 && "n_expert_used must be > 0 for GLM4_MOE MoE layers");
+
                     model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
                     // output
@@ -9035,40 +9197,6 @@ static bool llm_load_tensors(
                         model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
                     }
                     
-                    // --- NextN / MTP tensors (preserved but unused), on the final layer ---
-                    {
-                        const int final_layer = n_layer - 1;
-                        // EH_PROJ: [2*embd, embd]
-                        create_tensor(ctx_for_layer(final_layer),
-                                      tn(LLM_TENSOR_NEXTN_EH_PROJ, final_layer),
-                                      { 2*n_embd, n_embd },
-                                      llama_model_loader::TENSOR_NOT_REQUIRED);
-                        // EMBED_TOKENS: [embd, vocab]
-                        create_tensor(ctx_for_layer(final_layer),
-                                      tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, final_layer),
-                                      { n_embd, n_vocab },
-                                      llama_model_loader::TENSOR_NOT_REQUIRED);
-                        // ENORM, HNORM: [embd]
-                        create_tensor(ctx_for_layer(final_layer),
-                                      tn(LLM_TENSOR_NEXTN_ENORM, final_layer),
-                                      { n_embd },
-                                      llama_model_loader::TENSOR_NOT_REQUIRED);
-                        create_tensor(ctx_for_layer(final_layer),
-                                      tn(LLM_TENSOR_NEXTN_HNORM, final_layer),
-                                      { n_embd },
-                                      llama_model_loader::TENSOR_NOT_REQUIRED);
-                        // SHARED_HEAD_HEAD: [embd, vocab]
-                        create_tensor(ctx_for_layer(final_layer),
-                                      tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, final_layer),
-                                      { n_embd, n_vocab },
-                                      llama_model_loader::TENSOR_NOT_REQUIRED);
-                        // SHARED_HEAD_NORM: [embd]
-                        create_tensor(ctx_for_layer(final_layer),
-                                      tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, final_layer),
-                                      { n_embd },
-                                      llama_model_loader::TENSOR_NOT_REQUIRED);
-                    }
-
                     for (int i = 0; i < n_layer; ++i) {
                         ggml_context * ctx_layer = ctx_for_layer(i);
                         ggml_context * ctx_split = ctx_for_layer_split(i);
@@ -9081,9 +9209,9 @@ static bool llm_load_tensors(
                         layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0);
                         layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0);
                         layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0);
-                        layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd_head_k * n_head }, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_k_gqa }, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_v_gqa }, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd_head_k * n_head }, 0);
+                        layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_k_gqa }, 0);
+                        layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_v_gqa }, 0);
 
                         layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
 
@@ -9093,29 +9221,17 @@ static bool llm_load_tensors(
                         layer.attn_k_norm = create_tensor(ctx_layer, 
                             tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, llama_model_loader::TENSOR_NOT_REQUIRED);
 
-                        layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0);
+                        layer.attn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0);
 
                         // Check if this layer uses MoE or dense FFN based on n_layer_dense_lead
                         // GLM 4.5 uses hybrid architecture: layer 0 is dense, layers 1+ are MoE
-                        const bool use_moe =
-                            (hparams.n_expert > 0) && (static_cast<uint32_t>(i) >= hparams.n_layer_dense_lead);
+                        const bool use_moe = (static_cast<uint32_t>(i) >= hparams.n_layer_dense_lead);
 
                         if (use_moe) {
                             // MoE layers
-                            layer.ffn_gate_inp =
-                                create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0);
+                            layer.ffn_gate_inp = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0);
                             // gate bias
-                            layer.ffn_exp_probs_b =
-                                create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "bias", i), { n_expert },
-                                              llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                            if (n_expert == 0) {
-                                GGML_ASSERT(hparams.n_expert > 0 && "n_expert must be > 0 for GLM4_MOE MoE layers");
-                            }
-                            if (n_expert_used == 0) {
-                                GGML_ASSERT(hparams.n_expert_used > 0 &&
-                                            "n_expert_used must be > 0 for GLM4_MOE MoE layers");
-                            }
+                            layer.ffn_exp_probs_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), { n_expert }, 0);
 
                             // MoE branch
                             const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
@@ -9134,8 +9250,8 @@ static bool llm_load_tensors(
                                     tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0);
                                 layer.ffn_down_shexp = create_tensor(ctx_split, 
                                     tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0);
-                                layer.ffn_up_shexp =
-                                    create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0);
+                                layer.ffn_up_shexp = create_tensor(ctx_split, 
+                                    tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0);
                             }
                         } else {
                             // Dense layers (first k layers) - GLM uses separate gate/up projections
@@ -9143,6 +9259,40 @@ static bool llm_load_tensors(
                             layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
                             layer.ffn_up   = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0);
                         }
+                        // --- NextN / MTP tensors (preserved but unused), on the final layer ---
+                        if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
+                            const int final_layer = n_layer - 1;
+                            // EH_PROJ: [2*embd, embd]
+                            layer.nextn.eh_proj          = create_tensor(ctx_for_layer(final_layer),
+                                        tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", final_layer),
+                                        { 2*n_embd, n_embd },
+                                        llama_model_loader::TENSOR_NOT_REQUIRED);
+                            // EMBED_TOKENS: [embd, vocab]
+                            layer.nextn.embed_tokens     = create_tensor(ctx_for_layer(final_layer),
+                                        tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", final_layer),
+                                        { n_embd, n_vocab },
+                                        llama_model_loader::TENSOR_NOT_REQUIRED);
+                            // ENORM, HNORM: [embd]
+                            layer.nextn.enorm            = create_tensor(ctx_for_layer(final_layer),
+                                        tn(LLM_TENSOR_NEXTN_ENORM, "weight", final_layer),
+                                        { n_embd },
+                                        llama_model_loader::TENSOR_NOT_REQUIRED);
+                            layer.nextn.hnorm            = create_tensor(ctx_for_layer(final_layer),
+                                        tn(LLM_TENSOR_NEXTN_HNORM, "weight", final_layer),
+                                        { n_embd },
+                                        llama_model_loader::TENSOR_NOT_REQUIRED);
+                            // SHARED_HEAD_HEAD: [embd, vocab]
+                            layer.nextn.shared_head_head = create_tensor(ctx_for_layer(final_layer),
+                                        tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", final_layer),
+                                        { n_embd, n_vocab },
+                                        llama_model_loader::TENSOR_NOT_REQUIRED);
+                            // SHARED_HEAD_NORM: [embd]
+                            layer.nextn.shared_head_norm = create_tensor(ctx_for_layer(final_layer),
+                                        tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", final_layer),
+                                        { n_embd },
+                                        llama_model_loader::TENSOR_NOT_REQUIRED);
+                        }
+
                     }
                 }
                 break;
@@ -10174,6 +10324,10 @@ static struct ggml_tensor * llm_build_ffn(
 
     if (down) {
         cur = llm_build_lora_mm(lctx, ctx, down, cur);
+        if (lctx.model.arch == LLM_ARCH_GLM4 || lctx.model.arch == LLM_ARCH_GLM4_MOE) {
+            // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
+            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
+        }
     }
 
     if (down_b) {
@@ -10522,6 +10676,10 @@ static struct ggml_tensor * llm_build_kqv(
 
     if (wo) {
         cur = llm_build_lora_mm(lctx, ctx, wo, cur);
+        if (lctx.model.arch == LLM_ARCH_GLM4 || lctx.model.arch == LLM_ARCH_GLM4_MOE) {
+            // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
+            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
+        }
     }
 
     if (wo_b) {
@@ -16220,8 +16378,11 @@ struct llm_build_context {
     
         // output token IDs (for last layer cropping)
         struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-    
-        for (int il = 0; il < n_layer; ++il) {
+
+        // Only process up to last layer (skip final NextN layer)
+        // Final layer tensors are loaded but not processed in forward pass
+        const int n_transformer_layers = n_layer - hparams.nextn_predict_layers;
+        for (int il = 0; il < n_transformer_layers; ++il) {
             struct ggml_tensor * inpSA = inpL;
     
             // Pre-attention norm
@@ -16283,14 +16444,14 @@ struct llm_build_context {
     
                 // build attention KV (no unified cache)
                 cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                                   model.layers[il].wo, model.layers[il].bo,
+                                   model.layers[il].wo, NULL,
                                    Kcur, Vcur, Qcur, KQ_mask,
                                    n_tokens, kv_head, n_kv,
                                    1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
     
             // crop output on last layer
-            if (il == n_layer - 1) {
+            if (il == n_transformer_layers - 1 && inp_out_ids) {
                 // skip computing output for unused tokens
                 ggml_tensor * inp_out_ids = build_inp_out_ids();
                 cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
@@ -16301,11 +16462,11 @@ struct llm_build_context {
             struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
             cb(ffn_inp, "ffn_inp", il);
     
-            // FFN / MoE
+            // Post-attention norm
             cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                                 model.layers[il].ffn_norm, NULL,
+                                 model.layers[il].attn_post_norm, NULL,
                                  LLM_NORM_RMS, cb, il);
-            cb(cur, "ffn_norm", il);
+            cb(cur, "post_attn_norm", il);
     
             if ((uint32_t) il < hparams.n_layer_dense_lead) {
                 // dense FFN
@@ -16318,7 +16479,7 @@ struct llm_build_context {
                 cb(cur, "ffn_out", il);
             } else {
                 // MoE FFN
-                struct ggml_tensor * moe_out = llm_build_moe_ffn(ctx0, lctx, cur,
+                struct ggml_tensor * routed_out = llm_build_moe_ffn(ctx0, lctx, cur,
                                             model.layers[il].ffn_gate_inp,
                                             model.layers[il].ffn_up_exps,
                                             model.layers[il].ffn_gate_exps,
@@ -16329,18 +16490,18 @@ struct llm_build_context {
                                             true, hparams.expert_weights_scale,
                                             (enum llm_expert_gating_func_type) hparams.expert_gating_func,
                                             cb, il);
-                cb(moe_out, "ffn_moe_out", il);
+                cb(routed_out, "routed_out", il);
 
                 {
-                    struct ggml_tensor * shexp_out = llm_build_ffn(ctx0, lctx, cur,
+                    struct ggml_tensor * shared_out = llm_build_ffn(ctx0, lctx, cur,
                                                 model.layers[il].ffn_up_shexp, NULL, NULL,
                                                 model.layers[il].ffn_gate_shexp, NULL, NULL,
                                                 model.layers[il].ffn_down_shexp, NULL, NULL,
                                                 NULL,
                                                 LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-                    cb(shexp_out, "ffn_shexp_out", il);
+                    cb(shared_out, "ffn_shexp_out", il);
         
-                    cur = ggml_add(ctx0, moe_out, shexp_out);
+                    cur = ggml_add(ctx0, routed_out, shared_out);
                     cb(cur, "ffn_out", il);
                 }
             }
@@ -23555,6 +23716,36 @@ llama_token llama_token_eot(const struct llama_model * model) {
     return llama_token_eot_impl(model->vocab);
 }
 
+// deprecated
+llama_token llama_token_fim_pre(const struct llama_model * model) {
+    return llama_token_fim_pre_impl(model->vocab);
+}
+
+// deprecated
+llama_token llama_token_fim_suf(const struct llama_model * model) {
+    return llama_token_fim_suf_impl(model->vocab);
+}
+
+// deprecated
+llama_token llama_token_fim_mid(const struct llama_model * model) {
+    return llama_token_fim_mid_impl(model->vocab);
+}
+
+// deprecated
+llama_token llama_token_fim_pad(const struct llama_model * model) {
+    return llama_token_fim_pad_impl(model->vocab);
+}
+
+// deprecated
+llama_token llama_token_fim_rep(const struct llama_model * model) {
+    return llama_token_fim_rep_impl(model->vocab);
+}
+
+// deprecated
+llama_token llama_token_fim_sep(const struct llama_model * model) {
+    return llama_token_fim_sep_impl(model->vocab);
+}
+
 //
 // tokenization
 //

From 3f3e384fabe25ac916e12aba8c0f7462ec3c6e84 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Thireus=20=E2=98=A0?= 
Date: Tue, 5 Aug 2025 13:56:08 +0100
Subject: [PATCH 09/13] Handle TENSOR_SKIP

Ported the changes from:

https://github.com/sammcj/llama.cpp/commit/f129567dc0232272358ea71c5017486554b2abd3
https://github.com/sammcj/llama.cpp/commit/dcbbd2cb057a6c6e907e0195395a74201ef19e1b

Except the op info change, since ik_llama.cpp doesn't support this operation.
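
For reviewers, a minimal sketch (not part of the ported commits) of how the skip
flag is intended to combine with the existing loader flags; `layer_flags` and
`is_nextn_layer` are illustrative stand-ins for the `hparams.nextn_predict_layers`
check used in the hunks below:

    // Loader flags become distinct bits so they can be OR-ed together.
    static const int TENSOR_NOT_REQUIRED = 1 << 0;  // a missing tensor is not an error
    static const int TENSOR_DUPLICATED   = 1 << 1;  // e.g. output head reusing tok_embd
    static const int TENSOR_SKIP         = 1 << 2;  // tensor is created but its data is not used

    static int layer_flags(bool is_nextn_layer) {
        int flags = 0;
        if (is_nextn_layer) {
            flags |= TENSOR_SKIP;                   // skip every tensor of a NextN/MTP layer
        }
        return flags;
    }

    // Optional tensors on a skipped layer simply combine both bits, mirroring
    // the "TENSOR_NOT_REQUIRED | flags" calls in the diff:
    //   create_tensor(..., TENSOR_NOT_REQUIRED | layer_flags(true));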
---
 src/llama.cpp | 72 +++++++++++++++++++++++++++++----------------------
 1 file changed, 41 insertions(+), 31 deletions(-)

diff --git a/src/llama.cpp b/src/llama.cpp
index 34bab47e8..7c27d3cfc 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -4885,8 +4885,9 @@ struct llama_model_loader {
         return cur;
     }
 
-    static const int TENSOR_NOT_REQUIRED = 1;
-    static const int TENSOR_DUPLICATED   = 2;
+    static const int TENSOR_NOT_REQUIRED = 1 << 0;
+    static const int TENSOR_DUPLICATED   = 1 << 1;
+    static const int TENSOR_SKIP         = 1 << 2;
 
     struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::vector<int64_t> & ne, int flags = 0) {
         const struct ggml_tensor * cur = check_tensor_dims(name, ne, !(flags & TENSOR_NOT_REQUIRED));
@@ -7581,6 +7582,10 @@ static bool llm_load_tensors(
 
     LLAMA_LOG_INFO("%s: ggml ctx size = %7.2f MiB\n", __func__, model.ctxs.size()*ctx_size/1024.0/1024.0);
 
+    const auto TENSOR_DUPLICATED   = llama_model_loader::TENSOR_DUPLICATED;
+    const auto TENSOR_NOT_REQUIRED = llama_model_loader::TENSOR_NOT_REQUIRED;
+    const auto TENSOR_SKIP         = llama_model_loader::TENSOR_SKIP;
+
     // create tensors for the weights
     {
         // note: cast to int64_t since we will use these for the tensor dimensions
@@ -9201,27 +9206,33 @@ static bool llm_load_tensors(
                         ggml_context * ctx_layer = ctx_for_layer(i);
                         ggml_context * ctx_split = ctx_for_layer_split(i);
 
+                        int flags = 0;
+                        if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
+                            // skip all tensors in the NextN layers
+                            flags |= TENSOR_SKIP;
+                        }
+
                         auto & layer = model.layers[i];
 
-                        layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags);
 
                         // GLM-style attention with bias terms
-                        layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0);
-                        layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0);
-                        layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0);
-                        layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd_head_k * n_head }, 0);
-                        layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_k_gqa }, 0);
-                        layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_v_gqa }, 0);
+                        layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, flags);
+                        layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, flags);
+                        layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, flags);
+                        layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd_head_k * n_head }, flags);
+                        layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_k_gqa }, flags);
+                        layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_v_gqa }, flags);
 
-                        layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
+                        layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, flags);
 
                         // K/Q norm tensors (optional for GLM-4.5 355B variant)
                         layer.attn_q_norm = create_tensor(ctx_layer, 
-                            tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, llama_model_loader::TENSOR_NOT_REQUIRED);
+                            tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, llama_model_loader::TENSOR_NOT_REQUIRED | flags);
                         layer.attn_k_norm = create_tensor(ctx_layer, 
-                            tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, llama_model_loader::TENSOR_NOT_REQUIRED);
+                            tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, llama_model_loader::TENSOR_NOT_REQUIRED | flags);
 
-                        layer.attn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0);
+                        layer.attn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, flags);
 
                         // Check if this layer uses MoE or dense FFN based on n_layer_dense_lead
                         // GLM 4.5 uses hybrid architecture: layer 0 is dense, layers 1+ are MoE
@@ -9229,35 +9240,35 @@ static bool llm_load_tensors(
 
                         if (use_moe) {
                             // MoE layers
-                            layer.ffn_gate_inp = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0);
+                            layer.ffn_gate_inp = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, flags);
                             // gate bias
-                            layer.ffn_exp_probs_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), { n_expert }, 0);
+                            layer.ffn_exp_probs_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), { n_expert }, flags);
 
                             // MoE branch
                             const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
 
                             layer.ffn_gate_exps = create_tensor(ctx_split, 
-                                tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
+                                tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, flags);
                             layer.ffn_down_exps = create_tensor(ctx_split, 
-                                tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0);
+                                tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, flags);
                             layer.ffn_up_exps = create_tensor(ctx_split, 
-                                tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
+                                tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, flags);
 
                             // Shared expert
                             if (n_expert_shared > 0) {
                                 const int64_t n_ff_shexp = n_ff_exp * n_expert_shared;
                                 layer.ffn_gate_shexp     = create_tensor(ctx_split, 
-                                    tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0);
+                                    tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, flags);
                                 layer.ffn_down_shexp = create_tensor(ctx_split, 
-                                    tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0);
+                                    tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, flags);
                                 layer.ffn_up_shexp = create_tensor(ctx_split, 
-                                    tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0);
+                                    tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, flags);
                             }
                         } else {
                             // Dense layers (first k layers) - GLM uses separate gate/up projections
-                            layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0);
-                            layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
-                            layer.ffn_up   = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0);
+                            layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, flags);
+                            layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, flags);
+                            layer.ffn_up   = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, flags);
                         }
                         // --- NextN / MTP tensors (preserved but unused), on the final layer ---
                         if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
@@ -9266,33 +9277,32 @@ static bool llm_load_tensors(
                             layer.nextn.eh_proj          = create_tensor(ctx_for_layer(final_layer),
                                         tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", final_layer),
                                         { 2*n_embd, n_embd },
-                                        llama_model_loader::TENSOR_NOT_REQUIRED);
+                                        flags);
                             // EMBED_TOKENS: [embd, vocab]
                             layer.nextn.embed_tokens     = create_tensor(ctx_for_layer(final_layer),
                                         tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", final_layer),
                                         { n_embd, n_vocab },
-                                        llama_model_loader::TENSOR_NOT_REQUIRED);
+                                        flags);
                             // ENORM, HNORM: [embd]
                             layer.nextn.enorm            = create_tensor(ctx_for_layer(final_layer),
                                         tn(LLM_TENSOR_NEXTN_ENORM, "weight", final_layer),
                                         { n_embd },
-                                        llama_model_loader::TENSOR_NOT_REQUIRED);
+                                        flags);
                             layer.nextn.hnorm            = create_tensor(ctx_for_layer(final_layer),
                                         tn(LLM_TENSOR_NEXTN_HNORM, "weight", final_layer),
                                         { n_embd },
-                                        llama_model_loader::TENSOR_NOT_REQUIRED);
+                                        flags);
                             // SHARED_HEAD_HEAD: [embd, vocab]
                             layer.nextn.shared_head_head = create_tensor(ctx_for_layer(final_layer),
                                         tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", final_layer),
                                         { n_embd, n_vocab },
-                                        llama_model_loader::TENSOR_NOT_REQUIRED);
+                                        flags);
                             // SHARED_HEAD_NORM: [embd]
                             layer.nextn.shared_head_norm = create_tensor(ctx_for_layer(final_layer),
                                         tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", final_layer),
                                         { n_embd },
-                                        llama_model_loader::TENSOR_NOT_REQUIRED);
+                                        flags);
                         }
-
                     }
                 }
                 break;

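The hunk above threads a per-layer flags value through every create_tensor() call in place of the previous hard-coded 0 / TENSOR_NOT_REQUIRED, so the NextN / MTP tensors of the trailing layer(s) can be marked as skippable instead of being loaded unconditionally. The flag-selection logic itself is not part of this excerpt; the standalone sketch below only illustrates the idea, and its flag values, the nextn_predict_layers condition and the load_nextn toggle are assumptions, not the actual implementation.

    #include <cstdint>
    #include <cstdio>

    // Illustrative stand-ins for the loader flags used in the patch; the real
    // enum lives in llama_model_loader and may use different values.
    enum tensor_flags : uint32_t {
        TENSOR_NOT_REQUIRED = 1u << 0,
        TENSOR_SKIP         = 1u << 2,
    };

    // One plausible way a per-layer flags value could be derived: mark the
    // NextN / MTP layer(s) at the end of the stack as skippable unless the
    // caller explicitly asks to load them.
    static uint32_t layer_flags(uint32_t i, uint32_t n_layer, uint32_t nextn_predict_layers, bool load_nextn) {
        const bool is_nextn = nextn_predict_layers > 0 && i >= n_layer - nextn_predict_layers;
        return (is_nextn && !load_nextn) ? (uint32_t) TENSOR_SKIP : 0u;
    }

    int main() {
        const uint32_t n_layer = 4, nextn = 1;  // example counts only
        for (uint32_t i = 0; i < n_layer; ++i) {
            std::printf("layer %u -> flags 0x%x\n", i, layer_flags(i, n_layer, nextn, /*load_nextn=*/false));
        }
        return 0;
    }
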
From a3641e6646d77d3f0b36798ca313b50e41386f72 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Thireus=20=E2=98=A0?= 
Date: Tue, 5 Aug 2025 14:54:36 +0100
Subject: [PATCH 10/13] Bugfix for TENSOR_SKIP

Skip loading a tensor when it carries the TENSOR_SKIP flag, as suggested by @ubergarm via https://github.com/ikawrakow/ik_llama.cpp/pull/668#issuecomment-3155297198
---
 src/llama.cpp | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/src/llama.cpp b/src/llama.cpp
index 7c27d3cfc..ca68be504 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -4896,6 +4896,17 @@ struct llama_model_loader {
             return NULL;
         }
 
+        // skip unused tensors
+        if (flags & TENSOR_SKIP) {
+            const size_t nbytes = ggml_nbytes(cur);
+            LLAMA_LOG_WARN("model has unused tensor %s (size = %zu bytes) -- ignoring\n", name.c_str(), nbytes);
+
+            size_data -= nbytes;
+            n_created++;
+
+            return nullptr;
+        }
+
         return create_tensor_for(ctx, cur, flags & TENSOR_DUPLICATED);
     }
 

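The fix above makes the loader honour TENSOR_SKIP: a flagged tensor is logged, its byte count is subtracted from size_data, n_created is still incremented, and nullptr is returned instead of a backing tensor. Both bookkeeping updates are needed so the loader's end-of-load checks stay balanced. The toy program below is not llama.cpp code; the struct, tensor name and sizes are made up purely to illustrate the accounting.

    #include <cstddef>
    #include <cstdio>

    // Minimal model of the two counters touched by the skip path.
    struct loader_state {
        size_t size_data;  // bytes the loader still expects to read
        int    n_created;  // tensors accounted for so far
    };

    // Skipping a tensor removes its bytes from the expected total but still
    // counts it as created, mirroring the bookkeeping in the patch above.
    static void skip_tensor(loader_state & st, const char * name, size_t nbytes) {
        std::printf("model has unused tensor %s (size = %zu bytes) -- ignoring\n", name, nbytes);
        st.size_data -= nbytes;
        st.n_created++;
    }

    int main() {
        loader_state st = { /*size_data=*/1024, /*n_created=*/0 };
        skip_tensor(st, "blk.46.nextn.embed_tokens.weight", 512);  // hypothetical tensor name
        std::printf("remaining bytes = %zu, tensors accounted for = %d\n", st.size_data, st.n_created);
        return 0;
    }
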
From d97ebefd691dfda9adc977721959e77069647052 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Thireus=20=E2=98=A0?= 
Date: Tue, 5 Aug 2025 17:35:37 +0100
Subject: [PATCH 11/13] Update llama.cpp

Restore original GGML_ASSERT
---
 src/llama.cpp | 12 +-----------
 1 file changed, 1 insertion(+), 11 deletions(-)

diff --git a/src/llama.cpp b/src/llama.cpp
index ca68be504..343fa2d3e 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -20712,17 +20712,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     //  - qs.n_attention_wv == 3 * model.hparams.n_layer for Encoder-Decoder models
     //  - model.arch == LLM_ARCH_DECI                    for Deci-Nemotron   models
     //
-    //GGML_ASSERT((qs.n_attention_wv == 0 || qs.n_attention_wv == (int)model.hparams.n_layer || qs.n_attention_wv == 3 * (int)model.hparams.n_layer || model.arch == LLM_ARCH_DECI) && "n_attention_wv is unexpected");
-    // allow any count for GLM4-MoE, but still enforce for all others
-    if (model.arch != LLM_ARCH_GLM4_MOE) {
-        GGML_ASSERT(
-             qs.n_attention_wv == 0
-          || qs.n_attention_wv == (int)model.hparams.n_layer
-          || qs.n_attention_wv == 3 * (int)model.hparams.n_layer
-          || model.arch == LLM_ARCH_DECI
-          && "n_attention_wv is unexpected"
-        );
-    }
+    GGML_ASSERT((qs.n_attention_wv == 0 || qs.n_attention_wv == (int)model.hparams.n_layer || qs.n_attention_wv == 3 * (int)model.hparams.n_layer || model.arch == LLM_ARCH_DECI) && "n_attention_wv is unexpected");
     
     size_t total_size_org = 0;
     size_t total_size_new = 0;

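The assert restored above encodes the expectation that, during quantization, the number of attention V weights is either 0, one per layer, or three per layer (encoder-decoder models), with Deci-Nemotron exempted. The small check below merely restates that condition with made-up counts; it does not reflect GLM-4.5 specifics, which the later patches in this series continue to adjust.

    #include <cassert>

    // Toy restatement of the restored invariant; names and counts are
    // illustrative only and do not come from llama.cpp.
    static bool n_attention_wv_ok(int n_attention_wv, int n_layer, bool arch_exempt) {
        return n_attention_wv == 0
            || n_attention_wv == n_layer
            || n_attention_wv == 3 * n_layer
            || arch_exempt;
    }

    int main() {
        assert( n_attention_wv_ok(32, 32, false));   // one V weight per layer
        assert( n_attention_wv_ok(96, 32, false));   // encoder-decoder style
        assert(!n_attention_wv_ok(31, 32, false));   // one layer short trips the check
        return 0;
    }
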
From 41a235bbcad55e3fec7c07755793c52f2f2c61b5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Thireus=20=E2=98=A0?= 
Date: Tue, 5 Aug 2025 18:16:08 +0100
Subject: [PATCH 12/13] Fix chat template detection

Changes suggested by @ubergarm - https://github.com/ikawrakow/ik_llama.cpp/pull/668#issuecomment-3155927840
---
 src/llama.cpp | 34 ++++++++++++++++++++++------------
 1 file changed, 22 insertions(+), 12 deletions(-)

diff --git a/src/llama.cpp b/src/llama.cpp
index 343fa2d3e..ff4d23b5d 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -1740,8 +1740,8 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_DEEPSEEK_3,
     LLM_CHAT_TEMPLATE_COMMAND_R,
     LLM_CHAT_TEMPLATE_LLAMA_3,
-    LLM_CHAT_TEMPLATE_CHATGML_3,
-    LLM_CHAT_TEMPLATE_CHATGML_4,
+    LLM_CHAT_TEMPLATE_CHATGLM_3,
+    LLM_CHAT_TEMPLATE_CHATGLM_4,
     LLM_CHAT_TEMPLATE_MINICPM,
     LLM_CHAT_TEMPLATE_EXAONE_3,
     LLM_CHAT_TEMPLATE_RWKV_WORLD,
@@ -1781,8 +1781,8 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "deepseek3",         LLM_CHAT_TEMPLATE_DEEPSEEK_3        },
     { "command-r",         LLM_CHAT_TEMPLATE_COMMAND_R         },
     { "llama3",            LLM_CHAT_TEMPLATE_LLAMA_3           },
-    { "chatglm3",          LLM_CHAT_TEMPLATE_CHATGML_3         },
-    { "chatglm4",          LLM_CHAT_TEMPLATE_CHATGML_4         },
+    { "chatglm3",          LLM_CHAT_TEMPLATE_CHATGLM_3         },
+    { "chatglm4",          LLM_CHAT_TEMPLATE_CHATGLM_4         },
     { "minicpm",           LLM_CHAT_TEMPLATE_MINICPM           },
     { "exaone3",           LLM_CHAT_TEMPLATE_EXAONE_3          },
     { "rwkv-world",        LLM_CHAT_TEMPLATE_RWKV_WORLD        },
@@ -20712,7 +20712,17 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     //  - qs.n_attention_wv == 3 * model.hparams.n_layer for Encoder-Decoder models
     //  - model.arch == LLM_ARCH_DECI                    for Deci-Nemotron   models
     //
-    GGML_ASSERT((qs.n_attention_wv == 0 || qs.n_attention_wv == (int)model.hparams.n_layer || qs.n_attention_wv == 3 * (int)model.hparams.n_layer || model.arch == LLM_ARCH_DECI) && "n_attention_wv is unexpected");
+    //GGML_ASSERT((qs.n_attention_wv == 0 || qs.n_attention_wv == (int)model.hparams.n_layer || qs.n_attention_wv == 3 * (int)model.hparams.n_layer || model.arch == LLM_ARCH_DECI) && "n_attention_wv is unexpected");
+    // allow any count for GLM4-MoE, but still enforce for all others
+    if (model.arch != LLM_ARCH_GLM4_MOE) {
+        GGML_ASSERT(
+             qs.n_attention_wv == 0
+          || qs.n_attention_wv == (int)model.hparams.n_layer
+          || qs.n_attention_wv == 3 * (int)model.hparams.n_layer
+          || model.arch == LLM_ARCH_DECI
+          && "n_attention_wv is unexpected"
+        );
+    }
     
     size_t total_size_org = 0;
     size_t total_size_new = 0;
@@ -23841,6 +23851,11 @@ static llm_chat_template llama_chat_detect_template(const std::string & tmpl) {
                 return LLM_CHAT_TEMPLATE_LLAMA_2;
             }
         }
+    } else if (tmpl_contains("[gMASK]sop")) {
+        // chatglm3-6b
+        return LLM_CHAT_TEMPLATE_CHATGLM_3;
+    } else if (tmpl_contains("[gMASK]")) {
+        return LLM_CHAT_TEMPLATE_CHATGLM_4;
     } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>")) {
         return LLM_CHAT_TEMPLATE_PHI_3;
     } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) {
@@ -23873,11 +23888,6 @@ static llm_chat_template llama_chat_detect_template(const std::string & tmpl) {
         return LLM_CHAT_TEMPLATE_COMMAND_R;
     } else if (tmpl_contains("<|start_header_id|>") && tmpl_contains("<|end_header_id|>")) {
         return LLM_CHAT_TEMPLATE_LLAMA_3;
-    } else if (tmpl_contains("[gMASK]sop")) {
-        // chatglm3-6b
-        return LLM_CHAT_TEMPLATE_CHATGML_3;
-    } else if (tmpl_contains("[gMASK]")) {
-        return LLM_CHAT_TEMPLATE_CHATGML_4;
     } else if (tmpl_contains(LU8("<用户>"))) {
         // MiniCPM-3B-OpenHermes-2.5-v2-GGUF
         return LLM_CHAT_TEMPLATE_MINICPM;
@@ -24160,7 +24170,7 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|start_header_id|>assistant<|end_header_id|>\n\n";
         }
-    } else if (tmpl == LLM_CHAT_TEMPLATE_CHATGML_3) {
+    } else if (tmpl == LLM_CHAT_TEMPLATE_CHATGLM_3) {
         // chatglm3-6b
         ss << "[gMASK]" << "sop";
         for (auto message : chat) {
@@ -24170,7 +24180,7 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|assistant|>";
         }
-    } else if (tmpl == LLM_CHAT_TEMPLATE_CHATGML_4) {
+    } else if (tmpl == LLM_CHAT_TEMPLATE_CHATGLM_4) {
         ss << "[gMASK]" << "";
         for (auto message : chat) {
             std::string role(message->role);

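Besides renaming the misspelled CHATGML enum values to CHATGLM, the detection change above moves the [gMASK] checks ahead of the more generic marker checks. GLM-4 style templates also emit tokens such as <|assistant|> and <|user|>, so testing those first could send them down the wrong branch. The standalone sketch below uses made-up enum names and a simplified template string to illustrate the ordering; it is not the detection code itself.

    #include <cstdio>
    #include <string>

    // Simplified detector: the specific [gMASK] patterns are tested before the
    // generic assistant/user markers, mirroring the reordering in the patch above.
    enum class tmpl_kind { chatglm_3, chatglm_4, generic_assistant, unknown };

    static tmpl_kind detect(const std::string & tmpl) {
        auto contains = [&](const char * s) { return tmpl.find(s) != std::string::npos; };
        if (contains("[gMASK]sop")) return tmpl_kind::chatglm_3;   // chatglm3-6b
        if (contains("[gMASK]"))    return tmpl_kind::chatglm_4;   // GLM-4 / GLM-4.5
        if (contains("<|assistant|>") && contains("<|user|>")) return tmpl_kind::generic_assistant;
        return tmpl_kind::unknown;
    }

    int main() {
        // Contains both [gMASK] and the generic markers; with the [gMASK] check
        // first it resolves to chatglm_4 instead of falling through to the generic rule.
        const std::string glm4 = "[gMASK]<sop><|user|>hi<|assistant|>";
        std::printf("detected = %d\n", (int) detect(glm4));
        return 0;
    }
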
From 323e7f3dd647d18940998fb152a8852d3faf0ca6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Thireus=20=E2=98=A0?= 
Date: Tue, 5 Aug 2025 18:45:30 +0100
Subject: [PATCH 13/13] Revert to original GGML_ASSERT

---
 src/llama.cpp | 14 ++------------
 1 file changed, 2 insertions(+), 12 deletions(-)

diff --git a/src/llama.cpp b/src/llama.cpp
index ff4d23b5d..47e26a83a 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -20712,18 +20712,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     //  - qs.n_attention_wv == 3 * model.hparams.n_layer for Encoder-Decoder models
     //  - model.arch == LLM_ARCH_DECI                    for Deci-Nemotron   models
     //
-    //GGML_ASSERT((qs.n_attention_wv == 0 || qs.n_attention_wv == (int)model.hparams.n_layer || qs.n_attention_wv == 3 * (int)model.hparams.n_layer || model.arch == LLM_ARCH_DECI) && "n_attention_wv is unexpected");
-    // allow any count for GLM4-MoE, but still enforce for all others
-    if (model.arch != LLM_ARCH_GLM4_MOE) {
-        GGML_ASSERT(
-             qs.n_attention_wv == 0
-          || qs.n_attention_wv == (int)model.hparams.n_layer
-          || qs.n_attention_wv == 3 * (int)model.hparams.n_layer
-          || model.arch == LLM_ARCH_DECI
-          && "n_attention_wv is unexpected"
-        );
-    }
-    
+    GGML_ASSERT((qs.n_attention_wv == 0 || qs.n_attention_wv == (int)model.hparams.n_layer || qs.n_attention_wv == 3 * (int)model.hparams.n_layer || model.arch == LLM_ARCH_DECI) && "n_attention_wv is unexpected");
+
     size_t total_size_org = 0;
     size_t total_size_new = 0;