MoE Expert manipulation args #15165

Open: wants to merge 1 commit into base: master
57 changes: 57 additions & 0 deletions common/arg.cpp
@@ -2405,6 +2405,63 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
}
).set_env("LLAMA_ARG_N_CPU_MOE"));
add_opt(common_arg(
{"--num-experts"}, "N",
"Override the number of experts to use for MoE models (default: 0 = use model's default)",
[](common_params & params, int value) {
params.num_experts = value;
}
));
add_opt(common_arg(
{"--omit-experts"}, "IDs",
"comma-separated list of expert indices to omit from MoE selection (e.g. 1,3,5 or 1-5,7)",
[](common_params & params, const std::string & value) {
params.omit_experts.clear();
auto parts = string_split<std::string>(value, ',');
for (const auto& part : parts) {
if (part.find('-') != std::string::npos) {
// Parse range (e.g., "1-5")
auto range = string_split<int32_t>(part, '-');
if (range.size() == 2 && range[0] <= range[1]) {
for (int32_t i = range[0]; i <= range[1]; ++i) {
params.omit_experts.push_back(i);
}
}
} else {
params.omit_experts.push_back(std::stoi(part));
}
}

// Sort and remove duplicates for efficient processing later
std::sort(params.omit_experts.begin(), params.omit_experts.end());
params.omit_experts.erase(std::unique(params.omit_experts.begin(), params.omit_experts.end()), params.omit_experts.end());
}
));
add_opt(common_arg(
{"--force-experts"}, "IDs",
"comma-separated list of expert indices to always use in MoE selection (e.g. 1,3,5 or 1-5,7)",
[](common_params & params, const std::string & value) {
params.force_experts.clear();
auto parts = string_split<std::string>(value, ',');
for (const auto& part : parts) {
if (part.find('-') != std::string::npos) {
// Parse range (e.g., "1-5")
auto range = string_split<int32_t>(part, '-');
if (range.size() == 2 && range[0] <= range[1]) {
for (int32_t i = range[0]; i <= range[1]; ++i) {
params.force_experts.push_back(i);
}
}
} else {
params.force_experts.push_back(std::stoi(part));
}
}

// Sort and remove duplicates for efficient processing later
std::sort(params.force_experts.begin(), params.force_experts.end());
params.force_experts.erase(std::unique(params.force_experts.begin(), params.force_experts.end()), params.force_experts.end());
}
));
add_opt(common_arg(
{"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N",
"number of layers to store in VRAM",
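Note: both list-valued options above accept the same "IDs" grammar: comma-separated indices with optional inclusive ranges (e.g. `1,3,5` or `1-5,7`), sorted and deduplicated after parsing. The sketch below restates that rule using plain `std::` calls instead of the `string_split<>` helpers from `common/`; it is an illustration of the parsing behaviour, not code from this PR, and `parse_expert_ids` is a hypothetical name.

```cpp
// Hypothetical standalone helper mirroring the range parsing used by
// --omit-experts / --force-experts (parse_expert_ids is not part of the PR).
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <sstream>
#include <string>
#include <vector>

static std::vector<int32_t> parse_expert_ids(const std::string & value) {
    std::vector<int32_t> ids;
    std::stringstream ss(value);
    std::string part;
    while (std::getline(ss, part, ',')) {
        const size_t dash = part.find('-');
        if (dash != std::string::npos) {
            // inclusive range, e.g. "1-5" -> 1,2,3,4,5
            const int32_t lo = std::stoi(part.substr(0, dash));
            const int32_t hi = std::stoi(part.substr(dash + 1));
            for (int32_t i = lo; i <= hi; ++i) {
                ids.push_back(i);
            }
        } else {
            ids.push_back(std::stoi(part));
        }
    }
    // sort and deduplicate, as the option handlers do before storing the list
    std::sort(ids.begin(), ids.end());
    ids.erase(std::unique(ids.begin(), ids.end()), ids.end());
    return ids;
}

int main() {
    for (int32_t id : parse_expert_ids("1-3,7,3")) {
        printf("%d\n", (int) id); // prints 1 2 3 7
    }
    return 0;
}
```

For example, `1-3,7` yields `{1, 2, 3, 7}` after sorting and deduplication.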
5 changes: 5 additions & 0 deletions common/common.cpp
@@ -1130,6 +1130,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
GGML_ASSERT(params.kv_overrides.back().key[0] == 0 && "KV overrides not terminated with empty key");
mparams.kv_overrides = params.kv_overrides.data();
}
mparams.n_expert_used_override = params.num_experts;

if (params.tensor_buft_overrides.empty()) {
mparams.tensor_buft_overrides = NULL;
@@ -1178,6 +1179,10 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.type_k = params.cache_type_k;
cparams.type_v = params.cache_type_v;

cparams.num_experts = params.num_experts;
cparams.omit_experts = params.omit_experts;
cparams.force_experts = params.force_experts;

return cparams;
}

5 changes: 5 additions & 0 deletions common/common.h
@@ -467,6 +467,11 @@ struct common_params {
// return false from callback to abort model loading or true to continue
llama_progress_callback load_progress_callback = NULL;
void * load_progress_callback_user_data = NULL;

// MoE expert selection
int32_t num_experts = 0; // number of experts to use, 0 = model defined
std::vector<int32_t> omit_experts; // expert indices to omit from MoE routing
std::vector<int32_t> force_experts; // expert indices to always include in MoE routing
};

// call once at the start of a program if it uses libcommon
13 changes: 10 additions & 3 deletions include/llama.h
@@ -1,15 +1,17 @@
#ifndef LLAMA_H
#define LLAMA_H

#include "ggml.h"
#include "ggml-cpu.h"
#include "ggml-backend.h"
#include "ggml-cpu.h"
#include "ggml-opt.h"
#include "ggml.h"

#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdbool.h>

#include <vector>

#ifdef LLAMA_SHARED
# if defined(_WIN32) && !defined(__MINGW32__)
@@ -283,6 +285,7 @@ extern "C" {

// override key-value pairs of the model meta data
const struct llama_model_kv_override * kv_overrides;
int32_t n_expert_used_override; // override for the number of experts used per token, 0 = use the model's default

// Keep the booleans together to avoid misalignment during copy-by-value.
bool vocab_only; // only load the vocabulary, no weights
@@ -340,6 +343,10 @@ extern "C" {
bool kv_unified; // use a unified buffer across the input sequences when computing the attention
// try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
// ref: https://github.com/ggml-org/llama.cpp/pull/14363
// MoE expert selection
int32_t num_experts; // number of experts to use, 0 = model defined
std::vector<int32_t> omit_experts; // expert indices to omit from MoE routing
std::vector<int32_t> force_experts; // expert indices to always include in MoE routing
};

// model quantization parameters
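For orientation, a minimal sketch of how an application might drive these knobs through the public structs, assuming the members added in this diff (`n_expert_used_override`, `num_experts`, `omit_experts`, `force_experts`); the model path and expert indices are placeholders.

```cpp
#include "llama.h"

int main() {
    llama_model_params mparams = llama_model_default_params();
    mparams.n_expert_used_override = 4;        // use 4 experts per token instead of the model default

    llama_context_params cparams = llama_context_default_params();
    cparams.num_experts   = 4;                 // mirrors --num-experts 4
    cparams.omit_experts  = {1, 3, 5};         // mirrors --omit-experts 1,3,5
    cparams.force_experts = {0};               // mirrors --force-experts 0

    llama_model * model = llama_model_load_from_file("model.gguf", mparams);
    if (model == NULL) {
        return 1;
    }
    llama_context * ctx = llama_init_from_model(model, cparams);

    // ... normal inference; expert routing now honours the overrides above ...

    llama_free(ctx);
    llama_model_free(model);
    return 0;
}
```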
3 changes: 3 additions & 0 deletions src/llama-context.cpp
@@ -102,6 +102,8 @@ llama_context::llama_context(

cparams.op_offload = params.op_offload;
cparams.kv_unified = params.kv_unified;
cparams.omit_experts = params.omit_experts;
cparams.force_experts = params.force_experts;

{
const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
@@ -2269,6 +2271,7 @@ llama_context_params llama_context_default_params() {
/*.op_offload =*/ true,
/*.swa_full =*/ true,
/*.kv_unified =*/ false,
/*.omit_experts =*/ {},
};

return result;
4 changes: 4 additions & 0 deletions src/llama-cparams.h
@@ -3,6 +3,7 @@
#include "llama.h"

#include <cstdint>
#include <vector>

#define LLAMA_MAX_SEQ 64

@@ -26,6 +27,9 @@ struct llama_cparams {
float yarn_beta_slow;
float defrag_thold;

std::vector<int32_t> omit_experts;
std::vector<int32_t> force_experts;

bool embeddings;
bool causal_attn;
bool offload_kqv;
50 changes: 50 additions & 0 deletions src/llama-graph.cpp
@@ -258,6 +258,39 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
}
}

void llm_graph_input_expert_mask::set_input(const llama_ubatch * ubatch) {
if (mask == nullptr || (cparams.omit_experts.empty() && cparams.force_experts.empty())) {
return;
}
GGML_UNUSED(ubatch);

const int64_t n_expert = mask->ne[0];

GGML_ASSERT(ggml_backend_buffer_is_host(mask->buffer));
float * data = (float *) mask->data;

std::fill(data, data + n_expert, 0.0f);

for (int32_t expert_idx : cparams.omit_experts) {
if (expert_idx >= 0 && expert_idx < n_expert) {
data[expert_idx] = -INFINITY;
}
}
for (int32_t expert_idx : cparams.force_experts) {
if (expert_idx >= 0 && expert_idx < n_expert) {
data[expert_idx] = INFINITY;
}
}
}

bool llm_graph_input_expert_mask::can_reuse(const llm_graph_params & params) {
bool res = true;
res &= mask->ne[0] == params.hparams.n_expert;
res &= cparams.omit_experts == params.cparams.omit_experts;
res &= cparams.force_experts == params.cparams.force_experts;
return res;
}

void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
const int64_t n_kv = ubatch->n_tokens;
const int64_t n_tokens = ubatch->n_tokens;
@@ -787,6 +820,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
bool scale_w,
float w_scale,
llama_expert_gating_func_type gating_op,
ggml_tensor * expert_mask,
int il,
ggml_tensor * probs_in) const {
return build_moe_ffn(
@@ -803,6 +837,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
scale_w,
w_scale,
gating_op,
expert_mask,
il,
probs_in
);
@@ -826,6 +861,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
bool scale_w,
float w_scale,
llama_expert_gating_func_type gating_op,
ggml_tensor * expert_mask,
int il,
ggml_tensor * probs_in) const {
const int64_t n_embd = cur->ne[0];
@@ -879,6 +915,12 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
selection_probs = logits;
}

// Omit or force specified experts by adding a mask of -INF/INF respectively
if (expert_mask != nullptr) {
selection_probs = ggml_add(ctx0, selection_probs, expert_mask);
cb(selection_probs, "ffn_moe_probs_masked", il);
}

// select experts
ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
cb(selected_experts->src[0], "ffn_moe_argsort", il);
@@ -1352,6 +1394,14 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
}

llm_graph_input_expert_mask * llm_graph_context::build_inp_expert_mask() const {
auto inp = std::make_unique<llm_graph_input_expert_mask>(cparams);
auto & cur = inp->mask;
cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hparams.n_expert);
ggml_set_input(cur);
return (llm_graph_input_expert_mask *) res->add_input(std::move(inp));
}

ggml_tensor * llm_graph_context::build_attn(
llm_graph_input_attn_no_cache * inp,
ggml_tensor * wo,
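The mask above works because it is added to the routing scores before `ggml_top_k`: a `-INF` entry can never survive the top-k, while a `+INF` entry always wins a slot. Below is a standalone sketch of that selection arithmetic on plain floats, purely illustrative and independent of ggml; the scores and indices are made up.

```cpp
// Illustration of the -INF/+INF expert mask applied before top-k selection.
// Mirrors the arithmetic in build_moe_ffn, not the actual ggml graph code.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
    const int n_expert      = 8;
    const int n_expert_used = 2;

    // router scores for one token (made-up values)
    std::vector<float> scores = {0.1f, 2.0f, 0.3f, 1.5f, -0.2f, 0.7f, 0.0f, 0.9f};

    // mask built the same way as llm_graph_input_expert_mask::set_input
    std::vector<float> mask(n_expert, 0.0f);
    mask[1] = -INFINITY; // omitted expert
    mask[6] =  INFINITY; // forced expert

    for (int i = 0; i < n_expert; ++i) {
        scores[i] += mask[i];
    }

    // top-k over the masked scores (stand-in for ggml_top_k)
    std::vector<int> idx(n_expert);
    std::iota(idx.begin(), idx.end(), 0);
    std::partial_sort(idx.begin(), idx.begin() + n_expert_used, idx.end(),
                      [&](int a, int b) { return scores[a] > scores[b]; });

    for (int k = 0; k < n_expert_used; ++k) {
        printf("selected expert %d (masked score %f)\n", idx[k], scores[idx[k]]);
    }
    return 0;
}
```

Here expert 6 is always selected (masked score +INF) and expert 1 is never selected despite having the highest raw score; expert 3 takes the remaining slot.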
18 changes: 18 additions & 0 deletions src/llama-graph.h
@@ -238,6 +238,20 @@ class llm_graph_input_expert_mask : public llm_graph_input_i {
const llama_cross * cross;
};

class llm_graph_input_expert_mask : public llm_graph_input_i {
public:
llm_graph_input_expert_mask(const llama_cparams & cparams) : cparams(cparams) {}

virtual ~llm_graph_input_expert_mask() = default;

void set_input(const llama_ubatch * ubatch) override;
bool can_reuse(const llm_graph_params & params) override;

ggml_tensor * mask = nullptr; // F32 [n_expert]

const llama_cparams & cparams;
};

class llm_graph_input_attn_no_cache : public llm_graph_input_i {
public:
llm_graph_input_attn_no_cache(const llama_hparams & hparams, const llama_cparams & cparams) :
@@ -635,6 +649,7 @@ struct llm_graph_context {
bool scale_w,
float w_scale,
llama_expert_gating_func_type gating_op,
ggml_tensor * expert_mask,
int il,
ggml_tensor * probs_in = nullptr) const;

@@ -656,6 +671,7 @@ struct llm_graph_context {
bool scale_w,
float w_scale,
llama_expert_gating_func_type gating_op,
ggml_tensor * expert_mask,
int il,
ggml_tensor * probs_in = nullptr) const;

@@ -814,6 +830,8 @@ struct llm_graph_context {
ggml_tensor * cls_b,
ggml_tensor * cls_out,
ggml_tensor * cls_out_b) const;

llm_graph_input_expert_mask * build_inp_expert_mask() const;
};

// TODO: better name