
Commit e1fcf8b

model : add AfmoeForCausalLM support (ggml-org#16477)

Authored by bartowski1182 and CISC.

* Add AFMOE model support
* Update to vocab
* Add model sizing
* Undo Rope change for ARCEE model
* Address review comments
* Update modeling code is_sliding -> use_rope, replace hard-coded logic
* Fix AFMOE tokenizer
* Update convert_hf_to_gguf.py (Co-authored-by: Sigbjørn Skjæret <[email protected]>)
* Update convert_hf_to_gguf.py (Co-authored-by: Sigbjørn Skjæret <[email protected]>)
* Update AFMoE tokenizer class identification to be more unique

Co-authored-by: Sigbjørn Skjæret <[email protected]>

1 parent 6cd0cf7 · commit e1fcf8b

File tree: 14 files changed, +541 −1 lines

convert_hf_to_gguf.py
Lines changed: 78 additions & 0 deletions

@@ -1124,6 +1124,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "a1e163ecab2e718a4c829d1148b6e86824ec36163bb71941c3dca9cd5ac25756":
             # ref: https://huggingface.co/JetBrains/Mellum-4b-base
             res = "mellum"
+        if chkhsh == "49fc0303c9e0d2c2c565c510f64b2d9b271276acdcdadff733249eda9f7d59df":
+            # ref: https://huggingface.co/arcee-ai/Trinity-Tokenizer
+            res = "afmoe"
         if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206":
             # ref: https://huggingface.co/inclusionAI/Ling-mini-base-2.0
             res = "bailingmoe2"

@@ -2533,6 +2536,81 @@ def set_gguf_parameters(self):
             self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])


+@ModelBase.register("AfmoeForCausalLM")
+class AfmoeModel(LlamaModel):
+    model_arch = gguf.MODEL_ARCH.AFMOE
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+
+        # MoE parameters
+        if (n_experts := self.hparams.get("num_experts")) is not None:
+            self.gguf_writer.add_expert_count(n_experts)
+        if (n_shared_experts := self.hparams.get("num_shared_experts")) is not None:
+            self.gguf_writer.add_expert_shared_count(n_shared_experts)
+        if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
+            self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
+        if (n_dense_layers := self.hparams.get("num_dense_layers")) is not None:
+            self.gguf_writer.add_leading_dense_block_count(n_dense_layers)
+
+        # expert gating function
+        score_func = self.hparams.get("score_func")
+        if score_func == "sigmoid":
+            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
+        elif score_func == "softmax":
+            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
+        elif score_func is not None:
+            raise ValueError(f"Unsupported score_func value: {score_func}")
+
+        # route normalization and scaling
+        if (route_norm := self.hparams.get("route_norm")) is not None:
+            self.gguf_writer.add_expert_weights_norm(route_norm)
+        if (route_scale := self.hparams.get("route_scale")) is not None:
+            self.gguf_writer.add_expert_weights_scale(route_scale)
+
+        # sliding window attention
+        if (sliding_window := self.hparams.get("sliding_window")) is not None:
+            self.gguf_writer.add_sliding_window(sliding_window)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # expert weights are stored per expert in the HF checkpoint;
+        # buffer them and merge each block's experts into a single tensor
+        if name.find("mlp.experts") != -1:
+            n_experts = self.hparams["num_experts"]
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
+                tensors: list[tuple[str, Tensor]] = []
+
+                # merge the experts into a single 3d tensor
+                for w_name in ["gate_proj", "up_proj", "down_proj"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename_to_retrieve = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+                        datas.append(self._experts[bid][ename_to_retrieve])
+                        del self._experts[bid][ename_to_retrieve]
+
+                    data_torch = torch.stack(datas, dim=0)
+                    merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+                    new_name = self.map_tensor_name(merged_name)
+                    tensors.append((new_name, data_torch))
+
+                return tensors
+            else:
+                return []
+
+        if name.endswith(".expert_bias"):
+            name = name.replace(".expert_bias", ".expert_bias.bias")
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+
 @ModelBase.register(
     "LlavaForConditionalGeneration",    # pixtral
     "Mistral3ForConditionalGeneration", # mistral small 3.1

convert_hf_to_gguf_update.py
Lines changed: 1 addition & 0 deletions

@@ -139,6 +139,7 @@ class TOKENIZER_TYPE(IntEnum):
     {"name": "lfm2",            "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"},
     {"name": "exaone4",         "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", },
     {"name": "mellum",          "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", },
+    {"name": "afmoe",           "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/arcee-ai/Trinity-Tokenizer", },
     {"name": "bailingmoe2",     "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-mini-base-2.0", },
     {"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", },
     {"name": "minimax-m2",      "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/MiniMaxAI/MiniMax-M2", },

gguf-py/gguf/constants.py
Lines changed: 31 additions & 0 deletions

@@ -409,6 +409,7 @@ class MODEL_ARCH(IntEnum):
     BAILINGMOE2 = auto()
     DOTS1 = auto()
     ARCEE = auto()
+    AFMOE = auto()
     ERNIE4_5 = auto()
     ERNIE4_5_MOE = auto()
     HUNYUAN_MOE = auto()

@@ -464,6 +465,7 @@ class MODEL_TENSOR(IntEnum):
     ATTN_POST_NORM = auto()
     ATTN_ROT_EMBD = auto()
     ATTN_SINKS = auto()
+    ATTN_GATE = auto()
     FFN_GATE_INP = auto()
     FFN_GATE_INP_SHEXP = auto()
     FFN_NORM = auto()

@@ -776,6 +778,7 @@
     MODEL_ARCH.BAILINGMOE2: "bailingmoe2",
     MODEL_ARCH.DOTS1: "dots1",
     MODEL_ARCH.ARCEE: "arcee",
+    MODEL_ARCH.AFMOE: "afmoe",
     MODEL_ARCH.ERNIE4_5: "ernie4_5",
     MODEL_ARCH.ERNIE4_5_MOE: "ernie4_5-moe",
     MODEL_ARCH.FALCON_H1: "falcon-h1",

@@ -828,6 +831,7 @@
     MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
     MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
     MODEL_TENSOR.ATTN_SINKS: "blk.{bid}.attn_sinks",
+    MODEL_TENSOR.ATTN_GATE: "blk.{bid}.attn_gate",
     MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm",
     MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm",
     MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm",

@@ -2693,6 +2697,33 @@
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.AFMOE: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_POST_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.ATTN_GATE,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_GATE_SHEXP,
+        MODEL_TENSOR.FFN_UP_SHEXP,
+        MODEL_TENSOR.FFN_DOWN_SHEXP,
+        MODEL_TENSOR.FFN_PRE_NORM,
+        MODEL_TENSOR.FFN_POST_NORM,
+        MODEL_TENSOR.FFN_EXP_PROBS_B,
+    ],
     MODEL_ARCH.ERNIE4_5: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
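The MODEL_TENSORS[MODEL_ARCH.AFMOE] list declares which tensor kinds the architecture may contain, and TENSOR_NAMES supplies the GGUF naming pattern for each. A small sketch of how the two tables combine (assuming the gguf package from this repo is importable):

from gguf.constants import MODEL_ARCH, MODEL_TENSORS, TENSOR_NAMES

# print every tensor name AFMoE can emit for block 0;
# per-layer patterns carry a {bid} placeholder, global ones do not
for t in MODEL_TENSORS[MODEL_ARCH.AFMOE]:
    print(TENSOR_NAMES[t].format(bid=0))
# -> token_embd, ..., blk.0.attn_gate, ..., blk.0.exp_probs_b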

gguf-py/gguf/tensor_mapping.py
Lines changed: 8 additions & 1 deletion

@@ -314,6 +314,10 @@ class TensorNameMap:
             "model.layers.{bid}.self_attn.sinks", # openai-moe
         ),

+        MODEL_TENSOR.ATTN_GATE: (
+            "model.layers.{bid}.self_attn.gate_proj", # afmoe
+        ),
+
         # Feed-forward norm
         MODEL_TENSOR.FFN_NORM: (
             "gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox

@@ -340,11 +344,12 @@ class TensorNameMap:
             "model.layers.{bid}.feedforward_layernorm", # apertus
         ),

-        # Post feed-forward norm
+        # Pre feed-forward norm
         MODEL_TENSOR.FFN_PRE_NORM: (
             "model.layers.{bid}.pre_feedforward_layernorm", # gemma2
             "layers.{bid}.pre_feedforward_layernorm", # embeddinggemma
             "model.layers.{bid}.pre_ff_layernorm.weight",
+            "model.layers.{bid}.pre_mlp_layernorm", # afmoe
         ),

         # Post feed-forward norm

@@ -370,6 +375,7 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.gate.wg", # hunyuan
             "model.layers.{bid}.block_sparse_moe.primary_router", # smallthinker
             "model.layers.{bid}.feed_forward.gate", # lfm2moe
+            "model.layers.{bid}.mlp.router.gate", # afmoe
         ),

         MODEL_TENSOR.FFN_GATE_INP_SHEXP: (

@@ -380,6 +386,7 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 dots1
             "model.layers.{bid}.mlp.moe_statics.e_score_correction", # ernie4.5-moe
             "model.layers.{bid}.mlp.gate.expert_bias", # bailingmoe2
+            "model.layers.{bid}.mlp.expert_bias", # afmoe
             "model.layers.{bid}.feed_forward.expert_bias", # lfm2moe
             "model.layers.{bid}.block_sparse_moe.e_score_correction", # minimax-m2
         ),
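These entries feed TensorNameMap, which the converter uses (via map_tensor_name) to translate checkpoint names into canonical GGUF names. A quick usage sketch, based on the gguf-py API as of this commit (the block count 32 is an arbitrary example):

from gguf.constants import MODEL_ARCH
from gguf.tensor_mapping import get_tensor_name_map

name_map = get_tensor_name_map(MODEL_ARCH.AFMOE, 32)

# HF checkpoint name -> GGUF base name (suffix stripped via try_suffixes)
hf_name = "model.layers.3.self_attn.gate_proj.weight"
print(name_map.get_name(hf_name, try_suffixes=(".weight", ".bias")))
# -> blk.3.attn_gate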

src/CMakeLists.txt
Lines changed: 1 addition & 0 deletions

@@ -35,6 +35,7 @@ add_library(llama
     unicode-data.cpp
     unicode.cpp
     unicode.h
+    models/afmoe.cpp
     models/apertus.cpp
     models/arcee.cpp
     models/arctic.cpp

src/llama-arch.cpp
Lines changed: 32 additions & 0 deletions

@@ -90,6 +90,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_BAILINGMOE2,  "bailingmoe2"  },
     { LLM_ARCH_DOTS1,        "dots1"        },
     { LLM_ARCH_ARCEE,        "arcee"        },
+    { LLM_ARCH_AFMOE,        "afmoe"        },
     { LLM_ARCH_ERNIE4_5,     "ernie4_5"     },
     { LLM_ARCH_ERNIE4_5_MOE, "ernie4_5-moe" },
     { LLM_ARCH_HUNYUAN_MOE,  "hunyuan-moe"  },

@@ -333,6 +334,36 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_AFMOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,     "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,    "output_norm" },
+            { LLM_TENSOR_OUTPUT,         "output" },
+            { LLM_TENSOR_ATTN_NORM,      "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_ATTN_Q,         "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,         "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,         "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,       "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_Q_NORM,    "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K_NORM,    "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_GATE,      "blk.%d.attn_gate" },
+            { LLM_TENSOR_FFN_NORM,       "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_POST_NORM,  "blk.%d.post_ffw_norm" },
+            { LLM_TENSOR_FFN_GATE_INP,   "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE,       "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,       "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,         "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_EXPS,  "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,  "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,    "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
+            { LLM_TENSOR_FFN_UP_SHEXP,   "blk.%d.ffn_up_shexp" },
+            { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
+            { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
+        },
+    },
     {
         LLM_ARCH_LLAMA4,
         {

@@ -2444,6 +2475,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_ATTN_V,    {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_ATTN_QKV,  {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_ATTN_OUT,  {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_FFN_GATE,  {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_FFN_DOWN,  {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_FFN_UP,    {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},

src/llama-arch.h
Lines changed: 2 additions & 0 deletions

@@ -94,6 +94,7 @@ enum llm_arch {
     LLM_ARCH_BAILINGMOE2,
    LLM_ARCH_DOTS1,
     LLM_ARCH_ARCEE,
+    LLM_ARCH_AFMOE,
     LLM_ARCH_ERNIE4_5,
     LLM_ARCH_ERNIE4_5_MOE,
     LLM_ARCH_HUNYUAN_MOE,

@@ -312,6 +313,7 @@ enum llm_tensor {
     LLM_TENSOR_ATTN_POST_NORM,
     LLM_TENSOR_ATTN_ROT_EMBD,
     LLM_TENSOR_ATTN_SINKS,
+    LLM_TENSOR_ATTN_GATE,
     LLM_TENSOR_FFN_GATE_INP,
     LLM_TENSOR_FFN_GATE_INP_SHEXP,
     LLM_TENSOR_FFN_NORM,
