|
1 | 1 | #include "arg.h" |
2 | 2 |
|
| 3 | +#include "common.h" |
3 | 4 | #include "log.h" |
4 | 5 | #include "sampling.h" |
5 | 6 | #include "chat.h" |
@@ -322,6 +323,10 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context |
322 | 323 | params.kv_overrides.back().key[0] = 0; |
323 | 324 | } |
324 | 325 |
|
| 326 | + if (!params.tensor_buft_overrides.empty()) { |
| 327 | + params.tensor_buft_overrides.push_back({nullptr, nullptr}); |
| 328 | + } |
| 329 | + |
325 | 330 | if (params.reranking && params.embedding) { |
326 | 331 | throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both"); |
327 | 332 | } |
@@ -1629,6 +1634,41 @@ common_params_context common_params_parser_init(common_params & params, llama_ex |
1629 | 1634 | exit(0); |
1630 | 1635 | } |
1631 | 1636 | )); |
| 1637 | + add_opt(common_arg( |
| 1638 | + {"--override-tensor", "-ot"}, "<tensor name pattern>=<buffer type>,...", |
| 1639 | + "override tensor buffer type", [](common_params & params, const std::string & value) { |
| 1640 | + /* static */ std::map<std::string, ggml_backend_buffer_type_t> buft_list; |
| 1641 | + if (buft_list.empty()) { |
| 1642 | + // enumerate all the devices and add their buffer types to the list |
| 1643 | + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { |
| 1644 | + auto * dev = ggml_backend_dev_get(i); |
| 1645 | + auto * buft = ggml_backend_dev_buffer_type(dev); |
| 1646 | + if (buft) { |
| 1647 | + buft_list[ggml_backend_buft_name(buft)] = buft; |
| 1648 | + } |
| 1649 | + } |
| 1650 | + } |
| 1651 | + |
| 1652 | + for (const auto & override : string_split<std::string>(value, ',')) { |
| 1653 | + std::string::size_type pos = override.find('='); |
| 1654 | + if (pos == std::string::npos) { |
| 1655 | + throw std::invalid_argument("invalid value"); |
| 1656 | + } |
| 1657 | + std::string tensor_name = override.substr(0, pos); |
| 1658 | + std::string buffer_type = override.substr(pos + 1); |
| 1659 | + |
| 1660 | + if (buft_list.find(buffer_type) == buft_list.end()) { |
| 1661 | + printf("Available buffer types:\n"); |
| 1662 | + for (const auto & it : buft_list) { |
| 1663 | + printf(" %s\n", ggml_backend_buft_name(it.second)); |
| 1664 | + } |
| 1665 | + throw std::invalid_argument("unknown buffer type"); |
| 1666 | + } |
| 1667 | + // FIXME: this leaks memory |
| 1668 | + params.tensor_buft_overrides.push_back({strdup(tensor_name.c_str()), buft_list.at(buffer_type)}); |
| 1669 | + } |
| 1670 | + } |
| 1671 | + )); |
1632 | 1672 | add_opt(common_arg( |
1633 | 1673 | {"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N", |
1634 | 1674 | "number of layers to store in VRAM", |
|
0 commit comments