Skip to content

common : add --override-tensor-draft, --cpu-moe-draft and --n-cpu-moe-draft parameters #15191

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Aug 13, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -993,6 +993,10 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
params.tensor_buft_overrides.push_back({nullptr, nullptr});
}

if (!params.speculative.tensor_buft_overrides.empty()) {
params.speculative.tensor_buft_overrides.push_back({nullptr, nullptr});
}

if (!params.chat_template.empty() && !common_chat_verify_template(params.chat_template, params.use_jinja)) {
throw std::runtime_error(string_format(
"error: the supplied chat template is not supported: %s%s\n",
Expand Down Expand Up @@ -2383,6 +2387,39 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
}
));
add_opt(common_arg(
{"--override-tensor-draft"}, "<tensor name pattern>=<buffer type>,...",
"override tensor buffer type for draft model", [](common_params & params, const std::string & value) {
/* static */ std::map<std::string, ggml_backend_buffer_type_t> buft_list;
if (buft_list.empty()) {
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
auto * dev = ggml_backend_dev_get(i);
auto * buft = ggml_backend_dev_buffer_type(dev);
if (buft) {
buft_list[ggml_backend_buft_name(buft)] = buft;
}
}
}
for (const auto & override : string_split<std::string>(value, ',')) {
std::string::size_type pos = override.find('=');
if (pos == std::string::npos) {
throw std::invalid_argument("invalid value");
}
std::string tensor_name = override.substr(0, pos);
std::string buffer_type = override.substr(pos + 1);
if (buft_list.find(buffer_type) == buft_list.end()) {
printf("Available buffer types:\n");
for (const auto & it : buft_list) {
printf(" %s\n", ggml_backend_buft_name(it.second));
}
throw std::invalid_argument("unknown buffer type");
}
static std::list<std::string> buft_overrides;
buft_overrides.push_back(tensor_name);
params.speculative.tensor_buft_overrides.push_back({buft_overrides.back().c_str(), buft_list.at(buffer_type)});
}
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--cpu-moe", "-cmoe"},
"keep all Mixture of Experts (MoE) weights in the CPU",
Expand All @@ -2405,6 +2442,27 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
}
).set_env("LLAMA_ARG_N_CPU_MOE"));
add_opt(common_arg(
{"--cpu-moe-draft", "-cmoed"},
"keep all Mixture of Experts (MoE) weights in the CPU for the draft model",
[](common_params & params) {
params.speculative.tensor_buft_overrides.push_back({"\\.ffn_(up|down|gate)_exps", ggml_backend_cpu_buffer_type()});
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CPU_MOE_DRAFT"));
add_opt(common_arg(
{"--n-cpu-moe-draft", "-ncmoed"}, "N",
"keep the Mixture of Experts (MoE) weights of the first N layers in the CPU for the draft model",
[](common_params & params, int value) {
if (value < 0) {
throw std::invalid_argument("invalid value");
}
for (int i = 0; i < value; ++i) {
static std::list<std::string> buft_overrides_draft;
buft_overrides_draft.push_back(string_format("blk\\.%d\\.ffn_(up|down|gate)_exps", i));
params.speculative.tensor_buft_overrides.push_back({buft_overrides_draft.back().c_str(), ggml_backend_cpu_buffer_type()});
}
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_N_CPU_MOE_DRAFT"));
add_opt(common_arg(
{"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N",
"number of layers to store in VRAM",
Expand Down
1 change: 1 addition & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ struct common_params_speculative {
float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;

ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
Expand Down
8 changes: 8 additions & 0 deletions examples/speculative-simple/speculative-simple.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,14 @@ int main(int argc, char ** argv) {
}

params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;

// Apply tensor overrides for draft model
if (!params.speculative.tensor_buft_overrides.empty()) {
params.tensor_buft_overrides = params.speculative.tensor_buft_overrides;
} else {
params.tensor_buft_overrides.clear();
}

common_init_result llama_init_dft = common_init_from_params(params);

//model_dft = llama_init_dft.model.get();
Expand Down
8 changes: 8 additions & 0 deletions examples/speculative/speculative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,14 @@ int main(int argc, char ** argv) {
}

params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;

// Apply tensor overrides for draft model
if (!params.speculative.tensor_buft_overrides.empty()) {
params.tensor_buft_overrides = params.speculative.tensor_buft_overrides;
} else {
params.tensor_buft_overrides.clear();
}

common_init_result llama_init_dft = common_init_from_params(params);

model_dft = llama_init_dft.model.get();
Expand Down
7 changes: 7 additions & 0 deletions tools/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2011,6 +2011,13 @@ struct server_context {
params_dft.cache_type_k = params_base.speculative.cache_type_k;
params_dft.cache_type_v = params_base.speculative.cache_type_v;

// Apply tensor overrides for draft model
if (!params_base.speculative.tensor_buft_overrides.empty()) {
params_dft.tensor_buft_overrides = params_base.speculative.tensor_buft_overrides;
} else {
params_dft.tensor_buft_overrides.clear(); // ensure no main overrides leak in
}

llama_init_dft = common_init_from_params(params_dft);

model_dft = llama_init_dft.model.get();
Expand Down
Loading