Commit dbfadb6

feat: support GLM 4.5 family of models
1 parent 3d15c4a commit dbfadb6

5 files changed: +36 -8 lines changed


convert_hf_to_gguf.py

Lines changed: 22 additions & 4 deletions
@@ -6605,9 +6605,9 @@ def set_vocab(self):
         self.gguf_writer.add_token_types(toktypes)
 
         # Special tokens
-        # BOS should be [gMASK] (151331), EOS should be <|endoftext|> (151329) as per official config
+        # BOS should be [gMASK] (151331), EOS should be <|endoftext|> (151329) as per tokenizer analysis
         special_vocab._set_special_token(
-            "eos", tokenizer.get_added_vocab()["<|endoftext|>"]  # 151329 - official EOS token
+            "eos", tokenizer.get_added_vocab()["<|endoftext|>"]  # 151329 - correct EOS token
         )
         special_vocab._set_special_token(
             "eot", tokenizer.get_added_vocab()["<|endoftext|>"]  # 151329 - same as EOS
@@ -6620,9 +6620,25 @@ def set_vocab(self):
         )
         special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"])  # 151338
 
-        if "/nothink" in tokenizer.get_added_vocab():
-            special_vocab._set_special_token("nothink", tokenizer.get_added_vocab()["/nothink"])  # 151360
+        if "<sop>" in tokenizer.get_added_vocab():
+            special_vocab._set_special_token("sop", tokenizer.get_added_vocab()["<sop>"])  # 151333
+        if "<eop>" in tokenizer.get_added_vocab():
+            special_vocab._set_special_token("eop", tokenizer.get_added_vocab()["<eop>"])  # 151334
+        if "[sMASK]" in tokenizer.get_added_vocab():
+            special_vocab._set_special_token("smask", tokenizer.get_added_vocab()["[sMASK]"])  # 151332
+
+        # TODO: clean up once decided on an approach to think and /nothink
+        #
+        # Previously:
+        # if "/nothink" in tokenizer.get_added_vocab():
+        #     special_vocab._set_special_token("nothink", tokenizer.get_added_vocab()["/nothink"])  # 151360
         # Note: <think> and </think> are regular tokens (special=false in official config), not special tokens
+        #
+        # Latest thinking is:
+        # NOTE: /nothink token exists but causes generation issues as mentioned in
+        # https://huggingface.co/zai-org/GLM-4.5/discussions/9
+        # "it is a very special token. Even as input, it will be encoded into a special token, causing generation issues."
+        # Therefore we do NOT add it to avoid generation problems
 
         special_vocab.add_to_gguf(self.gguf_writer)
 
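The token IDs cited in the comments above can be double-checked against the tokenizer's added vocabulary using the same get_added_vocab() lookup the converter relies on. A minimal sketch, assuming transformers (>= 4.54.1, as pinned later in this commit) and access to zai-org/GLM-4.5:

# Sketch only (not part of this commit): print the GLM-4.5 special-token IDs
# referenced in the diff comments above.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("zai-org/GLM-4.5")
added = tokenizer.get_added_vocab()

# Expected per the comments: [gMASK] -> 151331, <|endoftext|> -> 151329,
# [sMASK] -> 151332, <sop> -> 151333, <eop> -> 151334, <|observation|> -> 151338,
# /nothink -> 151360 (present in the vocab but deliberately not registered here).
for tok in ("[gMASK]", "<|endoftext|>", "[sMASK]", "<sop>", "<eop>", "<|observation|>", "/nothink"):
    print(tok, added.get(tok))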

@@ -6639,6 +6655,8 @@ def set_gguf_parameters(self):
         # MoE parameters - Use only routed expert count (shared experts handled separately)
         if (n_routed_experts := self.hparams.get("n_routed_experts")) is not None:
             self.gguf_writer.add_expert_count(n_routed_experts)
+        if (num_experts_per_tok := self.hparams.get("num_experts_per_tok")) is not None:
+            self.gguf_writer.add_expert_used_count(num_experts_per_tok)
         if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
             self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
         if (n_shared_experts := self.hparams.get("n_shared_experts")) is not None:
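As a quick post-conversion check, the expert metadata written by set_gguf_parameters() can be read back from the output file. A minimal sketch, assuming the gguf package from this repo's gguf-py is installed and "glm-4.5.gguf" is a hypothetical path to a converted model:

# Sketch only: list the expert-related metadata keys written above
# (expert_count, expert_used_count, expert_feed_forward_length, ...).
from gguf import GGUFReader

reader = GGUFReader("glm-4.5.gguf")  # hypothetical output path
for name in reader.fields:
    if "expert" in name:
        print(name)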

models/templates/README.md

Lines changed: 2 additions & 1 deletion
@@ -21,4 +21,5 @@ These templates can be updated with the following commands:
 ./scripts/get_chat_template.py Qwen/Qwen2.5-7B-Instruct > models/templates/Qwen-Qwen2.5-7B-Instruct.jinja
 ./scripts/get_chat_template.py Qwen/QwQ-32B > models/templates/Qwen-QwQ-32B.jinja
 ./scripts/get_chat_template.py Qwen/Qwen3-0.6B > models/templates/Qwen-Qwen3-0.6B.jinja
-```
+./scripts/get_chat_template.py zai-org/GLM-4.5 > models/templates/zai-org-GLM-4.5.jinja
+```
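For a quick look at what the exported GLM-4.5 template produces, the same chat template can be exercised through transformers before regenerating the .jinja file. A minimal sketch (the message is illustrative):

# Sketch only: render the GLM-4.5 chat template for a single user turn.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("zai-org/GLM-4.5")
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Hello"}],
    tokenize=False,
    add_generation_prompt=True,
)
print(prompt)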
Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 numpy~=1.26.4
 sentencepiece~=0.2.0
-transformers>=4.45.1,<5.0.0
+transformers>=4.54.1,<5.0.0
 gguf>=0.1.0
 protobuf>=4.21.0,<5.0.0

src/llama-kv-cache-unified.cpp

Lines changed: 4 additions & 0 deletions
@@ -39,6 +39,10 @@ llama_kv_cache_unified::llama_kv_cache_unified(
     if (model.arch == LLM_ARCH_GEMMA3N) {
         n_layer_cache = 20;
     }
+    if (model.arch == LLM_ARCH_GLM4_MOE) {
+        // GLM4_MOE: Only process first 46 transformer layers, skip NextN layer
+        n_layer_cache = hparams.n_layer - 1;
+    }
 
     // create a context for each buffer type
     std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
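One practical effect of this change is KV-cache sizing: for GLM4_MOE the unified cache allocates buffers for n_layer - 1 layers, so the NextN layer adds no KV memory. A rough sizing sketch; every dimension below is a hypothetical placeholder, not an official GLM-4.5 value:

# Sketch only: rough f16 KV-cache footprint for a full context window.
n_layer       = 47           # hypothetical block count, including the NextN layer
n_layer_cache = n_layer - 1  # GLM4_MOE skips the NextN layer, as in this commit
n_embd_k_gqa  = 1024         # hypothetical per-layer K width (n_head_kv * head dim)
n_embd_v_gqa  = 1024         # hypothetical per-layer V width
n_ctx         = 4096         # hypothetical context length
bytes_per_elt = 2            # f16

kv_bytes = n_layer_cache * n_ctx * (n_embd_k_gqa + n_embd_v_gqa) * bytes_per_elt
print(f"approx KV cache size: {kv_bytes / 1024**2:.1f} MiB")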

src/llama-model.cpp

Lines changed: 7 additions & 2 deletions
@@ -4397,6 +4397,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, final_layer), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
                     create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, final_layer), { n_embd }, TENSOR_NOT_REQUIRED);
 
+                    // Load ALL tensors including NextN layer to satisfy tensor count (803)
+                    // but only PROCESS first 46 transformer layers in forward pass
                     for (int i = 0; i < n_layer; ++i) {
                         auto & layer = layers[i];
 
@@ -13492,7 +13494,10 @@ struct llm_build_glm4_moe : public llm_graph_context {
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
-        for (int il = 0; il < n_layer; ++il) {
+        // Only process first 46 transformer layers (skip NextN layer 46)
+        // Layer 46 tensors are loaded but not processed in forward pass
+        const int n_transformer_layers = n_layer - 1;
+        for (int il = 0; il < n_transformer_layers; ++il) {
             ggml_tensor * inpSA = inpL;
 
             // Pre-attention norm
@@ -13554,7 +13559,7 @@ struct llm_build_glm4_moe : public llm_graph_context {
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
-            if (il == n_layer - 1 && inp_out_ids) {
+            if (il == n_transformer_layers - 1 && inp_out_ids) {
                 cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                 inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
             }
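To confirm that the NextN tensors are present in a converted file even though llm_build_glm4_moe never evaluates them, the tensor list can be inspected with gguf-py. A minimal sketch; the path is hypothetical, and the assumption that the NextN tensor names contain "nextn" is based on the LLM_TENSOR_NEXTN_* identifiers above:

# Sketch only: count tensors and list the NextN-layer ones.
from gguf import GGUFReader

reader = GGUFReader("glm-4.5.gguf")  # hypothetical path
names = [t.name for t in reader.tensors]
print(len(names))  # total tensor count (the commit comment mentions 803)
print([n for n in names if "nextn" in n.lower()])  # name pattern is an assumption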
