diff --git a/common/arg.cpp b/common/arg.cpp
index 0f01bb31454a4..16cf915e3de20 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -2405,6 +2405,63 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             }
         }
     ).set_env("LLAMA_ARG_N_CPU_MOE"));
+    add_opt(common_arg(
+        {"--num-experts"}, "N",
+        "Override the number of experts to use for MoE models (default: 0 = use model's default)",
+        [](common_params & params, int value) {
+            params.num_experts = value;
+        }
+    ));
+    add_opt(common_arg(
+        {"--omit-experts"}, "IDs",
+        "comma-separated list of expert indices to omit from MoE selection (e.g. 1,3,5 or 1-5,7)",
+        [](common_params & params, const std::string & value) {
+            params.omit_experts.clear();
+            auto parts = string_split<std::string>(value, ',');
+            for (const auto& part : parts) {
+                if (part.find('-') != std::string::npos) {
+                    // Parse range (e.g., "1-5")
+                    auto range = string_split<int32_t>(part, '-');
+                    if (range.size() == 2 && range[0] <= range[1]) {
+                        for (int32_t i = range[0]; i <= range[1]; ++i) {
+                            params.omit_experts.push_back(i);
+                        }
+                    }
+                } else {
+                    params.omit_experts.push_back(std::stoi(part));
+                }
+            }
+
+            // Sort and remove duplicates for efficient processing later
+            std::sort(params.omit_experts.begin(), params.omit_experts.end());
+            params.omit_experts.erase(std::unique(params.omit_experts.begin(), params.omit_experts.end()), params.omit_experts.end());
+        }
+    ));
+    add_opt(common_arg(
+        {"--force-experts"}, "IDs",
+        "comma-separated list of expert indices to always use in MoE selection (e.g. 1,3,5 or 1-5,7)",
+        [](common_params & params, const std::string & value) {
+            params.force_experts.clear();
+            auto parts = string_split<std::string>(value, ',');
+            for (const auto& part : parts) {
+                if (part.find('-') != std::string::npos) {
+                    // Parse range (e.g., "1-5")
+                    auto range = string_split<int32_t>(part, '-');
+                    if (range.size() == 2 && range[0] <= range[1]) {
+                        for (int32_t i = range[0]; i <= range[1]; ++i) {
+                            params.force_experts.push_back(i);
+                        }
+                    }
+                } else {
+                    params.force_experts.push_back(std::stoi(part));
+                }
+            }
+
+            // Sort and remove duplicates for efficient processing later
+            std::sort(params.force_experts.begin(), params.force_experts.end());
+            params.force_experts.erase(std::unique(params.force_experts.begin(), params.force_experts.end()), params.force_experts.end());
+        }
+    ));
     add_opt(common_arg(
         {"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N",
         "number of layers to store in VRAM",
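Both list flags accept single indices and inclusive ranges, and the handlers above sort and de-duplicate the result. For readers who want to try the expansion outside of llama.cpp, here is a minimal standalone sketch of the same parsing rule using only the standard library; parse_expert_list is an illustrative name, not a helper introduced by this patch (the patch itself goes through string_split from common.h):

    #include <algorithm>
    #include <cstdint>
    #include <sstream>
    #include <string>
    #include <vector>

    // Expand "1,3,5" or "1-5,7" into a sorted, de-duplicated list of expert indices.
    static std::vector<int32_t> parse_expert_list(const std::string & value) {
        std::vector<int32_t> out;
        std::stringstream ss(value);
        std::string part;
        while (std::getline(ss, part, ',')) {
            const size_t dash = part.find('-');
            if (dash != std::string::npos) {
                const int32_t lo = std::stoi(part.substr(0, dash));
                const int32_t hi = std::stoi(part.substr(dash + 1));
                for (int32_t i = lo; i <= hi; ++i) {
                    out.push_back(i);
                }
            } else {
                out.push_back(std::stoi(part));
            }
        }
        std::sort(out.begin(), out.end());
        out.erase(std::unique(out.begin(), out.end()), out.end());
        return out; // e.g. "1-3,7,3" -> {1, 2, 3, 7}
    }

As in the patch, an inverted range such as "5-1" simply contributes no indices.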
diff --git a/common/common.cpp b/common/common.cpp
index c6962d1d19b33..6b61a1168e90e 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1130,6 +1130,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
         GGML_ASSERT(params.kv_overrides.back().key[0] == 0 && "KV overrides not terminated with empty key");
         mparams.kv_overrides = params.kv_overrides.data();
     }
+    mparams.n_expert_used_override = params.num_experts;
 
     if (params.tensor_buft_overrides.empty()) {
         mparams.tensor_buft_overrides = NULL;
@@ -1178,6 +1179,10 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.type_k = params.cache_type_k;
     cparams.type_v = params.cache_type_v;
 
+    cparams.num_experts = params.num_experts;
+    cparams.omit_experts = params.omit_experts;
+    cparams.force_experts = params.force_experts;
+
     return cparams;
 }
diff --git a/common/common.h b/common/common.h
index 5eab199af559e..a871326ebab9a 100644
--- a/common/common.h
+++ b/common/common.h
@@ -467,6 +467,11 @@ struct common_params {
    // return false from callback to abort model loading or true to continue
    llama_progress_callback load_progress_callback = NULL;
    void * load_progress_callback_user_data = NULL;
+
+    // MoE expert selection
+    int32_t num_experts = 0;            // number of experts to use, 0 = model defined
+    std::vector<int32_t> omit_experts;  // expert indices to omit from MoE selection
+    std::vector<int32_t> force_experts; // expert indices to always include in MoE selection
 };
 
 // call once at the start of a program if it uses libcommon
diff --git a/include/llama.h b/include/llama.h
index 545e957e5f52b..f224e4253f3f9 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -1,15 +1,17 @@
 #ifndef LLAMA_H
 #define LLAMA_H
 
-#include "ggml.h"
-#include "ggml-cpu.h"
 #include "ggml-backend.h"
+#include "ggml-cpu.h"
 #include "ggml-opt.h"
+#include "ggml.h"
 
+#include <stdbool.h>
 #include <stddef.h>
 #include <stdint.h>
 #include <stdio.h>
-#include <stdbool.h>
+
+#include <vector>
 
 #ifdef LLAMA_SHARED
 #    if defined(_WIN32) && !defined(__MINGW32__)
@@ -283,6 +285,7 @@ extern "C" {
 
         // override key-value pairs of the model meta data
         const struct llama_model_kv_override * kv_overrides;
+        int32_t n_expert_used_override; // override for the number of experts used, 0 = no override
 
         // Keep the booleans together to avoid misalignment during copy-by-value.
         bool vocab_only;    // only load the vocabulary, no weights
@@ -340,6 +343,10 @@ extern "C" {
         bool kv_unified; // use a unified buffer across the input sequences when computing the attention
                          // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
                          // ref: https://github.com/ggml-org/llama.cpp/pull/14363
+        // MoE expert selection
+        int32_t num_experts;                // number of experts to use, 0 = model defined
+        std::vector<int32_t> omit_experts;  // expert indices to omit from MoE selection
+        std::vector<int32_t> force_experts; // expert indices to always include in MoE selection
     };
 
     // model quantization parameters
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 26a5cf9c3f8db..ffd486c634e4c 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -102,6 +102,8 @@ llama_context::llama_context(
     cparams.op_offload = params.op_offload;
     cparams.kv_unified = params.kv_unified;
+    cparams.omit_experts = params.omit_experts;
+    cparams.force_experts = params.force_experts;
 
     {
         const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
@@ -2269,6 +2271,9 @@ llama_context_params llama_context_default_params() {
         /*.op_offload    =*/ true,
         /*.swa_full      =*/ true,
         /*.kv_unified    =*/ false,
+        /*.num_experts   =*/ 0,
+        /*.omit_experts  =*/ {},
+        /*.force_experts =*/ {},
     };
 
     return result;
diff --git a/src/llama-cparams.h b/src/llama-cparams.h
index 38750affc500b..991e6b64abe7d 100644
--- a/src/llama-cparams.h
+++ b/src/llama-cparams.h
@@ -3,6 +3,7 @@
 #include "llama.h"
 
 #include <cstdint>
+#include <vector>
 
 #define LLAMA_MAX_SEQ 64
 
@@ -26,6 +27,9 @@ struct llama_cparams {
     float yarn_beta_slow;
     float defrag_thold;
 
+    std::vector<int32_t> omit_experts;
+    std::vector<int32_t> force_experts;
+
     bool embeddings;
     bool causal_attn;
     bool offload_kqv;
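Since the new knobs are plain fields on llama_model_params and llama_context_params, an embedding application can opt in without going through common/arg.cpp. A rough usage sketch, assuming the fields land exactly as declared above (the helper name and the concrete index values are made up for illustration):

    #include "llama.h"

    // Hypothetical helper: load a MoE model that routes to 4 experts per token,
    // never uses experts 0 and 1, and always includes expert 7 in the selection.
    static llama_context * init_with_expert_overrides(const char * path) {
        llama_model_params mparams = llama_model_default_params();
        mparams.n_expert_used_override = 4; // 0 keeps the model's own n_expert_used

        llama_model * model = llama_model_load_from_file(path, mparams);
        if (model == nullptr) {
            return nullptr;
        }

        llama_context_params cparams = llama_context_default_params();
        cparams.omit_experts  = {0, 1}; // -INF is added to these router scores
        cparams.force_experts = {7};    // +INF is added to this router score

        return llama_init_from_model(model, cparams);
    }

The equivalent CLI invocation would be --num-experts 4 --omit-experts 0,1 --force-experts 7.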
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 053c72d6dc8d1..7e9aca692b55a 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -258,6 +258,40 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+void llm_graph_input_expert_mask::set_input(const llama_ubatch * ubatch) {
+    if (mask == nullptr) {
+        return;
+    }
+    GGML_UNUSED(ubatch);
+
+    const int64_t n_expert = mask->ne[0];
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(mask->buffer));
+    float * data = (float *) mask->data;
+
+    // the mask is always added to the router scores, so reset it even when no experts are omitted or forced
+    std::fill(data, data + n_expert, 0.0f);
+
+    for (int32_t expert_idx : cparams.omit_experts) {
+        if (expert_idx >= 0 && expert_idx < n_expert) {
+            data[expert_idx] = -INFINITY;
+        }
+    }
+    for (int32_t expert_idx : cparams.force_experts) {
+        if (expert_idx >= 0 && expert_idx < n_expert) {
+            data[expert_idx] = INFINITY;
+        }
+    }
+}
+
+bool llm_graph_input_expert_mask::can_reuse(const llm_graph_params & params) {
+    bool res = true;
+    res &= mask->ne[0] == params.hparams.n_expert;
+    res &= cparams.omit_experts == params.cparams.omit_experts;
+    res &= cparams.force_experts == params.cparams.force_experts;
+    return res;
+}
+
 void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
     const int64_t n_kv = ubatch->n_tokens;
     const int64_t n_tokens = ubatch->n_tokens;
@@ -787,6 +820,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
             bool scale_w,
             float w_scale,
             llama_expert_gating_func_type gating_op,
+            ggml_tensor * expert_mask,
             int il,
             ggml_tensor * probs_in) const {
     return build_moe_ffn(
@@ -803,6 +837,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
             scale_w,
             w_scale,
             gating_op,
+            expert_mask,
             il,
             probs_in
     );
@@ -826,6 +861,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
             bool scale_w,
             float w_scale,
             llama_expert_gating_func_type gating_op,
+            ggml_tensor * expert_mask,
             int il,
             ggml_tensor * probs_in) const {
     const int64_t n_embd = cur->ne[0];
@@ -879,6 +915,12 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         selection_probs = logits;
     }
 
+    // Omit or force specified experts by adding a mask of -INF/INF respectively
+    if (expert_mask != nullptr) {
+        selection_probs = ggml_add(ctx0, selection_probs, expert_mask);
+        cb(selection_probs, "ffn_moe_probs_masked", il);
+    }
+
     // select experts
     ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
     cb(selected_experts->src[0], "ffn_moe_argsort", il);
@@ -1352,6 +1394,14 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
     return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
 }
 
+llm_graph_input_expert_mask * llm_graph_context::build_inp_expert_mask() const {
+    auto inp = std::make_unique<llm_graph_input_expert_mask>(cparams);
+    auto & cur = inp->mask;
+    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hparams.n_expert);
+    ggml_set_input(cur);
+    return (llm_graph_input_expert_mask *) res->add_input(std::move(inp));
+}
+
 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_no_cache * inp,
         ggml_tensor * wo,
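The mechanism in build_moe_ffn is deliberately simple: the per-expert mask is added to the selection scores right before ggml_top_k, so an expert at -INF can never be selected and an expert at +INF always is. A toy CPU-side illustration of that selection rule (plain C++, no ggml; the function and the numbers are made up for the example):

    #include <algorithm>
    #include <cstdint>
    #include <numeric>
    #include <vector>

    // Pick the n_used highest-scoring experts after applying an additive mask,
    // mirroring what ggml_add + ggml_top_k do on the masked selection_probs.
    static std::vector<int32_t> select_experts(std::vector<float> scores,
                                               const std::vector<float> & mask,
                                               int n_used) {
        for (size_t i = 0; i < scores.size(); ++i) {
            scores[i] += mask[i]; // -INFINITY omits expert i, +INFINITY forces it
        }
        std::vector<int32_t> idx(scores.size());
        std::iota(idx.begin(), idx.end(), 0);
        std::partial_sort(idx.begin(), idx.begin() + n_used, idx.end(),
                          [&](int32_t a, int32_t b) { return scores[a] > scores[b]; });
        idx.resize(n_used);
        return idx;
    }

    // select_experts({0.9f, 0.1f, 0.5f, 0.4f}, {-INF, 0, 0, +INF}, 2) -> {3, 2}:
    // expert 0 has the best raw score but is omitted, expert 3 is forced,
    // and the remaining slot goes to the best unmasked expert (index 2).

Note that the hunk above only masks selection_probs; the routing weights are still taken from the unmasked probabilities, so a forced expert keeps its natural gate weight.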
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 6ff49de3a1ce8..a054d6bf4dc29 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -238,6 +238,20 @@ class llm_graph_input_cross_embd : public llm_graph_input_i {
     const llama_cross * cross;
 };
 
+class llm_graph_input_expert_mask : public llm_graph_input_i {
+public:
+    llm_graph_input_expert_mask(const llama_cparams & cparams) : cparams(cparams) {}
+
+    virtual ~llm_graph_input_expert_mask() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+    bool can_reuse(const llm_graph_params & params) override;
+
+    ggml_tensor * mask = nullptr; // F32 [n_expert]
+
+    const llama_cparams & cparams;
+};
+
 class llm_graph_input_attn_no_cache : public llm_graph_input_i {
 public:
     llm_graph_input_attn_no_cache(const llama_hparams & hparams, const llama_cparams & cparams) :
@@ -635,6 +649,7 @@ struct llm_graph_context {
             bool scale_w,
             float w_scale,
             llama_expert_gating_func_type gating_op,
+            ggml_tensor * expert_mask,
             int il,
             ggml_tensor * probs_in = nullptr) const;
 
@@ -656,6 +671,7 @@ struct llm_graph_context {
             bool scale_w,
             float w_scale,
             llama_expert_gating_func_type gating_op,
+            ggml_tensor * expert_mask,
             int il,
             ggml_tensor * probs_in = nullptr) const;
 
@@ -814,6 +830,8 @@ struct llm_graph_context {
             ggml_tensor * cls_b,
             ggml_tensor * cls_out,
             ggml_tensor * cls_out_b) const;
+
+    llm_graph_input_expert_mask * build_inp_expert_mask() const;
 };
 
 // TODO: better name
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 58ca7df707ef3..2b0b4a355326c 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -5986,6 +5986,8 @@ struct llm_build_llama : public llm_graph_context {
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
+        llm_graph_input_expert_mask * inp_expert_mask = build_inp_expert_mask();
+
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
 
@@ -6088,6 +6090,7 @@
                         LLM_FFN_SILU, true,
                         false, 0.0,
                         LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        inp_expert_mask->mask,
                         il);
                 cb(cur, "ffn_moe_out", il);
             }
@@ -6146,6 +6149,8 @@ struct llm_build_llama_iswa : public llm_graph_context {
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
+        llm_graph_input_expert_mask * inp_expert_mask = build_inp_expert_mask();
+
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
 
@@ -6260,6 +6265,7 @@
                         LLM_FFN_SILU, false,
                         false, 0.0,
                         LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID,
+                        inp_expert_mask->mask,
                         il);
 
                 // Shared experts
@@ -6839,6 +6845,8 @@ struct llm_build_grok : public llm_graph_context {
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
+        llm_graph_input_expert_mask * inp_expert_mask = build_inp_expert_mask();
+
        for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
 
@@ -6932,6 +6940,7 @@
                         LLM_FFN_GELU, true,
                         false, 0.0,
                         LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        inp_expert_mask->mask,
                         il);
             cb(cur, "ffn_moe_out", il);
 
@@ -6999,6 +7008,8 @@ struct llm_build_dbrx : public llm_graph_context {
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
+        llm_graph_input_expert_mask * inp_expert_mask = build_inp_expert_mask();
+
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
 
@@ -7072,6 +7083,7 @@
                         LLM_FFN_SILU, true,
                         false, 0.0,
                         LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        inp_expert_mask->mask,
                         il);
             cb(cur, "ffn_moe_out", il);
 
@@ -7348,6 +7360,8 @@ struct llm_build_bert : public llm_graph_context {
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
+        llm_graph_input_expert_mask * inp_expert_mask = build_inp_expert_mask();
+
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * cur = inpL;
 
@@ -7451,7 +7465,9 @@
                         LLM_FFN_GELU,
                         false, false,
                         0.0f,
-                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il);
+                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        inp_expert_mask->mask,
+                        il);
                 cb(cur, "ffn_moe_out", il);
             } else if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) {
                 cur = build_ffn(cur,
@@ -8593,6 +8609,8 @@ struct llm_build_qwen2moe : public llm_graph_context {
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
+        llm_graph_input_expert_mask * inp_expert_mask = build_inp_expert_mask();
+
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
 
@@ -8676,6 +8694,7 @@
                         LLM_FFN_SILU, false,
                         false, 0.0,
                         LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        inp_expert_mask->mask,
                         il);
             cb(moe_out, "ffn_moe_out", il);
 
@@ -8873,6 +8892,8 @@ struct llm_build_qwen3moe : public llm_graph_context {
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
+        llm_graph_input_expert_mask * inp_expert_mask = build_inp_expert_mask();
+
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
 
@@ -8950,6 +8971,7 @@
                         LLM_FFN_SILU, true,
                         false, 0.0,
                         LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        inp_expert_mask->mask,
                         il);
                 cb(moe_out, "ffn_moe_out", il);
                 cur = moe_out;
@@ -9139,6 +9161,8 @@ struct llm_build_phi3 : public llm_graph_context {
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
+        llm_graph_input_expert_mask * inp_expert_mask = build_inp_expert_mask();
+
         for (int il = 0; il < n_layer; ++il) {
             auto * residual = inpL;
 
@@ -9236,6 +9260,7 @@
                         LLM_FFN_SILU, true,
                         false, 0.0,
                         LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        inp_expert_mask->mask,
                         il);
                 cb(cur, "ffn_moe_out", il);
             }
@@ -11349,6 +11374,8 @@ struct llm_build_jamba : public llm_graph_context_mamba {
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
+        llm_graph_input_expert_mask * inp_expert_mask = build_inp_expert_mask();
+
         for (int il = 0; il < n_layer; ++il) {
             const int64_t n_head_kv = hparams.n_head_kv(il);
 
@@ -11414,6 +11441,7 @@
                         LLM_FFN_SILU, false,
                         false, 0.0,
                         LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        inp_expert_mask->mask,
                         il);
                 cb(cur, "ffn_moe_out", il);
             }
@@ -12003,6 +12031,8 @@ struct llm_build_olmoe : public llm_graph_context {
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
+        llm_graph_input_expert_mask * inp_expert_mask = build_inp_expert_mask();
+
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
 
@@ -12081,6 +12111,7 @@
                         LLM_FFN_SILU, false,
                         false, 0.0,
                         LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        inp_expert_mask->mask,
                         il);
             cb(cur, "ffn_moe_out", il);
 
@@ -12406,6 +12437,8 @@ struct llm_build_arctic : public llm_graph_context {
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
+        llm_graph_input_expert_mask * inp_expert_mask = build_inp_expert_mask();
+
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
 
@@ -12493,6 +12526,7 @@
                         LLM_FFN_SILU, true,
                         false, 0.0,
                         LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        inp_expert_mask->mask,
                         il);
             cb(cur, "ffn_moe_out", il);
 
@@ -12546,6 +12580,8 @@ struct llm_build_deepseek : public llm_graph_context {
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
+        llm_graph_input_expert_mask * inp_expert_mask = build_inp_expert_mask();
+
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
 
@@ -12641,6 +12677,7 @@
                         LLM_FFN_SILU, false,
                         false, hparams.expert_weights_scale,
                         LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        inp_expert_mask->mask,
                         il);
             cb(moe_out, "ffn_moe_out", il);
 
@@ -12721,6 +12758,8 @@ struct llm_build_deepseek2 : public llm_graph_context {
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
+        llm_graph_input_expert_mask * inp_expert_mask = build_inp_expert_mask();
+
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
 
@@ -12904,6 +12943,7 @@
                         LLM_FFN_SILU, hparams.expert_weights_norm,
                         true, hparams.expert_weights_scale,
                         (llama_expert_gating_func_type) hparams.expert_gating_func,
+                        inp_expert_mask->mask,
                         il);
             cb(moe_out, "ffn_moe_out", il);
 
@@ -13778,6 +13818,8 @@ struct llm_build_glm4_moe : public llm_graph_context {
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
+        llm_graph_input_expert_mask * inp_expert_mask = build_inp_expert_mask();
+
         // Only process up to last layer (skip final NextN layer)
         // Final layer tensors are loaded but not processed in forward pass
         const int n_transformer_layers = n_layer - hparams.nextn_predict_layers;
@@ -13877,6 +13919,7 @@
                         LLM_FFN_SILU, hparams.expert_weights_norm,
                         true, hparams.expert_weights_scale,
                         (llama_expert_gating_func_type) hparams.expert_gating_func,
+                        inp_expert_mask->mask,
                         il);
             cb(routed_out, "ffn_moe_out", il);
 
@@ -15242,6 +15285,7 @@ struct llm_build_granite : public llm_graph_context {
                         LLM_FFN_SILU, true,
                         false, 0.0,
                         LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        nullptr,
                         il);
             cb(moe_out, "ffn_moe_out", il);
 
@@ -15461,6 +15505,7 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba {
                         LLM_FFN_SILU, true,
                         false, 0.0,
                         LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        nullptr,
                         il);
             cb(moe_out, "ffn_moe_out", il);
 
@@ -16016,6 +16061,8 @@ struct llm_build_bailingmoe : public llm_graph_context {
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
+        llm_graph_input_expert_mask * inp_expert_mask = build_inp_expert_mask();
+
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
 
@@ -16101,6 +16148,7 @@
                         LLM_FFN_SILU, hparams.expert_weights_norm,
                         false, hparams.expert_weights_scale,
                         LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        inp_expert_mask->mask,
                         il);
             cb(moe_out, "ffn_moe_out", il);
 
@@ -16165,6 +16213,8 @@ struct llm_build_dots1 : public llm_graph_context {
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
+        llm_graph_input_expert_mask * inp_expert_mask = build_inp_expert_mask();
+
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
 
@@ -16251,6 +16301,7 @@
                         LLM_FFN_SILU, hparams.expert_weights_norm,
                         true, hparams.expert_weights_scale,
                         (llama_expert_gating_func_type) hparams.expert_gating_func,
+                        inp_expert_mask->mask,
                         il);
             cb(moe_out, "ffn_moe_out", il);
 
@@ -16445,6 +16496,8 @@ struct llm_build_ernie4_5_moe : public llm_graph_context {
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
+        llm_graph_input_expert_mask * inp_expert_mask = build_inp_expert_mask();
+
         GGML_ASSERT(hparams.n_moe_layer_step > 0 && "Ernie 4.5 MoE requires n_moe_layer_step > 0");
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
@@ -16547,6 +16600,7 @@
                         LLM_FFN_SILU, true,
                         false, 0.0,
                         LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        inp_expert_mask->mask,
                         il);
             cb(moe_out, "ffn_moe_out", il);
 
@@ -17191,6 +17245,8 @@ struct llm_build_hunyuan_moe : public llm_graph_context {
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
+        llm_graph_input_expert_mask * inp_expert_mask = build_inp_expert_mask();
+
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
 
@@ -17298,6 +17354,7 @@
                         false,
                         0.0,
                         LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        inp_expert_mask->mask,
                         il);
             cb(cur_moe, "ffn_moe_out", il);
 
@@ -17618,6 +17675,8 @@ struct llm_build_openai_moe_iswa : public llm_graph_context {
 
         auto * inp_attn = build_attn_inp_kv_unified_iswa();
 
+        llm_graph_input_expert_mask * inp_expert_mask = build_inp_expert_mask();
+
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
 
"ffn_moe_out", il); @@ -17940,6 +18000,8 @@ struct llm_build_smallthinker : public llm_graph_context{ ggml_tensor * inp_out_ids = build_inp_out_ids(); + llm_graph_input_expert_mask * inp_expert_mask = build_inp_expert_mask(); + for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; ggml_tensor * probs = nullptr; @@ -18007,6 +18069,7 @@ struct llm_build_smallthinker : public llm_graph_context{ LLM_FFN_RELU, true, false, 0.0, static_cast(hparams.expert_gating_func), + inp_expert_mask->mask, il, probs); cb(ffn_out, "ffn_out", il); @@ -18524,6 +18587,7 @@ llama_model_params llama_model_default_params() { /*.progress_callback =*/ nullptr, /*.progress_callback_user_data =*/ nullptr, /*.kv_overrides =*/ nullptr, + /*.n_experts_used_override =*/ 0, /*.vocab_only =*/ false, /*.use_mmap =*/ true, /*.use_mlock =*/ false, diff --git a/src/llama.cpp b/src/llama.cpp index 34906cdb62844..9c3c7bf415d92 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -109,6 +109,20 @@ static int llama_model_load(const std::string & fname, std::vector } catch(const std::exception & e) { throw std::runtime_error("error loading model hyperparameters: " + std::string(e.what())); } + if (params.n_expert_used_override > 0) { + if (model.hparams.n_expert == 0) { + LLAMA_LOG_WARN("%s: --num-experts is set to %d, but the model is not a Mixture-of-Experts model. Ignoring.\n", + __func__, params.n_expert_used_override); + } else if (params.n_expert_used_override > (int32_t)model.hparams.n_expert) { + LLAMA_LOG_WARN("%s: --num-experts is set to %d, which is greater than the total number of experts available in the model (%u). Clamping to %u.\n", + __func__, params.n_expert_used_override, model.hparams.n_expert, model.hparams.n_expert); + model.hparams.n_expert_used = model.hparams.n_expert; + } else { + LLAMA_LOG_INFO("%s: Overriding n_expert_used from %u to %d.\n", + __func__, model.hparams.n_expert_used, params.n_expert_used_override); + model.hparams.n_expert_used = params.n_expert_used_override; + } + } try { model.load_vocab(ml); } catch(const std::exception & e) {