
Commit 63b7b5e

wdl339 authored and Nexesenex committed
llama : merge build_moe_ffn_from_probs function into build_moe_ffn (ggml-org#14968)
1 parent 465efe1 commit 63b7b5e

File tree

3 files changed: +52 -114 lines

src/llama-graph.cpp

Lines changed: 38 additions & 97 deletions
@@ -804,13 +804,20 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         bool scale_w,
         float w_scale,
         llama_expert_gating_func_type gating_op,
-        int il) const {
+        int il,
+        ggml_tensor * probs_in) const {
     const int64_t n_embd = cur->ne[0];
     const int64_t n_tokens = cur->ne[1];
     const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
 
-    ggml_tensor * logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
-    cb(logits, "ffn_moe_logits", il);
+    ggml_tensor * logits = nullptr;
+
+    if (probs_in == nullptr) {
+        logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
+        cb(logits, "ffn_moe_logits", il);
+    } else {
+        logits = probs_in;
+    }
 
     ggml_tensor * probs = nullptr;
     switch (gating_op) {
@@ -930,6 +937,34 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
     //             GGML_ABORT("fatal error");
    //     }
 
+    // switch (type_op) { // NEW 2
+    //     case LLM_FFN_SILU:
+    //         if (gate_exps) {
+    //             cur = ggml_swiglu_split(ctx0, cur, up);
+    //             cb(cur, "ffn_moe_swiglu", il);
+    //         } else {
+    //             cur = ggml_silu(ctx0, cur);
+    //             cb(cur, "ffn_moe_silu", il);
+    //         } break;
+    //     case LLM_FFN_GELU:
+    //         if (gate_exps) {
+    //             cur = ggml_geglu_split(ctx0, cur, up);
+    //             cb(cur, "ffn_moe_geglu", il);
+    //         } else {
+    //             cur = ggml_gelu(ctx0, cur);
+    //             cb(cur, "ffn_moe_gelu", il);
+    //         } break;
+    //     case LLM_FFN_RELU:
+    //         if (gate_exps) {
+    //             cur = ggml_reglu_split(ctx0, cur, up);
+    //             cb(cur, "ffn_moe_reglu", il);
+    //         } else {
+    //             cur = ggml_relu(ctx0, cur);
+    //             cb(cur, "ffn_moe_relu", il);
+    //         } break;
+    //     default:
+    //         GGML_ABORT("fatal error");
+
     // ggml_tensor * parent = ggml_fused_mul_unary(ctx0, cur, up, type_op == LLM_FFN_SILU ? GGML_GLU_OP_SWIGLU : GGML_GLU_OP_GEGLU); // NEW
 
 
@@ -1029,100 +1064,6 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
     //return moe_out;
 }
 
-ggml_tensor * llm_graph_context::build_moe_ffn_from_probs(
-        ggml_tensor * cur,
-        ggml_tensor * probs,
-        ggml_tensor * up_exps,
-        ggml_tensor * gate_exps,
-        ggml_tensor * down_exps,
-        ggml_tensor * exp_probs_b,
-        int64_t n_expert,
-        int64_t n_expert_used,
-        llama_expert_gating_func_type gating_op,
-        int il) const {
-    const int64_t n_embd = cur->ne[0];
-    const int64_t n_tokens = cur->ne[1];
-
-    // add experts selection bias - introduced in DeepSeek V3
-    // leave probs unbiased as it's later used to get expert weights
-    ggml_tensor * selection_probs = probs;
-    if (exp_probs_b != nullptr) {
-        selection_probs = ggml_add(ctx0, probs, exp_probs_b);
-        cb(selection_probs, "ffn_moe_probs_biased", il);
-    }
-
-    // select experts
-    ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
-    cb(selected_experts->src[0], "ffn_moe_argsort", il);
-    cb(selected_experts, "ffn_moe_topk", il);
-
-    ggml_tensor * weights = ggml_get_rows(ctx0,
-            ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
-    cb(weights, "ffn_moe_weights", il);
-
-    weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
-    if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX) {
-        weights = ggml_soft_max(ctx0, weights);
-    } else {
-        weights = ggml_sigmoid(ctx0, weights);
-        ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
-        cb(weights_sum, "ffn_moe_weights_sum", il);
-
-        weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
-        cb(weights, "ffn_moe_weights_norm", il);
-    }
-
-    weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
-
-    cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
-
-    ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
-    cb(up, "ffn_moe_up", il);
-
-    ggml_tensor * experts = nullptr;
-    cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
-    cb(cur, "ffn_moe_gate", il);
-
-    cur = ggml_reglu_split(ctx0, cur, up);
-    cb(cur, "ffn_moe_reglu", il);
-
-    experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
-    cb(experts, "ffn_moe_down", il);
-
-    experts = ggml_mul(ctx0, experts, weights);
-    cb(cur, "ffn_moe_weighted", il);
-
-    ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr };
-
-    assert(n_expert_used > 0);
-
-    // order the views before the adds
-    for (uint32_t i = 0; i < hparams.n_expert_used; ++i) {
-        cur_experts[i] = ggml_view_2d(ctx0, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]);
-
-        ggml_build_forward_expand(gf, cur_experts[i]);
-    }
-
-    // aggregate experts
-    // note: here we explicitly use hparams.n_expert_used instead of n_expert_used
-    // to avoid potentially a large number of add nodes during warmup
-    // ref: https://github.com/ggml-org/llama.cpp/pull/14753
-    ggml_tensor * moe_out = cur_experts[0];
-
-    for (uint32_t i = 1; i < hparams.n_expert_used; ++i) {
-        moe_out = ggml_add(ctx0, moe_out, cur_experts[i]);
-    }
-
-    if (n_expert_used == 1) {
-        // avoid returning a non-contiguous tensor
-        moe_out = ggml_cont(ctx0, moe_out);
-    }
-
-    cb(moe_out, "ffn_moe_out", il);
-
-    return moe_out;
-}
-
 // input embeddings with optional lora
 ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
     const int64_t n_embd = hparams.n_embd;
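Note: the removed helper hard-coded a gated-ReLU expert FFN over caller-supplied routing probabilities. After the merge, the same request is expressed through the existing build_moe_ffn parameters, as the updated smallthinker call site in src/llama-model.cpp below does. A hypothetical compatibility shim, not part of this commit, sketching only the parameter mapping (and assuming the merged implementation reproduces the old behaviour for these arguments):

    // Hypothetical shim: the removed entry point expressed via the merged function.
    // gate_inp is nullptr because the probabilities are supplied directly through
    // the new trailing argument; LLM_FFN_RELU together with gate_exps requests the
    // gated-ReLU expert FFN that build_moe_ffn_from_probs used to hard-code.
    static ggml_tensor * moe_ffn_from_probs_compat(
            const llm_graph_context & g,
            ggml_tensor * cur, ggml_tensor * probs,
            ggml_tensor * up_exps, ggml_tensor * gate_exps, ggml_tensor * down_exps,
            ggml_tensor * exp_probs_b, int64_t n_expert, int64_t n_expert_used,
            llama_expert_gating_func_type gating_op, int il) {
        return g.build_moe_ffn(cur, /*gate_inp*/ nullptr,
                up_exps, gate_exps, down_exps, exp_probs_b,
                n_expert, n_expert_used,
                LLM_FFN_RELU, /*norm_w*/ true, /*scale_w*/ false, /*w_scale*/ 0.0f,
                gating_op, il, /*probs_in*/ probs);
    }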

src/llama-graph.h

Lines changed: 2 additions & 13 deletions
@@ -631,19 +631,8 @@ struct llm_graph_context {
             bool scale_w,
             float w_scale,
             llama_expert_gating_func_type gating_op,
-            int il) const;
-
-    ggml_tensor * build_moe_ffn_from_probs(
-            ggml_tensor * cur,
-            ggml_tensor * probs,
-            ggml_tensor * up_exps,
-            ggml_tensor * gate_exps,
-            ggml_tensor * down_exps,
-            ggml_tensor * exp_probs_b,
-            int64_t n_expert,
-            int64_t n_expert_used,
-            llama_expert_gating_func_type gating_op,
-            int il) const;
+            int il,
+            ggml_tensor * probs_in = nullptr) const;
 
     //
     // inputs
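Because the new trailing parameter defaults to nullptr, pre-existing call sites keep compiling and behave as before without modification; only callers that previously went through build_moe_ffn_from_probs pass the extra argument. A minimal sketch of an unchanged caller (illustrative tensor names, not taken from this commit):

    // Unchanged caller: probs_in is omitted and defaults to nullptr,
    // so the gating logits are still computed from gate_inp inside build_moe_ffn.
    ggml_tensor * moe_out = build_moe_ffn(cur, gate_inp,
            up_exps, gate_exps, down_exps, exp_probs_b,
            n_expert, n_expert_used,
            LLM_FFN_SILU, /*norm_w*/ true, /*scale_w*/ false, /*w_scale*/ 0.0f,
            gating_op, il);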

src/llama-model.cpp

Lines changed: 12 additions & 4 deletions
@@ -17729,10 +17729,18 @@ struct llm_build_smallthinker : public llm_graph_context{
         cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
         cb(cur, "ffn_norm", il);
 
-        ggml_tensor * ffn_out = build_moe_ffn_from_probs(cur, probs, model.layers[il].ffn_up_exps,
-                model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps,
-                nullptr, n_expert, n_expert_used,
-                static_cast<llama_expert_gating_func_type>(hparams.expert_gating_func), il);
+        ggml_tensor * ffn_out =
+            build_moe_ffn(cur,
+                nullptr,
+                model.layers[il].ffn_up_exps,
+                model.layers[il].ffn_gate_exps,
+                model.layers[il].ffn_down_exps,
+                nullptr,
+                n_expert, n_expert_used,
+                LLM_FFN_RELU, true,
+                false, 0.0,
+                static_cast<llama_expert_gating_func_type>(hparams.expert_gating_func),
+                il, probs);
 
         cb(ffn_out, "ffn_out", il);
         cur = ffn_out;
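For readability, the same call again with each positional argument labeled. The labels are an annotation added here, assuming the parameter order of the build_moe_ffn declaration in src/llama-graph.h shown above:

    ggml_tensor * ffn_out = build_moe_ffn(
            /*cur*/           cur,
            /*gate_inp*/      nullptr,    // no gating matmul: probabilities arrive via probs_in
            /*up_exps*/       model.layers[il].ffn_up_exps,
            /*gate_exps*/     model.layers[il].ffn_gate_exps,
            /*down_exps*/     model.layers[il].ffn_down_exps,
            /*exp_probs_b*/   nullptr,
            /*n_expert*/      n_expert,
            /*n_expert_used*/ n_expert_used,
            /*type_op*/       LLM_FFN_RELU,
            /*norm_w*/        true,
            /*scale_w*/       false,
            /*w_scale*/       0.0,
            /*gating_op*/     static_cast<llama_expert_gating_func_type>(hparams.expert_gating_func),
            /*il*/            il,
            /*probs_in*/      probs);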
