@@ -749,6 +749,39 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
749749// utils
750750//
751751
752+ // Helper function to parse tensor buffer override strings
//
// `value` is a separator-delimited list of "<tensor name pattern>=<buffer type>" entries
// (split via string_split below). For every entry one element is appended to `overrides`,
// pairing the pattern text with the buffer type whose name matches the text after '='.
//
// Throws std::invalid_argument when an entry contains no '=' or when the buffer type name
// is not among the names reported by the available devices. In the latter case the known
// names are printed with printf before throwing.
//
// Notes:
//  - Entries appended before a failing entry are not rolled back.
//  - No terminating element is appended here.
753+ static void parse_tensor_buffer_overrides (const std::string & value, std::vector<llama_model_tensor_buft_override> & overrides) {
// name -> buffer type lookup table; rebuilt on every call
754+ std::map<std::string, ggml_backend_buffer_type_t > buft_list;
// enumerate all the devices and add their buffer types to the list
755+ for (size_t i = 0 ; i < ggml_backend_dev_count (); ++i) {
756+ auto * dev = ggml_backend_dev_get (i);
757+ auto * buft = ggml_backend_dev_buffer_type (dev);
// devices for which no buffer type is returned are skipped
758+ if (buft) {
759+ buft_list[ggml_backend_buft_name (buft)] = buft;
760+ }
761+ }
762+
763+ for (const auto & override : string_split<std::string>(value, ' ,' )) {
// split the entry at the first '=': left side is the pattern, right side the buffer type name
764+ std::string::size_type pos = override .find (' =' );
765+ if (pos == std::string::npos) {
766+ throw std::invalid_argument (" invalid value" );
767+ }
768+ std::string tensor_name = override .substr (0 , pos);
769+ std::string buffer_type = override .substr (pos + 1 );
770+
771+ if (buft_list.find (buffer_type) == buft_list.end ()) {
// unknown name: list the valid choices before reporting the error
772+ printf (" Available buffer types:\n " );
773+ for (const auto & it : buft_list) {
774+ printf (" %s\n " , ggml_backend_buft_name (it.second ));
775+ }
776+ throw std::invalid_argument (" unknown buffer type" );
777+ }
779+ // keep the pattern strings alive by storing them in a function-local static std::list:
// the override element only holds the c_str() pointer, and std::list never relocates
// existing elements when more are appended, so earlier pointers stay valid
779+ static std::list<std::string> buft_overrides;
780+ buft_overrides.push_back (tensor_name);
781+ overrides.push_back ({buft_overrides.back ().c_str (), buft_list.at (buffer_type)});
782+ }
783+ }
784+
752785struct handle_model_result {
753786 bool found_mmproj = false ;
754787 common_params_model mmproj;
@@ -993,6 +1026,10 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
9931026 params.tensor_buft_overrides .push_back ({nullptr , nullptr });
9941027 }
9951028
// Same handling as for params.tensor_buft_overrides directly above: when any draft-model
// overrides were collected, append a {nullptr, nullptr} element at the end.
// NOTE(review): presumably this is the end-of-list marker expected by the consumer of the
// override array - confirm against the model-loading code.
1029+ if (!params.speculative .tensor_buft_overrides .empty ()) {
1030+ params.speculative .tensor_buft_overrides .push_back ({nullptr , nullptr });
1031+ }
1032+
9961033 if (!params.chat_template .empty () && !common_chat_verify_template (params.chat_template , params.use_jinja )) {
9971034 throw std::runtime_error (string_format (
9981035 " error: the supplied chat template is not supported: %s%s\n " ,
@@ -2349,40 +2386,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
23492386 add_opt (common_arg (
23502387 {" --override-tensor" , " -ot" }, " <tensor name pattern>=<buffer type>,..." ,
23512388 " override tensor buffer type" , [](common_params & params, const std::string & value) {
2352- /* static */ std::map<std::string, ggml_backend_buffer_type_t > buft_list;
2353- if (buft_list.empty ()) {
2354- // enumerate all the devices and add their buffer types to the list
2355- for (size_t i = 0 ; i < ggml_backend_dev_count (); ++i) {
2356- auto * dev = ggml_backend_dev_get (i);
2357- auto * buft = ggml_backend_dev_buffer_type (dev);
2358- if (buft) {
2359- buft_list[ggml_backend_buft_name (buft)] = buft;
2360- }
2361- }
2362- }
2363-
2364- for (const auto & override : string_split<std::string>(value, ' ,' )) {
2365- std::string::size_type pos = override .find (' =' );
2366- if (pos == std::string::npos) {
2367- throw std::invalid_argument (" invalid value" );
2368- }
2369- std::string tensor_name = override .substr (0 , pos);
2370- std::string buffer_type = override .substr (pos + 1 );
2371-
2372- if (buft_list.find (buffer_type) == buft_list.end ()) {
2373- printf (" Available buffer types:\n " );
2374- for (const auto & it : buft_list) {
2375- printf (" %s\n " , ggml_backend_buft_name (it.second ));
2376- }
2377- throw std::invalid_argument (" unknown buffer type" );
2378- }
2379- // keep strings alive and avoid leaking memory by storing them in a static vector
2380- static std::list<std::string> buft_overrides;
2381- buft_overrides.push_back (tensor_name);
2382- params.tensor_buft_overrides .push_back ({buft_overrides.back ().c_str (), buft_list.at (buffer_type)});
2383- }
2389+ parse_tensor_buffer_overrides (value, params.tensor_buft_overrides );
23842390 }
23852391 ));
// Draft-model variant of --override-tensor: same value syntax and same parser, but the
// results are stored in params.speculative.tensor_buft_overrides.
2392+ add_opt (common_arg (
2393+ {" --override-tensor-draft" , " -otd" }, " <tensor name pattern>=<buffer type>,..." ,
2394+ " override tensor buffer type for draft model" , [](common_params & params, const std::string & value) {
2395+ parse_tensor_buffer_overrides (value, params.speculative .tensor_buft_overrides );
2396+ }
2397+ ).set_examples ({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
23862398 add_opt (common_arg (
23872399 {" --cpu-moe" , " -cmoe" },
23882400 " keep all Mixture of Experts (MoE) weights in the CPU" ,
@@ -2405,6 +2417,27 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
24052417 }
24062418 }
24072419 ).set_env (" LLAMA_ARG_N_CPU_MOE" ));
// Flag option (no value): adds a single draft-model override that maps the
// ffn_(up|down|gate)_exps pattern to the CPU buffer type.
// NOTE(review): the pattern is presumably a regex matched against tensor names - confirm.
2420+ add_opt (common_arg (
2421+ {" --cpu-moe-draft" , " -cmoed" },
2422+ " keep all Mixture of Experts (MoE) weights in the CPU for the draft model" ,
2423+ [](common_params & params) {
// the pattern is a string literal (static storage duration), so no keep-alive container is needed
2424+ params.speculative .tensor_buft_overrides .push_back ({" \\ .ffn_(up|down|gate)_exps" , ggml_backend_cpu_buffer_type ()});
2425+ }
2426+ ).set_examples ({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env (" LLAMA_ARG_CPU_MOE_DRAFT" ));
// Integer option: for each layer index i in [0, N) adds one draft-model override that maps
// that block's ffn_(up|down|gate)_exps pattern to the CPU buffer type.
// A negative N is rejected with std::invalid_argument; N == 0 adds nothing.
2427+ add_opt (common_arg (
2428+ {" --n-cpu-moe-draft" , " -ncmoed" }, " N" ,
2429+ " keep the Mixture of Experts (MoE) weights of the first N layers in the CPU for the draft model" ,
2430+ [](common_params & params, int value) {
2431+ if (value < 0 ) {
2432+ throw std::invalid_argument (" invalid value" );
2433+ }
2434+ for (int i = 0 ; i < value; ++i) {
// although declared inside the loop body, the list is static: it is initialised once and
// persists across iterations and calls, keeping the formatted pattern strings alive for
// the c_str() pointers stored in the override elements
2435+ static std::list<std::string> buft_overrides_draft;
2436+ buft_overrides_draft.push_back (string_format (" blk\\ .%d\\ .ffn_(up|down|gate)_exps" , i));
2437+ params.speculative .tensor_buft_overrides .push_back ({buft_overrides_draft.back ().c_str (), ggml_backend_cpu_buffer_type ()});
2438+ }
2439+ }
2440+ ).set_examples ({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env (" LLAMA_ARG_N_CPU_MOE_DRAFT" ));
24082441 add_opt (common_arg (
24092442 {" -ngl" , " --gpu-layers" , " --n-gpu-layers" }, " N" ,
24102443 " number of layers to store in VRAM" ,
@@ -3130,7 +3163,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
31303163 params.speculative .cpuparams .n_threads = std::thread::hardware_concurrency ();
31313164 }
31323165 }
3133- ).set_examples ({LLAMA_EXAMPLE_SPECULATIVE}));
3166+ ).set_examples ({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER }));
31343167 add_opt (common_arg (
31353168 {" -tbd" , " --threads-batch-draft" }, " N" ,
31363169 " number of threads to use during batch and prompt processing (default: same as --threads-draft)" ,
@@ -3140,7 +3173,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
31403173 params.speculative .cpuparams_batch .n_threads = std::thread::hardware_concurrency ();
31413174 }
31423175 }
3143- ).set_examples ({LLAMA_EXAMPLE_SPECULATIVE}));
3176+ ).set_examples ({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER }));
31443177 add_opt (common_arg (
31453178 {" -Cd" , " --cpu-mask-draft" }, " M" ,
31463179 " Draft model CPU affinity mask. Complements cpu-range-draft (default: same as --cpu-mask)" ,
0 commit comments