Commit 19951fb

expert selection
1 parent e725a1a commit 19951fb

File tree: 10 files changed (+231, -4 lines changed)


common/arg.cpp

Lines changed: 57 additions & 0 deletions
@@ -2405,6 +2405,63 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             }
         }
     ).set_env("LLAMA_ARG_N_CPU_MOE"));
+    add_opt(common_arg(
+        {"--num-experts"}, "N",
+        "Override the number of experts to use for MoE models (default: 0 = use model's default)",
+        [](common_params & params, int value) {
+            params.num_experts = value;
+        }
+    ));
+    add_opt(common_arg(
+        {"--omit-experts"}, "IDs",
+        "comma-separated list of expert indices to omit from MoE selection (e.g. 1,3,5 or 1-5,7)",
+        [](common_params & params, const std::string & value) {
+            params.omit_experts.clear();
+            auto parts = string_split<std::string>(value, ',');
+            for (const auto& part : parts) {
+                if (part.find('-') != std::string::npos) {
+                    // Parse range (e.g., "1-5")
+                    auto range = string_split<int32_t>(part, '-');
+                    if (range.size() == 2 && range[0] <= range[1]) {
+                        for (int32_t i = range[0]; i <= range[1]; ++i) {
+                            params.omit_experts.push_back(i);
+                        }
+                    }
+                } else {
+                    params.omit_experts.push_back(std::stoi(part));
+                }
+            }
+
+            // Sort and remove duplicates for efficient processing later
+            std::sort(params.omit_experts.begin(), params.omit_experts.end());
+            params.omit_experts.erase(std::unique(params.omit_experts.begin(), params.omit_experts.end()), params.omit_experts.end());
+        }
+    ));
+    add_opt(common_arg(
+        {"--force-experts"}, "IDs",
+        "comma-separated list of expert indices to always use in MoE selection (e.g. 1,3,5 or 1-5,7)",
+        [](common_params & params, const std::string & value) {
+            params.force_experts.clear();
+            auto parts = string_split<std::string>(value, ',');
+            for (const auto& part : parts) {
+                if (part.find('-') != std::string::npos) {
+                    // Parse range (e.g., "1-5")
+                    auto range = string_split<int32_t>(part, '-');
+                    if (range.size() == 2 && range[0] <= range[1]) {
+                        for (int32_t i = range[0]; i <= range[1]; ++i) {
+                            params.force_experts.push_back(i);
+                        }
+                    }
+                } else {
+                    params.force_experts.push_back(std::stoi(part));
+                }
+            }
+
+            // Sort and remove duplicates for efficient processing later
+            std::sort(params.force_experts.begin(), params.force_experts.end());
+            params.force_experts.erase(std::unique(params.force_experts.begin(), params.force_experts.end()), params.force_experts.end());
+        }
+    ));
     add_opt(common_arg(
         {"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N",
         "number of layers to store in VRAM",

common/common.cpp

Lines changed: 5 additions & 0 deletions
@@ -1130,6 +1130,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
         GGML_ASSERT(params.kv_overrides.back().key[0] == 0 && "KV overrides not terminated with empty key");
         mparams.kv_overrides = params.kv_overrides.data();
     }
+    mparams.n_expert_used_override = params.num_experts;

     if (params.tensor_buft_overrides.empty()) {
         mparams.tensor_buft_overrides = NULL;
@@ -1178,6 +1179,10 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.type_k = params.cache_type_k;
     cparams.type_v = params.cache_type_v;

+    cparams.num_experts = params.num_experts;
+    cparams.omit_experts = params.omit_experts;
+    cparams.force_experts = params.force_experts;
+
     return cparams;
 }

common/common.h

Lines changed: 5 additions & 0 deletions
@@ -467,6 +467,11 @@ struct common_params {
     // return false from callback to abort model loading or true to continue
     llama_progress_callback load_progress_callback = NULL;
     void * load_progress_callback_user_data = NULL;
+
+    // MoE expert selection
+    int32_t num_experts = 0; // number of experts to use, 0 = model defined
+    std::vector<int32_t> omit_experts; // comma-separated list of expert indices to omit
+    std::vector<int32_t> force_experts; // comma-separated list of expert indices to force
 };

 // call once at the start of a program if it uses libcommon

include/llama.h

Lines changed: 10 additions & 3 deletions
@@ -1,15 +1,17 @@
 #ifndef LLAMA_H
 #define LLAMA_H

-#include "ggml.h"
-#include "ggml-cpu.h"
 #include "ggml-backend.h"
+#include "ggml-cpu.h"
 #include "ggml-opt.h"
+#include "ggml.h"

+#include <stdbool.h>
 #include <stddef.h>
 #include <stdint.h>
 #include <stdio.h>
-#include <stdbool.h>
+
+#include <vector>

 #ifdef LLAMA_SHARED
 # if defined(_WIN32) && !defined(__MINGW32__)
@@ -283,6 +285,7 @@ extern "C" {

         // override key-value pairs of the model meta data
         const struct llama_model_kv_override * kv_overrides;
+        int32_t n_expert_used_override; // number of expert overrides, 0 = no overrides

         // Keep the booleans together to avoid misalignment during copy-by-value.
         bool vocab_only; // only load the vocabulary, no weights
@@ -340,6 +343,10 @@ extern "C" {
         bool kv_unified; // use a unified buffer across the input sequences when computing the attention
                          // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
                          // ref: https://github.com/ggml-org/llama.cpp/pull/14363
+        // MoE expert selection
+        int32_t num_experts; // number of experts to use, 0 = model defined
+        std::vector<int32_t> omit_experts; // comma-separated list of expert indices to omit
+        std::vector<int32_t> force_experts; // comma-separated list of expert indices to force
     };

     // model quantization parameters
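
Taken together, the new llama_model_params and llama_context_params fields let an application steer expert routing directly through the public API, without going through the common layer. A minimal sketch assuming the headers from this commit; the expert counts and indices are illustrative, and llama_init_from_model is the existing context-creation entry point:

    #include "llama.h"

    // Sketch: cap routing at 4 experts per token, never route to experts 0-1, always include expert 7.
    static llama_context * init_with_expert_overrides(llama_model * model) {
        llama_context_params cparams = llama_context_default_params();
        cparams.num_experts   = 4;       // 0 keeps the model-defined n_expert_used
        cparams.omit_experts  = {0, 1};  // written as -INF into the expert mask
        cparams.force_experts = {7};     // written as +INF into the expert mask
        return llama_init_from_model(model, cparams);
    }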

src/llama-context.cpp

Lines changed: 3 additions & 0 deletions
@@ -102,6 +102,8 @@ llama_context::llama_context(

     cparams.op_offload = params.op_offload;
     cparams.kv_unified = params.kv_unified;
+    cparams.omit_experts = params.omit_experts;
+    cparams.force_experts = params.force_experts;

     {
         const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
@@ -2269,6 +2271,7 @@ llama_context_params llama_context_default_params() {
        /*.op_offload   =*/ true,
        /*.swa_full     =*/ true,
        /*.kv_unified   =*/ false,
+       /*.omit_experts =*/ {},
    };

    return result;

src/llama-cparams.h

Lines changed: 4 additions & 0 deletions
@@ -3,6 +3,7 @@
 #include "llama.h"

 #include <cstdint>
+#include <vector>

 #define LLAMA_MAX_SEQ 64

@@ -26,6 +27,9 @@ struct llama_cparams {
     float yarn_beta_slow;
     float defrag_thold;

+    std::vector<int32_t> omit_experts;
+    std::vector<int32_t> force_experts;
+
     bool embeddings;
     bool causal_attn;
     bool offload_kqv;

src/llama-graph.cpp

Lines changed: 50 additions & 0 deletions
@@ -258,6 +258,39 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
     }
 }

+void llm_graph_input_expert_mask::set_input(const llama_ubatch * ubatch) {
+    if (mask == nullptr || (cparams.omit_experts.empty() && cparams.force_experts.empty())) {
+        return;
+    }
+    GGML_UNUSED(ubatch);
+
+    const int64_t n_expert = mask->ne[0];
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(mask->buffer));
+    float * data = (float *) mask->data;
+
+    std::fill(data, data + n_expert, 0.0f);
+
+    for (int32_t expert_idx : cparams.omit_experts) {
+        if (expert_idx >= 0 && expert_idx < n_expert) {
+            data[expert_idx] = -INFINITY;
+        }
+    }
+    for (int32_t expert_idx : cparams.force_experts) {
+        if (expert_idx >= 0 && expert_idx < n_expert) {
+            data[expert_idx] = INFINITY;
+        }
+    }
+}
+
+bool llm_graph_input_expert_mask::can_reuse(const llm_graph_params & params) {
+    bool res = true;
+    res &= mask->ne[0] == params.hparams.n_expert;
+    res &= cparams.omit_experts == params.cparams.omit_experts;
+    res &= cparams.force_experts == params.cparams.force_experts;
+    return res;
+}
+
 void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
     const int64_t n_kv     = ubatch->n_tokens;
     const int64_t n_tokens = ubatch->n_tokens;
@@ -787,6 +820,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         bool scale_w,
         float w_scale,
         llama_expert_gating_func_type gating_op,
+        ggml_tensor * expert_mask,
         int il,
         ggml_tensor * probs_in) const {
     return build_moe_ffn(
@@ -803,6 +837,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         scale_w,
         w_scale,
         gating_op,
+        expert_mask,
         il,
         probs_in
     );
@@ -826,6 +861,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         bool scale_w,
         float w_scale,
         llama_expert_gating_func_type gating_op,
+        ggml_tensor * expert_mask,
         int il,
         ggml_tensor * probs_in) const {
     const int64_t n_embd = cur->ne[0];
@@ -879,6 +915,12 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         selection_probs = logits;
     }

+    // Omit or force specified experts by adding a mask of -INF/INF respectively
+    if (expert_mask != nullptr) {
+        selection_probs = ggml_add(ctx0, selection_probs, expert_mask);
+        cb(selection_probs, "ffn_moe_probs_masked", il);
+    }
+
     // select experts
     ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
     cb(selected_experts->src[0], "ffn_moe_argsort", il);
@@ -1352,6 +1394,14 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
     return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
 }

+llm_graph_input_expert_mask * llm_graph_context::build_inp_expert_mask() const {
+    auto inp = std::make_unique<llm_graph_input_expert_mask>(cparams);
+    auto & cur = inp->mask;
+    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hparams.n_expert);
+    ggml_set_input(cur);
+    return (llm_graph_input_expert_mask *) res->add_input(std::move(inp));
+}
+
 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_no_cache * inp,
         ggml_tensor * wo,
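
The mask works because it is added to the router scores before top-k selection: a -INF entry can never rank among the top n_expert_used, while a +INF entry always does. A standalone sketch of that selection step on plain floats (illustrative values, no ggml involved), mirroring the ggml_add + ggml_top_k sequence above:

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <numeric>
    #include <vector>

    // Pick the indices of the k largest scores after applying an omit/force mask.
    static std::vector<int> top_k_experts(std::vector<float> logits, const std::vector<float> & mask, int k) {
        for (size_t i = 0; i < logits.size(); ++i) {
            logits[i] += mask[i]; // -INFINITY omits the expert, +INFINITY forces it
        }
        std::vector<int> idx(logits.size());
        std::iota(idx.begin(), idx.end(), 0);
        std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                          [&](int a, int b) { return logits[a] > logits[b]; });
        idx.resize(k);
        return idx;
    }

    int main() {
        std::vector<float> logits = {0.9f, 0.1f, 0.8f, 0.7f};          // router scores for 4 experts
        std::vector<float> mask   = {-INFINITY, 0.0f, 0.0f, INFINITY}; // omit expert 0, force expert 3
        for (int e : top_k_experts(logits, mask, 2)) {
            std::printf("%d ", e); // prints: 3 2
        }
        return 0;
    }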

src/llama-graph.h

Lines changed: 18 additions & 0 deletions
@@ -238,6 +238,20 @@ class llm_graph_input_cross_embd : public llm_graph_input_i {
     const llama_cross * cross;
 };

+class llm_graph_input_expert_mask : public llm_graph_input_i {
+public:
+    llm_graph_input_expert_mask(const llama_cparams & cparams) : cparams(cparams) {}
+
+    virtual ~llm_graph_input_expert_mask() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+    bool can_reuse(const llm_graph_params & params) override;
+
+    ggml_tensor * mask = nullptr; // F32 [n_expert]
+
+    const llama_cparams & cparams;
+};
+
 class llm_graph_input_attn_no_cache : public llm_graph_input_i {
 public:
     llm_graph_input_attn_no_cache(const llama_hparams & hparams, const llama_cparams & cparams) :
@@ -635,6 +649,7 @@ struct llm_graph_context {
             bool scale_w,
             float w_scale,
             llama_expert_gating_func_type gating_op,
+            ggml_tensor * expert_mask,
             int il,
             ggml_tensor * probs_in = nullptr) const;

@@ -656,6 +671,7 @@ struct llm_graph_context {
             bool scale_w,
             float w_scale,
             llama_expert_gating_func_type gating_op,
+            ggml_tensor * expert_mask,
             int il,
             ggml_tensor * probs_in = nullptr) const;

@@ -814,6 +830,8 @@ struct llm_graph_context {
             ggml_tensor * cls_b,
             ggml_tensor * cls_out,
             ggml_tensor * cls_out_b) const;
+
+    llm_graph_input_expert_mask * build_inp_expert_mask() const;
 };

 // TODO: better name
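
The call sites in the model-building code belong to the files of this commit not shown in this extract. For orientation only, a hypothetical call site wiring the new input into build_moe_ffn might look like the following; the layer tensors, gating function, and flag values are illustrative and not taken from the commit:

    // Hypothetical wiring inside a model's layer builder (not part of this diff):
    llm_graph_input_expert_mask * inp_expert_mask = build_inp_expert_mask();

    ggml_tensor * moe_out = build_moe_ffn(cur,
            model.layers[il].ffn_gate_inp,
            model.layers[il].ffn_up_exps,
            model.layers[il].ffn_gate_exps,
            model.layers[il].ffn_down_exps,
            nullptr,                 // exp_probs_b
            n_expert, n_expert_used,
            LLM_FFN_SILU, true,      // type_op, norm_w
            false, 0.0,              // scale_w, w_scale
            LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
            inp_expert_mask->mask,   // new expert_mask argument added by this commit
            il);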
