diff --git a/ggml/include/ggml-rpc.h b/ggml/include/ggml-rpc.h
index e6dca3f62b09c..832c26c61d3eb 100644
--- a/ggml/include/ggml-rpc.h
+++ b/ggml/include/ggml-rpc.h
@@ -8,7 +8,7 @@ extern "C" {
 #endif
 
 #define RPC_PROTO_MAJOR_VERSION    3
-#define RPC_PROTO_MINOR_VERSION    0
+#define RPC_PROTO_MINOR_VERSION    5
 #define RPC_PROTO_PATCH_VERSION    0
 #define GGML_RPC_MAX_SERVERS       16
 
diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp
index a38df5a97e1f0..054cead39356b 100644
--- a/ggml/src/ggml-rpc/ggml-rpc.cpp
+++ b/ggml/src/ggml-rpc/ggml-rpc.cpp
@@ -106,13 +106,17 @@ enum rpc_cmd {
     RPC_CMD_GET_ALLOC_SIZE,
     RPC_CMD_HELLO,
     RPC_CMD_DEVICE_COUNT,
+    RPC_CMD_GRAPH_COMPUTE_AND_STORE,
+    RPC_CMD_GRAPH_RECOMPUTE,
     RPC_CMD_COUNT,
 };
 
 static_assert(RPC_CMD_HELLO == 14, "RPC_CMD_HELLO must be always 14");
 
 // Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
-const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
+const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
+const int MAX_STORED_GRAPHS = 64;
+const uint64_t INVALID_GRAPH_ID = UINT64_MAX;
 
 struct rpc_msg_hello_rsp {
     uint8_t major;
@@ -217,6 +221,20 @@ struct rpc_msg_get_device_memory_rsp {
     uint64_t free_mem;
     uint64_t total_mem;
 };
+
+struct rpc_msg_graph_compute_and_store_rsp {
+    uint8_t  result;
+    uint64_t graph_id;
+};
+
+struct rpc_msg_graph_recompute_req {
+    uint64_t graph_id;
+};
+
+struct rpc_msg_graph_recompute_rsp {
+    uint8_t result;
+};
+
 #pragma pack(pop)
 
 // RPC data structures
@@ -238,6 +256,7 @@ struct ggml_backend_rpc_context {
     std::string endpoint;
     uint32_t    device;
     std::string name;
+    uint64_t    curr_graph_id;
 };
 
 struct ggml_backend_rpc_buffer_context {
@@ -592,6 +611,8 @@ static enum ggml_status ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
         bool status = send_rpc_cmd(ctx->sock, RPC_CMD_INIT_TENSOR, &request, sizeof(request), nullptr, 0);
         RPC_STATUS_ASSERT(status);
     }
+    // HACK: use the extra field for storing the graph ID
+    tensor->extra = reinterpret_cast<void *>(INVALID_GRAPH_ID);
     return GGML_STATUS_SUCCESS;
 }
 
@@ -815,13 +836,30 @@ static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {
 
 static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
-    std::vector<uint8_t> input;
-    serialize_graph(rpc_ctx->device, cgraph, input);
-    rpc_msg_graph_compute_rsp response;
-    auto sock = get_socket(rpc_ctx->endpoint);
-    bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response));
-    RPC_STATUS_ASSERT(status);
-    return (enum ggml_status)response.result;
+
+    GGML_ASSERT(cgraph->n_nodes > 0);
+    // HACK: we store the graph ID in the first node's extra field
+    uint64_t stored_graph_id = reinterpret_cast<uint64_t>(cgraph->nodes[0]->extra);
+    bool reuse_graph = stored_graph_id != INVALID_GRAPH_ID && (stored_graph_id + MAX_STORED_GRAPHS > rpc_ctx->curr_graph_id);
+    if (reuse_graph) {
+        rpc_msg_graph_recompute_req request;
+        request.graph_id = stored_graph_id;
+        rpc_msg_graph_recompute_rsp response;
+        auto sock = get_socket(rpc_ctx->endpoint);
+        bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request), &response, sizeof(response));
+        RPC_STATUS_ASSERT(status);
+        return (enum ggml_status)response.result;
+    } else {
+        std::vector<uint8_t> input;
+        serialize_graph(rpc_ctx->device, cgraph, input);
+        rpc_msg_graph_compute_and_store_rsp response;
+        auto sock = get_socket(rpc_ctx->endpoint);
+        bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE_AND_STORE, input.data(), input.size(), &response, sizeof(response));
+        RPC_STATUS_ASSERT(status);
+        rpc_ctx->curr_graph_id = response.graph_id;
+        cgraph->nodes[0]->extra = reinterpret_cast<void *>(response.graph_id);
+        return (enum ggml_status)response.result;
+    }
 }
 
 static ggml_backend_i ggml_backend_rpc_interface = {
@@ -878,9 +916,10 @@ ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, uint32_t device) {
 ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) {
     std::string dev_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]";
     ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
-        /* .endpoint = */ endpoint,
-        /* .device   = */ device,
-        /* .name     = */ dev_name
+        /* .endpoint      = */ endpoint,
+        /* .device        = */ device,
+        /* .name          = */ dev_name,
+        /* .curr_graph_id = */ 0,
     };
     auto reg = ggml_backend_rpc_add_server(endpoint);
     ggml_backend_t backend = new ggml_backend {
@@ -921,7 +960,8 @@ void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total) {
 
 class rpc_server {
 public:
     rpc_server(std::vector<ggml_backend_t> backends, const char * cache_dir)
-        : backends(std::move(backends)), cache_dir(cache_dir) {
+        : backends(std::move(backends)), cache_dir(cache_dir), curr_graph_id(0) {
+        stored_graphs.resize(MAX_STORED_GRAPHS);
     }
     ~rpc_server();
@@ -937,10 +977,18 @@ class rpc_server {
     bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
     bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
     bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
+    bool graph_compute_and_store(const std::vector<uint8_t> & input, rpc_msg_graph_compute_and_store_rsp & response);
+    bool graph_recompute(const rpc_msg_graph_recompute_req & request, rpc_msg_graph_recompute_rsp & response);
     bool init_tensor(const rpc_msg_init_tensor_req & request);
     bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
     bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);
 
+    struct stored_graph {
+        uint32_t device;
+        ggml_context_ptr ctx_ptr;
+        ggml_cgraph * graph;
+    };
+
 private:
     bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
     ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
@@ -948,11 +996,15 @@ class rpc_server {
     ggml_tensor * create_node(uint64_t id,
                               struct ggml_context * ctx,
                               const std::unordered_map<uint64_t, const rpc_tensor *> & tensor_ptrs,
                               std::unordered_map<uint64_t, ggml_tensor *> & tensor_map);
+    bool store_graph(const std::vector<uint8_t> & input, stored_graph & sg);
 
     std::vector<ggml_backend_t> backends;
     const char * cache_dir;
     std::unordered_set<ggml_backend_buffer_t> buffers;
+    uint64_t curr_graph_id;
+    // ring buffer for storing graphs
+    std::vector<stored_graph> stored_graphs;
 };
 
 void rpc_server::hello(rpc_msg_hello_rsp & response) {
@@ -1394,7 +1446,7 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
     return result;
 }
 
-bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
+bool rpc_server::store_graph(const std::vector<uint8_t> & input, stored_graph & sg) {
     // serialization format:
     // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
     if (input.size() < 2*sizeof(uint32_t)) {
@@ -1422,7 +1474,6 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
         return false;
     }
     const rpc_tensor * tensors = (const rpc_tensor *)src;
-    LOG_DBG("[%s] device: %u, n_nodes: %u, n_tensors: %u\n", __func__, device, n_nodes, n_tensors);
 
     size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
@@ -1454,6 +1505,47 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
             return false;
         }
     }
+    sg.ctx_ptr.swap(ctx_ptr);
+    sg.graph = graph;
+    sg.device = device;
+    return true;
+}
+
+bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
+    stored_graph sg;
+    if (!store_graph(input, sg)) {
+        return false;
+    }
+    uint32_t device = sg.device;
+    LOG_DBG("[%s] device: %u, input: %zu bytes\n", __func__, device, input.size());
+    ggml_status status = ggml_backend_graph_compute(backends[device], sg.graph);
+    response.result = status;
+    return true;
+}
+
+bool rpc_server::graph_compute_and_store(const std::vector<uint8_t> & input, rpc_msg_graph_compute_and_store_rsp & response) {
+    int graph_slot = curr_graph_id % MAX_STORED_GRAPHS;
+    if (!store_graph(input, stored_graphs[graph_slot])) {
+        return false;
+    }
+    ggml_cgraph * graph = stored_graphs[graph_slot].graph;
+    uint32_t device = stored_graphs[graph_slot].device;
+    LOG_DBG("[%s] device: %u, input: %zu bytes, graph_id: %" PRIu64 "\n", __func__, device, input.size(), curr_graph_id);
+    ggml_status status = ggml_backend_graph_compute(backends[device], graph);
+    response.result = status;
+    response.graph_id = curr_graph_id;
+    curr_graph_id++;
+    return true;
+}
+
+bool rpc_server::graph_recompute(const rpc_msg_graph_recompute_req & request, rpc_msg_graph_recompute_rsp & response) {
+    int graph_slot = request.graph_id % MAX_STORED_GRAPHS;
+    if (stored_graphs[graph_slot].graph == nullptr) {
+        return false;
+    }
+    ggml_cgraph * graph = stored_graphs[graph_slot].graph;
+    uint32_t device = stored_graphs[graph_slot].device;
+    LOG_DBG("[%s] device: %u, graph_id: %" PRIu64 "\n", __func__, device, request.graph_id);
     ggml_status status = ggml_backend_graph_compute(backends[device], graph);
     response.result = status;
     return true;
 }
@@ -1699,6 +1791,35 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const char * cache_dir, sockfd_t sockfd, uint64_t free_mem, uint64_t total_mem) {
                 }
                 break;
             }
+
+            case RPC_CMD_GRAPH_COMPUTE_AND_STORE: {
+                std::vector<uint8_t> input;
+                if (!recv_msg(sockfd, input)) {
+                    return;
+                }
+                rpc_msg_graph_compute_and_store_rsp response;
+                if (!server.graph_compute_and_store(input, response)) {
+                    return;
+                }
+                if (!send_msg(sockfd, &response, sizeof(response))) {
+                    return;
+                }
+                break;
+            }
+            case RPC_CMD_GRAPH_RECOMPUTE: {
+                rpc_msg_graph_recompute_req request;
+                if (!recv_msg(sockfd, &request, sizeof(request))) {
+                    return;
+                }
+                rpc_msg_graph_recompute_rsp response;
+                if (!server.graph_recompute(request, response)) {
+                    return;
+                }
+                if (!send_msg(sockfd, &response, sizeof(response))) {
+                    return;
+                }
+                break;
+            }
             case RPC_CMD_GET_DEVICE_MEMORY: {
                 rpc_msg_get_device_memory_req request;
                 if (!recv_msg(sockfd, &request, sizeof(request))) {