diff --git a/common/arg.cpp b/common/arg.cpp
index 0f01bb31454a4..066abff1d713c 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -749,6 +749,39 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
 // utils
 //
 
+// Helper function to parse tensor buffer override strings
+static void parse_tensor_buffer_overrides(const std::string & value, std::vector<llama_model_tensor_buft_override> & overrides) {
+    std::map<std::string, ggml_backend_buffer_type_t> buft_list;
+    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+        auto * dev = ggml_backend_dev_get(i);
+        auto * buft = ggml_backend_dev_buffer_type(dev);
+        if (buft) {
+            buft_list[ggml_backend_buft_name(buft)] = buft;
+        }
+    }
+
+    for (const auto & override : string_split<std::string>(value, ',')) {
+        std::string::size_type pos = override.find('=');
+        if (pos == std::string::npos) {
+            throw std::invalid_argument("invalid value");
+        }
+        std::string tensor_name = override.substr(0, pos);
+        std::string buffer_type = override.substr(pos + 1);
+
+        if (buft_list.find(buffer_type) == buft_list.end()) {
+            printf("Available buffer types:\n");
+            for (const auto & it : buft_list) {
+                printf("  %s\n", ggml_backend_buft_name(it.second));
+            }
+            throw std::invalid_argument("unknown buffer type");
+        }
+        // keep strings alive and avoid leaking memory by storing them in a static vector
+        static std::list<std::string> buft_overrides;
+        buft_overrides.push_back(tensor_name);
+        overrides.push_back({buft_overrides.back().c_str(), buft_list.at(buffer_type)});
+    }
+}
+
 struct handle_model_result {
     bool found_mmproj = false;
     common_params_model mmproj;
@@ -993,6 +1026,10 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
         params.tensor_buft_overrides.push_back({nullptr, nullptr});
     }
 
+    if (!params.speculative.tensor_buft_overrides.empty()) {
+        params.speculative.tensor_buft_overrides.push_back({nullptr, nullptr});
+    }
+
     if (!params.chat_template.empty() && !common_chat_verify_template(params.chat_template, params.use_jinja)) {
         throw std::runtime_error(string_format(
             "error: the supplied chat template is not supported: %s%s\n",
@@ -2349,40 +2386,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     add_opt(common_arg(
         {"--override-tensor", "-ot"}, "<tensor name pattern>=<buffer type>,...",
         "override tensor buffer type", [](common_params & params, const std::string & value) {
-            /* static */ std::map<std::string, ggml_backend_buffer_type_t> buft_list;
-            if (buft_list.empty()) {
-                // enumerate all the devices and add their buffer types to the list
-                for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
-                    auto * dev = ggml_backend_dev_get(i);
-                    auto * buft = ggml_backend_dev_buffer_type(dev);
-                    if (buft) {
-                        buft_list[ggml_backend_buft_name(buft)] = buft;
-                    }
-                }
-            }
-
-            for (const auto & override : string_split<std::string>(value, ',')) {
-                std::string::size_type pos = override.find('=');
-                if (pos == std::string::npos) {
-                    throw std::invalid_argument("invalid value");
-                }
-                std::string tensor_name = override.substr(0, pos);
-                std::string buffer_type = override.substr(pos + 1);
-
-                if (buft_list.find(buffer_type) == buft_list.end()) {
-                    printf("Available buffer types:\n");
-                    for (const auto & it : buft_list) {
-                        printf("  %s\n", ggml_backend_buft_name(it.second));
-                    }
-                    throw std::invalid_argument("unknown buffer type");
-                }
-                // keep strings alive and avoid leaking memory by storing them in a static vector
-                static std::list<std::string> buft_overrides;
-                buft_overrides.push_back(tensor_name);
-                params.tensor_buft_overrides.push_back({buft_overrides.back().c_str(), buft_list.at(buffer_type)});
-            }
+            parse_tensor_buffer_overrides(value, params.tensor_buft_overrides);
         }
     ));
+    add_opt(common_arg(
+        {"--override-tensor-draft", "-otd"}, "<tensor name pattern>=<buffer type>,...",
+        "override tensor buffer type for draft model", [](common_params & params, const std::string & value) {
+            parse_tensor_buffer_overrides(value, params.speculative.tensor_buft_overrides);
+        }
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"--cpu-moe", "-cmoe"},
         "keep all Mixture of Experts (MoE) weights in the CPU",
@@ -2405,6 +2417,27 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             }
         }
     ).set_env("LLAMA_ARG_N_CPU_MOE"));
+    add_opt(common_arg(
+        {"--cpu-moe-draft", "-cmoed"},
+        "keep all Mixture of Experts (MoE) weights in the CPU for the draft model",
+        [](common_params & params) {
+            params.speculative.tensor_buft_overrides.push_back({"\\.ffn_(up|down|gate)_exps", ggml_backend_cpu_buffer_type()});
+        }
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CPU_MOE_DRAFT"));
+    add_opt(common_arg(
+        {"--n-cpu-moe-draft", "-ncmoed"}, "N",
+        "keep the Mixture of Experts (MoE) weights of the first N layers in the CPU for the draft model",
+        [](common_params & params, int value) {
+            if (value < 0) {
+                throw std::invalid_argument("invalid value");
+            }
+            for (int i = 0; i < value; ++i) {
+                static std::list<std::string> buft_overrides_draft;
+                buft_overrides_draft.push_back(string_format("blk\\.%d\\.ffn_(up|down|gate)_exps", i));
+                params.speculative.tensor_buft_overrides.push_back({buft_overrides_draft.back().c_str(), ggml_backend_cpu_buffer_type()});
+            }
+        }
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_N_CPU_MOE_DRAFT"));
     add_opt(common_arg(
         {"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N",
         "number of layers to store in VRAM",
diff --git a/common/common.h b/common/common.h
index 5eab199af559e..c09509b669e54 100644
--- a/common/common.h
+++ b/common/common.h
@@ -202,6 +202,7 @@ struct common_params_speculative {
     float p_split        = 0.1f;  // speculative decoding split probability
     float p_min          = 0.75f; // minimum speculative decoding probability (greedy)
     std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
+    std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
 
     ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
     ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp
index 722cd7f40f088..a8e53f28eb597 100644
--- a/examples/speculative-simple/speculative-simple.cpp
+++ b/examples/speculative-simple/speculative-simple.cpp
@@ -59,6 +59,8 @@ int main(int argc, char ** argv) {
     }
 
     params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
+    params.tensor_buft_overrides = params.speculative.tensor_buft_overrides;
+
     common_init_result llama_init_dft = common_init_from_params(params);
 
     //model_dft = llama_init_dft.model.get();
diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index 0adffdb006bcf..8449406a6d27a 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -85,6 +85,8 @@ int main(int argc, char ** argv) {
     }
 
     params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
+    params.tensor_buft_overrides = params.speculative.tensor_buft_overrides;
+
     common_init_result llama_init_dft = common_init_from_params(params);
 
     model_dft = llama_init_dft.model.get();
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index a255d481a4d1c..34877c1d8304a 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -2011,6 +2011,8 @@ struct server_context {
             params_dft.cache_type_k = params_base.speculative.cache_type_k;
             params_dft.cache_type_v = params_base.speculative.cache_type_v;
 
+            params_dft.tensor_buft_overrides = params_base.speculative.tensor_buft_overrides;
+
             llama_init_dft = common_init_from_params(params_dft);
 
             model_dft = llama_init_dft.model.get();
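
Usage sketch: the new `-otd`/`-cmoed`/`-ncmoed` flags feed `{regex pattern, buffer type}` pairs into `params.speculative.tensor_buft_overrides`, which the speculative examples and the server then copy into the draft model's init params, as shown in the hunks above. The standalone program below illustrates which tensor names the pattern generated by `--n-cpu-moe-draft 1` (i.e. `blk\.0\.ffn_(up|down|gate)_exps`) would select; the sample tensor names and the use of `std::regex` for the matching are assumptions for illustration, not code from this patch.

```cpp
// Illustration only: match the layer-0 pattern built by --n-cpu-moe-draft
// against hypothetical draft-model tensor names. Assumes the matching is
// unanchored (search, not full match), as with std::regex_search.
#include <cstdio>
#include <regex>
#include <string>
#include <vector>

int main() {
    // same string that string_format("blk\\.%d\\.ffn_(up|down|gate)_exps", 0) produces
    const std::regex pattern("blk\\.0\\.ffn_(up|down|gate)_exps");

    const std::vector<std::string> tensor_names = {
        "blk.0.ffn_up_exps.weight",   // layer-0 expert FFN -> overridden to CPU
        "blk.0.ffn_gate_exps.weight", // layer-0 expert FFN -> overridden to CPU
        "blk.1.ffn_up_exps.weight",   // layer 1            -> untouched for N=1
        "blk.0.attn_q.weight",        // attention weight   -> untouched
    };

    for (const auto & name : tensor_names) {
        printf("%-28s %s\n", name.c_str(),
               std::regex_search(name, pattern) ? "CPU buffer" : "default buffer");
    }
    return 0;
}
```

An equivalent manual override would be something like `-otd "blk\.0\.ffn_(up|down|gate)_exps=CPU"`, assuming the CPU backend reports its buffer type under the name `CPU`.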
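
On the terminator convention: `common_params_parse_ex` appends `{nullptr, nullptr}` only when the draft override list is non-empty, so consumers can walk the raw array until they hit a null pattern instead of carrying a separate length. A minimal sketch of that iteration follows, using a stand-in struct that mirrors `llama_model_tensor_buft_override` from `llama.h`; the entries and the null buffer-type handles are placeholders, since no backend is initialized in this toy program.

```cpp
// Sketch of walking a sentinel-terminated override list. The struct layout
// mirrors llama_model_tensor_buft_override (pattern + buffer type); the
// entries and null buffer types are placeholders for illustration.
#include <cstdio>
#include <vector>

typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t; // opaque handle stand-in

struct llama_model_tensor_buft_override {
    const char *               pattern; // regex matched against tensor names
    ggml_backend_buffer_type_t buft;    // target buffer type (null placeholder here)
};

int main() {
    std::vector<llama_model_tensor_buft_override> overrides;
    overrides.push_back({"\\.ffn_(up|down|gate)_exps", nullptr}); // e.g. pushed by -cmoed
    overrides.push_back({nullptr, nullptr});                      // terminator appended after parsing

    // consumers stop at the null pattern, C-string style
    for (const auto * ov = overrides.data(); ov->pattern != nullptr; ++ov) {
        printf("override pattern: %s\n", ov->pattern);
    }
    return 0;
}
```

Guarding the `push_back` with `empty()` keeps a run with no draft overrides represented as an empty vector rather than as a list containing only the sentinel.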