@@ -939,6 +939,7 @@ class rpc_server {
939939 bool graph_compute (const std::vector<uint8_t > & input, rpc_msg_graph_compute_rsp & response);
940940 bool init_tensor (const rpc_msg_init_tensor_req & request);
941941 bool get_alloc_size (const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
942+ bool get_device_memory (const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);
942943
943944private:
944945 bool get_cached_file (uint64_t hash, std::vector<uint8_t > & data);
@@ -1458,14 +1459,28 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
14581459 return true ;
14591460}
14601461
1462+ bool rpc_server::get_device_memory (const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response) {
1463+ uint32_t dev_id = request.device ;
1464+ if (dev_id >= backends.size ()) {
1465+ return false ;
1466+ }
1467+ size_t free, total;
1468+ ggml_backend_dev_t dev = ggml_backend_get_device (backends[dev_id]);
1469+ ggml_backend_dev_memory (dev, &free, &total);
1470+ response.free_mem = free;
1471+ response.total_mem = total;
1472+ LOG_DBG (" [%s] device: %u, free_mem: %" PRIu64 " , total_mem: %" PRIu64 " \n " , __func__, dev_id, response.free_mem , response.total_mem );
1473+ return true ;
1474+ }
1475+
14611476rpc_server::~rpc_server () {
14621477 for (auto buffer : buffers) {
14631478 ggml_backend_buffer_free (buffer);
14641479 }
14651480}
14661481
14671482static void rpc_serve_client (const std::vector<ggml_backend_t > & backends, const char * cache_dir,
1468- sockfd_t sockfd, const std::vector< size_t > & free_mem, const std::vector< size_t > & total_mem ) {
1483+ sockfd_t sockfd) {
14691484 rpc_server server (backends, cache_dir);
14701485 uint8_t cmd;
14711486 if (!recv_data (sockfd, &cmd, 1 )) {
@@ -1689,15 +1704,10 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
16891704 if (!recv_msg (sockfd, &request, sizeof (request))) {
16901705 return ;
16911706 }
1692- auto dev_id = request. device ;
1693- if (dev_id >= backends. size ( )) {
1707+ rpc_msg_get_device_memory_rsp response ;
1708+ if (!server. get_device_memory (request, response )) {
16941709 return ;
16951710 }
1696- rpc_msg_get_device_memory_rsp response;
1697- response.free_mem = free_mem[dev_id];
1698- response.total_mem = total_mem[dev_id];
1699- LOG_DBG (" [get_device_mem] device: %u, free_mem: %" PRIu64 " , total_mem: %" PRIu64 " \n " , dev_id,
1700- response.free_mem , response.total_mem );
17011711 if (!send_msg (sockfd, &response, sizeof (response))) {
17021712 return ;
17031713 }
@@ -1712,15 +1722,12 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
17121722}
17131723
17141724void ggml_backend_rpc_start_server (const char * endpoint, const char * cache_dir,
1715- size_t n_threads, size_t n_devices,
1716- ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem) {
1717- if (n_devices == 0 || devices == nullptr || free_mem == nullptr || total_mem == nullptr ) {
1725+ size_t n_threads, size_t n_devices, ggml_backend_dev_t * devices) {
1726+ if (n_devices == 0 || devices == nullptr ) {
17181727 fprintf (stderr, " Invalid arguments to ggml_backend_rpc_start_server\n " );
17191728 return ;
17201729 }
17211730 std::vector<ggml_backend_t > backends;
1722- std::vector<size_t > free_mem_vec (free_mem, free_mem + n_devices);
1723- std::vector<size_t > total_mem_vec (total_mem, total_mem + n_devices);
17241731 printf (" Starting RPC server v%d.%d.%d\n " ,
17251732 RPC_PROTO_MAJOR_VERSION,
17261733 RPC_PROTO_MINOR_VERSION,
@@ -1730,8 +1737,10 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir
17301737 printf (" Devices:\n " );
17311738 for (size_t i = 0 ; i < n_devices; i++) {
17321739 auto dev = devices[i];
1740+ size_t free, total;
1741+ ggml_backend_dev_memory (dev, &free, &total);
17331742 printf (" %s: %s (%zu MiB, %zu MiB free)\n " , ggml_backend_dev_name (dev), ggml_backend_dev_description (dev),
1734- total_mem[i] / 1024 / 1024 , free_mem[i] / 1024 / 1024 );
1743+ total / 1024 / 1024 , free / 1024 / 1024 );
17351744 auto backend = ggml_backend_dev_init (dev, nullptr );
17361745 if (!backend) {
17371746 fprintf (stderr, " Failed to create backend for device %s\n " , dev->iface .get_name (dev));
@@ -1775,7 +1784,7 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir
17751784 }
17761785 printf (" Accepted client connection\n " );
17771786 fflush (stdout);
1778- rpc_serve_client (backends, cache_dir, client_socket->fd , free_mem_vec, total_mem_vec );
1787+ rpc_serve_client (backends, cache_dir, client_socket->fd );
17791788 printf (" Client connection closed\n " );
17801789 fflush (stdout);
17811790 }
0 commit comments