#include "gguf.h" // for reading GGUF splits
#include "arg.h"

+#include "common.h"
#include "log.h"
#include "sampling.h"
#include "chat.h"
@@ -848,6 +849,10 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context |
        params.kv_overrides.back().key[0] = 0;
    }

+    if (!params.tensor_buft_overrides.empty()) {
+        params.tensor_buft_overrides.push_back({nullptr, nullptr});
+    }
+
    if (params.reranking && params.embedding) {
        throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both");
    }
@@ -2180,6 +2185,41 @@ common_params_context common_params_parser_init(common_params & params, llama_ex |
            exit(0);
        }
    ));
+    add_opt(common_arg(
+        {"--override-tensor", "-ot"}, "<tensor name pattern>=<buffer type>,...",
+        "override tensor buffer type", [](common_params & params, const std::string & value) {
+            /* static */ std::map<std::string, ggml_backend_buffer_type_t> buft_list;
+            if (buft_list.empty()) {
+                // enumerate all the devices and add their buffer types to the list
+                for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+                    auto * dev = ggml_backend_dev_get(i);
+                    auto * buft = ggml_backend_dev_buffer_type(dev);
+                    if (buft) {
+                        buft_list[ggml_backend_buft_name(buft)] = buft;
+                    }
+                }
+            }
+
+            for (const auto & override : string_split<std::string>(value, ',')) {
+                std::string::size_type pos = override.find('=');
+                if (pos == std::string::npos) {
+                    throw std::invalid_argument("invalid value");
+                }
+                std::string tensor_name = override.substr(0, pos);
+                std::string buffer_type = override.substr(pos + 1);
+
+                if (buft_list.find(buffer_type) == buft_list.end()) {
+                    printf("Available buffer types:\n");
+                    for (const auto & it : buft_list) {
+                        printf(" %s\n", ggml_backend_buft_name(it.second));
+                    }
+                    throw std::invalid_argument("unknown buffer type");
+                }
+                // FIXME: this leaks memory
+                params.tensor_buft_overrides.push_back({strdup(tensor_name.c_str()), buft_list.at(buffer_type)});
+            }
+        }
+    ));
    add_opt(common_arg(
        {"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N",
        "number of layers to store in VRAM",