
Commit c1dacaa

llama : merge build_moe_ffn_from_probs function into build_moe_ffn (#14968)
1 parent a9f77a8 commit c1dacaa
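
In short, this change folds the old build_moe_ffn_from_probs path into build_moe_ffn: the function gains an optional trailing probs_in argument (defaulting to nullptr) and an LLM_FFN_RELU activation case, the standalone helper is deleted, and its only caller (the SmallThinker graph builder) switches to the merged function. The heart of the merge is the routing dispatch below, quoted in condensed form from the src/llama-graph.cpp hunk further down; nothing here goes beyond what the diff itself adds:

    // when the caller already holds router probabilities, skip the gating matmul
    ggml_tensor * logits = nullptr;
    if (probs_in == nullptr) {
        logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
        cb(logits, "ffn_moe_logits", il);
    } else {
        logits = probs_in;                     // precomputed by the caller
    }
    // gating (softmax/sigmoid), top-k expert selection and the expert FFN follow unchanged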

File tree: 3 files changed, +32 -114 lines changed


src/llama-graph.cpp

Lines changed: 18 additions & 97 deletions
@@ -785,13 +785,20 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
             bool   scale_w,
             float  w_scale,
             llama_expert_gating_func_type gating_op,
-            int    il) const {
+            int    il,
+            ggml_tensor * probs_in) const {
     const int64_t n_embd   = cur->ne[0];
     const int64_t n_tokens = cur->ne[1];
     const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN

-    ggml_tensor * logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
-    cb(logits, "ffn_moe_logits", il);
+    ggml_tensor * logits = nullptr;
+
+    if (probs_in == nullptr) {
+        logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
+        cb(logits, "ffn_moe_logits", il);
+    } else {
+        logits = probs_in;
+    }

     ggml_tensor * probs = nullptr;
     switch (gating_op) {
@@ -884,6 +891,14 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
                 cur = ggml_gelu(ctx0, cur);
                 cb(cur, "ffn_moe_gelu", il);
             } break;
+        case LLM_FFN_RELU:
+            if (gate_exps) {
+                cur = ggml_reglu_split(ctx0, cur, up);
+                cb(cur, "ffn_moe_reglu", il);
+            } else {
+                cur = ggml_relu(ctx0, cur);
+                cb(cur, "ffn_moe_relu", il);
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
@@ -927,100 +942,6 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
     return moe_out;
 }

-ggml_tensor * llm_graph_context::build_moe_ffn_from_probs(
-        ggml_tensor * cur,
-        ggml_tensor * probs,
-        ggml_tensor * up_exps,
-        ggml_tensor * gate_exps,
-        ggml_tensor * down_exps,
-        ggml_tensor * exp_probs_b,
-        int64_t n_expert,
-        int64_t n_expert_used,
-        llama_expert_gating_func_type gating_op,
-        int il) const {
-    const int64_t n_embd   = cur->ne[0];
-    const int64_t n_tokens = cur->ne[1];
-
-    // add experts selection bias - introduced in DeepSeek V3
-    // leave probs unbiased as it's later used to get expert weights
-    ggml_tensor * selection_probs = probs;
-    if (exp_probs_b != nullptr) {
-        selection_probs = ggml_add(ctx0, probs, exp_probs_b);
-        cb(selection_probs, "ffn_moe_probs_biased", il);
-    }
-
-    // select experts
-    ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
-    cb(selected_experts->src[0], "ffn_moe_argsort", il);
-    cb(selected_experts, "ffn_moe_topk", il);
-
-    ggml_tensor * weights = ggml_get_rows(ctx0,
-            ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
-    cb(weights, "ffn_moe_weights", il);
-
-    weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
-    if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX) {
-        weights = ggml_soft_max(ctx0, weights);
-    } else {
-        weights = ggml_sigmoid(ctx0, weights);
-        ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
-        cb(weights_sum, "ffn_moe_weights_sum", il);
-
-        weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
-        cb(weights, "ffn_moe_weights_norm", il);
-    }
-
-    weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
-
-    cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
-
-    ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
-    cb(up, "ffn_moe_up", il);
-
-    ggml_tensor * experts = nullptr;
-    cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
-    cb(cur, "ffn_moe_gate", il);
-
-    cur = ggml_reglu_split(ctx0, cur, up);
-    cb(cur, "ffn_moe_reglu", il);
-
-    experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
-    cb(experts, "ffn_moe_down", il);
-
-    experts = ggml_mul(ctx0, experts, weights);
-    cb(cur, "ffn_moe_weighted", il);
-
-    ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr };
-
-    assert(n_expert_used > 0);
-
-    // order the views before the adds
-    for (uint32_t i = 0; i < hparams.n_expert_used; ++i) {
-        cur_experts[i] = ggml_view_2d(ctx0, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]);
-
-        ggml_build_forward_expand(gf, cur_experts[i]);
-    }
-
-    // aggregate experts
-    // note: here we explicitly use hparams.n_expert_used instead of n_expert_used
-    // to avoid potentially a large number of add nodes during warmup
-    // ref: https://github.com/ggml-org/llama.cpp/pull/14753
-    ggml_tensor * moe_out = cur_experts[0];
-
-    for (uint32_t i = 1; i < hparams.n_expert_used; ++i) {
-        moe_out = ggml_add(ctx0, moe_out, cur_experts[i]);
-    }
-
-    if (n_expert_used == 1) {
-        // avoid returning a non-contiguous tensor
-        moe_out = ggml_cont(ctx0, moe_out);
-    }
-
-    cb(moe_out, "ffn_moe_out", il);
-
-    return moe_out;
-}
-
 // input embeddings with optional lora
 ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
     const int64_t n_embd = hparams.n_embd;
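
For orientation, the new LLM_FFN_RELU case reproduces the activation that build_moe_ffn_from_probs hard-coded: with gate experts present it applies ReGLU via ggml_reglu_split, otherwise a plain ReLU on the up projection. A minimal scalar sketch of the intended element-wise math, assuming ggml_reglu_split(ctx, gate, up) computes relu(gate) * up (the real code operates on ggml tensors, not floats):

    #include <algorithm>

    // element-wise view of the LLM_FFN_RELU branch (illustrative sketch only)
    static float moe_relu_act(float gate, float up, bool has_gate_exps) {
        if (has_gate_exps) {
            return std::max(gate, 0.0f) * up; // ReGLU: relu(gate) * up
        }
        return std::max(up, 0.0f);            // plain ReLU on the single projection
    }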

src/llama-graph.h

Lines changed: 2 additions & 13 deletions
@@ -631,19 +631,8 @@ struct llm_graph_context {
             bool   scale_w,
             float  w_scale,
             llama_expert_gating_func_type gating_op,
-            int    il) const;
-
-    ggml_tensor * build_moe_ffn_from_probs(
-            ggml_tensor * cur,
-            ggml_tensor * probs,
-            ggml_tensor * up_exps,
-            ggml_tensor * gate_exps,
-            ggml_tensor * down_exps,
-            ggml_tensor * exp_probs_b,
-            int64_t n_expert,
-            int64_t n_expert_used,
-            llama_expert_gating_func_type gating_op,
-            int il) const;
+            int    il,
+            ggml_tensor * probs_in = nullptr) const;

     //
     // inputs
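
Because probs_in is declared with a default of nullptr, every existing build_moe_ffn call site compiles unchanged; only graphs that route experts outside the function pass the extra argument. A rough caller-side sketch of the two styles (the argument order follows the declaration and the llama-model.cpp hunk below; the surrounding variable names are placeholders, not code from this commit):

    // 1) usual routing: gate_inp supplies the router, probs_in is omitted
    ggml_tensor * out_routed = build_moe_ffn(cur, gate_inp,
            up_exps, gate_exps, down_exps, exp_probs_b,
            n_expert, n_expert_used,
            LLM_FFN_SILU, /*norm_w =*/ true,
            /*scale_w =*/ false, /*w_scale =*/ 0.0,
            gating_op, il);

    // 2) external routing: skip gate_inp and hand over precomputed probabilities
    ggml_tensor * out_from_probs = build_moe_ffn(cur, /*gate_inp =*/ nullptr,
            up_exps, gate_exps, down_exps, /*exp_probs_b =*/ nullptr,
            n_expert, n_expert_used,
            LLM_FFN_RELU, true,
            false, 0.0,
            gating_op, il, /*probs_in =*/ probs);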

src/llama-model.cpp

Lines changed: 12 additions & 4 deletions
@@ -17320,10 +17320,18 @@ struct llm_build_smallthinker : public llm_graph_context{
                 cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
                 cb(cur, "ffn_norm", il);

-                ggml_tensor * ffn_out = build_moe_ffn_from_probs(cur, probs, model.layers[il].ffn_up_exps,
-                                                                 model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps,
-                                                                 nullptr, n_expert, n_expert_used,
-                                                                 static_cast<llama_expert_gating_func_type>(hparams.expert_gating_func), il);
+                ggml_tensor * ffn_out =
+                    build_moe_ffn(cur,
+                            nullptr,
+                            model.layers[il].ffn_up_exps,
+                            model.layers[il].ffn_gate_exps,
+                            model.layers[il].ffn_down_exps,
+                            nullptr,
+                            n_expert, n_expert_used,
+                            LLM_FFN_RELU, true,
+                            false, 0.0,
+                            static_cast<llama_expert_gating_func_type>(hparams.expert_gating_func),
+                            il, probs);

                 cb(ffn_out, "ffn_out", il);
                 cur = ffn_out;
