22#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
33#endif
44
5- #include " ggml-cpu.h"
6-
7- #ifdef GGML_USE_CUDA
8- #include " ggml-cuda.h"
9- #endif
10-
11- #ifdef GGML_USE_METAL
12- #include " ggml-metal.h"
13- #endif
14-
15- #ifdef GGML_USE_VULKAN
16- #include " ggml-vulkan.h"
17- #endif
18-
19- #ifdef GGML_USE_SYCL
20- #include " ggml-sycl.h"
21- #endif
22-
235#include " ggml-rpc.h"
246#ifdef _WIN32
257# define NOMINMAX
@@ -154,13 +136,15 @@ struct rpc_server_params {
154136 size_t backend_mem = 0 ;
155137 bool use_cache = false ;
156138 int n_threads = std::max(1U , std::thread::hardware_concurrency()/2 );
139+ std::string device;
157140};
158141
159142static void print_usage (int /* argc*/ , char ** argv, rpc_server_params params) {
160143 fprintf (stderr, " Usage: %s [options]\n\n " , argv[0 ]);
161144 fprintf (stderr, " options:\n " );
162145 fprintf (stderr, " -h, --help show this help message and exit\n " );
163146 fprintf (stderr, " -t, --threads number of threads for the CPU backend (default: %d)\n " , params.n_threads );
147+ fprintf (stderr, " -d DEV, --device device to use\n " );
164148 fprintf (stderr, " -H HOST, --host HOST host to bind to (default: %s)\n " , params.host .c_str ());
165149 fprintf (stderr, " -p PORT, --port PORT port to bind to (default: %d)\n " , params.port );
166150 fprintf (stderr, " -m MEM, --mem MEM backend memory size (in MB)\n " );
@@ -186,6 +170,22 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params &
186170 fprintf (stderr, " error: invalid number of threads: %d\n " , params.n_threads );
187171 return false ;
188172 }
173+ } else if (arg == " -d" || arg == " --device" ) {
174+ if (++i >= argc) {
175+ return false ;
176+ }
177+ params.device = argv[i];
178+ if (ggml_backend_dev_by_name (params.device .c_str ()) == nullptr ) {
179+ fprintf (stderr, " error: unknown device: %s\n " , params.device .c_str ());
180+ fprintf (stderr, " available devices:\n " );
181+ for (size_t i = 0 ; i < ggml_backend_dev_count (); i++) {
182+ auto * dev = ggml_backend_dev_get (i);
183+ size_t free, total;
184+ ggml_backend_dev_memory (dev, &free, &total);
185+ printf (" %s: %s (%zu MiB, %zu MiB free)\n " , ggml_backend_dev_name (dev), ggml_backend_dev_description (dev), total / 1024 / 1024 , free / 1024 / 1024 );
186+ }
187+ return false ;
188+ }
189189 } else if (arg == " -p" || arg == " --port" ) {
190190 if (++i >= argc) {
191191 return false ;
@@ -214,66 +214,37 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params &
214214}
215215
216216static ggml_backend_t create_backend (const rpc_server_params & params) {
217- ggml_backend_t backend = NULL ;
218- #ifdef GGML_USE_CUDA
219- fprintf (stderr, " %s: using CUDA backend\n " , __func__);
220- backend = ggml_backend_cuda_init (0 ); // init device 0
221- if (!backend) {
222- fprintf (stderr, " %s: ggml_backend_cuda_init() failed\n " , __func__);
223- }
224- #elif GGML_USE_METAL
225- fprintf (stderr, " %s: using Metal backend\n " , __func__);
226- backend = ggml_backend_metal_init ();
227- if (!backend) {
228- fprintf (stderr, " %s: ggml_backend_metal_init() failed\n " , __func__);
229- }
230- #elif GGML_USE_VULKAN
231- fprintf (stderr, " %s: using Vulkan backend\n " , __func__);
232- backend = ggml_backend_vk_init (0 ); // init device 0
233- if (!backend) {
234- fprintf (stderr, " %s: ggml_backend_vulkan_init() failed\n " , __func__);
235- }
236- #elif GGML_USE_SYCL
237- fprintf (stderr, " %s: using SYCL backend\n " , __func__);
238- backend = ggml_backend_sycl_init (0 ); // init device 0
239- if (!backend) {
240- fprintf (stderr, " %s: ggml_backend_sycl_init() failed\n " , __func__);
241- }
242- #endif
217+ ggml_backend_t backend = ggml_backend_init_by_type (GGML_BACKEND_DEVICE_TYPE_GPU, nullptr );
243218
244219 // if there aren't GPU Backends fallback to CPU backend
245220 if (!backend) {
246- fprintf (stderr, " %s: using CPU backend\n " , __func__);
247- backend = ggml_backend_cpu_init ();
248- ggml_backend_cpu_set_n_threads (backend, params.n_threads );
221+ backend = ggml_backend_init_by_type (GGML_BACKEND_DEVICE_TYPE_CPU, nullptr );
222+ }
223+
224+ fprintf (stderr, " %s: using %s backend\n " , __func__, ggml_backend_name (backend));
225+
226+ // set the number of threads
227+ ggml_backend_dev_t dev = ggml_backend_get_device (backend);
228+ ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg (dev) : nullptr ;
229+ if (reg) {
230+ auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t ) ggml_backend_reg_get_proc_address (reg, " ggml_backend_set_n_threads" );
231+ if (ggml_backend_set_n_threads_fn) {
232+ ggml_backend_set_n_threads_fn (backend, params.n_threads );
233+ }
249234 }
235+
250236 return backend;
251237}
252238
253- static void get_backend_memory (size_t * free_mem, size_t * total_mem) {
254- #ifdef GGML_USE_CUDA
255- ggml_backend_cuda_get_device_memory (0 , free_mem, total_mem);
256- #elif GGML_USE_VULKAN
257- ggml_backend_vk_get_device_memory (0 , free_mem, total_mem);
258- #elif GGML_USE_SYCL
259- ggml_backend_sycl_get_device_memory (0 , free_mem, total_mem);
260- #else
261- #ifdef _WIN32
262- MEMORYSTATUSEX status;
263- status.dwLength = sizeof (status);
264- GlobalMemoryStatusEx (&status);
265- *total_mem = status.ullTotalPhys ;
266- *free_mem = status.ullAvailPhys ;
267- #else
268- long pages = sysconf (_SC_PHYS_PAGES);
269- long page_size = sysconf (_SC_PAGE_SIZE);
270- *total_mem = pages * page_size;
271- *free_mem = *total_mem;
272- #endif
273- #endif
239+ static void get_backend_memory (ggml_backend_t backend, size_t * free_mem, size_t * total_mem) {
240+ ggml_backend_dev_t dev = ggml_backend_get_device (backend);
241+ GGML_ASSERT (dev != nullptr );
242+ ggml_backend_dev_memory (dev, free_mem, total_mem);
274243}
275244
276245int main (int argc, char * argv[]) {
246+ ggml_backend_load_all ();
247+
277248 rpc_server_params params;
278249 if (!rpc_server_params_parse (argc, argv, params)) {
279250 fprintf (stderr, " Invalid parameters\n " );
@@ -301,7 +272,7 @@ int main(int argc, char * argv[]) {
301272 free_mem = params.backend_mem ;
302273 total_mem = params.backend_mem ;
303274 } else {
304- get_backend_memory (&free_mem, &total_mem);
275+ get_backend_memory (backend, &free_mem, &total_mem);
305276 }
306277 const char * cache_dir = nullptr ;
307278 std::string cache_dir_str;
@@ -320,7 +291,21 @@ int main(int argc, char * argv[]) {
320291 printf (" endpoint : %s\n " , endpoint.c_str ());
321292 printf (" local cache : %s\n " , cache_dir ? cache_dir : " n/a" );
322293 printf (" backend memory : %zu MB\n " , free_mem / (1024 * 1024 ));
323- ggml_backend_rpc_start_server (backend, endpoint.c_str (), cache_dir, free_mem, total_mem);
294+
295+ ggml_backend_reg_t reg = ggml_backend_reg_by_name (" RPC" );
296+ if (!reg) {
297+ fprintf (stderr, " Failed to find RPC backend\n " );
298+ return 1 ;
299+ }
300+
301+ auto start_server_fn = (decltype (ggml_backend_rpc_start_server)*) ggml_backend_reg_get_proc_address (reg, " ggml_backend_rpc_start_server" );
302+ if (!start_server_fn) {
303+ fprintf (stderr, " Failed to obtain RPC backend start server function\n " );
304+ return 1 ;
305+ }
306+
307+ start_server_fn (backend, endpoint.c_str (), cache_dir, free_mem, total_mem);
308+
324309 ggml_backend_free (backend);
325310 return 0 ;
326311}
0 commit comments