|
17 | 17 | #include <string> |
18 | 18 | #include <thread> |
19 | 19 | #include <vector> |
| 20 | +#include <unordered_set> |
20 | 21 |
|
21 | 22 | #include "common.h" |
22 | 23 | #include "ggml.h" |
@@ -184,13 +185,26 @@ static std::vector<ggml_backend_dev_t> register_rpc_device_list(const std::strin |
184 | 185 | throw std::invalid_argument("failed to find RPC device add function"); |
185 | 186 | } |
186 | 187 |
|
| 188 | + static std::unordered_set<std::string> registered; |
187 | 189 | std::vector<ggml_backend_dev_t> devices; |
188 | 190 | for (const auto & server : rpc_servers) { |
189 | | - ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str()); |
| 191 | + ggml_backend_dev_t dev = nullptr; |
| 192 | + |
| 193 | + std::string name = string_format("RPC[%s]", server.c_str()); |
| 194 | + |
| 195 | + if (registered.find(server) != registered.end()) { |
| 196 | + dev = ggml_backend_dev_by_name(name.c_str()); |
| 197 | + } |
| 198 | + |
190 | 199 | if (!dev) { |
191 | | - throw std::invalid_argument(string_format("failed to add RPC device for server '%s'", server.c_str())); |
| 200 | + dev = ggml_backend_rpc_add_device_fn(server.c_str()); |
| 201 | + if (!dev) { |
| 202 | + throw std::invalid_argument(string_format("failed to add RPC device for server '%s'", server.c_str())); |
| 203 | + } |
| 204 | + ggml_backend_device_register(dev); |
| 205 | + registered.insert(server); |
192 | 206 | } |
193 | | - ggml_backend_device_register(dev); |
| 207 | + |
194 | 208 | devices.push_back(dev); |
195 | 209 | } |
196 | 210 |
|
@@ -382,6 +396,7 @@ struct cmd_params { |
382 | 396 | bool no_warmup; |
383 | 397 | output_formats output_format; |
384 | 398 | output_formats output_format_stderr; |
| 399 | + bool list_devices; |
385 | 400 | }; |
386 | 401 |
|
387 | 402 | static const cmd_params cmd_params_defaults = { |
@@ -421,6 +436,7 @@ static const cmd_params cmd_params_defaults = { |
421 | 436 | /* no_warmup */ false, |
422 | 437 | /* output_format */ MARKDOWN, |
423 | 438 | /* output_format_stderr */ NONE, |
| 439 | + /* list_devices */ false, |
424 | 440 | }; |
425 | 441 |
|
426 | 442 | static void print_usage(int /* argc */, char ** argv) { |
@@ -545,6 +561,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { |
545 | 561 | params.delay = cmd_params_defaults.delay; |
546 | 562 | params.progress = cmd_params_defaults.progress; |
547 | 563 | params.no_warmup = cmd_params_defaults.no_warmup; |
| 564 | + params.list_devices = cmd_params_defaults.list_devices; |
548 | 565 |
|
549 | 566 | for (int i = 1; i < argc; i++) { |
550 | 567 | arg = argv[i]; |
@@ -668,7 +685,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { |
668 | 685 | break; |
669 | 686 | } |
670 | 687 | } else if (arg == "--list-devices") { |
671 | | - print_available_devices_and_exit(); |
| 688 | + params.list_devices = true; |
672 | 689 | } else if (arg == "-t" || arg == "--threads") { |
673 | 690 | if (++i >= argc) { |
674 | 691 | invalid_param = true; |
@@ -1006,9 +1023,6 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { |
1006 | 1023 | if (params.rpc_device_sets.empty()) { |
1007 | 1024 | params.rpc_device_sets = cmd_params_defaults.rpc_device_sets; |
1008 | 1025 | } |
1009 | | - if (params.rpc_device_sets.size() < params.rpc_servers.size()) { |
1010 | | - params.rpc_device_sets.resize(params.rpc_servers.size()); |
1011 | | - } |
1012 | 1026 | if (params.split_mode.empty()) { |
1013 | 1027 | params.split_mode = cmd_params_defaults.split_mode; |
1014 | 1028 | } |
@@ -2037,6 +2051,20 @@ int main(int argc, char ** argv) { |
2037 | 2051 |
|
2038 | 2052 | cmd_params params = parse_cmd_params(argc, argv); |
2039 | 2053 |
|
| 2054 | + if (params.list_devices) { |
| 2055 | + ggml_backend_load_all(); |
| 2056 | + for (const auto & rpc : params.rpc_servers) { |
| 2057 | + if (!rpc.empty()) { |
| 2058 | + try { |
| 2059 | + register_rpc_device_list(rpc); |
| 2060 | + } catch (const std::exception & e) { |
| 2061 | + fprintf(stderr, "warning: %s\n", e.what()); |
| 2062 | + } |
| 2063 | + } |
| 2064 | + } |
| 2065 | + print_available_devices_and_exit(); |
| 2066 | + } |
| 2067 | + |
2040 | 2068 | auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); |
2041 | 2069 | if (!cpu_dev) { |
2042 | 2070 | fprintf(stderr, "%s: error: CPU backend is not loaded\n", __func__); |
|
0 commit comments