
Commit 053234f

tamarPal authored and committed
refactor: use standard build_moe_ffn instead of custom build_mergez_moe_ffn
- Remove the custom build_mergez_moe_ffn implementation (100+ lines)
- Use the existing build_moe_ffn with LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID
- Pre-compute gate logits from pre_gate_hidden (Megrez-MoE's unique gating)
- Pass the pre-computed logits via the probs_in parameter
- Maintain exactly the same behavior and output quality

This addresses review feedback to reuse the existing MoE infrastructure instead of duplicating code. Sigmoid gating with the bias added after activation is already supported by build_moe_ffn.
1 parent c00b183 · commit 053234f
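For reference, the gating scheme being folded into build_moe_ffn (visible in the removed implementation below) works like this: apply a sigmoid to the router logits, add exp_probs_b to the sigmoid outputs only for selecting the top-k experts, then renormalize the unbiased sigmoid scores of the selected experts to obtain the mixture weights. A minimal standalone sketch of that arithmetic for one token and four experts, with made-up logits and bias values (plain C++, no ggml):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
    // made-up router output and per-expert bias (exp_probs_b) for a single token
    std::vector<float> logits = {0.2f, -1.0f, 1.5f, 0.3f};
    std::vector<float> bias   = {0.1f,  0.4f, 0.0f, -0.2f};
    const int n_expert_used = 2;

    // 1) sigmoid of the logits -- these are the expert scores before normalization
    std::vector<float> scores(logits.size());
    for (size_t i = 0; i < logits.size(); ++i) {
        scores[i] = 1.0f / (1.0f + std::exp(-logits[i]));
    }

    // 2) add the bias AFTER the sigmoid, but use it only to pick the top-k experts
    std::vector<int> idx(logits.size());
    std::iota(idx.begin(), idx.end(), 0);
    std::partial_sort(idx.begin(), idx.begin() + n_expert_used, idx.end(),
        [&](int a, int b) { return scores[a] + bias[a] > scores[b] + bias[b]; });

    // 3) final weights = unbiased sigmoid scores of the selected experts, renormalized
    float sum = 0.0f;
    for (int k = 0; k < n_expert_used; ++k) sum += scores[idx[k]];
    for (int k = 0; k < n_expert_used; ++k) {
        std::printf("expert %d weight %.4f\n", idx[k], scores[idx[k]] / sum);
    }
    return 0;
}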

3 files changed: +17 additions, −111 deletions

src/llama-graph.cpp

Lines changed: 0 additions & 94 deletions
@@ -1140,100 +1140,6 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
     return moe_out;
 }
 
-ggml_tensor * llm_graph_context::build_mergez_moe_ffn(
-        ggml_tensor * cur,
-        ggml_tensor * hidden_state,
-        ggml_tensor * gate_inp,
-        ggml_tensor * exp_probs_b,
-        ggml_tensor * up_exps,
-        ggml_tensor * gate_exps,
-        ggml_tensor * down_exps,
-        int64_t n_expert,
-        int64_t n_expert_used,
-        int il) const {
-    const int64_t n_embd   = cur->ne[0];
-    const int64_t n_tokens = cur->ne[1];
-
-    ggml_tensor * logits = build_lora_mm(gate_inp, hidden_state); // [n_expert, n_tokens]
-    cb(logits, "ffn_moe_logits", il);
-
-    ggml_tensor * normalized_logits = nullptr;
-    ggml_tensor * probs = nullptr;
-    if (exp_probs_b) {
-        // For Megrez: sigmoid THEN add bias (not the other way around!)
-        normalized_logits = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
-        cb(normalized_logits, "ffn_moe_logits_normalize", il);
-        probs = ggml_add(ctx0, normalized_logits, exp_probs_b); // Add bias AFTER sigmoid
-        cb(probs, "ffn_moe_probs", il);
-    } else {
-        probs = ggml_soft_max(ctx0, logits); // [n_expert, n_tokens]
-    }
-
-    // select experts
-    ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_expert_used, n_tokens]
-    cb(selected_experts->src[0], "ffn_moe_argsort", il);
-    cb(selected_experts, "ffn_moe_topk", il);
-
-    ggml_tensor * weights = nullptr;
-    if (exp_probs_b) {
-        ggml_tensor * weight0s = ggml_get_rows(ctx0,
-                ggml_reshape_3d(ctx0, normalized_logits, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
-        cb(weight0s, "ffn_moe_weights0", il);
-        weight0s = ggml_reshape_2d(ctx0, weight0s, n_expert_used, n_tokens);
-        ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weight0s); // [1, n_tokens]
-        cb(weights_sum, "ffn_moe_weights0_sum", il);
-        weights = ggml_div(ctx0, weight0s, weights_sum); // [n_expert_used, n_tokens]
-        cb(weights, "ffn_moe_weights_norm", il);
-        weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
-    } else {
-        weights = ggml_get_rows(ctx0,
-                ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
-        cb(weights, "ffn_moe_weights", il);
-    }
-
-    cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
-
-    ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
-    cb(up, "ffn_moe_up", il);
-
-    ggml_tensor * gate = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
-    cb(gate, "ffn_moe_gate", il);
-
-    gate = ggml_silu(ctx0, gate);
-    cb(gate, "ffn_moe_silu", il);
-
-    ggml_tensor * par = ggml_mul(ctx0, up, gate); // [n_ff, n_expert_used, n_tokens]
-    cb(par, "ffn_moe_gate_par", il);
-
-    ggml_tensor * experts = build_lora_mm_id(down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
-    cb(experts, "ffn_moe_down", il);
-
-    experts = ggml_mul(ctx0, experts, weights);
-    cb(experts, "ffn_moe_weighted", il);
-
-    // aggregate experts
-    ggml_tensor * moe_out = nullptr;
-    for (int i = 0; i < n_expert_used; ++i) {
-        ggml_tensor * cur_expert = ggml_view_2d(ctx0, experts, n_embd, n_tokens,
-                experts->nb[2], i*experts->nb[1]);
-
-        if (i == 0) {
-            moe_out = cur_expert;
-        } else {
-            moe_out = ggml_add(ctx0, moe_out, cur_expert);
-        }
-    }
-
-    if (n_expert_used == 1) {
-        // avoid returning a non-contiguous tensor
-        moe_out = ggml_cont(ctx0, moe_out);
-    }
-
-    cb(moe_out, "ffn_moe_out", il);
-
-    return moe_out;
-}
-
 // input embeddings with optional lora
 ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
     const int64_t n_embd = hparams.n_embd;

src/llama-graph.h

Lines changed: 0 additions & 13 deletions
@@ -672,19 +672,6 @@ struct llm_graph_context {
                     int   il,
             ggml_tensor * probs_in = nullptr) const;
 
-    // build Megrez MoE FFN (special gating with sigmoid + bias)
-    ggml_tensor * build_mergez_moe_ffn(
-            ggml_tensor * cur,
-            ggml_tensor * hidden_state,
-            ggml_tensor * gate_inp,
-            ggml_tensor * exp_probs_b,
-            ggml_tensor * up_exps,
-            ggml_tensor * gate_exps,
-            ggml_tensor * down_exps,
-                int64_t   n_expert,
-                int64_t   n_expert_used,
-                    int   il) const;
-
     //
     // inputs
     //

src/models/megrez-moe.cpp

Lines changed: 17 additions & 4 deletions
@@ -163,14 +163,27 @@ llm_build_megrez_moe::llm_build_megrez_moe(const llama_model & model, const llm_
             cb(cur, "ffn_out", il);
         } else {
             // MoE branch
-            ggml_tensor * moe_out = build_mergez_moe_ffn(cur,
-                    pre_gate_hidden,
-                    model.layers[il].ffn_gate_inp, model.layers[il].ffn_exp_probs_b,
+            // Note: Megrez-MoE uses pre_gate_hidden (from previous layer's FFN norm) for gating
+            // This is different from standard MoE which uses current layer's input
+            // Compute gate logits from pre_gate_hidden instead of cur
+            ggml_tensor * gate_logits = build_lora_mm(model.layers[il].ffn_gate_inp, pre_gate_hidden);
+            cb(gate_logits, "ffn_moe_logits", il);
+
+            // Use standard build_moe_ffn but with pre-computed gate logits
+            ggml_tensor * moe_out = build_moe_ffn(cur,
+                    model.layers[il].ffn_gate_inp,
                     model.layers[((il - 1) / (3) * (3)) + 1].ffn_up_exps,
                     model.layers[((il - 1) / (3) * (3)) + 1].ffn_gate_exps,
                     model.layers[((il - 1) / (3) * (3)) + 1].ffn_down_exps,
+                    model.layers[il].ffn_exp_probs_b,
                     n_expert, n_expert_used,
-                    il);
+                    LLM_FFN_SILU,
+                    true,  // norm_w
+                    false, // scale_w
+                    1.0f,  // w_scale
+                    LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID,
+                    il,
+                    gate_logits); // Use pre-computed logits from pre_gate_hidden
             cb(moe_out, "ffn_moe_out", il);
 
             pre_gate_hidden = cur;

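A side note on the unchanged context lines above: the expert tensors are taken from model.layers[((il - 1) / (3) * (3)) + 1], which with integer division resolves consecutive MoE layers to the first layer of their group of three (e.g. layers 1-3 map to 1, layers 4-6 to 4), suggesting the expert weights are shared across each group of three layers. A tiny check of that index arithmetic, with an arbitrary layer range for illustration:

#include <cstdio>

int main() {
    // same expression as in megrez-moe.cpp: integer division groups layers in threes
    for (int il = 1; il <= 9; ++il) {
        std::printf("layer %d -> expert tensors of layer %d\n", il, ((il - 1) / (3) * (3)) + 1);
    }
    return 0;
}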