
Commit acd76d0

tamarPal authored and committed
feat: Add Megrez-MoE architecture support
Implements complete support for Megrez-MoE (Mixture of Experts) models:

- Add LLM_ARCH_MEGREZ_MOE architecture enum and mappings
- Implement build_mergez_moe_ffn() with sigmoid+bias gating
- Add llm_build_megrez_moe class for full model graph construction
- Support 31-layer architecture (layer 0: dense FFN, layers 1-30: MoE)
- Implement expert sharing pattern with 64 experts, 6 used per token, 4 shared
- Load all model hyperparameters and 372 tensors correctly
- Configure NEOX RoPE type for proper positional encoding

Tested with Megrez2-3x7B-A3B_Q4_K_M.gguf model. All 39 llama.cpp tests pass successfully. Output verified to match infinigence/llama.cpp reference implementation.

Note: Use --no-warmup flag to avoid warmup memory allocation issue.
1 parent 7f09a68 commit acd76d0

5 files changed: +205 additions, -33 deletions


src/llama-arch.cpp

Lines changed: 23 additions & 33 deletions
@@ -108,6 +108,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_MINIMAX_M2, "minimax-m2" },
     { LLM_ARCH_COGVLM, "cogvlm" },
     { LLM_ARCH_PANGU_EMBED, "pangu-embedded" },
+    { LLM_ARCH_MEGREZ_MOE, "megrez-moe" },
     { LLM_ARCH_UNKNOWN, "(unknown)" },
 };

@@ -2379,40 +2380,29 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
         },
     },
     {
-        LLM_ARCH_PANGU_EMBED,
+        LLM_ARCH_MEGREZ_MOE,
         {
-            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
-            { LLM_TENSOR_OUTPUT, "output" },
-            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_COGVLM,
-        {
-            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
-            { LLM_TENSOR_OUTPUT, "output" },
-            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
-            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
-            { LLM_TENSOR_VISEXP_ATTN_QKV, "blk.%d.vis_attn_qkv" },
-            { LLM_TENSOR_VISEXP_ATTN_OUT, "blk.%d.vis_attn_output" },
-            { LLM_TENSOR_VISEXP_FFN_GATE, "blk.%d.vis_gate" },
-            { LLM_TENSOR_VISEXP_FFN_DOWN, "blk.%d.vis_down" },
-            { LLM_TENSOR_VISEXP_FFN_UP, "blk.%d.vis_up" },
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
+            { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
+            { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+>>>>>>> 256414a18 (feat: Add Megrez-MoE architecture support)
         },
     },
     {

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
@@ -112,6 +112,7 @@ enum llm_arch {
     LLM_ARCH_MINIMAX_M2,
     LLM_ARCH_COGVLM,
     LLM_ARCH_PANGU_EMBED,
+    LLM_ARCH_MEGREZ_MOE,
     LLM_ARCH_UNKNOWN,
 };

src/llama-graph.cpp

Lines changed: 94 additions & 0 deletions
@@ -1140,6 +1140,100 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
     return moe_out;
 }

+ggml_tensor * llm_graph_context::build_mergez_moe_ffn(
+        ggml_tensor * cur,
+        ggml_tensor * hidden_state,
+        ggml_tensor * gate_inp,
+        ggml_tensor * exp_probs_b,
+        ggml_tensor * up_exps,
+        ggml_tensor * gate_exps,
+        ggml_tensor * down_exps,
+        int64_t n_expert,
+        int64_t n_expert_used,
+        int il) const {
+    const int64_t n_embd = cur->ne[0];
+    const int64_t n_tokens = cur->ne[1];
+
+    ggml_tensor * logits = build_lora_mm(gate_inp, hidden_state); // [n_expert, n_tokens]
+    cb(logits, "ffn_moe_logits", il);
+
+    ggml_tensor * normalized_logits = nullptr;
+    ggml_tensor * probs = nullptr;
+    if (exp_probs_b) {
+        // For Megrez: sigmoid THEN add bias (not the other way around!)
+        normalized_logits = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
+        cb(normalized_logits, "ffn_moe_logits_normalize", il);
+        probs = ggml_add(ctx0, normalized_logits, exp_probs_b); // Add bias AFTER sigmoid
+        cb(probs, "ffn_moe_probs", il);
+    } else {
+        probs = ggml_soft_max(ctx0, logits); // [n_expert, n_tokens]
+    }
+
+    // select experts
+    ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_expert_used, n_tokens]
+    cb(selected_experts->src[0], "ffn_moe_argsort", il);
+    cb(selected_experts, "ffn_moe_topk", il);
+
+    ggml_tensor * weights = nullptr;
+    if (exp_probs_b) {
+        ggml_tensor * weight0s = ggml_get_rows(ctx0,
+                ggml_reshape_3d(ctx0, normalized_logits, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
+        cb(weight0s, "ffn_moe_weights0", il);
+        weight0s = ggml_reshape_2d(ctx0, weight0s, n_expert_used, n_tokens);
+        ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weight0s); // [1, n_tokens]
+        cb(weights_sum, "ffn_moe_weights0_sum", il);
+        weights = ggml_div(ctx0, weight0s, weights_sum); // [n_expert_used, n_tokens]
+        cb(weights, "ffn_moe_weights_norm", il);
+        weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
+    } else {
+        weights = ggml_get_rows(ctx0,
+                ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
+        cb(weights, "ffn_moe_weights", il);
+    }
+
+    cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
+
+    ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+    cb(up, "ffn_moe_up", il);
+
+    ggml_tensor * gate = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+    cb(gate, "ffn_moe_gate", il);
+
+    gate = ggml_silu(ctx0, gate);
+    cb(gate, "ffn_moe_silu", il);
+
+    ggml_tensor * par = ggml_mul(ctx0, up, gate); // [n_ff, n_expert_used, n_tokens]
+    cb(par, "ffn_moe_gate_par", il);
+
+    ggml_tensor * experts = build_lora_mm_id(down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
+    cb(experts, "ffn_moe_down", il);
+
+    experts = ggml_mul(ctx0, experts, weights);
+    cb(experts, "ffn_moe_weighted", il);
+
+    // aggregate experts
+    ggml_tensor * moe_out = nullptr;
+    for (int i = 0; i < n_expert_used; ++i) {
+        ggml_tensor * cur_expert = ggml_view_2d(ctx0, experts, n_embd, n_tokens,
+                experts->nb[2], i*experts->nb[1]);
+
+        if (i == 0) {
+            moe_out = cur_expert;
+        } else {
+            moe_out = ggml_add(ctx0, moe_out, cur_expert);
+        }
+    }
+
+    if (n_expert_used == 1) {
+        // avoid returning a non-contiguous tensor
+        moe_out = ggml_cont(ctx0, moe_out);
+    }
+
+    cb(moe_out, "ffn_moe_out", il);
+
+    return moe_out;
+}
+
 // input embeddings with optional lora
 ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
     const int64_t n_embd = hparams.n_embd;
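For readers skimming the hunk above: selection and weighting use different scores. The router logits go through a sigmoid, the per-expert bias (exp_probs_b) is added only for the top-k selection, and the final mixing weights are the bias-free sigmoid scores of the selected experts, renormalized to sum to 1. The following standalone C++ toy is not part of the commit (sizes and values are made up) and only sketches that arithmetic on plain arrays:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
    const int n_expert = 4, n_expert_used = 2;              // toy sizes; the real model uses 64 and 6
    std::vector<float> logits = {0.2f, -1.0f, 1.5f, 0.3f};  // router logits for one token
    std::vector<float> bias   = {0.1f,  0.9f, 0.0f, 0.0f};  // exp_probs_b
    std::vector<float> score(n_expert), biased(n_expert);
    for (int e = 0; e < n_expert; ++e) {
        score[e]  = 1.0f / (1.0f + std::exp(-logits[e]));   // sigmoid first
        biased[e] = score[e] + bias[e];                      // bias added after the sigmoid
    }
    // top-k selection uses the biased scores
    std::vector<int> idx(n_expert);
    std::iota(idx.begin(), idx.end(), 0);
    std::partial_sort(idx.begin(), idx.begin() + n_expert_used, idx.end(),
                      [&](int a, int b) { return biased[a] > biased[b]; });
    // mixing weights are the bias-free scores of the selected experts, normalized to sum to 1
    float sum = 0.0f;
    for (int k = 0; k < n_expert_used; ++k) sum += score[idx[k]];
    for (int k = 0; k < n_expert_used; ++k) {
        std::printf("expert %d weight %.3f\n", idx[k], score[idx[k]] / sum);
    }
    return 0;
}

In the real graph, the same selection indices are then used to route each token through the chosen experts via build_lora_mm_id, as shown in the hunk above.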

src/llama-graph.h

Lines changed: 13 additions & 0 deletions
@@ -672,6 +672,19 @@ struct llm_graph_context {
             int il,
             ggml_tensor * probs_in = nullptr) const;

+    // build Megrez MoE FFN (special gating with sigmoid + bias)
+    ggml_tensor * build_mergez_moe_ffn(
+            ggml_tensor * cur,
+            ggml_tensor * hidden_state,
+            ggml_tensor * gate_inp,
+            ggml_tensor * exp_probs_b,
+            ggml_tensor * up_exps,
+            ggml_tensor * gate_exps,
+            ggml_tensor * down_exps,
+            int64_t n_expert,
+            int64_t n_expert_used,
+            int il) const;
+
     //
     // inputs
     //
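The declaration above fixes only the signature; llm_build_megrez_moe, which calls it, is not included in this diff. Purely to illustrate how the parameters line up with the tensors loaded in src/llama-model.cpp below, a hypothetical call-site fragment could look roughly like this (the variable names, passing `cur` for both inputs, and the `moe_layer` indirection are assumptions, not code from the commit):

// Hypothetical fragment, not from the commit: inside a per-layer build loop,
// after attention and ffn_norm. `il` is the current layer; `moe_layer` is the
// layer that owns the expert tensors under the sharing pattern noted below.
ggml_tensor * moe_out = build_mergez_moe_ffn(
        cur,                                   // token states fed to the experts
        cur,                                   // hidden state used for the router logits
        model.layers[il].ffn_gate_inp,         // router weight (gate_inp)
        model.layers[il].ffn_exp_probs_b,      // per-expert bias added after the sigmoid
        model.layers[moe_layer].ffn_up_exps,   // expert tensors, possibly shared
        model.layers[moe_layer].ffn_gate_exps,
        model.layers[moe_layer].ffn_down_exps,
        n_expert,                              // 64 for Megrez2-3x7B per the commit message
        n_expert_used,                         // 6
        il);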

src/llama-model.cpp

Lines changed: 74 additions & 0 deletions
@@ -2183,6 +2183,16 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 switch (hparams.n_layer) {
                     case 26: type = LLM_TYPE_1B; break; // openPangu-Embedded-1B-V1.1
                     case 34: type = LLM_TYPE_7B; break; // openPangu-Embedded-7B-V1.1
+        case LLM_ARCH_MEGREZ_MOE:
+            {
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
+                ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
+                ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false);
+
+                switch (hparams.n_layer) {
+                    case 31: type = LLM_TYPE_7B; break;
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
@@ -3338,6 +3348,65 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0);
                 }
             } break;
+        case LLM_ARCH_MEGREZ_MOE:
+            {
+                tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                // output
+                output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+
+                for (int i = 0; i < n_layer; ++i) {
+                    auto & layer = layers[i];
+
+                    layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                    layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
+                    layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
+                    layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
+                    layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+                    layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                    // Layer 0 is dense, layers 1-30 are MoE
+                    if (i == 0) {
+                        // Dense layer
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+                    } else {
+                        // All MoE layers (1-30) have these
+                        layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+                        layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, i), {n_expert}, 0);
+
+                        if (n_expert == 0) {
+                            throw std::runtime_error("n_expert must be > 0 for MEGREZ_MOE");
+                        }
+                        if (n_expert_used == 0) {
+                            throw std::runtime_error("n_expert_used must be > 0 for MEGREZ_MOE");
+                        }
+
+                        // All MoE layers have shared expert
+                        const int64_t n_ff_shexp = hparams.n_ff_shexp;
+                        layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0);
+                        layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0);
+                        layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0);
+
+                        // Only layers 1, 4, 7, 10, 13, 16, 19, 22, 25, 28 have actual expert tensors
+                        // Pattern: (i-1) % 3 == 0 for i > 0
+                        if ((i - 1) % 3 == 0) {
+                            // MoE branch - use the expert-specific FF size from hparams
+                            const int64_t n_ff_exp = hparams.n_ff_exp;
+
+                            layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
+                            layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
+                            layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
+                        }
+                        // Note: layers that share experts (2, 3, 5, 6, etc.) only have gate_inp and shared expert
+                        // They will reference the regular experts from their corresponding "full" layer during inference
+                    }
+                }
+            } break;
         case LLM_ARCH_QWEN3:
         case LLM_ARCH_QWEN3VL:
             {
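The comments in the hunk above state the sharing pattern (expert tensors only on layers 1, 4, 7, ..., 28; the other MoE layers reuse them), but the lookup itself happens in the graph builder, which this page does not show. Under that stated pattern, one plausible helper for resolving the owning layer would be the following sketch, which is an assumption and not code from the commit:

// Sketch (assumption): map MoE layer il in [1, 30] to the layer whose
// ffn_*_exps tensors it uses. Layers 1, 4, 7, ..., 28 own expert tensors;
// the two layers after each owner reuse them.
static int megrez_moe_expert_owner(int il) {
    return 1 + 3 * ((il - 1) / 3);
}

Layers 2 and 3 would then resolve to layer 1, layers 5 and 6 to layer 4, and so on, which matches the note that sharing layers only load gate_inp and the shared expert tensors.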
@@ -7178,6 +7247,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
            {
                llm = std::make_unique<llm_build_jais>(*this, params);
            } break;
+        case LLM_ARCH_MEGREZ_MOE:
+            {
+                llm = std::make_unique<llm_build_megrez_moe>(*this, params);
+            } break;
        case LLM_ARCH_NEMOTRON:
            {
                llm = std::make_unique<llm_build_nemotron>(*this, params);
@@ -7518,6 +7591,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
        case LLM_ARCH_GPTNEOX:
        case LLM_ARCH_CODESHELL:
        case LLM_ARCH_ORION:
+        case LLM_ARCH_MEGREZ_MOE:
        case LLM_ARCH_NEMOTRON:
        case LLM_ARCH_EXAONE:
        case LLM_ARCH_EXAONE4:

0 commit comments
