79 changes: 79 additions & 0 deletions convert_hf_to_gguf.py
@@ -894,6 +894,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206":
# ref: https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base
res = "llada-moe"
if chkhsh == "49fc0303c9e0d2c2c565c510f64b2d9b271276acdcdadff733249eda9f7d59df":
# ref: https://huggingface.co/arcee-ai/Trinity-Tokenizer
res = "afmoe"
if chkhsh == "53e325976a6e142379c19b09afcae354f2f496f147afa8f9e189a33fe4e3024e":
# ref: https://huggingface.co/ibm-granite/granite-docling-258M
res = "granite-docling"
@@ -2283,6 +2286,82 @@ def set_gguf_parameters(self):
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])


@ModelBase.register("AfmoeForCausalLM")
class AfmoeModel(LlamaModel):
model_arch = gguf.MODEL_ARCH.AFMOE

def set_gguf_parameters(self):
super().set_gguf_parameters()

# MoE parameters
if (n_experts := self.hparams.get("num_experts")) is not None:
self.gguf_writer.add_expert_count(n_experts)
if (n_shared_experts := self.hparams.get("num_shared_experts")) is not None:
self.gguf_writer.add_expert_shared_count(n_shared_experts)
if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
if (n_dense_layers := self.hparams.get("num_dense_layers")) is not None:
self.gguf_writer.add_leading_dense_block_count(n_dense_layers)

# Gating function (sigmoid)
if (score_func := self.hparams.get("score_func")) is not None and score_func == "sigmoid":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)

# Route normalization and scaling
if (route_norm := self.hparams.get("route_norm")) is not None:
self.gguf_writer.add_expert_weights_norm(route_norm)
if (route_scale := self.hparams.get("route_scale")) is not None:
self.gguf_writer.add_expert_weights_scale(route_scale)

# Sliding window attention
if (sliding_window := self.hparams.get("sliding_window")) is not None:
self.gguf_writer.add_sliding_window(sliding_window)
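# Note: the add_* calls above serialize to arch-prefixed GGUF KV pairs, e.g.
# afmoe.expert_count, afmoe.expert_shared_count, afmoe.expert_feed_forward_length,
# afmoe.leading_dense_block_count, afmoe.expert_weights_norm,
# afmoe.expert_weights_scale, afmoe.expert_gating_func and
# afmoe.attention.sliding_window (key names recalled from gguf-py's Keys
# constants; verify against gguf.constants if in doubt)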

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Handle expert weights - they're already merged in the HF format
if ".block_sparse_moe.experts.w1" in name:
assert bid is not None
return [(self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_EXP, bid), data_torch)]
elif ".block_sparse_moe.experts.w2" in name:
assert bid is not None
return [(self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN_EXP, bid), data_torch)]
elif ".block_sparse_moe.experts.w3" in name:
assert bid is not None
return [(self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_EXP, bid), data_torch)]

# Map dual normalization layers
if ".attn_norm_a." in name and bid is not None:
return [(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_NORM, bid), data_torch)]
elif ".attn_norm_b." in name and bid is not None:
return [(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_NORM_2, bid), data_torch)]
elif ".ffn_norm_a." in name and bid is not None:
return [(self.format_tensor_name(gguf.MODEL_TENSOR.FFN_NORM, bid), data_torch)]
elif ".ffn_norm_b." in name and bid is not None:
return [(self.format_tensor_name(gguf.MODEL_TENSOR.FFN_POST_NORM, bid), data_torch)]

# Map attention gate
elif ".self_attn.gate_proj." in name and bid is not None:
return [(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_GATE, bid), data_torch)]

# Map shared experts
elif ".block_sparse_moe.shared_experts.gate_proj." in name and bid is not None:
return [(self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SHEXP, bid), data_torch)]
elif ".block_sparse_moe.shared_experts.up_proj." in name and bid is not None:
return [(self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SHEXP, bid), data_torch)]
elif ".block_sparse_moe.shared_experts.down_proj." in name and bid is not None:
return [(self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN_SHEXP, bid), data_torch)]

# Map router
elif ".block_sparse_moe.router.gate." in name and bid is not None:
return [(self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_INP, bid), data_torch)]

# Skip expert_bias
elif "expert_bias" in name:
return []

return [(self.map_tensor_name(name), data_torch)]
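
A note on the expert branches in `modify_tensors` above: many HF MoE checkpoints store one 2D weight per expert and a converter has to stack them into a single 3D `*_exps` tensor, but this checkpoint ships `w1`/`w2`/`w3` already merged across the expert dimension, so the code only renames them. A hedged sketch of the stacking that would otherwise be needed (sizes below are illustrative, not read from the model):

import torch

n_expert, n_ff, n_embd = 8, 1408, 2048  # illustrative sizes
per_expert = [torch.randn(n_ff, n_embd) for _ in range(n_expert)]  # one 2D weight per expert
merged = torch.stack(per_expert, dim=0)  # [n_expert, n_ff, n_embd], the *_exps layout
assert merged.shape == (n_expert, n_ff, n_embd)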


@ModelBase.register(
"LlavaForConditionalGeneration", # pixtral
"Mistral3ForConditionalGeneration", # mistral small 3.1
1 change: 1 addition & 0 deletions convert_hf_to_gguf_update.py
@@ -140,6 +140,7 @@ class TOKENIZER_TYPE(IntEnum):
{"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", },
{"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", },
{"name": "llada-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base", },
{"name": "afmoe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/arcee-ai/Trinity-Tokenizer", },
{"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", },
]

30 changes: 30 additions & 0 deletions gguf-py/gguf/constants.py
@@ -402,6 +402,7 @@ class MODEL_ARCH(IntEnum):
BAILINGMOE = auto()
DOTS1 = auto()
ARCEE = auto()
AFMOE = auto()
ERNIE4_5 = auto()
ERNIE4_5_MOE = auto()
HUNYUAN_MOE = auto()
@@ -452,6 +453,7 @@ class MODEL_TENSOR(IntEnum):
ATTN_POST_NORM = auto()
ATTN_ROT_EMBD = auto()
ATTN_SINKS = auto()
ATTN_GATE = auto()
FFN_GATE_INP = auto()
FFN_GATE_INP_SHEXP = auto()
FFN_NORM = auto()
@@ -746,6 +748,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.BAILINGMOE: "bailingmoe",
MODEL_ARCH.DOTS1: "dots1",
MODEL_ARCH.ARCEE: "arcee",
MODEL_ARCH.AFMOE: "afmoe",
MODEL_ARCH.ERNIE4_5: "ernie4_5",
MODEL_ARCH.ERNIE4_5_MOE: "ernie4_5-moe",
MODEL_ARCH.FALCON_H1: "falcon-h1",
@@ -795,6 +798,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
MODEL_TENSOR.ATTN_SINKS: "blk.{bid}.attn_sinks",
MODEL_TENSOR.ATTN_GATE: "blk.{bid}.attn_gate",
MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm",
MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm",
MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm",
@@ -2572,6 +2576,32 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.AFMOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_NORM_2,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_GATE,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_POST_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
],
MODEL_ARCH.ERNIE4_5: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
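The `MODEL_ARCH.AFMOE` tensor list added above declares which tensors a GGUF file for this architecture may contain; the name templates come from `TENSOR_NAMES` earlier in this file and must line up with the `LLM_TENSOR_NAMES` strings added to `src/llama-arch.cpp` below. A quick consistency check from the Python side, assuming gguf-py re-exports `MODEL_ARCH_TENSORS` and `TENSOR_NAMES` at package level as it does for its other constants:

import gguf

# print every tensor name the afmoe arch may use, with a sample block id;
# compare against the blk.%d.* strings in llama-arch.cpp
for t in gguf.MODEL_ARCH_TENSORS[gguf.MODEL_ARCH.AFMOE]:
    print(gguf.TENSOR_NAMES[t].format(bid=0))
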
Binary file added models/ggml-vocab-afmoe.gguf
Binary file not shown.
112 changes: 112 additions & 0 deletions models/ggml-vocab-afmoe.gguf.inp
@@ -0,0 +1,112 @@
ied 4 ½ months
__ggml_vocab_test__
Äpfel
__ggml_vocab_test__

__ggml_vocab_test__

__ggml_vocab_test__

__ggml_vocab_test__

__ggml_vocab_test__

__ggml_vocab_test__


__ggml_vocab_test__



__ggml_vocab_test__




__ggml_vocab_test__


__ggml_vocab_test__
Hello world
__ggml_vocab_test__
Hello world
__ggml_vocab_test__
Hello World
__ggml_vocab_test__
Hello World
__ggml_vocab_test__
Hello World!
__ggml_vocab_test__
Hello, world!
__ggml_vocab_test__
Hello, world!
__ggml_vocab_test__
this is 🦙.cpp
__ggml_vocab_test__
w048 7tuijk dsdfhu
__ggml_vocab_test__
нещо на Български
__ggml_vocab_test__
កាន់តែពិសេសអាចខលចេញ
__ggml_vocab_test__
🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
Hello
__ggml_vocab_test__
(
__ggml_vocab_test__

=
__ggml_vocab_test__
' era
__ggml_vocab_test__
Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
__ggml_vocab_test__
!!!!!!
__ggml_vocab_test__
3
__ggml_vocab_test__
33
__ggml_vocab_test__
333
__ggml_vocab_test__
3333
__ggml_vocab_test__
33333
__ggml_vocab_test__
333333
__ggml_vocab_test__
3333333
__ggml_vocab_test__
33333333
__ggml_vocab_test__
333333333
__ggml_vocab_test__
Cửa Việt
__ggml_vocab_test__
discards
__ggml_vocab_test__











🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL
__ggml_vocab_test__
46 changes: 46 additions & 0 deletions models/ggml-vocab-afmoe.gguf.out
@@ -0,0 +1,46 @@
1129 252 51 252 20861 3621
49116 25524 343

252
288
344
229
230
327
1866
4402
14795 1117
30197 1117
14795 3295
30197 3295
30197 3295 32
14795 43 1117 32
30197 43 1117 32
483 351 69865 279 45 11112
118 18799 252 54 115 4546 30869 25372 4191 13934
23835 183893 7432 30515 125974 185839 20324
124940 92255 273 160060 191869 44968 256 188211 21207 147 142156 195704 142156 21207 127 92255 259 21207 255 190792 21207 259 195704 21207 263
12479 387 10171 40 22860 146 18932 15540 136 10094 387 49707 77415 91293 40 70574 387 9266 56494 384 651 692 1204 9776 40
14795
30197
252 30197
288 30197
344 30197
344 30197 230 344 30197
387
230 399
38 6260
14795 43 366 76896 32 822 429 383 22860 255 2972 111778 3712 27304 19409 48 23988 18044 13814 73996
183574
50
2158
11805
50 11805
2158 11805
11805 11805
50 11805 11805
2158 11805 11805
11805 11805 11805
66 70789 96 140747
104867
144635 20623 120822 22300 4402 71947 2759 24373 12479 387 10171 40 22860 146 18932 15540 136 10094 387 49707 77415 91293 40 70574 69865 279 63816 279 252 50 252 2158 252 11805 252 50 11805 252 2158 11805 252 11805 11805 252 50 11805 11805 252 2158 11805 11805 252 50 45 50 252 50 634 50 252 50 1472 50 252 124940 92255 273 160060 191869 44968 256 188211 21207 147 142156 195704 142156 21207 127 92255 259 45614 255 2972 111778 3712 27304 19409 48 23988 18044 13814 73996 79520 1235 23427 13373 183893 7432 30515 125974 185839 20324 27123 36632 25121 3124 36057 36678 183574 31148 10446 365 1908 874 578 63490 438 414 765 43 578 1954 383 2259 62 578 76 487 2259 365 2130 960 394 43 578 67 383 679 766 8748 62 1155 38 35185 290 66450 75
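
Each chunk of `ggml-vocab-afmoe.gguf.inp` (chunks are delimited by `__ggml_vocab_test__` lines) pairs with one line of expected token IDs above, and the tokenizer tests check that the GGUF vocab reproduces them. A hedged sketch of cross-checking the fixtures against the reference HF tokenizer; the delimiter handling and `add_special_tokens=False` are assumptions to verify:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("arcee-ai/Trinity-Tokenizer")
texts = open("models/ggml-vocab-afmoe.gguf.inp", encoding="utf-8").read().split("\n__ggml_vocab_test__\n")
expected = open("models/ggml-vocab-afmoe.gguf.out", encoding="utf-8").read().splitlines()
for text, line in zip(texts, expected):
    ids = tokenizer.encode(text, add_special_tokens=False)
    assert ids == [int(tok) for tok in line.split()], repr(text)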
31 changes: 31 additions & 0 deletions src/llama-arch.cpp
@@ -86,6 +86,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_BAILINGMOE, "bailingmoe" },
{ LLM_ARCH_DOTS1, "dots1" },
{ LLM_ARCH_ARCEE, "arcee" },
{ LLM_ARCH_AFMOE, "afmoe" },
{ LLM_ARCH_ERNIE4_5, "ernie4_5" },
{ LLM_ARCH_ERNIE4_5_MOE, "ernie4_5-moe" },
{ LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
@@ -319,6 +320,35 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_AFMOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_GATE, "blk.%d.attn_gate" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
},
},
{
LLM_ARCH_LLAMA4,
{
@@ -2301,6 +2331,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
2 changes: 2 additions & 0 deletions src/llama-arch.h
@@ -90,6 +90,7 @@ enum llm_arch {
LLM_ARCH_BAILINGMOE,
LLM_ARCH_DOTS1,
LLM_ARCH_ARCEE,
LLM_ARCH_AFMOE,
LLM_ARCH_ERNIE4_5,
LLM_ARCH_ERNIE4_5_MOE,
LLM_ARCH_HUNYUAN_MOE,
@@ -302,6 +303,7 @@ enum llm_tensor {
LLM_TENSOR_ATTN_POST_NORM,
LLM_TENSOR_ATTN_ROT_EMBD,
LLM_TENSOR_ATTN_SINKS,
LLM_TENSOR_ATTN_GATE,
LLM_TENSOR_FFN_GATE_INP,
LLM_TENSOR_FFN_GATE_INP_SHEXP,
LLM_TENSOR_FFN_NORM,