
Commit ef0144c

sammcj authored, with co-authors CISC and slaren
model: support GLM 4.5 family of models (#14939)
* model: Add GLM 4.5 (#14921)

  Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Merge in PR suggestions

  Co-authored-by: Sigbjørn Skjæret <[email protected]>

* model: Add GLM 4.5 family of models (#14921)

  1. Updated tensor_mapping.py with NextN tensor mappings
     - Added tensor mappings for all NextN/MTP tensors in gguf-py/gguf/tensor_mapping.py
     - Added mappings for: eh_proj, embed_tokens, enorm, hnorm, shared_head.head, shared_head.norm

  2. Added num_nextn_predict_layers configuration
     - Added the LLM_KV_NEXTN_PREDICT_LAYERS constant to llama-arch.h and llama-arch.cpp
     - Added the num_nextn_predict_layers field to the llama_hparams struct
     - Updated GLM4_MOE parameter loading in llama-model.cpp to read this parameter
     - Modified the tensor-loading logic to conditionally load NextN tensors based on num_nextn_predict_layers
     - Added GGUF writer support in gguf_writer.py with an add_nextn_predict_layers() method
     - Updated the conversion script to extract and write this parameter from the HuggingFace config

  3. Added FIM tokens for GLM4_MOE
     - Added GLM-4.5's FIM tokens to llama-vocab.cpp:
       - <|code_prefix|> for FIM_PRE
       - <|code_suffix|> for FIM_SUF
       - <|code_middle|> for FIM_MID

  4. Removed manual NextN tensor handling
     - Removed the special-case handling in convert_hf_to_gguf.py that manually mapped NextN tensors
     - NextN tensors are now handled automatically through the regular tensor-mapping system

* glm 4.5: update tensor names

* model: glm 4.5 apply suggestions from code review

  Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Update src/llama-model.cpp

  Co-authored-by: Sigbjørn Skjæret <[email protected]>

* model: glm 4.5 apply suggestions from code review

  Co-authored-by: Sigbjørn Skjæret <[email protected]>

* model: glm 4.5 apply suggestions from code review

* Apply suggestions from code review

* patch broken chat template

* typings fix

* add TENSOR_SKIP flag

  Co-authored-by: Diego Devesa <[email protected]>

* Update src/llama-model-loader.h

  Co-authored-by: Sigbjørn Skjæret <[email protected]>

---------

Co-authored-by: Sigbjørn Skjæret <[email protected]>
Co-authored-by: Diego Devesa <[email protected]>
1 parent 2721257 commit ef0144c

15 files changed: +594 −8 lines changed
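The central piece of plumbing in this commit is the NextN/MTP (multi-token prediction) layer count: GLM 4.5 checkpoints carry `num_nextn_predict_layers` extra prediction layers on top of `num_hidden_layers`, and the converter sizes its block map accordingly. A minimal sketch of that arithmetic, using illustrative values rather than a real config.json:

```python
# Sketch of the block-count logic added in Glm4MoeModel.__init__ (see the
# convert_hf_to_gguf.py diff below). The hparams values are illustrative.
hparams = {"num_hidden_layers": 46, "num_nextn_predict_layers": 1}

block_count = hparams["num_hidden_layers"] + hparams.get("num_nextn_predict_layers", 0)
print(block_count)  # 47 -> the NextN tensors occupy the extra trailing block

# NextN tensors therefore land at the last block index:
print(f"blk.{hparams['num_hidden_layers']}.nextn.eh_proj")  # blk.46.nextn.eh_proj
```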

convert_hf_to_gguf.py

Lines changed: 136 additions & 0 deletions
@@ -678,6 +678,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2":
             # ref: https://huggingface.co/THUDM/glm-4-9b-hf
             res = "glm4"
+        if chkhsh == "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902":
+            # ref: https://huggingface.co/zai-org/GLM-4.5-Air
+            res = "glm4"
         if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35":
             # ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0
             res = "minerva-7b"
@@ -6696,6 +6699,139 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return super().modify_tensors(data_torch, name, bid)
 
 
+@ModelBase.register("Glm4MoeForCausalLM")
+class Glm4MoeModel(TextModel):
+    model_arch = gguf.MODEL_ARCH.GLM4_MOE
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # GLM4_MOE has num_hidden_layers + 1 actual layers (including NextN layer)
+        self.block_count = self.hparams["num_hidden_layers"] + self.hparams.get("num_nextn_predict_layers", 0)
+        self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
+
+    def set_vocab(self):
+        from transformers import AutoTokenizer
+
+        tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
+        tokens, toktypes, tokpre = self.get_vocab_base()
+        self.gguf_writer.add_tokenizer_model("gpt2")
+        self.gguf_writer.add_tokenizer_pre(tokpre)
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+
+        # Special tokens
+        # Note: Using <|endoftext|> (151329) for eot causes endless generation
+        special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"])          # 151331
+        special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"])         # 151336
+        special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"])    # 151329
+        special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"])  # 151338
+
+        # Patch broken chat template
+        if isinstance(special_vocab.chat_template, str) and "visible_text(m.content).endswith" in special_vocab.chat_template:
+            special_vocab.chat_template = special_vocab.chat_template.replace(
+                """{{ visible_text(m.content) }}\n{{- '/nothink' if (enable_thinking is defined and not enable_thinking and not visible_text(m.content).endswith("/nothink")) else '' -}}""",
+                """{% set content = visible_text(m.content) %}{{ content }}\n{{- '/nothink' if (enable_thinking is defined and not enable_thinking and not content.endswith("/nothink")) else '' -}}""")
+
+        special_vocab.add_to_gguf(self.gguf_writer)
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        if (rope_dim := self.hparams.get("head_dim")) is None:
+            rope_dim = (
+                self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
+            )
+        self.gguf_writer.add_rope_dimension_count(
+            int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5))
+        )
+
+        # MoE parameters - Use only routed expert count (shared experts handled separately)
+        if (n_routed_experts := self.hparams.get("n_routed_experts")) is not None:
+            self.gguf_writer.add_expert_count(n_routed_experts)
+        if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
+            self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
+        if (n_shared_experts := self.hparams.get("n_shared_experts")) is not None:
+            self.gguf_writer.add_expert_shared_count(n_shared_experts)
+        if (first_k_dense_replace := self.hparams.get("first_k_dense_replace")) is not None:
+            self.gguf_writer.add_leading_dense_block_count(first_k_dense_replace)
+
+        # Expert gating function (sigmoid for GLM4_MOE)
+        self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
+
+        # Routed scaling factor
+        if (routed_scaling_factor := self.hparams.get("routed_scaling_factor")) is not None:
+            self.gguf_writer.add_expert_weights_scale(routed_scaling_factor)
+
+        # Normalise topk probabilities
+        if (norm_topk_prob := self.hparams.get("norm_topk_prob")) is not None:
+            self.gguf_writer.add_expert_weights_norm(norm_topk_prob)
+
+        # NextN/MTP prediction layers
+        if (num_nextn_predict_layers := self.hparams.get("num_nextn_predict_layers")) is not None:
+            self.gguf_writer.add_nextn_predict_layers(num_nextn_predict_layers)
+
+    _experts: list[dict[str, Tensor]] | None = None
+
+    def modify_tensors(
+        self, data_torch: Tensor, name: str, bid: int | None
+    ) -> Iterable[tuple[str, Tensor]]:
+        if name.startswith("model.visual."):  # ignore visual part
+            return []
+        elif name.startswith("model.language_model."):
+            name = name.replace("language_model.", "")  # for multimodal variants
+
+        # Handle main token embedding (but not layer-specific NextN embeddings)
+        if name == "model.embed_tokens.weight" and ".layers." not in name:
+            return [(self.map_tensor_name("token_embd.weight"), data_torch)]
+
+        # Handle routed experts
+        if name.find("mlp.experts") != -1:
+            n_experts = self.hparams["n_routed_experts"]
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
+                tensors: list[tuple[str, Tensor]] = []
+
+                # merge the experts into a single 3d tensor
+                for w_name in ["down_proj", "gate_proj", "up_proj"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+
+                    merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+
+                    new_name = self.map_tensor_name(merged_name)
+                    tensors.append((new_name, data_torch))
+                return tensors
+            else:
+                return []
+
+        if name.endswith("e_score_correction_bias"):
+            name = name.replace("e_score_correction_bias", "e_score_correction.bias")
+
+        new_name = self.map_tensor_name(name)
+
+        return [(new_name, data_torch)]
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
+
 @ModelBase.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration")
 class ChatGLMModel(TextModel):
     model_arch = gguf.MODEL_ARCH.CHATGLM
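The `modify_tensors` logic above buffers each layer's per-expert `down_proj`/`gate_proj`/`up_proj` weights until all `n_routed_experts * 3` have arrived, then stacks each projection into a single 3D tensor. A self-contained sketch of that merge step, with made-up shapes and expert count:

```python
# Standalone illustration of the expert merge performed above: n_experts
# separate (out, in) matrices become one (n_expert, out, in) tensor.
import torch

n_experts = 4
experts = {
    f"model.layers.0.mlp.experts.{xid}.gate_proj.weight": torch.randn(8, 16)
    for xid in range(n_experts)
}

datas = [experts[f"model.layers.0.mlp.experts.{xid}.gate_proj.weight"]
         for xid in range(n_experts)]
merged = torch.stack(datas, dim=0)
print(merged.shape)  # torch.Size([4, 8, 16])
```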

convert_hf_to_gguf_update.py

Lines changed: 1 addition & 0 deletions
@@ -147,6 +147,7 @@ class TOKENIZER_TYPE(IntEnum):
     {"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b"},
     {"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"},
     {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"},
+    {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.5-Air", "chkhsh": "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902"},
     {"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"},
     {"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"},
     {"name": "hunyuan-dense", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-4B-Instruct", "chkhsh": "bba3b3366b646dbdded5dbc42d59598b849371afc42f7beafa914afaa5b70aa6"},

gguf-py/gguf/constants.py

Lines changed: 48 additions & 0 deletions
@@ -105,6 +105,7 @@ class LLM:
         EXPERT_WEIGHTS_NORM              = "{arch}.expert_weights_norm"
         EXPERT_GATING_FUNC               = "{arch}.expert_gating_func"
         MOE_EVERY_N_LAYERS               = "{arch}.moe_every_n_layers"
+        NEXTN_PREDICT_LAYERS             = "{arch}.nextn_predict_layers"
         POOLING_TYPE                     = "{arch}.pooling_type"
         LOGIT_SCALE                      = "{arch}.logit_scale"
         DECODER_START_TOKEN_ID           = "{arch}.decoder_start_token_id"

@@ -357,6 +358,7 @@ class MODEL_ARCH(IntEnum):
     DEEPSEEK2        = auto()
     CHATGLM          = auto()
     GLM4             = auto()
+    GLM4_MOE         = auto()
     BITNET           = auto()
     T5               = auto()
     T5ENCODER        = auto()

@@ -614,6 +616,13 @@ class MODEL_TENSOR(IntEnum):
     A_MMPROJ_FC      = auto()
     A_MM_NORM_PRE    = auto()
     A_MM_NORM_MID    = auto()
+    # nextn/mtp
+    NEXTN_EH_PROJ          = auto()
+    NEXTN_EMBED_TOKENS     = auto()
+    NEXTN_ENORM            = auto()
+    NEXTN_HNORM            = auto()
+    NEXTN_SHARED_HEAD_HEAD = auto()
+    NEXTN_SHARED_HEAD_NORM = auto()
 
 
 MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {

@@ -678,6 +687,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.DEEPSEEK2:        "deepseek2",
     MODEL_ARCH.CHATGLM:          "chatglm",
     MODEL_ARCH.GLM4:             "glm4",
+    MODEL_ARCH.GLM4_MOE:         "glm4moe",
     MODEL_ARCH.BITNET:           "bitnet",
     MODEL_ARCH.T5:               "t5",
     MODEL_ARCH.T5ENCODER:        "t5encoder",

@@ -936,6 +946,13 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.A_MMPROJ_FC:            "mm.a.fc",
     MODEL_TENSOR.A_MM_NORM_PRE:          "mm.a.norm_pre",
     MODEL_TENSOR.A_MM_NORM_MID:          "mm.a.norm_mid",
+    # NextN/MTP
+    MODEL_TENSOR.NEXTN_EH_PROJ:          "blk.{bid}.nextn.eh_proj",
+    MODEL_TENSOR.NEXTN_EMBED_TOKENS:     "blk.{bid}.nextn.embed_tokens",
+    MODEL_TENSOR.NEXTN_ENORM:            "blk.{bid}.nextn.enorm",
+    MODEL_TENSOR.NEXTN_HNORM:            "blk.{bid}.nextn.hnorm",
+    MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD: "blk.{bid}.nextn.shared_head_head",
+    MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM: "blk.{bid}.nextn.shared_head_norm",
 }
 
 MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {

@@ -2124,6 +2141,37 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.ATTN_POST_NORM,
         MODEL_TENSOR.FFN_POST_NORM,
     ],
+    MODEL_ARCH.GLM4_MOE: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_POST_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_GATE_SHEXP,
+        MODEL_TENSOR.FFN_DOWN_SHEXP,
+        MODEL_TENSOR.FFN_UP_SHEXP,
+        MODEL_TENSOR.FFN_EXP_PROBS_B,
+        # NextN/MTP tensors - preserved but unused
+        MODEL_TENSOR.NEXTN_EH_PROJ,
+        MODEL_TENSOR.NEXTN_EMBED_TOKENS,
+        MODEL_TENSOR.NEXTN_ENORM,
+        MODEL_TENSOR.NEXTN_HNORM,
+        MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
+        MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
+    ],
     MODEL_ARCH.BITNET: [
         MODEL_TENSOR.ATTN_Q,
         MODEL_TENSOR.ATTN_K,
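The `blk.{bid}.nextn.*` strings registered above are format templates: `{bid}` is substituted with the block index when tensor names are materialised. Trivially:

```python
# How the {bid} placeholder in TENSOR_NAMES expands; 46 is an illustrative
# block index (the trailing NextN block of a 46-hidden-layer model).
template = "blk.{bid}.nextn.eh_proj"
print(template.format(bid=46))  # blk.46.nextn.eh_proj
```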

gguf-py/gguf/gguf_writer.py

Lines changed: 3 additions & 0 deletions
@@ -753,6 +753,9 @@ def add_expert_gating_func(self, value: ExpertGatingFuncType) -> None:
     def add_moe_every_n_layers(self, value: int) -> None:
         self.add_uint32(Keys.LLM.MOE_EVERY_N_LAYERS.format(arch=self.arch), value)
 
+    def add_nextn_predict_layers(self, count: int) -> None:
+        self.add_uint32(Keys.LLM.NEXTN_PREDICT_LAYERS.format(arch=self.arch), count)
+
     def add_swin_norm(self, value: bool) -> None:
         self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value)
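A hedged usage sketch of the new writer method; the output path and layer count are illustrative, and "glm4moe" matches the arch string this commit registers:

```python
# Minimal sketch: add_nextn_predict_layers() stores a uint32 under
# "<arch>.nextn_predict_layers". Path and count are illustrative.
import gguf

writer = gguf.GGUFWriter("glm4moe-demo.gguf", "glm4moe")
writer.add_nextn_predict_layers(1)  # -> key "glm4moe.nextn_predict_layers"
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.write_tensors_to_file()  # no tensors in this toy file
writer.close()
```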

gguf-py/gguf/tensor_mapping.py

Lines changed: 25 additions & 0 deletions
@@ -1369,6 +1369,31 @@ class TensorNameMap:
         MODEL_TENSOR.A_MM_NORM_MID: (
             "audio.multi_modal_projector.ln_mid",  # ultravox
         ),
+
+        # NextN/MTP tensors for GLM4_MOE
+        MODEL_TENSOR.NEXTN_EH_PROJ: (
+            "model.layers.{bid}.eh_proj",
+        ),
+
+        MODEL_TENSOR.NEXTN_EMBED_TOKENS: (
+            "model.layers.{bid}.embed_tokens",
+        ),
+
+        MODEL_TENSOR.NEXTN_ENORM: (
+            "model.layers.{bid}.enorm",
+        ),
+
+        MODEL_TENSOR.NEXTN_HNORM: (
+            "model.layers.{bid}.hnorm",
+        ),
+
+        MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD: (
+            "model.layers.{bid}.shared_head.head",
+        ),
+
+        MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM: (
+            "model.layers.{bid}.shared_head.norm",
+        ),
     }
 
     # architecture-specific block mappings
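With these entries in place, `TensorNameMap` resolves the HuggingFace NextN names to the GGUF names defined in constants.py. A hedged lookup sketch (the block count of 47 is illustrative):

```python
# Sketch: resolving an HF NextN tensor name through the new mappings.
import gguf

tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.GLM4_MOE, 47)  # 47 blocks, illustrative
name = tmap.get_name("model.layers.46.eh_proj.weight", try_suffixes=(".weight",))
print(name)  # expected: blk.46.nextn.eh_proj
```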

models/templates/README.md

Lines changed: 2 additions & 1 deletion
@@ -21,4 +21,5 @@ These templates can be updated with the following commands:
 ./scripts/get_chat_template.py Qwen/Qwen2.5-7B-Instruct > models/templates/Qwen-Qwen2.5-7B-Instruct.jinja
 ./scripts/get_chat_template.py Qwen/QwQ-32B > models/templates/Qwen-QwQ-32B.jinja
 ./scripts/get_chat_template.py Qwen/Qwen3-0.6B > models/templates/Qwen-Qwen3-0.6B.jinja
-```
+./scripts/get_chat_template.py zai-org/GLM-4.5 > models/templates/zai-org-GLM-4.5.jinja
+```

src/llama-arch.cpp

Lines changed: 44 additions & 0 deletions
@@ -62,6 +62,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_DEEPSEEK2,        "deepseek2"    },
     { LLM_ARCH_CHATGLM,          "chatglm"      },
     { LLM_ARCH_GLM4,             "glm4"         },
+    { LLM_ARCH_GLM4_MOE,         "glm4moe"      },
     { LLM_ARCH_BITNET,           "bitnet"       },
     { LLM_ARCH_T5,               "t5"           },
     { LLM_ARCH_T5ENCODER,        "t5encoder"    },

@@ -127,6 +128,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_EXPERT_WEIGHTS_NORM,    "%s.expert_weights_norm"    },
     { LLM_KV_EXPERT_GATING_FUNC,     "%s.expert_gating_func"     },
     { LLM_KV_MOE_EVERY_N_LAYERS,     "%s.moe_every_n_layers"     },
+    { LLM_KV_NEXTN_PREDICT_LAYERS,   "%s.nextn_predict_layers"   },
     { LLM_KV_POOLING_TYPE,           "%s.pooling_type"           },
     { LLM_KV_LOGIT_SCALE,            "%s.logit_scale"            },
     { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },

@@ -1391,6 +1393,40 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_POST_NORM,  "blk.%d.post_ffw_norm" },
         },
     },
+    {
+        LLM_ARCH_GLM4_MOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,         "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,        "output_norm" },
+            { LLM_TENSOR_OUTPUT,             "output" },
+            { LLM_TENSOR_ATTN_NORM,          "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_POST_NORM,     "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_ATTN_Q,             "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,             "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,             "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,           "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_Q_NORM,        "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K_NORM,        "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_FFN_GATE,           "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,           "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,             "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_INP,       "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_EXPS,      "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,      "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,        "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_GATE_SHEXP,     "blk.%d.ffn_gate_shexp" },
+            { LLM_TENSOR_FFN_DOWN_SHEXP,     "blk.%d.ffn_down_shexp" },
+            { LLM_TENSOR_FFN_UP_SHEXP,       "blk.%d.ffn_up_shexp" },
+            { LLM_TENSOR_FFN_EXP_PROBS_B,    "blk.%d.exp_probs_b" },
+            // NextN/MTP tensors - preserved but unused (in final layer, dynamic layer number)
+            { LLM_TENSOR_NEXTN_EH_PROJ,          "blk.%d.nextn.eh_proj" },
+            { LLM_TENSOR_NEXTN_EMBED_TOKENS,     "blk.%d.nextn.embed_tokens" },
+            { LLM_TENSOR_NEXTN_ENORM,            "blk.%d.nextn.enorm" },
+            { LLM_TENSOR_NEXTN_HNORM,            "blk.%d.nextn.hnorm" },
+            { LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "blk.%d.nextn.shared_head_head" },
+            { LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "blk.%d.nextn.shared_head_norm" },
+        },
+    },
     {
         LLM_ARCH_BITNET,
         {

@@ -2181,6 +2217,14 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_SHORTCONV_CONV,          {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}},
     {LLM_TENSOR_SHORTCONV_INPROJ,        {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_SHORTCONV_OUTPROJ,       {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    // NextN/MTP tensors are currently ignored (reserved for future MTP support)
+    // These tensors only exist in the last layer(s) and are treated as output tensors
+    {LLM_TENSOR_NEXTN_EH_PROJ,           {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_NEXTN_EMBED_TOKENS,      {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
+    {LLM_TENSOR_NEXTN_ENORM,             {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
+    {LLM_TENSOR_NEXTN_HNORM,             {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
+    {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD,  {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM,  {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
 };
 
 LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}
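On the C++ side, the `%s` in `LLM_KV_NAMES` is formatted with the arch name, so a converted file carries keys like `glm4moe.nextn_predict_layers`. A hedged read-back sketch using gguf-py's reader (the path is illustrative; the raw `parts`/`data` access below matches the reader's layout but is worth verifying against your gguf-py version):

```python
# Sketch: inspecting the NextN KV in a converted GGUF file.
import gguf

reader = gguf.GGUFReader("glm4.5-air-example.gguf")  # illustrative path
field = reader.get_field("glm4moe.nextn_predict_layers")
if field is not None:
    # for scalar KVs, field.data holds the index of the value part
    print(int(field.parts[field.data[0]][0]))
```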

src/llama-arch.h

Lines changed: 8 additions & 0 deletions
@@ -66,6 +66,7 @@ enum llm_arch {
     LLM_ARCH_DEEPSEEK2,
     LLM_ARCH_CHATGLM,
     LLM_ARCH_GLM4,
+    LLM_ARCH_GLM4_MOE,
     LLM_ARCH_BITNET,
     LLM_ARCH_T5,
     LLM_ARCH_T5ENCODER,

@@ -131,6 +132,7 @@ enum llm_kv {
     LLM_KV_EXPERT_WEIGHTS_NORM,
     LLM_KV_EXPERT_GATING_FUNC,
     LLM_KV_MOE_EVERY_N_LAYERS,
+    LLM_KV_NEXTN_PREDICT_LAYERS,
     LLM_KV_POOLING_TYPE,
     LLM_KV_LOGIT_SCALE,
     LLM_KV_DECODER_START_TOKEN_ID,

@@ -409,6 +411,12 @@ enum llm_tensor {
     LLM_TENSOR_SHORTCONV_CONV,
     LLM_TENSOR_SHORTCONV_INPROJ,
     LLM_TENSOR_SHORTCONV_OUTPROJ,
+    LLM_TENSOR_NEXTN_EH_PROJ,
+    LLM_TENSOR_NEXTN_EMBED_TOKENS,
+    LLM_TENSOR_NEXTN_ENORM,
+    LLM_TENSOR_NEXTN_HNORM,
+    LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD,
+    LLM_TENSOR_NEXTN_SHARED_HEAD_NORM,
 };
 
 enum llm_tensor_layer {
