Skip to content

Commit 072ab9c

Browse files
Copilot and ggerganov committed
Extract tensor override parsing logic to common function (addresses @slaren's feedback)
Co-authored-by: ggerganov <[email protected]>
1 parent c87f4b0 commit 072ab9c

File tree

1 file changed

+37
-63
lines changed

1 file changed

+37
-63
lines changed

common/arg.cpp

Lines changed: 37 additions & 63 deletions
Original file line number | Diff line number | Diff line change
@@ -749,6 +749,41 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
749749
// utils
750750
//
751751

752+
// Helper function to parse tensor buffer override strings
753+
static void parse_tensor_buffer_overrides(const std::string & value, std::vector<llama_model_tensor_buft_override> & overrides) {
754+
static std::map<std::string, ggml_backend_buffer_type_t> buft_list;
755+
if (buft_list.empty()) {
756+
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
757+
auto * dev = ggml_backend_dev_get(i);
758+
auto * buft = ggml_backend_dev_buffer_type(dev);
759+
if (buft) {
760+
buft_list[ggml_backend_buft_name(buft)] = buft;
761+
}
762+
}
763+
}
764+
765+
for (const auto & override : string_split<std::string>(value, ',')) {
766+
std::string::size_type pos = override.find('=');
767+
if (pos == std::string::npos) {
768+
throw std::invalid_argument("invalid value");
769+
}
770+
std::string tensor_name = override.substr(0, pos);
771+
std::string buffer_type = override.substr(pos + 1);
772+
773+
if (buft_list.find(buffer_type) == buft_list.end()) {
774+
printf("Available buffer types:\n");
775+
for (const auto & it : buft_list) {
776+
printf(" %s\n", ggml_backend_buft_name(it.second));
777+
}
778+
throw std::invalid_argument("unknown buffer type");
779+
}
780+
// keep strings alive and avoid leaking memory by storing them in a static vector
781+
static std::list<std::string> buft_overrides;
782+
buft_overrides.push_back(tensor_name);
783+
overrides.push_back({buft_overrides.back().c_str(), buft_list.at(buffer_type)});
784+
}
785+
}
786+
752787
struct handle_model_result {
753788
bool found_mmproj = false;
754789
common_params_model mmproj;
@@ -2353,74 +2388,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
23532388
add_opt(common_arg(
23542389
{"--override-tensor", "-ot"}, "<tensor name pattern>=<buffer type>,...",
23552390
"override tensor buffer type", [](common_params & params, const std::string & value) {
2356-
/* static */ std::map<std::string, ggml_backend_buffer_type_t> buft_list;
2357-
if (buft_list.empty()) {
2358-
// enumerate all the devices and add their buffer types to the list
2359-
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
2360-
auto * dev = ggml_backend_dev_get(i);
2361-
auto * buft = ggml_backend_dev_buffer_type(dev);
2362-
if (buft) {
2363-
buft_list[ggml_backend_buft_name(buft)] = buft;
2364-
}
2365-
}
2366-
}
2367-
2368-
for (const auto & override : string_split<std::string>(value, ',')) {
2369-
std::string::size_type pos = override.find('=');
2370-
if (pos == std::string::npos) {
2371-
throw std::invalid_argument("invalid value");
2372-
}
2373-
std::string tensor_name = override.substr(0, pos);
2374-
std::string buffer_type = override.substr(pos + 1);
2375-
2376-
if (buft_list.find(buffer_type) == buft_list.end()) {
2377-
printf("Available buffer types:\n");
2378-
for (const auto & it : buft_list) {
2379-
printf(" %s\n", ggml_backend_buft_name(it.second));
2380-
}
2381-
throw std::invalid_argument("unknown buffer type");
2382-
}
2383-
// keep strings alive and avoid leaking memory by storing them in a static vector
2384-
static std::list<std::string> buft_overrides;
2385-
buft_overrides.push_back(tensor_name);
2386-
params.tensor_buft_overrides.push_back({buft_overrides.back().c_str(), buft_list.at(buffer_type)});
2387-
}
2391+
parse_tensor_buffer_overrides(value, params.tensor_buft_overrides);
23882392
}
23892393
));
23902394
add_opt(common_arg(
23912395
{"--override-tensor-draft"}, "<tensor name pattern>=<buffer type>,...",
23922396
"override tensor buffer type for draft model", [](common_params & params, const std::string & value) {
2393-
/* static */ std::map<std::string, ggml_backend_buffer_type_t> buft_list;
2394-
if (buft_list.empty()) {
2395-
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
2396-
auto * dev = ggml_backend_dev_get(i);
2397-
auto * buft = ggml_backend_dev_buffer_type(dev);
2398-
if (buft) {
2399-
buft_list[ggml_backend_buft_name(buft)] = buft;
2400-
}
2401-
}
2402-
}
2403-
2404-
for (const auto & override : string_split<std::string>(value, ',')) {
2405-
std::string::size_type pos = override.find('=');
2406-
if (pos == std::string::npos) {
2407-
throw std::invalid_argument("invalid value");
2408-
}
2409-
std::string tensor_name = override.substr(0, pos);
2410-
std::string buffer_type = override.substr(pos + 1);
2411-
2412-
if (buft_list.find(buffer_type) == buft_list.end()) {
2413-
printf("Available buffer types:\n");
2414-
for (const auto & it : buft_list) {
2415-
printf(" %s\n", ggml_backend_buft_name(it.second));
2416-
}
2417-
throw std::invalid_argument("unknown buffer type");
2418-
}
2419-
// keep strings alive and avoid leaking memory by storing them in a static vector
2420-
static std::list<std::string> buft_overrides;
2421-
buft_overrides.push_back(tensor_name);
2422-
params.speculative.tensor_buft_overrides.push_back({buft_overrides.back().c_str(), buft_list.at(buffer_type)});
2423-
}
2397+
parse_tensor_buffer_overrides(value, params.speculative.tensor_buft_overrides);
24242398
}
24252399
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
24262400
add_opt(common_arg(

0 commit comments

Comments
 (0)