Skip to content

Commit 84ddaa5

Browse files
sammcjCISC
andcommitted
model: Add GLM 4.5 (#14921)
Co-authored-by: Sigbjørn Skjæret <[email protected]>
1 parent 0a5036b commit 84ddaa5

File tree

10 files changed

+544
-6
lines changed

10 files changed

+544
-6
lines changed

convert_hf_to_gguf.py

Lines changed: 147 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -678,6 +678,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
678678
if chkhsh == "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2":
679679
# ref: https://huggingface.co/THUDM/glm-4-9b-hf
680680
res = "glm4"
681+
if chkhsh == "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902":
682+
# ref: https://huggingface.co/zai-org/GLM-4.5-Air
683+
res = "glm4"
681684
if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35":
682685
# ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0
683686
res = "minerva-7b"
@@ -6578,6 +6581,149 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
65786581
return super().modify_tensors(data_torch, name, bid)
65796582

65806583

6584+
@ModelBase.register("Glm4MoeForCausalLM")
6585+
class Glm4MoeModel(TextModel):
6586+
model_arch = gguf.MODEL_ARCH.GLM4_MOE
6587+
6588+
def __init__(self, *args, **kwargs):
6589+
super().__init__(*args, **kwargs)
6590+
# GLM4_MOE has num_hidden_layers + 1 actual layers (including NextN layer)
6591+
self.block_count = self.hparams["num_hidden_layers"] + 1
6592+
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
6593+
6594+
def set_vocab(self):
6595+
from transformers import AutoTokenizer
6596+
6597+
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
6598+
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
6599+
tokens, toktypes, tokpre = self.get_vocab_base()
6600+
self.gguf_writer.add_tokenizer_model("gpt2")
6601+
self.gguf_writer.add_tokenizer_pre(tokpre)
6602+
self.gguf_writer.add_token_list(tokens)
6603+
self.gguf_writer.add_token_types(toktypes)
6604+
6605+
# Special tokens
6606+
# Note: Using <|endoftext|> (151329) for eos and eot causes endless generation
6607+
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # 151331
6608+
special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|user|>"]) # 151336 - end of
6609+
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # 151336 - same as EOS
6610+
special_vocab._set_special_token("eog", tokenizer.get_added_vocab()["<|user|>"]) # 151336 - same as EOS
6611+
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329
6612+
special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338
6613+
6614+
if "<sop>" in tokenizer.get_added_vocab():
6615+
special_vocab._set_special_token("sop", tokenizer.get_added_vocab()["<sop>"]) # 151333
6616+
if "<eop>" in tokenizer.get_added_vocab():
6617+
special_vocab._set_special_token("eop", tokenizer.get_added_vocab()["<eop>"]) # 151334
6618+
6619+
special_vocab.add_to_gguf(self.gguf_writer)
6620+
6621+
def set_gguf_parameters(self):
6622+
super().set_gguf_parameters()
6623+
if (rope_dim := self.hparams.get("head_dim")) is None:
6624+
rope_dim = (
6625+
self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
6626+
)
6627+
self.gguf_writer.add_rope_dimension_count(
6628+
int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5))
6629+
)
6630+
6631+
# MoE parameters - Use only routed expert count (shared experts handled separately)
6632+
if (n_routed_experts := self.hparams.get("n_routed_experts")) is not None:
6633+
self.gguf_writer.add_expert_count(n_routed_experts)
6634+
if (num_experts_per_tok := self.hparams.get("num_experts_per_tok")) is not None:
6635+
self.gguf_writer.add_expert_used_count(num_experts_per_tok)
6636+
if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
6637+
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
6638+
if (n_shared_experts := self.hparams.get("n_shared_experts")) is not None:
6639+
self.gguf_writer.add_expert_shared_count(n_shared_experts)
6640+
if (first_k_dense_replace := self.hparams.get("first_k_dense_replace")) is not None:
6641+
self.gguf_writer.add_leading_dense_block_count(first_k_dense_replace)
6642+
6643+
# Expert gating function (sigmoid for GLM4_MOE)
6644+
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
6645+
6646+
# Routed scaling factor
6647+
if (routed_scaling_factor := self.hparams.get("routed_scaling_factor")) is not None:
6648+
self.gguf_writer.add_expert_weights_scale(routed_scaling_factor)
6649+
6650+
# Normalise topk probabilities
6651+
if (norm_topk_prob := self.hparams.get("norm_topk_prob")) is not None:
6652+
self.gguf_writer.add_expert_weights_norm(norm_topk_prob)
6653+
6654+
_experts: list[dict[str, Tensor]] | None = None
6655+
6656+
def modify_tensors(
6657+
self, data_torch: Tensor, name: str, bid: int | None
6658+
) -> Iterable[tuple[str, Tensor]]:
6659+
if name.startswith("model.visual."): # ignore visual part
6660+
return []
6661+
elif name.startswith("model.language_model."):
6662+
name = name.replace("language_model.", "") # for multimodal variants
6663+
6664+
# Handle main token embedding (but not layer-specific NextN embeddings)
6665+
if name == "model.embed_tokens.weight" and ".layers." not in name:
6666+
return [(self.map_tensor_name("token_embd.weight"), data_torch)]
6667+
6668+
# Handle routed experts
6669+
if name.find("mlp.experts") != -1:
6670+
n_experts = self.hparams["n_routed_experts"]
6671+
assert bid is not None
6672+
6673+
if self._experts is None:
6674+
self._experts = [{} for _ in range(self.block_count)]
6675+
6676+
self._experts[bid][name] = data_torch
6677+
6678+
if len(self._experts[bid]) >= n_experts * 3:
6679+
tensors: list[tuple[str, Tensor]] = []
6680+
6681+
# merge the experts into a single 3d tensor
6682+
for w_name in ["down_proj", "gate_proj", "up_proj"]:
6683+
datas: list[Tensor] = []
6684+
6685+
for xid in range(n_experts):
6686+
ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
6687+
datas.append(self._experts[bid][ename])
6688+
del self._experts[bid][ename]
6689+
6690+
data_torch = torch.stack(datas, dim=0)
6691+
6692+
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
6693+
6694+
new_name = self.map_tensor_name(merged_name)
6695+
tensors.append((new_name, data_torch))
6696+
return tensors
6697+
else:
6698+
return []
6699+
6700+
if name.endswith("e_score_correction_bias"):
6701+
name = name.replace("e_score_correction_bias", "e_score_correction.bias")
6702+
6703+
# Handle special NextN tensors - preserve for future MTP support
6704+
if (
6705+
".embed_tokens." in name
6706+
or ".shared_head." in name
6707+
or ".eh_proj." in name
6708+
or ".enorm." in name
6709+
or ".hnorm." in name
6710+
):
6711+
new_name = name.replace("model.layers.", "blk.").replace("model.", "").replace(".weight", "")
6712+
return [(new_name, data_torch)]
6713+
6714+
new_name = self.map_tensor_name(name)
6715+
6716+
return [(new_name, data_torch)]
6717+
6718+
def prepare_tensors(self):
6719+
super().prepare_tensors()
6720+
if self._experts is not None:
6721+
# flatten `list[dict[str, Tensor]]` into `list[str]`
6722+
experts = [k for d in self._experts for k in d.keys()]
6723+
if len(experts) > 0:
6724+
raise ValueError(f"Unprocessed experts: {experts}")
6725+
6726+
65816727
@ModelBase.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration")
65826728
class ChatGLMModel(TextModel):
65836729
model_arch = gguf.MODEL_ARCH.CHATGLM
@@ -6594,7 +6740,7 @@ def set_vocab_chatglm3(self):
65946740
vocab_size = hparams.get("padded_vocab_size", len(tokenizer.get_vocab()))
65956741
assert max(tokenizer.get_vocab().values()) < vocab_size
65966742
role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"]
6597-
special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens
6743+
special_tokens = ["[MASK]", "[gMASK]", "sop", "eop"] + role_special_tokens
65986744
for token_id in range(vocab_size):
65996745
piece = tokenizer._convert_id_to_token(token_id)
66006746
if token_id == 0:

convert_hf_to_gguf_update.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ class TOKENIZER_TYPE(IntEnum):
138138
{"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b"},
139139
{"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"},
140140
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"},
141+
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.5-Air", "chkhsh": "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902"},
141142
{"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"},
142143
{"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"},
143144
# falcon-h1 series uses 4 different tokenizers across model sizes (0.5b - 34b), hence we need to define 4 different hashes

gguf-py/gguf/constants.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,7 @@ class MODEL_ARCH(IntEnum):
354354
DEEPSEEK2 = auto()
355355
CHATGLM = auto()
356356
GLM4 = auto()
357+
GLM4_MOE = auto()
357358
BITNET = auto()
358359
T5 = auto()
359360
T5ENCODER = auto()
@@ -609,6 +610,12 @@ class MODEL_TENSOR(IntEnum):
609610
A_MMPROJ_FC = auto()
610611
A_MM_NORM_PRE = auto()
611612
A_MM_NORM_MID = auto()
613+
NEXTN_EH_PROJ = auto() # nextn tensors (glm4moe)
614+
NEXTN_EMBED_TOKENS = auto() # nextn tensors (glm4moe)
615+
NEXTN_ENORM = auto() # nextn tensors (glm4moe)
616+
NEXTN_HNORM = auto() # nextn tensors (glm4moe)
617+
NEXTN_SHARED_HEAD_HEAD = auto() # nextn tensors (glm4moe)
618+
NEXTN_SHARED_HEAD_NORM = auto() # nextn tensors (glm4moe)
612619

613620

614621
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@@ -673,6 +680,7 @@ class MODEL_TENSOR(IntEnum):
673680
MODEL_ARCH.DEEPSEEK2: "deepseek2",
674681
MODEL_ARCH.CHATGLM: "chatglm",
675682
MODEL_ARCH.GLM4: "glm4",
683+
MODEL_ARCH.GLM4_MOE: "glm4moe",
676684
MODEL_ARCH.BITNET: "bitnet",
677685
MODEL_ARCH.T5: "t5",
678686
MODEL_ARCH.T5ENCODER: "t5encoder",
@@ -929,6 +937,13 @@ class MODEL_TENSOR(IntEnum):
929937
MODEL_TENSOR.A_MMPROJ_FC: "mm.a.fc",
930938
MODEL_TENSOR.A_MM_NORM_PRE: "mm.a.norm_pre",
931939
MODEL_TENSOR.A_MM_NORM_MID: "mm.a.norm_mid",
940+
# NextN/MTP tensors (GLM4_MOE)
941+
MODEL_TENSOR.NEXTN_EH_PROJ: "blk.{bid}.eh_proj",
942+
MODEL_TENSOR.NEXTN_EMBED_TOKENS: "blk.{bid}.embed_tokens",
943+
MODEL_TENSOR.NEXTN_ENORM: "blk.{bid}.enorm",
944+
MODEL_TENSOR.NEXTN_HNORM: "blk.{bid}.hnorm",
945+
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD: "blk.{bid}.shared_head.head",
946+
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM: "blk.{bid}.shared_head.norm",
932947
}
933948

934949
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@@ -2102,6 +2117,37 @@ class MODEL_TENSOR(IntEnum):
21022117
MODEL_TENSOR.ATTN_POST_NORM,
21032118
MODEL_TENSOR.FFN_POST_NORM,
21042119
],
2120+
MODEL_ARCH.GLM4_MOE: [
2121+
MODEL_TENSOR.TOKEN_EMBD,
2122+
MODEL_TENSOR.OUTPUT_NORM,
2123+
MODEL_TENSOR.OUTPUT,
2124+
MODEL_TENSOR.ATTN_NORM,
2125+
MODEL_TENSOR.ATTN_POST_NORM,
2126+
MODEL_TENSOR.ATTN_Q,
2127+
MODEL_TENSOR.ATTN_K,
2128+
MODEL_TENSOR.ATTN_V,
2129+
MODEL_TENSOR.ATTN_OUT,
2130+
MODEL_TENSOR.ATTN_Q_NORM,
2131+
MODEL_TENSOR.ATTN_K_NORM,
2132+
MODEL_TENSOR.FFN_GATE, # dense layers
2133+
MODEL_TENSOR.FFN_DOWN, # dense layers
2134+
MODEL_TENSOR.FFN_UP, # dense layers
2135+
MODEL_TENSOR.FFN_GATE_INP,
2136+
MODEL_TENSOR.FFN_GATE_EXP,
2137+
MODEL_TENSOR.FFN_DOWN_EXP,
2138+
MODEL_TENSOR.FFN_UP_EXP,
2139+
MODEL_TENSOR.FFN_GATE_SHEXP,
2140+
MODEL_TENSOR.FFN_DOWN_SHEXP,
2141+
MODEL_TENSOR.FFN_UP_SHEXP,
2142+
MODEL_TENSOR.FFN_EXP_PROBS_B,
2143+
# NextN/MTP tensors - preserved but unused
2144+
MODEL_TENSOR.NEXTN_EH_PROJ,
2145+
MODEL_TENSOR.NEXTN_EMBED_TOKENS,
2146+
MODEL_TENSOR.NEXTN_ENORM,
2147+
MODEL_TENSOR.NEXTN_HNORM,
2148+
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
2149+
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
2150+
],
21052151
MODEL_ARCH.BITNET: [
21062152
MODEL_TENSOR.ATTN_Q,
21072153
MODEL_TENSOR.ATTN_K,

models/templates/README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,5 @@ These templates can be updated with the following commands:
2121
./scripts/get_chat_template.py Qwen/Qwen2.5-7B-Instruct > models/templates/Qwen-Qwen2.5-7B-Instruct.jinja
2222
./scripts/get_chat_template.py Qwen/QwQ-32B > models/templates/Qwen-QwQ-32B.jinja
2323
./scripts/get_chat_template.py Qwen/Qwen3-0.6B > models/templates/Qwen-Qwen3-0.6B.jinja
24-
```
24+
./scripts/get_chat_template.py zai-org/GLM-4.5 > models/templates/zai-org-GLM-4.5.jinja
25+
```

src/llama-arch.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
6262
{ LLM_ARCH_DEEPSEEK2, "deepseek2" },
6363
{ LLM_ARCH_CHATGLM, "chatglm" },
6464
{ LLM_ARCH_GLM4, "glm4" },
65+
{ LLM_ARCH_GLM4_MOE, "glm4moe" },
6566
{ LLM_ARCH_BITNET, "bitnet" },
6667
{ LLM_ARCH_T5, "t5" },
6768
{ LLM_ARCH_T5ENCODER, "t5encoder" },
@@ -1389,6 +1390,40 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
13891390
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
13901391
},
13911392
},
1393+
{
1394+
LLM_ARCH_GLM4_MOE,
1395+
{
1396+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
1397+
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
1398+
{ LLM_TENSOR_OUTPUT, "output" },
1399+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
1400+
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
1401+
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
1402+
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
1403+
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
1404+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
1405+
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
1406+
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
1407+
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, // dense layers
1408+
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, // dense layers
1409+
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, // dense layers
1410+
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
1411+
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
1412+
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
1413+
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
1414+
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
1415+
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
1416+
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
1417+
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
1418+
// NextN/MTP tensors - preserved but unused (in final layer, dynamic layer number)
1419+
{ LLM_TENSOR_NEXTN_EH_PROJ, "blk.%d.eh_proj" },
1420+
{ LLM_TENSOR_NEXTN_EMBED_TOKENS, "blk.%d.embed_tokens" },
1421+
{ LLM_TENSOR_NEXTN_ENORM, "blk.%d.enorm" },
1422+
{ LLM_TENSOR_NEXTN_HNORM, "blk.%d.hnorm" },
1423+
{ LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "blk.%d.shared_head.head" },
1424+
{ LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "blk.%d.shared_head.norm" },
1425+
},
1426+
},
13921427
{
13931428
LLM_ARCH_BITNET,
13941429
{
@@ -2142,6 +2177,14 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
21422177
{LLM_TENSOR_SHORTCONV_CONV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}},
21432178
{LLM_TENSOR_SHORTCONV_INPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
21442179
{LLM_TENSOR_SHORTCONV_OUTPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
2180+
// NextN/MTP tensors are loaded but never used (reserved for future MTP support)
2181+
// These tensors only exist in the last layer (layer 46 for GLM-4.5-Air) and are treated as output tensors
2182+
{LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_NONE}},
2183+
{LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_NONE}},
2184+
{LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_NONE}},
2185+
{LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_NONE}},
2186+
{LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_NONE}},
2187+
{LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_NONE}},
21452188
};
21462189

21472190
LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}

src/llama-arch.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ enum llm_arch {
6666
LLM_ARCH_DEEPSEEK2,
6767
LLM_ARCH_CHATGLM,
6868
LLM_ARCH_GLM4,
69+
LLM_ARCH_GLM4_MOE,
6970
LLM_ARCH_BITNET,
7071
LLM_ARCH_T5,
7172
LLM_ARCH_T5ENCODER,
@@ -407,6 +408,12 @@ enum llm_tensor {
407408
LLM_TENSOR_SHORTCONV_CONV,
408409
LLM_TENSOR_SHORTCONV_INPROJ,
409410
LLM_TENSOR_SHORTCONV_OUTPROJ,
411+
LLM_TENSOR_NEXTN_EH_PROJ,
412+
LLM_TENSOR_NEXTN_EMBED_TOKENS,
413+
LLM_TENSOR_NEXTN_ENORM,
414+
LLM_TENSOR_NEXTN_HNORM,
415+
LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD,
416+
LLM_TENSOR_NEXTN_SHARED_HEAD_NORM,
410417
};
411418

412419
enum llm_tensor_layer {

src/llama-graph.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -760,8 +760,8 @@ ggml_tensor * llm_graph_context::build_ffn(
760760

761761
if (down) {
762762
cur = build_lora_mm(down, cur);
763-
if (arch == LLM_ARCH_GLM4) {
764-
// GLM4 seems to have numerical issues with half-precision accumulators
763+
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
764+
// GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
765765
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
766766
}
767767
}
@@ -1481,8 +1481,8 @@ ggml_tensor * llm_graph_context::build_attn(
14811481

14821482
if (wo) {
14831483
cur = build_lora_mm(wo, cur);
1484-
if (arch == LLM_ARCH_GLM4) {
1485-
// GLM4 seems to have numerical issues with half-precision accumulators
1484+
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
1485+
// GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
14861486
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
14871487
}
14881488
}

src/llama-kv-cache-unified.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ llama_kv_cache_unified::llama_kv_cache_unified(
3939
if (model.arch == LLM_ARCH_GEMMA3N) {
4040
n_layer_cache = 20;
4141
}
42+
if (model.arch == LLM_ARCH_GLM4_MOE) {
43+
// GLM4_MOE: Only process first 46 transformer layers, skip NextN layer
44+
n_layer_cache = hparams.n_layer - 1;
45+
}
4246

4347
// create a context for each buffer type
4448
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;

0 commit comments

Comments
 (0)