Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1048,6 +1048,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
if chkhsh == "a1e163ecab2e718a4c829d1148b6e86824ec36163bb71941c3dca9cd5ac25756":
# ref: https://huggingface.co/JetBrains/Mellum-4b-base
res = "mellum"
if chkhsh == "49fc0303c9e0d2c2c565c510f64b2d9b271276acdcdadff733249eda9f7d59df":
# ref: https://huggingface.co/arcee-ai/Trinity-Tokenizer
res = "afmoe"
if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206":
# ref: https://huggingface.co/inclusionAI/Ling-mini-base-2.0
res = "bailingmoe2"
Expand Down Expand Up @@ -2457,6 +2460,82 @@ def set_gguf_parameters(self):
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])


@ModelBase.register("AfmoeForCausalLM")
class AfmoeModel(LlamaModel):
model_arch = gguf.MODEL_ARCH.AFMOE

def set_gguf_parameters(self):
super().set_gguf_parameters()

# MoE parameters
if (n_experts := self.hparams.get("num_experts")) is not None:
self.gguf_writer.add_expert_count(n_experts)
if (n_shared_experts := self.hparams.get("num_shared_experts")) is not None:
self.gguf_writer.add_expert_shared_count(n_shared_experts)
if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
if (n_dense_layers := self.hparams.get("num_dense_layers")) is not None:
self.gguf_writer.add_leading_dense_block_count(n_dense_layers)

# Expert Gating Function
score_func = self.hparams.get("score_func", "sigmoid")
if score_func == "sigmoid":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
elif score_func == "softmax":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
else:
raise ValueError(f"Unsupported score_function value: {score_func}")

# Route normalization and scaling
if (route_norm := self.hparams.get("route_norm")) is not None:
self.gguf_writer.add_expert_weights_norm(route_norm)
if (route_scale := self.hparams.get("route_scale")) is not None:
self.gguf_writer.add_expert_weights_scale(route_scale)

# Sliding window attention
if (sliding_window := self.hparams.get("sliding_window")) is not None:
self.gguf_writer.add_sliding_window(sliding_window)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Handle expert weights - they're already merged in the HF format
# process the experts separately
if name.find("mlp.experts") != -1:
n_experts = self.hparams["num_experts"]
assert bid is not None

if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]

self._experts[bid][name] = data_torch

if len(self._experts[bid]) >= n_experts * 3:
tensors: list[tuple[str, Tensor]] = []

# merge the experts into a single 3d tensor
for w_name in ["gate_proj", "up_proj", "down_proj"]:
datas: list[Tensor] = []

for xid in range(n_experts):
ename_to_retrieve = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
datas.append(self._experts[bid][ename_to_retrieve])
del self._experts[bid][ename_to_retrieve]

data_torch = torch.stack(datas, dim=0)
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
new_name = self.map_tensor_name(merged_name)
tensors.append((new_name, data_torch))

return tensors
else:
return []


if name.endswith(".expert_bias"):
name = name.replace(".expert_bias", ".expert_bias.bias")

return [(self.map_tensor_name(name), data_torch)]


@ModelBase.register(
"LlavaForConditionalGeneration", # pixtral
"Mistral3ForConditionalGeneration", # mistral small 3.1
Expand Down
1 change: 1 addition & 0 deletions convert_hf_to_gguf_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ class TOKENIZER_TYPE(IntEnum):
{"name": "lfm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"},
{"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", },
{"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", },
{"name": "afmoe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/arcee-ai/Trinity-Tokenizer", },
{"name": "bailingmoe2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-mini-base-2.0", },
{"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", },
{"name": "minimax-m2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/MiniMaxAI/MiniMax-M2", },
Expand Down
31 changes: 31 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ class MODEL_ARCH(IntEnum):
BAILINGMOE2 = auto()
DOTS1 = auto()
ARCEE = auto()
AFMOE = auto()
ERNIE4_5 = auto()
ERNIE4_5_MOE = auto()
HUNYUAN_MOE = auto()
Expand Down Expand Up @@ -464,6 +465,7 @@ class MODEL_TENSOR(IntEnum):
ATTN_POST_NORM = auto()
ATTN_ROT_EMBD = auto()
ATTN_SINKS = auto()
ATTN_GATE = auto()
FFN_GATE_INP = auto()
FFN_GATE_INP_SHEXP = auto()
FFN_NORM = auto()
Expand Down Expand Up @@ -776,6 +778,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.BAILINGMOE2: "bailingmoe2",
MODEL_ARCH.DOTS1: "dots1",
MODEL_ARCH.ARCEE: "arcee",
MODEL_ARCH.AFMOE: "afmoe",
MODEL_ARCH.ERNIE4_5: "ernie4_5",
MODEL_ARCH.ERNIE4_5_MOE: "ernie4_5-moe",
MODEL_ARCH.FALCON_H1: "falcon-h1",
Expand Down Expand Up @@ -828,6 +831,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
MODEL_TENSOR.ATTN_SINKS: "blk.{bid}.attn_sinks",
MODEL_TENSOR.ATTN_GATE: "blk.{bid}.attn_gate",
MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm",
MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm",
MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm",
Expand Down Expand Up @@ -2693,6 +2697,33 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.AFMOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_POST_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_GATE,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_PRE_NORM,
MODEL_TENSOR.FFN_POST_NORM,
MODEL_TENSOR.FFN_EXP_PROBS_B,
],
MODEL_ARCH.ERNIE4_5: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
Expand Down
9 changes: 8 additions & 1 deletion gguf-py/gguf/tensor_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,10 @@ class TensorNameMap:
"model.layers.{bid}.self_attn.sinks", # openai-moe
),

MODEL_TENSOR.ATTN_GATE: (
"model.layers.{bid}.self_attn.gate_proj", # afmoe
),

# Feed-forward norm
MODEL_TENSOR.FFN_NORM: (
"gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox
Expand All @@ -340,11 +344,12 @@ class TensorNameMap:
"model.layers.{bid}.feedforward_layernorm", # apertus
),

# Post feed-forward norm
# Pre feed-forward norm
MODEL_TENSOR.FFN_PRE_NORM: (
"model.layers.{bid}.pre_feedforward_layernorm", # gemma2
"layers.{bid}.pre_feedforward_layernorm", # embeddinggemma
"model.layers.{bid}.pre_ff_layernorm.weight",
"model.layers.{bid}.pre_mlp_layernorm", # afmoe
),

# Post feed-forward norm
Expand All @@ -370,6 +375,7 @@ class TensorNameMap:
"model.layers.{bid}.mlp.gate.wg", # hunyuan
"model.layers.{bid}.block_sparse_moe.primary_router", # smallthinker
"model.layers.{bid}.feed_forward.gate", # lfm2moe
"model.layers.{bid}.mlp.router.gate", # afmoe
),

MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
Expand All @@ -380,6 +386,7 @@ class TensorNameMap:
"model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 dots1
"model.layers.{bid}.mlp.moe_statics.e_score_correction", # ernie4.5-moe
"model.layers.{bid}.mlp.gate.expert_bias", # bailingmoe2
"model.layers.{bid}.mlp.expert_bias", # afmoe
"model.layers.{bid}.feed_forward.expert_bias", # lfm2moe
"model.layers.{bid}.block_sparse_moe.e_score_correction", # minimax-m2
),
Expand Down
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ add_library(llama
unicode-data.cpp
unicode.cpp
unicode.h
models/afmoe.cpp
models/apertus.cpp
models/arcee.cpp
models/arctic.cpp
Expand Down
32 changes: 32 additions & 0 deletions src/llama-arch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_BAILINGMOE2, "bailingmoe2" },
{ LLM_ARCH_DOTS1, "dots1" },
{ LLM_ARCH_ARCEE, "arcee" },
{ LLM_ARCH_AFMOE, "afmoe" },
{ LLM_ARCH_ERNIE4_5, "ernie4_5" },
{ LLM_ARCH_ERNIE4_5_MOE, "ernie4_5-moe" },
{ LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
Expand Down Expand Up @@ -333,6 +334,36 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_AFMOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_GATE, "blk.%d.attn_gate" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
},
},
{
LLM_ARCH_LLAMA4,
{
Expand Down Expand Up @@ -2444,6 +2475,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
Expand Down
2 changes: 2 additions & 0 deletions src/llama-arch.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ enum llm_arch {
LLM_ARCH_BAILINGMOE2,
LLM_ARCH_DOTS1,
LLM_ARCH_ARCEE,
LLM_ARCH_AFMOE,
LLM_ARCH_ERNIE4_5,
LLM_ARCH_ERNIE4_5_MOE,
LLM_ARCH_HUNYUAN_MOE,
Expand Down Expand Up @@ -312,6 +313,7 @@ enum llm_tensor {
LLM_TENSOR_ATTN_POST_NORM,
LLM_TENSOR_ATTN_ROT_EMBD,
LLM_TENSOR_ATTN_SINKS,
LLM_TENSOR_ATTN_GATE,
LLM_TENSOR_FFN_GATE_INP,
LLM_TENSOR_FFN_GATE_INP_SHEXP,
LLM_TENSOR_FFN_NORM,
Expand Down
Loading
Loading