22#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
33#endif
44
5- #include " ggml-cpu.h"
6-
7- #ifdef GGML_USE_CUDA
8- #include " ggml-cuda.h"
9- #endif
10-
11- #ifdef GGML_USE_METAL
12- #include " ggml-metal.h"
13- #endif
14-
15- #ifdef GGML_USE_VULKAN
16- #include " ggml-vulkan.h"
17- #endif
18-
19- #ifdef GGML_USE_SYCL
20- #include " ggml-sycl.h"
21- #endif
22-
235#include " ggml-rpc.h"
246#ifdef _WIN32
257# define NOMINMAX
@@ -154,13 +136,15 @@ struct rpc_server_params {
154136 size_t backend_mem = 0 ;
155137 bool use_cache = false ;
156138 int n_threads = std::max(1U , std::thread::hardware_concurrency()/2 );
139+ std::string device;
157140};
158141
159142static void print_usage (int /* argc*/ , char ** argv, rpc_server_params params) {
160143 fprintf (stderr, " Usage: %s [options]\n\n " , argv[0 ]);
161144 fprintf (stderr, " options:\n " );
162145 fprintf (stderr, " -h, --help show this help message and exit\n " );
163146 fprintf (stderr, " -t, --threads number of threads for the CPU backend (default: %d)\n " , params.n_threads );
147+ fprintf (stderr, " -d DEV, --device device to use\n " );
164148 fprintf (stderr, " -H HOST, --host HOST host to bind to (default: %s)\n " , params.host .c_str ());
165149 fprintf (stderr, " -p PORT, --port PORT port to bind to (default: %d)\n " , params.port );
166150 fprintf (stderr, " -m MEM, --mem MEM backend memory size (in MB)\n " );
@@ -186,6 +170,22 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params &
186170 fprintf (stderr, " error: invalid number of threads: %d\n " , params.n_threads );
187171 return false ;
188172 }
173+ } else if (arg == " -d" || arg == " --device" ) {
174+ if (++i >= argc) {
175+ return false ;
176+ }
177+ params.device = argv[i];
178+ if (ggml_backend_dev_by_name (params.device .c_str ()) == nullptr ) {
179+ fprintf (stderr, " error: unknown device: %s\n " , params.device .c_str ());
180+ fprintf (stderr, " available devices:\n " );
181+ for (size_t i = 0 ; i < ggml_backend_dev_count (); i++) {
182+ auto * dev = ggml_backend_dev_get (i);
183+ size_t free, total;
184+ ggml_backend_dev_memory (dev, &free, &total);
185+ printf (" %s: %s (%zu MiB, %zu MiB free)\n " , ggml_backend_dev_name (dev), ggml_backend_dev_description (dev), total / 1024 / 1024 , free / 1024 / 1024 );
186+ }
187+ return false ;
188+ }
189189 } else if (arg == " -p" || arg == " --port" ) {
190190 if (++i >= argc) {
191191 return false ;
@@ -214,66 +214,53 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params &
214214}
215215
216216static ggml_backend_t create_backend (const rpc_server_params & params) {
217- ggml_backend_t backend = NULL ;
218- #ifdef GGML_USE_CUDA
219- fprintf (stderr, " %s: using CUDA backend\n " , __func__);
220- backend = ggml_backend_cuda_init (0 ); // init device 0
221- if (!backend) {
222- fprintf (stderr, " %s: ggml_backend_cuda_init() failed\n " , __func__);
223- }
224- #elif GGML_USE_METAL
225- fprintf (stderr, " %s: using Metal backend\n " , __func__);
226- backend = ggml_backend_metal_init ();
227- if (!backend) {
228- fprintf (stderr, " %s: ggml_backend_metal_init() failed\n " , __func__);
217+ ggml_backend_t backend = nullptr ;
218+
219+ if (!params.device .empty ()) {
220+ ggml_backend_dev_t dev = ggml_backend_dev_by_name (params.device .c_str ());
221+ if (dev) {
222+ backend = ggml_backend_dev_init (dev, nullptr );
223+ if (!backend) {
224+ fprintf (stderr, " Failed to create backend for device %s\n " , params.device .c_str ());
225+ return nullptr ;
226+ }
227+ }
229228 }
230- #elif GGML_USE_VULKAN
231- fprintf (stderr, " %s: using Vulkan backend\n " , __func__);
232- backend = ggml_backend_vk_init (0 ); // init device 0
229+
230+ // try to initialize a GPU backend first
233231 if (!backend) {
234- fprintf (stderr, " %s: ggml_backend_vulkan_init() failed \n " , __func__ );
232+ backend = ggml_backend_init_by_type (GGML_BACKEND_DEVICE_TYPE_GPU, nullptr );
235233 }
236- #elif GGML_USE_SYCL
237- fprintf (stderr, " %s: using SYCL backend\n " , __func__);
238- backend = ggml_backend_sycl_init (0 ); // init device 0
234+
235+ // if there aren't GPU backends fallback to CPU backend
239236 if (!backend) {
240- fprintf (stderr, " %s: ggml_backend_sycl_init() failed \n " , __func__ );
237+ backend = ggml_backend_init_by_type (GGML_BACKEND_DEVICE_TYPE_CPU, nullptr );
241238 }
242- #endif
243239
244- // if there aren't GPU Backends fallback to CPU backend
245- if (!backend) {
246- fprintf (stderr, " %s: using CPU backend\n " , __func__);
247- backend = ggml_backend_cpu_init ();
248- ggml_backend_cpu_set_n_threads (backend, params.n_threads );
240+ fprintf (stderr, " %s: using %s backend\n " , __func__, ggml_backend_name (backend));
241+
242+ // set the number of threads
243+ ggml_backend_dev_t dev = ggml_backend_get_device (backend);
244+ ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg (dev) : nullptr ;
245+ if (reg) {
246+ auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t ) ggml_backend_reg_get_proc_address (reg, " ggml_backend_set_n_threads" );
247+ if (ggml_backend_set_n_threads_fn) {
248+ ggml_backend_set_n_threads_fn (backend, params.n_threads );
249+ }
249250 }
251+
250252 return backend;
251253}
252254
253- static void get_backend_memory (size_t * free_mem, size_t * total_mem) {
254- #ifdef GGML_USE_CUDA
255- ggml_backend_cuda_get_device_memory (0 , free_mem, total_mem);
256- #elif GGML_USE_VULKAN
257- ggml_backend_vk_get_device_memory (0 , free_mem, total_mem);
258- #elif GGML_USE_SYCL
259- ggml_backend_sycl_get_device_memory (0 , free_mem, total_mem);
260- #else
261- #ifdef _WIN32
262- MEMORYSTATUSEX status;
263- status.dwLength = sizeof (status);
264- GlobalMemoryStatusEx (&status);
265- *total_mem = status.ullTotalPhys ;
266- *free_mem = status.ullAvailPhys ;
267- #else
268- long pages = sysconf (_SC_PHYS_PAGES);
269- long page_size = sysconf (_SC_PAGE_SIZE);
270- *total_mem = pages * page_size;
271- *free_mem = *total_mem;
272- #endif
273- #endif
255+ static void get_backend_memory (ggml_backend_t backend, size_t * free_mem, size_t * total_mem) {
256+ ggml_backend_dev_t dev = ggml_backend_get_device (backend);
257+ GGML_ASSERT (dev != nullptr );
258+ ggml_backend_dev_memory (dev, free_mem, total_mem);
274259}
275260
276261int main (int argc, char * argv[]) {
262+ ggml_backend_load_all ();
263+
277264 rpc_server_params params;
278265 if (!rpc_server_params_parse (argc, argv, params)) {
279266 fprintf (stderr, " Invalid parameters\n " );
@@ -301,7 +288,7 @@ int main(int argc, char * argv[]) {
301288 free_mem = params.backend_mem ;
302289 total_mem = params.backend_mem ;
303290 } else {
304- get_backend_memory (&free_mem, &total_mem);
291+ get_backend_memory (backend, &free_mem, &total_mem);
305292 }
306293 const char * cache_dir = nullptr ;
307294 std::string cache_dir_str;
@@ -313,14 +300,21 @@ int main(int argc, char * argv[]) {
313300 }
314301 cache_dir = cache_dir_str.c_str ();
315302 }
316- printf (" Starting RPC server v%d.%d.%d\n " ,
317- RPC_PROTO_MAJOR_VERSION,
318- RPC_PROTO_MINOR_VERSION,
319- RPC_PROTO_PATCH_VERSION);
320- printf (" endpoint : %s\n " , endpoint.c_str ());
321- printf (" local cache : %s\n " , cache_dir ? cache_dir : " n/a" );
322- printf (" backend memory : %zu MB\n " , free_mem / (1024 * 1024 ));
323- ggml_backend_rpc_start_server (backend, endpoint.c_str (), cache_dir, free_mem, total_mem);
303+
304+ ggml_backend_reg_t reg = ggml_backend_reg_by_name (" RPC" );
305+ if (!reg) {
306+ fprintf (stderr, " Failed to find RPC backend\n " );
307+ return 1 ;
308+ }
309+
310+ auto start_server_fn = (decltype (ggml_backend_rpc_start_server)*) ggml_backend_reg_get_proc_address (reg, " ggml_backend_rpc_start_server" );
311+ if (!start_server_fn) {
312+ fprintf (stderr, " Failed to obtain RPC backend start server function\n " );
313+ return 1 ;
314+ }
315+
316+ start_server_fn (backend, endpoint.c_str (), cache_dir, free_mem, total_mem);
317+
324318 ggml_backend_free (backend);
325319 return 0 ;
326320}
0 commit comments