diff --git a/ggml/include/ggml-rpc.h b/ggml/include/ggml-rpc.h index 72eff0027351a..e6dca3f62b09c 100644 --- a/ggml/include/ggml-rpc.h +++ b/ggml/include/ggml-rpc.h @@ -21,8 +21,7 @@ GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const c GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total); GGML_BACKEND_API void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir, - size_t n_threads, size_t n_devices, - ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem); + size_t n_threads, size_t n_devices, ggml_backend_dev_t * devices); GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void); GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint); diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index aad48d62a850c..a38df5a97e1f0 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -939,6 +939,7 @@ class rpc_server { bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response); bool init_tensor(const rpc_msg_init_tensor_req & request); bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response); + bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response); private: bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data); @@ -1458,6 +1459,20 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph return true; } +bool rpc_server::get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response) { + uint32_t dev_id = request.device; + if (dev_id >= backends.size()) { + return false; + } + size_t free, total; + ggml_backend_dev_t dev = ggml_backend_get_device(backends[dev_id]); + ggml_backend_dev_memory(dev, &free, &total); + response.free_mem = free; + response.total_mem 
= total; + LOG_DBG("[%s] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", __func__, dev_id, response.free_mem, response.total_mem); + return true; +} + rpc_server::~rpc_server() { for (auto buffer : buffers) { ggml_backend_buffer_free(buffer); @@ -1465,7 +1480,7 @@ rpc_server::~rpc_server() { } static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const char * cache_dir, - sockfd_t sockfd, const std::vector<size_t> & free_mem, const std::vector<size_t> & total_mem) { + sockfd_t sockfd) { rpc_server server(backends, cache_dir); uint8_t cmd; if (!recv_data(sockfd, &cmd, 1)) { @@ -1689,15 +1704,10 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const if (!recv_msg(sockfd, &request, sizeof(request))) { return; } - auto dev_id = request.device; - if (dev_id >= backends.size()) { + rpc_msg_get_device_memory_rsp response; + if (!server.get_device_memory(request, response)) { return; } - rpc_msg_get_device_memory_rsp response; - response.free_mem = free_mem[dev_id]; - response.total_mem = total_mem[dev_id]; - LOG_DBG("[get_device_mem] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", dev_id, - response.free_mem, response.total_mem); if (!send_msg(sockfd, &response, sizeof(response))) { return; } @@ -1712,15 +1722,12 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const } void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir, - size_t n_threads, size_t n_devices, - ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem) { + size_t n_threads, size_t n_devices, ggml_backend_dev_t * devices) { - if (n_devices == 0 || devices == nullptr || free_mem == nullptr || total_mem == nullptr) { + if (n_devices == 0 || devices == nullptr) { fprintf(stderr, "Invalid arguments to ggml_backend_rpc_start_server\n"); return; } std::vector<ggml_backend_t> backends; - std::vector<size_t> free_mem_vec(free_mem, free_mem + n_devices); - std::vector<size_t> total_mem_vec(total_mem, total_mem + n_devices); printf("Starting RPC server v%d.%d.%d\n", 
RPC_PROTO_MAJOR_VERSION, RPC_PROTO_MINOR_VERSION, @@ -1730,8 +1737,10 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir printf("Devices:\n"); for (size_t i = 0; i < n_devices; i++) { auto dev = devices[i]; + size_t free, total; + ggml_backend_dev_memory(dev, &free, &total); printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), - total_mem[i] / 1024 / 1024, free_mem[i] / 1024 / 1024); + total / 1024 / 1024, free / 1024 / 1024); auto backend = ggml_backend_dev_init(dev, nullptr); if (!backend) { fprintf(stderr, "Failed to create backend for device %s\n", dev->iface.get_name(dev)); @@ -1775,7 +1784,7 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir } printf("Accepted client connection\n"); fflush(stdout); - rpc_serve_client(backends, cache_dir, client_socket->fd, free_mem_vec, total_mem_vec); + rpc_serve_client(backends, cache_dir, client_socket->fd); printf("Client connection closed\n"); fflush(stdout); } diff --git a/tools/rpc/rpc-server.cpp b/tools/rpc/rpc-server.cpp index 088515612772d..58b93c7468ea3 100644 --- a/tools/rpc/rpc-server.cpp +++ b/tools/rpc/rpc-server.cpp @@ -137,7 +136,6 @@ struct rpc_server_params { bool use_cache = false; int n_threads = std::max(1U, std::thread::hardware_concurrency()/2); std::vector<std::string> devices; - std::vector<size_t> dev_mem; }; static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) { @@ -148,7 +147,6 @@ static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) { fprintf(stderr, " -d, --device comma-separated list of devices\n"); fprintf(stderr, " -H, --host HOST host to bind to (default: %s)\n", params.host.c_str()); fprintf(stderr, " -p, --port PORT port to bind to (default: %d)\n", params.port); - fprintf(stderr, " -m, --mem memory size for each device (in MB)\n"); fprintf(stderr, " -c, --cache enable local file cache\n"); fprintf(stderr, "\n"); } @@ -197,23 +195,6 @@ static 
bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params & } } else if (arg == "-c" || arg == "--cache") { params.use_cache = true; - } else if (arg == "-m" || arg == "--mem") { - if (++i >= argc) { - return false; - } - const std::regex regex{ R"([,/]+)" }; - std::string mem_str = argv[i]; - std::sregex_token_iterator iter(mem_str.begin(), mem_str.end(), regex, -1); - std::sregex_token_iterator end; - for ( ; iter != end; ++iter) { - try { - size_t mem = std::stoul(*iter) * 1024 * 1024; - params.dev_mem.push_back(mem); - } catch (const std::exception & ) { - fprintf(stderr, "error: invalid memory size: %s\n", iter->str().c_str()); - return false; - } - } } else if (arg == "-h" || arg == "--help") { print_usage(argc, argv, params); exit(0); @@ -293,18 +274,6 @@ int main(int argc, char * argv[]) { return 1; } std::string endpoint = params.host + ":" + std::to_string(params.port); - std::vector<size_t> free_mem, total_mem; - for (size_t i = 0; i < devices.size(); i++) { - if (i < params.dev_mem.size()) { - free_mem.push_back(params.dev_mem[i]); - total_mem.push_back(params.dev_mem[i]); - } else { - size_t free, total; - ggml_backend_dev_memory(devices[i], &free, &total); - free_mem.push_back(free); - total_mem.push_back(total); - } - } const char * cache_dir = nullptr; std::string cache_dir_str; if (params.use_cache) { @@ -328,7 +297,6 @@ int main(int argc, char * argv[]) { return 1; } - start_server_fn(endpoint.c_str(), cache_dir, params.n_threads, devices.size(), - devices.data(), free_mem.data(), total_mem.data()); + start_server_fn(endpoint.c_str(), cache_dir, params.n_threads, devices.size(), devices.data()); return 0; }