@@ -749,6 +749,41 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
749749// utils
750750//
751751
752+ // Helper function to parse tensor buffer override strings
753+ static void parse_tensor_buffer_overrides (const std::string & value, std::vector<llama_model_tensor_buft_override> & overrides) {
754+ static std::map<std::string, ggml_backend_buffer_type_t > buft_list;
755+ if (buft_list.empty ()) {
756+ for (size_t i = 0 ; i < ggml_backend_dev_count (); ++i) {
757+ auto * dev = ggml_backend_dev_get (i);
758+ auto * buft = ggml_backend_dev_buffer_type (dev);
759+ if (buft) {
760+ buft_list[ggml_backend_buft_name (buft)] = buft;
761+ }
762+ }
763+ }
764+
765+ for (const auto & override : string_split<std::string>(value, ' ,' )) {
766+ std::string::size_type pos = override .find (' =' );
767+ if (pos == std::string::npos) {
768+ throw std::invalid_argument (" invalid value" );
769+ }
770+ std::string tensor_name = override .substr (0 , pos);
771+ std::string buffer_type = override .substr (pos + 1 );
772+
773+ if (buft_list.find (buffer_type) == buft_list.end ()) {
774+ printf (" Available buffer types:\n " );
775+ for (const auto & it : buft_list) {
776+ printf (" %s\n " , ggml_backend_buft_name (it.second ));
777+ }
778+ throw std::invalid_argument (" unknown buffer type" );
779+ }
780+ // keep strings alive and avoid leaking memory by storing them in a static vector
781+ static std::list<std::string> buft_overrides;
782+ buft_overrides.push_back (tensor_name);
783+ overrides.push_back ({buft_overrides.back ().c_str (), buft_list.at (buffer_type)});
784+ }
785+ }
786+
752787struct handle_model_result {
753788 bool found_mmproj = false ;
754789 common_params_model mmproj;
@@ -2353,74 +2388,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
23532388 add_opt (common_arg (
23542389 {" --override-tensor" , " -ot" }, " <tensor name pattern>=<buffer type>,..." ,
23552390 " override tensor buffer type" , [](common_params & params, const std::string & value) {
2356- /* static */ std::map<std::string, ggml_backend_buffer_type_t > buft_list;
2357- if (buft_list.empty ()) {
2358- // enumerate all the devices and add their buffer types to the list
2359- for (size_t i = 0 ; i < ggml_backend_dev_count (); ++i) {
2360- auto * dev = ggml_backend_dev_get (i);
2361- auto * buft = ggml_backend_dev_buffer_type (dev);
2362- if (buft) {
2363- buft_list[ggml_backend_buft_name (buft)] = buft;
2364- }
2365- }
2366- }
2367-
2368- for (const auto & override : string_split<std::string>(value, ' ,' )) {
2369- std::string::size_type pos = override .find (' =' );
2370- if (pos == std::string::npos) {
2371- throw std::invalid_argument (" invalid value" );
2372- }
2373- std::string tensor_name = override .substr (0 , pos);
2374- std::string buffer_type = override .substr (pos + 1 );
2375-
2376- if (buft_list.find (buffer_type) == buft_list.end ()) {
2377- printf (" Available buffer types:\n " );
2378- for (const auto & it : buft_list) {
2379- printf (" %s\n " , ggml_backend_buft_name (it.second ));
2380- }
2381- throw std::invalid_argument (" unknown buffer type" );
2382- }
2383- // keep strings alive and avoid leaking memory by storing them in a static vector
2384- static std::list<std::string> buft_overrides;
2385- buft_overrides.push_back (tensor_name);
2386- params.tensor_buft_overrides .push_back ({buft_overrides.back ().c_str (), buft_list.at (buffer_type)});
2387- }
2391+ parse_tensor_buffer_overrides (value, params.tensor_buft_overrides );
23882392 }
23892393 ));
23902394 add_opt (common_arg (
23912395 {" --override-tensor-draft" }, " <tensor name pattern>=<buffer type>,..." ,
23922396 " override tensor buffer type for draft model" , [](common_params & params, const std::string & value) {
2393- /* static */ std::map<std::string, ggml_backend_buffer_type_t > buft_list;
2394- if (buft_list.empty ()) {
2395- for (size_t i = 0 ; i < ggml_backend_dev_count (); ++i) {
2396- auto * dev = ggml_backend_dev_get (i);
2397- auto * buft = ggml_backend_dev_buffer_type (dev);
2398- if (buft) {
2399- buft_list[ggml_backend_buft_name (buft)] = buft;
2400- }
2401- }
2402- }
2403-
2404- for (const auto & override : string_split<std::string>(value, ' ,' )) {
2405- std::string::size_type pos = override .find (' =' );
2406- if (pos == std::string::npos) {
2407- throw std::invalid_argument (" invalid value" );
2408- }
2409- std::string tensor_name = override .substr (0 , pos);
2410- std::string buffer_type = override .substr (pos + 1 );
2411-
2412- if (buft_list.find (buffer_type) == buft_list.end ()) {
2413- printf (" Available buffer types:\n " );
2414- for (const auto & it : buft_list) {
2415- printf (" %s\n " , ggml_backend_buft_name (it.second ));
2416- }
2417- throw std::invalid_argument (" unknown buffer type" );
2418- }
2419- // keep strings alive and avoid leaking memory by storing them in a static vector
2420- static std::list<std::string> buft_overrides;
2421- buft_overrides.push_back (tensor_name);
2422- params.speculative .tensor_buft_overrides .push_back ({buft_overrides.back ().c_str (), buft_list.at (buffer_type)});
2423- }
2397+ parse_tensor_buffer_overrides (value, params.speculative .tensor_buft_overrides );
24242398 }
24252399 ).set_examples ({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
24262400 add_opt (common_arg (
0 commit comments