Skip to content

Commit d8914fc

Browse files
Copilot, CISC, ggerganov, slaren
authored
common : add --override-tensor-draft, --cpu-moe-draft and --n-cpu-moe-draft parameters (#15191)
* Checkpoint from VS Code for coding agent session * Initial plan * Fix typo in --override-tensor-draft flag implementation * Add null termination for speculative tensor buffer overrides * Apply suggestions from code review * Apply suggestions from code review * Extract tensor override parsing logic to common function (addresses @slaren's feedback) * Apply suggestions from code review * Apply suggestions --------- Co-authored-by: Sigbjørn Skjæret <[email protected]> Co-authored-by: Georgi Gerganov <[email protected]> Co-authored-by: Diego Devesa <[email protected]>
1 parent e885445 commit d8914fc

File tree

5 files changed

+72
-32
lines changed

5 files changed

+72
-32
lines changed

common/arg.cpp

Lines changed: 65 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -749,6 +749,39 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
749749
// utils
750750
//
751751

752+
// Helper function to parse tensor buffer override strings
753+
static void parse_tensor_buffer_overrides(const std::string & value, std::vector<llama_model_tensor_buft_override> & overrides) {
754+
std::map<std::string, ggml_backend_buffer_type_t> buft_list;
755+
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
756+
auto * dev = ggml_backend_dev_get(i);
757+
auto * buft = ggml_backend_dev_buffer_type(dev);
758+
if (buft) {
759+
buft_list[ggml_backend_buft_name(buft)] = buft;
760+
}
761+
}
762+
763+
for (const auto & override : string_split<std::string>(value, ',')) {
764+
std::string::size_type pos = override.find('=');
765+
if (pos == std::string::npos) {
766+
throw std::invalid_argument("invalid value");
767+
}
768+
std::string tensor_name = override.substr(0, pos);
769+
std::string buffer_type = override.substr(pos + 1);
770+
771+
if (buft_list.find(buffer_type) == buft_list.end()) {
772+
printf("Available buffer types:\n");
773+
for (const auto & it : buft_list) {
774+
printf(" %s\n", ggml_backend_buft_name(it.second));
775+
}
776+
throw std::invalid_argument("unknown buffer type");
777+
}
778+
// keep strings alive and avoid leaking memory by storing them in a static vector
779+
static std::list<std::string> buft_overrides;
780+
buft_overrides.push_back(tensor_name);
781+
overrides.push_back({buft_overrides.back().c_str(), buft_list.at(buffer_type)});
782+
}
783+
}
784+
752785
struct handle_model_result {
753786
bool found_mmproj = false;
754787
common_params_model mmproj;
@@ -993,6 +1026,10 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
9931026
params.tensor_buft_overrides.push_back({nullptr, nullptr});
9941027
}
9951028

1029+
if (!params.speculative.tensor_buft_overrides.empty()) {
1030+
params.speculative.tensor_buft_overrides.push_back({nullptr, nullptr});
1031+
}
1032+
9961033
if (!params.chat_template.empty() && !common_chat_verify_template(params.chat_template, params.use_jinja)) {
9971034
throw std::runtime_error(string_format(
9981035
"error: the supplied chat template is not supported: %s%s\n",
@@ -2349,40 +2386,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
23492386
add_opt(common_arg(
23502387
{"--override-tensor", "-ot"}, "<tensor name pattern>=<buffer type>,...",
23512388
"override tensor buffer type", [](common_params & params, const std::string & value) {
2352-
/* static */ std::map<std::string, ggml_backend_buffer_type_t> buft_list;
2353-
if (buft_list.empty()) {
2354-
// enumerate all the devices and add their buffer types to the list
2355-
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
2356-
auto * dev = ggml_backend_dev_get(i);
2357-
auto * buft = ggml_backend_dev_buffer_type(dev);
2358-
if (buft) {
2359-
buft_list[ggml_backend_buft_name(buft)] = buft;
2360-
}
2361-
}
2362-
}
2363-
2364-
for (const auto & override : string_split<std::string>(value, ',')) {
2365-
std::string::size_type pos = override.find('=');
2366-
if (pos == std::string::npos) {
2367-
throw std::invalid_argument("invalid value");
2368-
}
2369-
std::string tensor_name = override.substr(0, pos);
2370-
std::string buffer_type = override.substr(pos + 1);
2371-
2372-
if (buft_list.find(buffer_type) == buft_list.end()) {
2373-
printf("Available buffer types:\n");
2374-
for (const auto & it : buft_list) {
2375-
printf(" %s\n", ggml_backend_buft_name(it.second));
2376-
}
2377-
throw std::invalid_argument("unknown buffer type");
2378-
}
2379-
// keep strings alive and avoid leaking memory by storing them in a static vector
2380-
static std::list<std::string> buft_overrides;
2381-
buft_overrides.push_back(tensor_name);
2382-
params.tensor_buft_overrides.push_back({buft_overrides.back().c_str(), buft_list.at(buffer_type)});
2383-
}
2389+
parse_tensor_buffer_overrides(value, params.tensor_buft_overrides);
23842390
}
23852391
));
2392+
add_opt(common_arg(
2393+
{"--override-tensor-draft", "-otd"}, "<tensor name pattern>=<buffer type>,...",
2394+
"override tensor buffer type for draft model", [](common_params & params, const std::string & value) {
2395+
parse_tensor_buffer_overrides(value, params.speculative.tensor_buft_overrides);
2396+
}
2397+
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
23862398
add_opt(common_arg(
23872399
{"--cpu-moe", "-cmoe"},
23882400
"keep all Mixture of Experts (MoE) weights in the CPU",
@@ -2405,6 +2417,27 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
24052417
}
24062418
}
24072419
).set_env("LLAMA_ARG_N_CPU_MOE"));
2420+
add_opt(common_arg(
2421+
{"--cpu-moe-draft", "-cmoed"},
2422+
"keep all Mixture of Experts (MoE) weights in the CPU for the draft model",
2423+
[](common_params & params) {
2424+
params.speculative.tensor_buft_overrides.push_back({"\\.ffn_(up|down|gate)_exps", ggml_backend_cpu_buffer_type()});
2425+
}
2426+
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CPU_MOE_DRAFT"));
2427+
add_opt(common_arg(
2428+
{"--n-cpu-moe-draft", "-ncmoed"}, "N",
2429+
"keep the Mixture of Experts (MoE) weights of the first N layers in the CPU for the draft model",
2430+
[](common_params & params, int value) {
2431+
if (value < 0) {
2432+
throw std::invalid_argument("invalid value");
2433+
}
2434+
for (int i = 0; i < value; ++i) {
2435+
static std::list<std::string> buft_overrides_draft;
2436+
buft_overrides_draft.push_back(string_format("blk\\.%d\\.ffn_(up|down|gate)_exps", i));
2437+
params.speculative.tensor_buft_overrides.push_back({buft_overrides_draft.back().c_str(), ggml_backend_cpu_buffer_type()});
2438+
}
2439+
}
2440+
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_N_CPU_MOE_DRAFT"));
24082441
add_opt(common_arg(
24092442
{"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N",
24102443
"number of layers to store in VRAM",

common/common.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ struct common_params_speculative {
202202
float p_split = 0.1f; // speculative decoding split probability
203203
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
204204
std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
205+
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
205206

206207
ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
207208
ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V

examples/speculative-simple/speculative-simple.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ int main(int argc, char ** argv) {
5959
}
6060

6161
params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
62+
params.tensor_buft_overrides = params.speculative.tensor_buft_overrides;
63+
6264
common_init_result llama_init_dft = common_init_from_params(params);
6365

6466
//model_dft = llama_init_dft.model.get();

examples/speculative/speculative.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ int main(int argc, char ** argv) {
8585
}
8686

8787
params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
88+
params.tensor_buft_overrides = params.speculative.tensor_buft_overrides;
89+
8890
common_init_result llama_init_dft = common_init_from_params(params);
8991

9092
model_dft = llama_init_dft.model.get();

tools/server/server.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2015,6 +2015,8 @@ struct server_context {
20152015
params_dft.cache_type_k = params_base.speculative.cache_type_k;
20162016
params_dft.cache_type_v = params_base.speculative.cache_type_v;
20172017

2018+
params_dft.tensor_buft_overrides = params_base.speculative.tensor_buft_overrides;
2019+
20182020
llama_init_dft = common_init_from_params(params_dft);
20192021

20202022
model_dft = llama_init_dft.model.get();

0 commit comments

Comments (0)