
Commit 8865bdf

GLM-4-0414

1 parent: d3bd719

10 files changed (+325, -41 lines)

README.md

Lines changed: 1 addition & 0 deletions

@@ -97,6 +97,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
 - [x] [Flan T5](https://huggingface.co/models?search=flan-t5)
 - [x] [Open Elm models](https://huggingface.co/collections/apple/openelm-instruct-models-6619ad295d7ae9f868b759ca)
 - [x] [ChatGLM3-6b](https://huggingface.co/THUDM/chatglm3-6b) + [ChatGLM4-9b](https://huggingface.co/THUDM/glm-4-9b) + [GLMEdge-1.5b](https://huggingface.co/THUDM/glm-edge-1.5b-chat) + [GLMEdge-4b](https://huggingface.co/THUDM/glm-edge-4b-chat)
+- [x] [GLM-4-0414](https://huggingface.co/collections/THUDM/glm-4-0414-67f3cbcb34dd9d252707cb2e)
 - [x] [SmolLM](https://huggingface.co/collections/HuggingFaceTB/smollm-6695016cad7167254ce15966)
 - [x] [EXAONE-3.0-7.8B-Instruct](https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct)
 - [x] [FalconMamba Models](https://huggingface.co/collections/tiiuae/falconmamba-7b-66b9a580324dd1598b0f6d4a)

convert_hf_to_gguf.py

Lines changed: 38 additions & 1 deletion

@@ -717,6 +717,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "d353350c764d8c3b39c763113960e4fb4919bea5fbf208a0e3b22e8469dc7406":
             # ref: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct
             res = "llama4"
+        if chkhsh == "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2":
+            # ref: https://huggingface.co/THUDM/glm-4-9b-hf
+            res = "glm4"
 
         if res is None:
             logger.warning("\n")

@@ -4882,6 +4885,41 @@ def prepare_tensors(self):
         super().prepare_tensors()
         self.gguf_writer.add_max_alibi_bias(self.max_alibi_bias)
 
+@Model.register("Glm4ForCausalLM")
+class Glm4Model(Model):
+    model_arch = gguf.MODEL_ARCH.GLM4
+
+    def set_vocab(self):
+        self._set_vocab_gpt2()
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
+            if self.hparams["rope_scaling"].get("type") == "yarn":
+                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+                self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
+                self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"])
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[Tuple[str, Tensor]]:
+        if "gate_up_proj" in name:
+            match = re.match(r"model\.layers\.(\d+)\.gate_up_proj\.weight", name)
+            if match:
+                bid = int(match.group(1))
+                return [(f"blk.{bid}.ffn_up.weight", data_torch)]
+
+        if "post_self_attn_layernorm" in name:
+            match = re.match(r"model\.layers\.(\d+)\.post_self_attn_layernorm\.weight", name)
+            if match:
+                bid = int(match.group(1))
+                return [(f"blk.{bid}.post_attn_norm.weight", data_torch)]
+
+        if "post_mlp_layernorm" in name:
+            match = re.match(r"model\.layers\.(\d+)\.post_mlp_layernorm\.weight", name)
+            if match:
+                bid = int(match.group(1))
+                return [(f"blk.{bid}.post_mlp_norm.weight", data_torch)]
+
+        return super().modify_tensors(data_torch, name, bid)
 
 @Model.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration")
 class ChatGLMModel(Model):

@@ -5551,7 +5589,6 @@ def main() -> None:
     with torch.inference_mode():
         output_type = ftype_map[args.outtype]
         model_architecture = hparams["architectures"][0]
-
         try:
             model_class = Model.from_model_architecture(model_architecture)
         except NotImplementedError:
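
The new Glm4Model.modify_tensors boils down to three regex renames from Hugging Face tensor names to GGUF names. A minimal standalone sketch of that mapping (the remap helper and RENAMES table below are illustrative, not part of the commit):

import re

# Illustrative table of the three HF-name -> GGUF-name renames performed
# by Glm4Model.modify_tensors in the diff above.
RENAMES = [
    (r"model\.layers\.(\d+)\.gate_up_proj\.weight",             "blk.{bid}.ffn_up.weight"),
    (r"model\.layers\.(\d+)\.post_self_attn_layernorm\.weight", "blk.{bid}.post_attn_norm.weight"),
    (r"model\.layers\.(\d+)\.post_mlp_layernorm\.weight",       "blk.{bid}.post_mlp_norm.weight"),
]

def remap(name: str) -> str | None:
    """Return the GGUF tensor name for a matching GLM-4-0414 HF name, else None."""
    for pattern, target in RENAMES:
        match = re.match(pattern, name)
        if match:
            return target.format(bid=int(match.group(1)))
    return None

assert remap("model.layers.3.gate_up_proj.weight") == "blk.3.ffn_up.weight"
assert remap("model.layers.0.post_mlp_layernorm.weight") == "blk.0.post_mlp_norm.weight"

Names that match none of the patterns fall through to the base Model.modify_tensors, which applies the generic TensorNameMap lookup.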

convert_hf_to_gguf_update.py

Lines changed: 1 addition & 0 deletions

@@ -81,6 +81,7 @@ class TOKENIZER_TYPE(IntEnum):
     {"name": "refact",     "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/smallcloudai/Refact-1_6-base", },
     {"name": "command-r",  "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/CohereForAI/c4ai-command-r-v01", },
     {"name": "qwen2",      "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen1.5-7B", },
+    {"name": "glm4",       "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", },
     {"name": "olmo",       "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/allenai/OLMo-1.7-7B-hf", },
     {"name": "dbrx",       "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/databricks/dbrx-base", },
     {"name": "jina-v1-en", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-reranker-v1-tiny-en", },

examples/server/README.md

Lines changed: 38 additions & 38 deletions
Large diffs are not rendered by default.

gguf-py/gguf/constants.py

Lines changed: 17 additions & 0 deletions

@@ -280,6 +280,7 @@ class MODEL_ARCH(IntEnum):
     DEEPSEEK  = auto()
     DEEPSEEK2 = auto()
     CHATGLM   = auto()
+    GLM4      = auto()
     BITNET    = auto()
     T5        = auto()
     T5ENCODER = auto()

@@ -487,6 +488,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.DEEPSEEK:  "deepseek",
     MODEL_ARCH.DEEPSEEK2: "deepseek2",
     MODEL_ARCH.CHATGLM:   "chatglm",
+    MODEL_ARCH.GLM4:      "glm4",
     MODEL_ARCH.BITNET:    "bitnet",
     MODEL_ARCH.T5:        "t5",
     MODEL_ARCH.T5ENCODER: "t5encoder",

@@ -1561,6 +1563,21 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.GLM4 : [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     MODEL_ARCH.BITNET: [
         MODEL_TENSOR.ATTN_Q,
         MODEL_TENSOR.ATTN_K,
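
With these three additions, gguf-py knows the architecture's enum value, its on-disk name, and the per-layer tensors it may contain. A quick interactive check (assuming gguf-py's usual exports MODEL_ARCH_NAMES, MODEL_TENSORS, and TENSOR_NAMES):

import gguf

arch = gguf.MODEL_ARCH.GLM4
print(gguf.MODEL_ARCH_NAMES[arch])    # "glm4"
for tensor in gguf.MODEL_TENSORS[arch]:
    # TENSOR_NAMES holds the GGUF base name for each tensor type,
    # e.g. "token_embd" or "blk.{bid}.ffn_up"
    print(gguf.TENSOR_NAMES[tensor])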

gguf-py/gguf/tensor_mapping.py

Lines changed: 2 additions & 2 deletions

@@ -13,7 +13,7 @@ class TensorNameMap:
             "transformer.wte",                           # gpt2 gpt-j mpt refact qwen dbrx jais exaone
             "transformer.word_embeddings",               # falcon
             "word_embeddings",                           # bloom
-            "model.embed_tokens",                        # llama-hf nemotron olmoe olmo2 rwkv6qwen2
+            "model.embed_tokens",                        # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414
             "tok_embeddings",                            # llama-pth
             "embeddings.word_embeddings",                # bert nomic-bert
             "language_model.embedding.word_embeddings",  # persimmon

@@ -306,7 +306,7 @@ class TensorNameMap:
             "h.{bid}.mlp.c_fc",                       # gpt2
             "transformer.h.{bid}.mlp.fc1",            # phi2
             "model.layers.{bid}.mlp.fc1",             # phi2
-            "model.layers.{bid}.mlp.gate_up_proj",    # phi3
+            "model.layers.{bid}.mlp.gate_up_proj",    # phi3 glm-4-0414
             "model.layers.layers.{bid}.mlp.up_proj",  # plamo
             "model.layers.{bid}.feed_forward.w3",     # internlm2
             "encoder.layers.{bid}.mlp.fc11",          # nomic-bert

src/llama-arch.cpp

Lines changed: 20 additions & 0 deletions

@@ -54,6 +54,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_DEEPSEEK,  "deepseek"  },
     { LLM_ARCH_DEEPSEEK2, "deepseek2" },
     { LLM_ARCH_CHATGLM,   "chatglm"   },
+    { LLM_ARCH_GLM4,      "glm4"      },
     { LLM_ARCH_BITNET,    "bitnet"    },
     { LLM_ARCH_T5,        "t5"        },
     { LLM_ARCH_T5ENCODER, "t5encoder" },

@@ -1152,6 +1153,25 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
         },
     },
+    {
+        LLM_ARCH_GLM4,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,     "token_embd" },
+            { LLM_TENSOR_ROPE_FREQS,     "rope_freqs" },
+            { LLM_TENSOR_OUTPUT_NORM,    "output_norm" },
+            { LLM_TENSOR_OUTPUT,         "output" },
+            { LLM_TENSOR_ATTN_NORM,      "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,         "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,         "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,         "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,       "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,       "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_UP,         "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_DOWN,       "blk.%d.ffn_down" },
+            { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attn_norm" },
+            { LLM_TENSOR_FFN_POST_NORM,  "blk.%d.post_mlp_norm" },
+        },
+    },
     {
         LLM_ARCH_BITNET,
         {

src/llama-arch.h

Lines changed: 3 additions & 0 deletions

@@ -58,6 +58,7 @@ enum llm_arch {
     LLM_ARCH_DEEPSEEK,
     LLM_ARCH_DEEPSEEK2,
     LLM_ARCH_CHATGLM,
+    LLM_ARCH_GLM4,
     LLM_ARCH_BITNET,
     LLM_ARCH_T5,
     LLM_ARCH_T5ENCODER,

@@ -256,6 +257,8 @@ enum llm_tensor {
     LLM_TENSOR_ATTN_Q_NORM,
     LLM_TENSOR_ATTN_K_NORM,
     LLM_TENSOR_LAYER_OUT_NORM,
+    LLM_TENSOR_POST_ATTN_NORM,
+    LLM_TENSOR_POST_MLP_NORM,
     LLM_TENSOR_SSM_IN,
     LLM_TENSOR_SSM_CONV1D,
     LLM_TENSOR_SSM_X,
