Commit c28932a: Updates for AFMOE
Parent: 7647992
8 files changed: 264 additions, 0 deletions

convert_hf_to_gguf.py

Lines changed: 87 additions & 0 deletions
@@ -891,6 +891,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206":
             # ref: https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base
             res = "llada-moe"
+        if chkhsh == "49fc0303c9e0d2c2c565c510f64b2d9b271276acdcdadff733249eda9f7d59df":
+            # ref: https://huggingface.co/arcee-ai/AFMoE
+            res = "afmoe"

         if res is None:
             logger.warning("\n")
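Note: the new chkhsh branch above registers the AFMoE pre-tokenizer by hash. As a rough, self-contained sketch of how such a hash is reproduced (the probe text is a placeholder rather than the script's real probe string, and the repo id is taken from the ref comment, so treat both as assumptions), one can hash the token ids the tokenizer produces for a fixed probe string:

# Hypothetical sketch - the authoritative value is whatever convert_hf_to_gguf.py
# logs for an unrecognized tokenizer; this only illustrates the hashing scheme.
from hashlib import sha256
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("arcee-ai/AFMoE")   # assumed repo id (see ref above)
chktxt = "placeholder probe text"                             # NOT the script's actual probe string
chkhsh = sha256(str(tokenizer.encode(chktxt)).encode()).hexdigest()
print(chkhsh)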
@@ -2275,6 +2278,90 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])


+@ModelBase.register("AfmoeForCausalLM")
+class AfmoeModel(LlamaModel):
+    model_arch = gguf.MODEL_ARCH.AFMOE
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+
+        # MoE parameters
+        if (n_experts := self.hparams.get("num_experts")) is not None:
+            self.gguf_writer.add_expert_count(n_experts)
+        #if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
+        #    self.gguf_writer.add_expert_used_count(n_experts_used)
+        if (n_shared_experts := self.hparams.get("num_shared_experts")) is not None:
+            self.gguf_writer.add_expert_shared_count(n_shared_experts)
+        if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
+            self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
+        if (n_dense_layers := self.hparams.get("num_dense_layers")) is not None:
+            self.gguf_writer.add_leading_dense_block_count(n_dense_layers)
+
+        # Gating function (sigmoid)
+        if (score_func := self.hparams.get("score_func")) is not None and score_func == "sigmoid":
+            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
+
+        # Route normalization and scaling
+        if (route_norm := self.hparams.get("route_norm")) is not None:
+            self.gguf_writer.add_expert_weights_norm(route_norm)
+        if (route_scale := self.hparams.get("route_scale")) is not None:
+            self.gguf_writer.add_expert_weights_scale(route_scale)
+
+        # Sliding window attention
+        if (sliding_window := self.hparams.get("sliding_window")) is not None:
+            self.gguf_writer.add_sliding_window(sliding_window)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # Handle expert weights - they're already merged in the HF format
+        if ".block_sparse_moe.experts.w1" in name:
+            assert bid is not None
+            return [(self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_EXPS, bid), data_torch)]
+        elif ".block_sparse_moe.experts.w2" in name:
+            assert bid is not None
+            return [(self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN_EXPS, bid), data_torch)]
+        elif ".block_sparse_moe.experts.w3" in name:
+            assert bid is not None
+            return [(self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_EXPS, bid), data_torch)]
+
+        # Map dual normalization layers
+        if ".attn_norm_a." in name and bid is not None:
+            return [(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_NORM, bid), data_torch)]
+        elif ".attn_norm_b." in name and bid is not None:
+            return [(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_NORM_2, bid), data_torch)]
+        elif ".ffn_norm_a." in name and bid is not None:
+            return [(self.format_tensor_name(gguf.MODEL_TENSOR.FFN_NORM, bid), data_torch)]
+        elif ".ffn_norm_b." in name and bid is not None:
+            return [(self.format_tensor_name(gguf.MODEL_TENSOR.FFN_POST_NORM, bid), data_torch)]
+
+        # Map Q/K norms
+        elif ".self_attn.q_norm." in name and bid is not None:
+            return [(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q_NORM, bid), data_torch)]
+        elif ".self_attn.k_norm." in name and bid is not None:
+            return [(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K_NORM, bid), data_torch)]
+
+        # Map attention gate
+        elif ".self_attn.gate_proj." in name and bid is not None:
+            return [(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_GATE, bid), data_torch)]
+
+        # Map shared experts
+        elif ".block_sparse_moe.shared_experts.gate_proj." in name and bid is not None:
+            return [(self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SHEXP, bid), data_torch)]
+        elif ".block_sparse_moe.shared_experts.up_proj." in name and bid is not None:
+            return [(self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SHEXP, bid), data_torch)]
+        elif ".block_sparse_moe.shared_experts.down_proj." in name and bid is not None:
+            return [(self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN_SHEXP, bid), data_torch)]
+
+        # Map router
+        elif ".block_sparse_moe.router.gate." in name and bid is not None:
+            return [(self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_INP, bid), data_torch)]
+
+        # Skip expert_bias
+        elif "expert_bias" in name:
+            return []
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+
 @ModelBase.register(
     "LlavaForConditionalGeneration", # pixtral
     "Mistral3ForConditionalGeneration", # mistral small 3.1

gguf-py/gguf/constants.py

Lines changed: 36 additions & 0 deletions
@@ -392,6 +392,7 @@ class MODEL_ARCH(IntEnum):
     BAILINGMOE = auto()
     DOTS1 = auto()
     ARCEE = auto()
+    AFMOE = auto()
     ERNIE4_5 = auto()
     ERNIE4_5_MOE = auto()
     HUNYUAN_MOE = auto()
@@ -438,6 +439,7 @@ class MODEL_TENSOR(IntEnum):
     ATTN_POST_NORM = auto()
     ATTN_ROT_EMBD = auto()
     ATTN_SINKS = auto()
+    ATTN_GATE = auto()
     FFN_GATE_INP = auto()
     FFN_GATE_INP_SHEXP = auto()
     FFN_NORM = auto()
@@ -451,6 +453,9 @@ class MODEL_TENSOR(IntEnum):
     FFN_GATE_EXP = auto()
     FFN_DOWN_EXP = auto()
     FFN_UP_EXP = auto()
+    FFN_GATE_EXPS = auto()
+    FFN_DOWN_EXPS = auto()
+    FFN_UP_EXPS = auto()
     FFN_GATE_SHEXP = auto()
     FFN_DOWN_SHEXP = auto()
     FFN_UP_SHEXP = auto()
@@ -732,6 +737,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.BAILINGMOE: "bailingmoe",
     MODEL_ARCH.DOTS1: "dots1",
     MODEL_ARCH.ARCEE: "arcee",
+    MODEL_ARCH.AFMOE: "afmoe",
     MODEL_ARCH.ERNIE4_5: "ernie4_5",
     MODEL_ARCH.ERNIE4_5_MOE: "ernie4_5-moe",
     MODEL_ARCH.FALCON_H1: "falcon-h1",
@@ -777,6 +783,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
     MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
     MODEL_TENSOR.ATTN_SINKS: "blk.{bid}.attn_sinks",
+    MODEL_TENSOR.ATTN_GATE: "blk.{bid}.attn_gate",
     MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm",
     MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm",
     MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm",
@@ -800,6 +807,9 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps",
     MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps",
     MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps",
+    MODEL_TENSOR.FFN_GATE_EXPS: "blk.{bid}.ffn_gate_exps",
+    MODEL_TENSOR.FFN_DOWN_EXPS: "blk.{bid}.ffn_down_exps",
+    MODEL_TENSOR.FFN_UP_EXPS: "blk.{bid}.ffn_up_exps",
     MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b",
     MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
     MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: "per_layer_token_embd", # gemma3n
@@ -2552,6 +2562,32 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.AFMOE: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_NORM_2,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.ATTN_GATE,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_POST_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_GATE_EXPS,
+        MODEL_TENSOR.FFN_DOWN_EXPS,
+        MODEL_TENSOR.FFN_UP_EXPS,
+        MODEL_TENSOR.FFN_GATE_SHEXP,
+        MODEL_TENSOR.FFN_UP_SHEXP,
+        MODEL_TENSOR.FFN_DOWN_SHEXP,
+    ],
     MODEL_ARCH.ERNIE4_5: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
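The MODEL_ARCH.AFMOE tensor list above, combined with the name templates, fixes the on-disk names a writer will emit per block. A minimal sketch, assuming this diff is applied and that gguf-py keeps re-exporting MODEL_TENSORS and TENSOR_NAMES at package level:

# Expand the per-block tensor names implied by the AFMOE table (block 0 shown).
import gguf

for t in gguf.MODEL_TENSORS[gguf.MODEL_ARCH.AFMOE]:
    print(gguf.TENSOR_NAMES[t].format(bid=0))
# -> token_embd, output_norm, output, blk.0.attn_norm, blk.0.attn_gate,
#    blk.0.ffn_gate_exps, blk.0.ffn_gate_shexp, ... in the order listed above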

src/llama-arch.cpp

Lines changed: 31 additions & 0 deletions
@@ -86,6 +86,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_BAILINGMOE, "bailingmoe" },
     { LLM_ARCH_DOTS1, "dots1" },
     { LLM_ARCH_ARCEE, "arcee" },
+    { LLM_ARCH_AFMOE, "afmoe" },
     { LLM_ARCH_ERNIE4_5, "ernie4_5" },
     { LLM_ARCH_ERNIE4_5_MOE, "ernie4_5-moe" },
     { LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
@@ -307,6 +308,35 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_AFMOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_GATE, "blk.%d.attn_gate" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
+            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
+            { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
+            { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
+        },
+    },
     {
         LLM_ARCH_LLAMA4,
         {
@@ -2240,6 +2270,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},

src/llama-arch.h

Lines changed: 2 additions & 0 deletions
@@ -90,6 +90,7 @@ enum llm_arch {
     LLM_ARCH_BAILINGMOE,
     LLM_ARCH_DOTS1,
     LLM_ARCH_ARCEE,
+    LLM_ARCH_AFMOE,
     LLM_ARCH_ERNIE4_5,
     LLM_ARCH_ERNIE4_5_MOE,
     LLM_ARCH_HUNYUAN_MOE,
@@ -287,6 +288,7 @@ enum llm_tensor {
     LLM_TENSOR_ATTN_POST_NORM,
     LLM_TENSOR_ATTN_ROT_EMBD,
     LLM_TENSOR_ATTN_SINKS,
+    LLM_TENSOR_ATTN_GATE,
     LLM_TENSOR_FFN_GATE_INP,
     LLM_TENSOR_FFN_GATE_INP_SHEXP,
     LLM_TENSOR_FFN_NORM,

src/llama-model.cpp

Lines changed: 92 additions & 0 deletions
@@ -663,6 +663,29 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_AFMOE:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
+                ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert);
+                ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used);
+                ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
+                ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false);
+                ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false);
+                ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false);
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
+
+                // Default to sigmoid if not set
+                if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) {
+                    hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID;
+                }
+
+                switch (hparams.n_layer) {
+                    case 56: type = LLM_TYPE_1B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_DECI:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
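These hyperparameters are read back from GGUF metadata written by the converter. A quick sanity check on a converted file is to list its afmoe.* keys; a small sketch, assuming gguf-py's GGUFReader and a hypothetical output path:

# Print the afmoe.* metadata keys present in a converted file.
from gguf import GGUFReader

reader = GGUFReader("afmoe-f16.gguf")  # hypothetical path from convert_hf_to_gguf.py
for name in sorted(reader.fields):
    if name.startswith("afmoe."):
        print(name)
# Expect keys such as afmoe.expert_count, afmoe.expert_shared_count,
# afmoe.expert_feed_forward_length and afmoe.leading_dense_block_count,
# matching the LLM_KV_* reads above (exact spellings live in gguf-py's constants).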
@@ -5519,6 +5542,70 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
                 }
             } break;
+        case LLM_ARCH_AFMOE:
+            {
+                tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                // output
+                output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+
+                // if output is NULL, init from the input tok embed
+                if (output == NULL) {
+                    output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                }
+
+                const int64_t n_ff_exp = hparams.n_ff_exp;
+                const int64_t n_expert_shared = hparams.n_expert_shared;
+
+                for (int i = 0; i < n_layer; ++i) {
+                    auto & layer = layers[i];
+
+                    // dual attention normalization
+                    layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                    layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, 0);
+
+                    // attention projections
+                    layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                    layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                    layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                    layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+                    // Q/K normalization
+                    layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
+                    layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
+
+                    // attention gating
+                    layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+
+                    // dual ffn normalization
+                    layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                    layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
+
+                    if (static_cast<uint32_t>(i) >= hparams.n_layer_dense_lead) {
+                        // MoE layers
+                        layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+
+                        // grouped expert weights
+                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0);
+                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
+                        layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0);
+
+                        // shared expert
+                        if (n_expert_shared > 0) {
+                            const int64_t n_ff_shexp = n_ff_exp * n_expert_shared;
+                            layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0);
+                            layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0);
+                            layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0);
+                        }
+                    } else {
+                        // Dense layers
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+                    }
+                }
+            } break;
         case LLM_ARCH_ERNIE4_5:
         case LLM_ARCH_ERNIE4_5_MOE:
             {
@@ -19578,6 +19665,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             {
                 llm = std::make_unique<llm_build_arcee>(*this, params);
             } break;
+        case LLM_ARCH_AFMOE:
+            {
+                llm = std::make_unique<llm_build_arcee>(*this, params);
+            } break;
         case LLM_ARCH_ERNIE4_5:
             {
                 llm = std::make_unique<llm_build_ernie4_5>(*this, params);
@@ -19776,6 +19867,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_NEO_BERT:
         case LLM_ARCH_SMOLLM3:
         case LLM_ARCH_ARCEE:
+        case LLM_ARCH_AFMOE:
         case LLM_ARCH_ERNIE4_5:
         case LLM_ARCH_ERNIE4_5_MOE:
             return LLAMA_ROPE_TYPE_NORM;
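The create_tensor() calls in the LLM_ARCH_AFMOE branch of load_tensors() pin down the expected shapes of the routed and shared expert tensors. The sketch below recomputes those shapes in isolation; every hyperparameter value is a placeholder, not AFMoE's actual configuration.

# Recompute the tensor shapes expected by the AFMOE load_tensors() branch.
# All values below are placeholders for illustration only.
n_embd          = 2048
n_ff_exp        = 512    # moe_intermediate_size / expert_feed_forward_length
n_expert        = 64
n_expert_shared = 1
n_ff            = 8192   # dense FFN width used in the leading dense blocks

shapes = {
    "ffn_gate_inp":   (n_embd, n_expert),
    "ffn_gate_exps":  (n_embd, n_ff_exp, n_expert),
    "ffn_down_exps":  (n_ff_exp, n_embd, n_expert),
    "ffn_up_exps":    (n_embd, n_ff_exp, n_expert),
    # the shared expert width scales with the number of shared experts
    "ffn_gate_shexp": (n_embd, n_ff_exp * n_expert_shared),
    "ffn_down_shexp": (n_ff_exp * n_expert_shared, n_embd),
    "ffn_up_shexp":   (n_embd, n_ff_exp * n_expert_shared),
    # leading dense blocks keep the plain FFN projections
    "ffn_gate": (n_embd, n_ff),
    "ffn_down": (n_ff, n_embd),
    "ffn_up":   (n_embd, n_ff),
}
for name, shape in shapes.items():
    print(f"blk.N.{name:<15} {shape}")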

src/llama-model.h

Lines changed: 1 addition & 0 deletions
@@ -229,6 +229,7 @@ struct llama_layer {
     struct ggml_tensor * wk_enc = nullptr;
     struct ggml_tensor * wv_enc = nullptr;
     struct ggml_tensor * wo_enc = nullptr;
+    struct ggml_tensor * wqkv_gate = nullptr;

     // attention bias
     struct ggml_tensor * bq = nullptr;
