@@ -947,14 +947,10 @@ ggml_tensor * llm_graph_context::build_moe_ffn_from_probs(
         ggml_tensor * exp_probs_b,
         int64_t n_expert,
         int64_t n_expert_used,
-        llm_ffn_op_type type_op,
-        bool norm_w,
-        bool scale_w,
-        float w_scale,
+        llama_expert_gating_func_type gating_op,
         int il) const {
     const int64_t n_embd   = cur->ne[0];
     const int64_t n_tokens = cur->ne[1];
-    const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
 
     // add experts selection bias - introduced in DeepSeek V3
     // leave probs unbiased as it's later used to get expert weights
@@ -973,90 +969,57 @@ ggml_tensor * llm_graph_context::build_moe_ffn_from_probs(
             ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
     cb(weights, "ffn_moe_weights", il);
 
-    if (norm_w) {
-        weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
-
+    weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
+    if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX) {
+        weights = ggml_soft_max(ctx0, weights);
+    } else {
+        weights = ggml_sigmoid(ctx0, weights);
         ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
         cb(weights_sum, "ffn_moe_weights_sum", il);
 
         weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
         cb(weights, "ffn_moe_weights_norm", il);
-
-        weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
-    }
-    if (scale_w) {
-        weights = ggml_scale(ctx0, weights, w_scale);
-        cb(weights, "ffn_moe_weights_scaled", il);
     }
 
-    cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
+    weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
 
-    if (weight_before_ffn) {
-        // repeat cur to [n_embd, n_expert_used, n_tokens]
-        ggml_tensor * repeated = ggml_repeat_4d(ctx0, cur, n_embd, n_expert_used, n_tokens, 1);
-        cur = ggml_mul(ctx0, repeated, weights);
-        cb(cur, "ffn_moe_weighted", il);
-    }
+    cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
 
     ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
     cb(up, "ffn_moe_up", il);
 
     ggml_tensor * experts = nullptr;
-    if (gate_exps) {
-        cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
-        cb(cur, "ffn_moe_gate", il);
-    } else {
-        cur = up;
-    }
+    cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+    cb(cur, "ffn_moe_gate", il);
 
-    switch (type_op) {
-        case LLM_FFN_SILU:
-            if (gate_exps) {
-                cur = ggml_swiglu_split(ctx0, cur, up);
-                cb(cur, "ffn_moe_swiglu", il);
-            } else {
-                cur = ggml_silu(ctx0, cur);
-                cb(cur, "ffn_moe_silu", il);
-            } break;
-        case LLM_FFN_GELU:
-            if (gate_exps) {
-                cur = ggml_geglu_split(ctx0, cur, up);
-                cb(cur, "ffn_moe_geglu", il);
-            } else {
-                cur = ggml_gelu(ctx0, cur);
-                cb(cur, "ffn_moe_gelu", il);
-            } break;
-        case LLM_FFN_RELU:
-            if (gate_exps) {
-                cur = ggml_reglu_split(ctx0, cur, up);
-                cb(cur, "ffn_moe_reglu", il);
-            } else {
-                cur = ggml_relu(ctx0, cur);
-                cb(cur, "ffn_moe_relu", il);
-            } break;
-        default:
-            GGML_ABORT("fatal error");
-    }
+    cur = ggml_reglu_split(ctx0, cur, up);
+    cb(cur, "ffn_moe_reglu", il);
 
     experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
     cb(experts, "ffn_moe_down", il);
 
-    if (!weight_before_ffn) {
-        experts = ggml_mul(ctx0, experts, weights);
-        cb(cur, "ffn_moe_weighted", il);
+    experts = ggml_mul(ctx0, experts, weights);
+    cb(cur, "ffn_moe_weighted", il);
+
+    ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr };
+
+    assert(n_expert_used > 0);
+
+    // order the views before the adds
+    for (uint32_t i = 0; i < hparams.n_expert_used; ++i) {
+        cur_experts[i] = ggml_view_2d(ctx0, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]);
+
+        ggml_build_forward_expand(gf, cur_experts[i]);
     }
 
     // aggregate experts
-    ggml_tensor * moe_out = nullptr;
-    for (int i = 0; i < n_expert_used; ++i) {
-        ggml_tensor * cur_expert = ggml_view_2d(ctx0, experts, n_embd, n_tokens,
-                experts->nb[2], i*experts->nb[1]);
+    // note: here we explicitly use hparams.n_expert_used instead of n_expert_used
+    //       to avoid potentially a large number of add nodes during warmup
+    //       ref: https://github.com/ggml-org/llama.cpp/pull/14753
+    ggml_tensor * moe_out = cur_experts[0];
 
-        if (i == 0) {
-            moe_out = cur_expert;
-        } else {
-            moe_out = ggml_add(ctx0, moe_out, cur_expert);
-        }
+    for (uint32_t i = 1; i < hparams.n_expert_used; ++i) {
+        moe_out = ggml_add(ctx0, moe_out, cur_experts[i]);
     }
 
     if (n_expert_used == 1) {
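
Note: the heart of the signature change is the new gating_op parameter, which replaces the norm_w/scale_w/w_scale trio with an explicit choice between softmax gating and DeepSeek-V3-style sigmoid gating with per-token renormalization. The following is a minimal standalone sketch of those two modes, not part of the commit; it assumes a recent ggml build (ggml_sigmoid, ggml_graph_compute_with_ctx in ggml-cpu.h), and the file name, sizes, and logit values are made up for illustration.

// gating_sketch.cpp -- minimal sketch of softmax vs sigmoid-normalized gating
#include "ggml.h"
#include "ggml-cpu.h"
#include <cstdio>

int main() {
    ggml_init_params params = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ nullptr, /*no_alloc*/ false };
    ggml_context * ctx = ggml_init(params);

    const int64_t n_expert_used = 4;
    const int64_t n_tokens      = 2;

    // router logits for the selected experts: [n_expert_used, n_tokens]
    ggml_tensor * logits = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_expert_used, n_tokens);
    for (int i = 0; i < (int) ggml_nelements(logits); ++i) {
        ggml_set_f32_1d(logits, i, 0.25f*i);
    }

    // LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX: weights sum to 1 per token by construction
    ggml_tensor * w_softmax = ggml_soft_max(ctx, logits);

    // sigmoid gating (DeepSeek V3 style): squash, then renormalize per token
    ggml_tensor * w_sig  = ggml_sigmoid(ctx, logits);
    ggml_tensor * w_sum  = ggml_sum_rows(ctx, w_sig);    // [1, n_tokens]
    ggml_tensor * w_norm = ggml_div(ctx, w_sig, w_sum);  // broadcast divide

    ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, w_softmax);
    ggml_build_forward_expand(gf, w_norm);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads*/ 1);

    for (int t = 0; t < (int) n_tokens; ++t) {
        for (int e = 0; e < (int) n_expert_used; ++e) {
            printf("token %d expert %d: softmax %.3f  sigmoid-norm %.3f\n",
                   t, e, ggml_get_f32_nd(w_softmax, e, t, 0, 0),
                         ggml_get_f32_nd(w_norm,    e, t, 0, 0));
        }
    }

    ggml_free(ctx);
}

Both modes yield per-token weights that sum to 1; they differ in how peaked the distribution is for the same logits.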
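Note: the new aggregation loop takes each expert's contribution as a 2D view into the weighted [n_embd, n_expert_used, n_tokens] tensor (row stride nb[2], byte offset i*nb[1]), expands the views into the graph before the adds to pin their order, and bounds the loop by the fixed hparams.n_expert_used rather than the local n_expert_used, which can grow to n_expert during warmup (PR #14753, linked in the diff). Below is a standalone sketch of the same view-and-sum pattern, not the commit's code; it assumes a recent ggml, and the file name and sizes are made up.

// moe_aggregate_sketch.cpp -- each expert slice is filled with 1.0f,
// so every output element should equal n_expert_used
#include "ggml.h"
#include "ggml-cpu.h"
#include <cstdio>

int main() {
    ggml_init_params params = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ nullptr, /*no_alloc*/ false };
    ggml_context * ctx = ggml_init(params);

    const int64_t n_embd = 3, n_expert_used = 4, n_tokens = 2;

    // weighted expert outputs, as after the ggml_mul in the diff: [n_embd, n_expert_used, n_tokens]
    ggml_tensor * experts = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, n_expert_used, n_tokens);
    for (int i = 0; i < (int) ggml_nelements(experts); ++i) {
        ggml_set_f32_1d(experts, i, 1.0f);
    }

    ggml_cgraph * gf = ggml_new_graph(ctx);

    // one [n_embd, n_tokens] view per expert: row stride nb[2], byte offset i*nb[1];
    // expanding the views first fixes their order in the graph, as in the diff
    ggml_tensor * views[4] = { nullptr };
    for (int64_t i = 0; i < n_expert_used; ++i) {
        views[i] = ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]);
        ggml_build_forward_expand(gf, views[i]);
    }

    // sum the views into the final MoE output
    ggml_tensor * moe_out = views[0];
    for (int64_t i = 1; i < n_expert_used; ++i) {
        moe_out = ggml_add(ctx, moe_out, views[i]);
    }
    ggml_build_forward_expand(gf, moe_out);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads*/ 1);

    printf("moe_out[0,0] = %.1f (expected %.1f)\n",
           ggml_get_f32_nd(moe_out, 0, 0, 0, 0), (float) n_expert_used);

    ggml_free(ctx);
}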