Skip to content

common : add --override-tensor-draft, --cpu-moe-draft and --n-cpu-moe-draft parameters #15191

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Aug 13, 2025
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 67 additions & 32 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,41 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
// utils
//

// Helper function to parse tensor buffer override strings
static void parse_tensor_buffer_overrides(const std::string & value, std::vector<llama_model_tensor_buft_override> & overrides) {
static std::map<std::string, ggml_backend_buffer_type_t> buft_list;
if (buft_list.empty()) {
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
auto * dev = ggml_backend_dev_get(i);
auto * buft = ggml_backend_dev_buffer_type(dev);
if (buft) {
buft_list[ggml_backend_buft_name(buft)] = buft;
}
}
}

for (const auto & override : string_split<std::string>(value, ',')) {
std::string::size_type pos = override.find('=');
if (pos == std::string::npos) {
throw std::invalid_argument("invalid value");
}
std::string tensor_name = override.substr(0, pos);
std::string buffer_type = override.substr(pos + 1);

if (buft_list.find(buffer_type) == buft_list.end()) {
printf("Available buffer types:\n");
for (const auto & it : buft_list) {
printf(" %s\n", ggml_backend_buft_name(it.second));
}
throw std::invalid_argument("unknown buffer type");
}
// keep strings alive and avoid leaking memory by storing them in a static vector
static std::list<std::string> buft_overrides;
buft_overrides.push_back(tensor_name);
overrides.push_back({buft_overrides.back().c_str(), buft_list.at(buffer_type)});
}
}

struct handle_model_result {
bool found_mmproj = false;
common_params_model mmproj;
Expand Down Expand Up @@ -993,6 +1028,10 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
params.tensor_buft_overrides.push_back({nullptr, nullptr});
}

if (!params.speculative.tensor_buft_overrides.empty()) {
params.speculative.tensor_buft_overrides.push_back({nullptr, nullptr});
}

if (!params.chat_template.empty() && !common_chat_verify_template(params.chat_template, params.use_jinja)) {
throw std::runtime_error(string_format(
"error: the supplied chat template is not supported: %s%s\n",
Expand Down Expand Up @@ -2349,40 +2388,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
add_opt(common_arg(
{"--override-tensor", "-ot"}, "<tensor name pattern>=<buffer type>,...",
"override tensor buffer type", [](common_params & params, const std::string & value) {
/* static */ std::map<std::string, ggml_backend_buffer_type_t> buft_list;
if (buft_list.empty()) {
// enumerate all the devices and add their buffer types to the list
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
auto * dev = ggml_backend_dev_get(i);
auto * buft = ggml_backend_dev_buffer_type(dev);
if (buft) {
buft_list[ggml_backend_buft_name(buft)] = buft;
}
}
}

for (const auto & override : string_split<std::string>(value, ',')) {
std::string::size_type pos = override.find('=');
if (pos == std::string::npos) {
throw std::invalid_argument("invalid value");
}
std::string tensor_name = override.substr(0, pos);
std::string buffer_type = override.substr(pos + 1);

if (buft_list.find(buffer_type) == buft_list.end()) {
printf("Available buffer types:\n");
for (const auto & it : buft_list) {
printf(" %s\n", ggml_backend_buft_name(it.second));
}
throw std::invalid_argument("unknown buffer type");
}
// keep strings alive and avoid leaking memory by storing them in a static vector
static std::list<std::string> buft_overrides;
buft_overrides.push_back(tensor_name);
params.tensor_buft_overrides.push_back({buft_overrides.back().c_str(), buft_list.at(buffer_type)});
}
parse_tensor_buffer_overrides(value, params.tensor_buft_overrides);
}
));
add_opt(common_arg(
{"--override-tensor-draft"}, "<tensor name pattern>=<buffer type>,...",
"override tensor buffer type for draft model", [](common_params & params, const std::string & value) {
parse_tensor_buffer_overrides(value, params.speculative.tensor_buft_overrides);
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--cpu-moe", "-cmoe"},
"keep all Mixture of Experts (MoE) weights in the CPU",
Expand All @@ -2405,6 +2419,27 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
}
).set_env("LLAMA_ARG_N_CPU_MOE"));
add_opt(common_arg(
{"--cpu-moe-draft", "-cmoed"},
"keep all Mixture of Experts (MoE) weights in the CPU for the draft model",
[](common_params & params) {
params.speculative.tensor_buft_overrides.push_back({"\\.ffn_(up|down|gate)_exps", ggml_backend_cpu_buffer_type()});
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CPU_MOE_DRAFT"));
add_opt(common_arg(
{"--n-cpu-moe-draft", "-ncmoed"}, "N",
"keep the Mixture of Experts (MoE) weights of the first N layers in the CPU for the draft model",
[](common_params & params, int value) {
if (value < 0) {
throw std::invalid_argument("invalid value");
}
for (int i = 0; i < value; ++i) {
static std::list<std::string> buft_overrides_draft;
buft_overrides_draft.push_back(string_format("blk\\.%d\\.ffn_(up|down|gate)_exps", i));
params.speculative.tensor_buft_overrides.push_back({buft_overrides_draft.back().c_str(), ggml_backend_cpu_buffer_type()});
}
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_N_CPU_MOE_DRAFT"));
add_opt(common_arg(
{"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N",
"number of layers to store in VRAM",
Expand Down
1 change: 1 addition & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ struct common_params_speculative {
float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;

ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
Expand Down
2 changes: 2 additions & 0 deletions examples/speculative-simple/speculative-simple.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ int main(int argc, char ** argv) {
}

params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
params.tensor_buft_overrides = params.speculative.tensor_buft_overrides;

common_init_result llama_init_dft = common_init_from_params(params);

//model_dft = llama_init_dft.model.get();
Expand Down
2 changes: 2 additions & 0 deletions examples/speculative/speculative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ int main(int argc, char ** argv) {
}

params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
params.tensor_buft_overrides = params.speculative.tensor_buft_overrides;

common_init_result llama_init_dft = common_init_from_params(params);

model_dft = llama_init_dft.model.get();
Expand Down
2 changes: 2 additions & 0 deletions tools/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2011,6 +2011,8 @@ struct server_context {
params_dft.cache_type_k = params_base.speculative.cache_type_k;
params_dft.cache_type_v = params_base.speculative.cache_type_v;

params_dft.tensor_buft_overrides = params_base.speculative.tensor_buft_overrides;

llama_init_dft = common_init_from_params(params_dft);

model_dft = llama_init_dft.model.get();
Expand Down