@@ -167,6 +167,36 @@ static std::vector<ggml_backend_dev_t> parse_devices_arg(const std::string & val
167167 return devices;
168168}
169169
170+ static std::vector<ggml_backend_dev_t > register_rpc_device_list (const std::string & servers) {
171+ auto rpc_servers = string_split<std::string>(servers, ' ,' );
172+ if (rpc_servers.empty ()) {
173+ throw std::invalid_argument (" no RPC servers specified" );
174+ }
175+
176+ auto * rpc_reg = ggml_backend_reg_by_name (" RPC" );
177+ if (!rpc_reg) {
178+ throw std::invalid_argument (" failed to find RPC backend" );
179+ }
180+
181+ using add_rpc_device_fn = ggml_backend_dev_t (*)(const char * endpoint);
182+ auto * ggml_backend_rpc_add_device_fn = (add_rpc_device_fn) ggml_backend_reg_get_proc_address (rpc_reg, " ggml_backend_rpc_add_device" );
183+ if (!ggml_backend_rpc_add_device_fn) {
184+ throw std::invalid_argument (" failed to find RPC device add function" );
185+ }
186+
187+ std::vector<ggml_backend_dev_t > devices;
188+ for (const auto & server : rpc_servers) {
189+ ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn (server.c_str ());
190+ if (!dev) {
191+ throw std::invalid_argument (string_format (" failed to add RPC device for server '%s'" , server.c_str ()));
192+ }
193+ ggml_backend_device_register (dev);
194+ devices.push_back (dev);
195+ }
196+
197+ return devices;
198+ }
199+
170200[[noreturn]] static void print_available_devices_and_exit () {
171201 std::vector<ggml_backend_dev_t > devices;
172202 for (size_t i = 0 ; i < ggml_backend_dev_count (); ++i) {
@@ -332,6 +362,7 @@ struct cmd_params {
332362 std::vector<int > n_gpu_layers;
333363 std::vector<int > n_cpu_moe;
334364 std::vector<std::string> rpc_servers;
365+ std::vector<std::vector<ggml_backend_dev_t >> rpc_device_sets;
335366 std::vector<llama_split_mode> split_mode;
336367 std::vector<int > main_gpu;
337368 std::vector<bool > no_kv_offload;
@@ -370,6 +401,7 @@ static const cmd_params cmd_params_defaults = {
370401 /* n_gpu_layers */ { 99 },
371402 /* n_cpu_moe */ { 0 },
372403 /* rpc_servers */ { " " },
404+ /* rpc_device_sets */ { std::vector<ggml_backend_dev_t >() },
373405 /* split_mode */ { LLAMA_SPLIT_MODE_LAYER },
374406 /* main_gpu */ { 0 },
375407 /* no_kv_offload */ { false },
@@ -684,7 +716,16 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
684716 invalid_param = true ;
685717 break ;
686718 }
687- params.rpc_servers .push_back (argv[i]);
719+ ggml_backend_load_all ();
720+ try {
721+ auto devices = register_rpc_device_list (argv[i]);
722+ params.rpc_servers .push_back (argv[i]);
723+ params.rpc_device_sets .push_back (devices);
724+ } catch (const std::exception & e) {
725+ fprintf (stderr, " error: %s\n " , e.what ());
726+ invalid_param = true ;
727+ break ;
728+ }
688729 } else if (arg == " -sm" || arg == " --split-mode" ) {
689730 if (++i >= argc) {
690731 invalid_param = true ;
@@ -962,6 +1003,12 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
9621003 if (params.rpc_servers .empty ()) {
9631004 params.rpc_servers = cmd_params_defaults.rpc_servers ;
9641005 }
1006+ if (params.rpc_device_sets .empty ()) {
1007+ params.rpc_device_sets = cmd_params_defaults.rpc_device_sets ;
1008+ }
1009+ if (params.rpc_device_sets .size () < params.rpc_servers .size ()) {
1010+ params.rpc_device_sets .resize (params.rpc_servers .size ());
1011+ }
9651012 if (params.split_mode .empty ()) {
9661013 params.split_mode = cmd_params_defaults.split_mode ;
9671014 }
@@ -1024,6 +1071,7 @@ struct cmd_params_instance {
10241071 int n_gpu_layers;
10251072 int n_cpu_moe;
10261073 std::string rpc_servers_str;
1074+ std::vector<ggml_backend_dev_t > rpc_devices;
10271075 llama_split_mode split_mode;
10281076 int main_gpu;
10291077 bool no_kv_offload;
@@ -1041,57 +1089,24 @@ struct cmd_params_instance {
10411089 mparams.n_gpu_layers = n_gpu_layers;
10421090 if (!devices.empty ()) {
10431091 mparams.devices = const_cast <ggml_backend_dev_t *>(devices.data ());
1044- } else if (!rpc_servers_str.empty ()) {
1045- auto rpc_servers = string_split<std::string>(rpc_servers_str, ' ,' );
1046-
1047- // add RPC devices
1048- if (!rpc_servers.empty ()) {
1049- ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name (" RPC" );
1050- if (!rpc_reg) {
1051- fprintf (stderr, " %s: failed to find RPC backend\n " , __func__);
1052- exit (1 );
1053- }
1092+ } else if (!rpc_devices.empty ()) {
1093+ static std::vector<ggml_backend_dev_t > merged_devices;
1094+ merged_devices.clear ();
1095+ merged_devices.insert (merged_devices.end (), rpc_devices.begin (), rpc_devices.end ());
10541096
1055- typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t )(const char * endpoint);
1056- ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t ) ggml_backend_reg_get_proc_address (rpc_reg, " ggml_backend_rpc_add_device" );
1057- if (!ggml_backend_rpc_add_device_fn) {
1058- fprintf (stderr, " %s: failed to find RPC device add function\n " , __func__);
1059- exit (1 );
1060- }
1061- static std::vector<ggml_backend_dev_t > rpc_devices;
1062- rpc_devices.clear ();
1063- // RPC devices should always come first for performance reasons
1064- for (const std::string & server : rpc_servers) {
1065- ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn (server.c_str ());
1066- if (dev) {
1067- rpc_devices.push_back (dev);
1068- } else {
1069- fprintf (stderr, " %s: failed to add RPC device for server '%s'\n " , __func__, server.c_str ());
1070- exit (1 );
1071- }
1097+ for (size_t i = 0 ; i < ggml_backend_dev_count (); ++i) {
1098+ ggml_backend_dev_t dev = ggml_backend_dev_get (i);
1099+ auto dev_type = ggml_backend_dev_type (dev);
1100+ if (dev_type == GGML_BACKEND_DEVICE_TYPE_CPU || dev_type == GGML_BACKEND_DEVICE_TYPE_ACCEL) {
1101+ continue ;
10721102 }
1073- // FIXME: use llama.cpp device selection logic
1074- // add local GPU devices if any
1075- for (size_t i = 0 ; i < ggml_backend_dev_count (); ++i) {
1076- ggml_backend_dev_t dev = ggml_backend_dev_get (i);
1077- switch (ggml_backend_dev_type (dev)) {
1078- case GGML_BACKEND_DEVICE_TYPE_CPU:
1079- case GGML_BACKEND_DEVICE_TYPE_ACCEL:
1080- // skip CPU backends since they are handled separately
1081- break ;
1082-
1083- case GGML_BACKEND_DEVICE_TYPE_GPU:
1084- rpc_devices.push_back (dev);
1085- break ;
1086-
1087- case GGML_BACKEND_DEVICE_TYPE_IGPU:
1088- // iGPUs are not used when there are RPC servers
1089- break ;
1090- }
1103+ if (std::find (merged_devices.begin (), merged_devices.end (), dev) == merged_devices.end ()) {
1104+ merged_devices.push_back (dev);
10911105 }
1092- rpc_devices.push_back (nullptr );
1093- mparams.devices = rpc_devices.data ();
10941106 }
1107+
1108+ merged_devices.push_back (nullptr );
1109+ mparams.devices = merged_devices.data ();
10951110 }
10961111 mparams.split_mode = split_mode;
10971112 mparams.main_gpu = main_gpu;
@@ -1139,7 +1154,7 @@ struct cmd_params_instance {
11391154
11401155 bool equal_mparams (const cmd_params_instance & other) const {
11411156 return model == other.model && n_gpu_layers == other.n_gpu_layers && n_cpu_moe == other.n_cpu_moe &&
1142- rpc_servers_str == other.rpc_servers_str && split_mode == other.split_mode &&
1157+ rpc_servers_str == other.rpc_servers_str && rpc_devices == other. rpc_devices && split_mode == other.split_mode &&
11431158 main_gpu == other.main_gpu && use_mmap == other.use_mmap && tensor_split == other.tensor_split &&
11441159 devices == other.devices &&
11451160 vec_tensor_buft_override_equal (tensor_buft_overrides, other.tensor_buft_overrides );
@@ -1171,7 +1186,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
11711186 for (const auto & m : params.model )
11721187 for (const auto & nl : params.n_gpu_layers )
11731188 for (const auto & ncmoe : params.n_cpu_moe )
1174- for (const auto & rpc : params.rpc_servers )
1189+ for (size_t rpc_idx = 0 ; rpc_idx < params.rpc_servers . size (); ++rpc_idx )
11751190 for (const auto & sm : params.split_mode )
11761191 for (const auto & mg : params.main_gpu )
11771192 for (const auto & devs : params.devices )
@@ -1191,6 +1206,9 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
11911206 for (const auto & cs : params.cpu_strict )
11921207 for (const auto & nd : params.n_depth )
11931208 for (const auto & pl : params.poll ) {
1209+ const auto & rpc = params.rpc_servers [rpc_idx];
1210+ const auto & rpc_set = params.rpc_device_sets [rpc_idx];
1211+
11941212 for (const auto & n_prompt : params.n_prompt ) {
11951213 if (n_prompt == 0 ) {
11961214 continue ;
@@ -1211,6 +1229,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
12111229 /* .n_gpu_layers = */ nl,
12121230 /* .n_cpu_moe = */ ncmoe,
12131231 /* .rpc_servers = */ rpc,
1232+ /* .rpc_devices = */ rpc_set,
12141233 /* .split_mode = */ sm,
12151234 /* .main_gpu = */ mg,
12161235 /* .no_kv_offload= */ nkvo,
@@ -1245,6 +1264,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
12451264 /* .n_gpu_layers = */ nl,
12461265 /* .n_cpu_moe = */ ncmoe,
12471266 /* .rpc_servers = */ rpc,
1267+ /* .rpc_devices = */ rpc_set,
12481268 /* .split_mode = */ sm,
12491269 /* .main_gpu = */ mg,
12501270 /* .no_kv_offload= */ nkvo,
@@ -1279,6 +1299,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
12791299 /* .n_gpu_layers = */ nl,
12801300 /* .n_cpu_moe = */ ncmoe,
12811301 /* .rpc_servers = */ rpc,
1302+ /* .rpc_devices = */ rpc_set,
12821303 /* .split_mode = */ sm,
12831304 /* .main_gpu = */ mg,
12841305 /* .no_kv_offload= */ nkvo,
0 commit comments