
Commit 256414a

tamarPal authored and committed
feat: Add Megrez-MoE architecture support
Implements complete support for Megrez-MoE (Mixture of Experts) models:

- Add LLM_ARCH_MEGREZ_MOE architecture enum and mappings
- Implement build_mergez_moe_ffn() with sigmoid+bias gating
- Add llm_build_megrez_moe class for full model graph construction
- Support 31-layer architecture (layer 0: dense FFN, layers 1-30: MoE)
- Implement expert sharing pattern with 64 experts, 6 used per token, 4 shared
- Load all model hyperparameters and 372 tensors correctly
- Configure NEOX RoPE type for proper positional encoding

Tested with the Megrez2-3x7B-A3B_Q4_K_M.gguf model. All 39 llama.cpp tests pass successfully. Output verified to match the infinigence/llama.cpp reference implementation.

Note: use the --no-warmup flag to avoid a warmup memory-allocation issue.
1 parent 1eeb523 commit 256414a
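
Editor's note on the "sigmoid+bias gating" mentioned above: it differs from the usual softmax router. The expert logits go through a sigmoid first, the per-expert bias (exp_probs_b) is added afterwards and is only used to pick the top-k experts, and the final mixing weights are the un-biased sigmoid values of the selected experts, renormalized to sum to 1 (see build_mergez_moe_ffn in the llama-graph.cpp diff below). A minimal scalar sketch of that routing step for a single token, in plain C++ for illustration only (not the ggml graph code; the megrez_route helper is made up):

// Illustrative scalar version of the Megrez-MoE router for one token
// (e.g. n_expert = 64, n_expert_used = 6 per the commit message).
#include <algorithm>
#include <cmath>
#include <numeric>
#include <vector>

struct expert_choice { int id; float weight; };

static std::vector<expert_choice> megrez_route(const std::vector<float> & logits,
                                               const std::vector<float> & bias,
                                               int n_expert_used) {
    const int n_expert = (int) logits.size();

    std::vector<float> sig(n_expert), score(n_expert);
    for (int e = 0; e < n_expert; ++e) {
        sig[e]   = 1.0f / (1.0f + std::exp(-logits[e])); // sigmoid first
        score[e] = sig[e] + bias[e];                     // bias added AFTER the sigmoid
    }

    // top-k selection uses the biased scores ...
    std::vector<int> idx(n_expert);
    std::iota(idx.begin(), idx.end(), 0);
    std::partial_sort(idx.begin(), idx.begin() + n_expert_used, idx.end(),
                      [&](int a, int b) { return score[a] > score[b]; });

    // ... but the mixing weights use the un-biased sigmoid values, renormalized to sum to 1
    float sum = 0.0f;
    for (int k = 0; k < n_expert_used; ++k) {
        sum += sig[idx[k]];
    }

    std::vector<expert_choice> out;
    for (int k = 0; k < n_expert_used; ++k) {
        out.push_back({ idx[k], sig[idx[k]] / sum });
    }
    return out;
}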

File tree

5 files changed: +426 -0 lines changed


src/llama-arch.cpp

Lines changed: 26 additions & 0 deletions
@@ -98,6 +98,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_LLADA,       "llada"       },
     { LLM_ARCH_LLADA_MOE,   "llada-moe"   },
     { LLM_ARCH_SEED_OSS,    "seed_oss"    },
+    { LLM_ARCH_MEGREZ_MOE,  "megrez-moe"  },
     { LLM_ARCH_UNKNOWN,     "(unknown)"   },
 };

@@ -2185,6 +2186,31 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_FFN_UP,   "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_MEGREZ_MOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE_EXPS,   "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_GATE_SHEXP,  "blk.%d.ffn_gate_shexp" },
+            { LLM_TENSOR_FFN_DOWN_SHEXP,  "blk.%d.ffn_down_shexp" },
+            { LLM_TENSOR_FFN_UP_SHEXP,    "blk.%d.ffn_up_shexp" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
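
The templates in this table are the GGUF tensor names the loader expects, with %d replaced by the block (layer) index and a suffix such as "weight" typically appended by llama.cpp's tensor-name helpers. A tiny illustrative expansion (the blk_tensor_name helper below is hypothetical, not part of this commit):

// Hypothetical helper: expands a "blk.%d.*" template from the table above into a
// concrete per-layer GGUF tensor name, e.g. "blk.7.ffn_gate_exps.weight".
#include <cstdio>
#include <string>

static std::string blk_tensor_name(const char * fmt, int il, const char * suffix = "weight") {
    char buf[256];
    std::snprintf(buf, sizeof(buf), fmt, il);  // "blk.%d.ffn_gate_exps" -> "blk.7.ffn_gate_exps"
    return std::string(buf) + "." + suffix;    // append the suffix, e.g. ".weight"
}

So a Megrez-MoE checkpoint contributes names such as blk.0.ffn_gate.weight (the dense layer-0 FFN) and blk.1.ffn_gate_exps.weight through blk.30.ffn_gate_exps.weight for the MoE layers, consistent with the layer-0-dense / layers-1-30-MoE split described in the commit message.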

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
@@ -102,6 +102,7 @@ enum llm_arch {
     LLM_ARCH_LLADA,
     LLM_ARCH_LLADA_MOE,
     LLM_ARCH_SEED_OSS,
+    LLM_ARCH_MEGREZ_MOE,
     LLM_ARCH_UNKNOWN,
 };

src/llama-graph.cpp

Lines changed: 94 additions & 0 deletions
@@ -1063,6 +1063,100 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
     return moe_out;
 }

+ggml_tensor * llm_graph_context::build_mergez_moe_ffn(
+         ggml_tensor * cur,
+         ggml_tensor * hidden_state,
+         ggml_tensor * gate_inp,
+         ggml_tensor * exp_probs_b,
+         ggml_tensor * up_exps,
+         ggml_tensor * gate_exps,
+         ggml_tensor * down_exps,
+              int64_t   n_expert,
+              int64_t   n_expert_used,
+                  int   il) const {
+    const int64_t n_embd   = cur->ne[0];
+    const int64_t n_tokens = cur->ne[1];
+
+    ggml_tensor * logits = build_lora_mm(gate_inp, hidden_state); // [n_expert, n_tokens]
+    cb(logits, "ffn_moe_logits", il);
+
+    ggml_tensor * normalized_logits = nullptr;
+    ggml_tensor * probs = nullptr;
+    if (exp_probs_b) {
+        // For Megrez: sigmoid THEN add bias (not the other way around!)
+        normalized_logits = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
+        cb(normalized_logits, "ffn_moe_logits_normalize", il);
+        probs = ggml_add(ctx0, normalized_logits, exp_probs_b); // Add bias AFTER sigmoid
+        cb(probs, "ffn_moe_probs", il);
+    } else {
+        probs = ggml_soft_max(ctx0, logits); // [n_expert, n_tokens]
+    }
+
+    // select experts
+    ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_expert_used, n_tokens]
+    cb(selected_experts->src[0], "ffn_moe_argsort", il);
+    cb(selected_experts, "ffn_moe_topk", il);
+
+    ggml_tensor * weights = nullptr;
+    if (exp_probs_b) {
+        ggml_tensor * weight0s = ggml_get_rows(ctx0,
+                ggml_reshape_3d(ctx0, normalized_logits, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
+        cb(weight0s, "ffn_moe_weights0", il);
+        weight0s = ggml_reshape_2d(ctx0, weight0s, n_expert_used, n_tokens);
+        ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weight0s); // [1, n_tokens]
+        cb(weights_sum, "ffn_moe_weights0_sum", il);
+        weights = ggml_div(ctx0, weight0s, weights_sum); // [n_expert_used, n_tokens]
+        cb(weights, "ffn_moe_weights_norm", il);
+        weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
+    } else {
+        weights = ggml_get_rows(ctx0,
+                ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
+        cb(weights, "ffn_moe_weights", il);
+    }
+
+    cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
+
+    ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+    cb(up, "ffn_moe_up", il);
+
+    ggml_tensor * gate = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+    cb(gate, "ffn_moe_gate", il);
+
+    gate = ggml_silu(ctx0, gate);
+    cb(gate, "ffn_moe_silu", il);
+
+    ggml_tensor * par = ggml_mul(ctx0, up, gate); // [n_ff, n_expert_used, n_tokens]
+    cb(par, "ffn_moe_gate_par", il);
+
+    ggml_tensor * experts = build_lora_mm_id(down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
+    cb(experts, "ffn_moe_down", il);
+
+    experts = ggml_mul(ctx0, experts, weights);
+    cb(experts, "ffn_moe_weighted", il);
+
+    // aggregate experts
+    ggml_tensor * moe_out = nullptr;
+    for (int i = 0; i < n_expert_used; ++i) {
+        ggml_tensor * cur_expert = ggml_view_2d(ctx0, experts, n_embd, n_tokens,
+                experts->nb[2], i*experts->nb[1]);
+
+        if (i == 0) {
+            moe_out = cur_expert;
+        } else {
+            moe_out = ggml_add(ctx0, moe_out, cur_expert);
+        }
+    }
+
+    if (n_expert_used == 1) {
+        // avoid returning a non-contiguous tensor
+        moe_out = ggml_cont(ctx0, moe_out);
+    }
+
+    cb(moe_out, "ffn_moe_out", il);
+
+    return moe_out;
+}
+
 // input embeddings with optional lora
 ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
     const int64_t n_embd = hparams.n_embd;
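
For context, here is a sketch of how the per-layer graph code added elsewhere in this commit (llm_build_megrez_moe, part of the changed file not shown in this excerpt) might call the helper above. The layer member names follow the tensor table in llama-arch.cpp and the conventions of other MoE architectures in llama.cpp; treat them as assumptions rather than the exact committed code:

// Sketch only: invoking build_mergez_moe_ffn for one MoE layer (layers 1-30; layer 0 uses a dense FFN).
// Member names (ffn_gate_inp, ffn_exp_probs_b, ffn_*_exps) are assumed from the tensor table above.
ggml_tensor * moe_out = build_mergez_moe_ffn(
        cur,                               // normalized FFN input
        cur,                               // hidden state fed to the router
        model.layers[il].ffn_gate_inp,     // router projection  [n_embd, n_expert]
        model.layers[il].ffn_exp_probs_b,  // per-expert bias added after the sigmoid
        model.layers[il].ffn_up_exps,
        model.layers[il].ffn_gate_exps,
        model.layers[il].ffn_down_exps,
        hparams.n_expert,                  // 64 for Megrez2-3x7B per the commit message
        hparams.n_expert_used,             // 6 experts activated per token
        il);
cb(moe_out, "ffn_moe_out", il);

The four shared experts mentioned in the commit message are presumably handled through the *_shexp tensors and combined with this routed output inside llm_build_megrez_moe, which lives in the remaining changed file not shown here.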

src/llama-graph.h

Lines changed: 13 additions & 0 deletions
@@ -667,6 +667,19 @@ struct llm_graph_context {
                     int   il,
             ggml_tensor * probs_in = nullptr) const;

+    // build Megrez MoE FFN (special gating with sigmoid + bias)
+    ggml_tensor * build_mergez_moe_ffn(
+            ggml_tensor * cur,
+            ggml_tensor * hidden_state,
+            ggml_tensor * gate_inp,
+            ggml_tensor * exp_probs_b,
+            ggml_tensor * up_exps,
+            ggml_tensor * gate_exps,
+            ggml_tensor * down_exps,
+                int64_t   n_expert,
+                int64_t   n_expert_used,
+                    int   il) const;
+
     //
     // inputs
     //
