@@ -920,8 +920,8 @@ void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device,
920920
921921class rpc_server {
922922public:
923- rpc_server (std::vector<ggml_backend_t > backends, const char * cache_dir)
924- : backends(std::move(backends)), cache_dir(cache_dir) {
923+ rpc_server (std::vector<ggml_backend_t > backends, std::vector< size_t > max_mem, const char * cache_dir)
924+ : backends(std::move(backends)), max_mem(std::move(max_mem)), cache_dir(cache_dir) {
925925 }
926926 ~rpc_server ();
927927
@@ -939,6 +939,7 @@ class rpc_server {
939939 bool graph_compute (const std::vector<uint8_t > & input, rpc_msg_graph_compute_rsp & response);
940940 bool init_tensor (const rpc_msg_init_tensor_req & request);
941941 bool get_alloc_size (const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
942+ bool get_device_memory (const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);
942943
943944private:
944945 bool get_cached_file (uint64_t hash, std::vector<uint8_t > & data);
@@ -950,6 +951,7 @@ class rpc_server {
950951
951952
952953 std::vector<ggml_backend_t > backends;
954+ std::vector<size_t > max_mem;
953955 const char * cache_dir;
954956 std::unordered_set<ggml_backend_buffer_t > buffers;
955957};
@@ -1458,15 +1460,39 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
14581460 return true ;
14591461}
14601462
1463+ static void rpc_dev_memory (ggml_backend_dev_t dev, size_t max_mem, size_t * free, size_t * total) {
1464+ // cap total and free memory to the user-specified max memory
1465+ size_t free_dev, total_dev;
1466+ ggml_backend_dev_memory (dev, &free_dev, &total_dev);
1467+ total_dev = (total_dev > max_mem) ? max_mem : total_dev;
1468+ free_dev = (free_dev > max_mem) ? max_mem : free_dev;
1469+ *free = free_dev;
1470+ *total = total_dev;
1471+ }
1472+
1473+ bool rpc_server::get_device_memory (const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response) {
1474+ uint32_t dev_id = request.device ;
1475+ if (dev_id >= backends.size ()) {
1476+ return false ;
1477+ }
1478+ size_t free, total;
1479+ ggml_backend_dev_t dev = ggml_backend_get_device (backends[dev_id]);
1480+ rpc_dev_memory (dev, max_mem[dev_id], &free, &total);
1481+ response.free_mem = free;
1482+ response.total_mem = total;
1483+ LOG_DBG (" [%s] device: %u, free_mem: %" PRIu64 " , total_mem: %" PRIu64 " \n " , __func__, dev_id, response.free_mem , response.total_mem );
1484+ return true ;
1485+ }
1486+
14611487rpc_server::~rpc_server () {
14621488 for (auto buffer : buffers) {
14631489 ggml_backend_buffer_free (buffer);
14641490 }
14651491}
14661492
14671493static void rpc_serve_client (const std::vector<ggml_backend_t > & backends, const char * cache_dir,
1468- sockfd_t sockfd, const std::vector<size_t > & free_mem, const std::vector< size_t > & total_mem ) {
1469- rpc_server server (backends, cache_dir);
1494+ sockfd_t sockfd, const std::vector<size_t > & max_mem ) {
1495+ rpc_server server (backends, max_mem, cache_dir);
14701496 uint8_t cmd;
14711497 if (!recv_data (sockfd, &cmd, 1 )) {
14721498 return ;
@@ -1689,15 +1715,10 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
16891715 if (!recv_msg (sockfd, &request, sizeof (request))) {
16901716 return ;
16911717 }
1692- auto dev_id = request. device ;
1693- if (dev_id >= backends. size ( )) {
1718+ rpc_msg_get_device_memory_rsp response ;
1719+ if (!server. get_device_memory (request, response )) {
16941720 return ;
16951721 }
1696- rpc_msg_get_device_memory_rsp response;
1697- response.free_mem = free_mem[dev_id];
1698- response.total_mem = total_mem[dev_id];
1699- LOG_DBG (" [get_device_mem] device: %u, free_mem: %" PRIu64 " , total_mem: %" PRIu64 " \n " , dev_id,
1700- response.free_mem , response.total_mem );
17011722 if (!send_msg (sockfd, &response, sizeof (response))) {
17021723 return ;
17031724 }
@@ -1713,14 +1734,13 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
17131734
17141735void ggml_backend_rpc_start_server (const char * endpoint, const char * cache_dir,
17151736 size_t n_threads, size_t n_devices,
1716- ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem ) {
1717- if (n_devices == 0 || devices == nullptr || free_mem == nullptr || total_mem == nullptr ) {
1737+ ggml_backend_dev_t * devices, size_t * max_mem ) {
1738+ if (n_devices == 0 || devices == nullptr || max_mem == nullptr ) {
17181739 fprintf (stderr, " Invalid arguments to ggml_backend_rpc_start_server\n " );
17191740 return ;
17201741 }
17211742 std::vector<ggml_backend_t > backends;
1722- std::vector<size_t > free_mem_vec (free_mem, free_mem + n_devices);
1723- std::vector<size_t > total_mem_vec (total_mem, total_mem + n_devices);
1743+ std::vector<size_t > max_mem_vec (max_mem, max_mem + n_devices);
17241744 printf (" Starting RPC server v%d.%d.%d\n " ,
17251745 RPC_PROTO_MAJOR_VERSION,
17261746 RPC_PROTO_MINOR_VERSION,
@@ -1730,8 +1750,10 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir
17301750 printf (" Devices:\n " );
17311751 for (size_t i = 0 ; i < n_devices; i++) {
17321752 auto dev = devices[i];
1753+ size_t free, total;
1754+ rpc_dev_memory (dev, max_mem[i], &free, &total);
17331755 printf (" %s: %s (%zu MiB, %zu MiB free)\n " , ggml_backend_dev_name (dev), ggml_backend_dev_description (dev),
1734- total_mem[i] / 1024 / 1024 , free_mem[i] / 1024 / 1024 );
1756+ total / 1024 / 1024 , free / 1024 / 1024 );
17351757 auto backend = ggml_backend_dev_init (dev, nullptr );
17361758 if (!backend) {
17371759 fprintf (stderr, " Failed to create backend for device %s\n " , dev->iface .get_name (dev));
@@ -1775,7 +1797,7 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir
17751797 }
17761798 printf (" Accepted client connection\n " );
17771799 fflush (stdout);
1778- rpc_serve_client (backends, cache_dir, client_socket->fd , free_mem_vec, total_mem_vec );
1800+ rpc_serve_client (backends, cache_dir, client_socket->fd , max_mem_vec );
17791801 printf (" Client connection closed\n " );
17801802 fflush (stdout);
17811803 }
0 commit comments