Commit 537b237

rpc : reuse compute graphs
Store compute graphs on the server side and reuse them when possible. Compute graphs are kept in a ring buffer of fixed size, so we can avoid serializing and deserializing the same graph every time.

Add two new commands:

* RPC_CMD_GRAPH_COMPUTE_AND_STORE -- store the graph, compute it and return its ID
* RPC_CMD_GRAPH_RECOMPUTE -- recompute the graph with the given ID

Currently there is no good way to associate an ID with `ggml_cgraph`, so we abuse `tensor->extra` of the first node for this purpose.
1 parent 6d7f111 commit 537b237
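
The reuse scheme is easiest to see in isolation. The sketch below is not part of the commit; it is a minimal, self-contained C++ model of the two pieces the commit adds: the server's ring buffer indexed by graph_id % MAX_STORED_GRAPHS, and the client's window check stored_graph_id + MAX_STORED_GRAPHS > curr_graph_id. The stub type and lambdas (stored_graph_stub, store, reusable) are illustrative names only, and curr_graph_id here counts like the server's next-ID counter, a slight simplification of the client-side bookkeeping; the real logic lives in rpc_server::graph_compute_and_store, rpc_server::graph_recompute and ggml_backend_rpc_graph_compute in the diff below.

// Minimal sketch (not from the commit): how graph IDs map to ring-buffer slots
// and when a client-held ID is still safe to reuse.
#include <cstdint>
#include <cstdio>
#include <vector>

static const int     MAX_STORED_GRAPHS = 64;   // ring-buffer capacity, as in the diff
static const int64_t INVALID_GRAPH_ID  = -1;   // "no graph stored yet"

struct stored_graph_stub { int64_t id = INVALID_GRAPH_ID; };   // stand-in for the server's stored_graph

int main() {
    std::vector<stored_graph_stub> ring(MAX_STORED_GRAPHS);
    int64_t curr_graph_id = 0;   // next ID the server will hand out

    // Server side: storing a graph claims slot curr_graph_id % MAX_STORED_GRAPHS,
    // evicting whatever older graph occupied that slot.
    auto store = [&]() -> int64_t {
        int slot = curr_graph_id % MAX_STORED_GRAPHS;
        ring[slot].id = curr_graph_id;
        return curr_graph_id++;
    };

    // Client side: an ID may be reused only while its slot cannot have been
    // recycled by a newer graph (a conservative window check).
    auto reusable = [&](int64_t stored_graph_id) {
        return stored_graph_id != INVALID_GRAPH_ID &&
               stored_graph_id + MAX_STORED_GRAPHS > curr_graph_id;
    };

    int64_t id = store();                                           // ID 0 lands in slot 0
    printf("id=%lld reusable=%d\n", (long long) id, reusable(id));  // still reusable

    for (int i = 0; i < MAX_STORED_GRAPHS; i++) {
        store();                                                    // 64 newer graphs; ID 64 recycles slot 0
    }
    printf("id=%lld reusable=%d\n", (long long) id, reusable(id));  // no longer reusable
    return 0;
}

The trade-off follows directly: a recompute request only skips re-serialization while fewer than MAX_STORED_GRAPHS newer graphs have been stored since; after that, the client falls back to RPC_CMD_GRAPH_COMPUTE_AND_STORE and sends the full graph again.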

File tree: 2 files changed (+126, -14 lines)

ggml/include/ggml-rpc.h

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ extern "C" {
 #endif
 
 #define RPC_PROTO_MAJOR_VERSION 2
-#define RPC_PROTO_MINOR_VERSION 0
+#define RPC_PROTO_MINOR_VERSION 5
 #define RPC_PROTO_PATCH_VERSION 0
 #define GGML_RPC_MAX_SERVERS    16
 
ggml/src/ggml-rpc/ggml-rpc.cpp

Lines changed: 125 additions & 13 deletions
@@ -99,12 +99,17 @@ enum rpc_cmd {
     RPC_CMD_INIT_TENSOR,
     RPC_CMD_GET_ALLOC_SIZE,
     RPC_CMD_HELLO,
+    RPC_CMD_GRAPH_COMPUTE_AND_STORE,
+    RPC_CMD_GRAPH_RECOMPUTE,
     RPC_CMD_COUNT,
 };
 
 // Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
 const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
 
+const int MAX_STORED_GRAPHS = 64;
+const int64_t INVALID_GRAPH_ID = -1;
+
 struct rpc_msg_hello_rsp {
     uint8_t major;
     uint8_t minor;
@@ -186,6 +191,19 @@ struct rpc_msg_graph_compute_rsp {
     uint8_t result;
 };
 
+struct rpc_msg_graph_compute_and_store_rsp {
+    uint8_t result;
+    int32_t graph_id;
+};
+
+struct rpc_msg_graph_recompute_req {
+    int32_t graph_id;
+};
+
+struct rpc_msg_graph_recompute_rsp {
+    uint8_t result;
+};
+
 struct rpc_msg_get_device_memory_rsp {
     uint64_t free_mem;
     uint64_t total_mem;
@@ -209,6 +227,7 @@ struct ggml_backend_rpc_buffer_type_context {
 struct ggml_backend_rpc_context {
     std::string endpoint;
     std::string name;
+    int32_t curr_graph_id;
 };
 
 struct ggml_backend_rpc_buffer_context {
@@ -563,6 +582,8 @@ static enum ggml_status ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_
         bool status = send_rpc_cmd(ctx->sock, RPC_CMD_INIT_TENSOR, &request, sizeof(request), nullptr, 0);
         RPC_STATUS_ASSERT(status);
     }
+    // HACK: use the extra field for storing the graph ID
+    tensor->extra = reinterpret_cast<void*>(INVALID_GRAPH_ID);
     return GGML_STATUS_SUCCESS;
 }
 
@@ -772,13 +793,29 @@ static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & o
 
 static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
-    std::vector<uint8_t> input;
-    serialize_graph(cgraph, input);
-    rpc_msg_graph_compute_rsp response;
-    auto sock = get_socket(rpc_ctx->endpoint);
-    bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response));
-    RPC_STATUS_ASSERT(status);
-    return (enum ggml_status)response.result;
+    GGML_ASSERT(cgraph->n_nodes > 0);
+    // HACK: we store the graph ID in the first node's extra field
+    int64_t stored_graph_id = reinterpret_cast<int64_t>(cgraph->nodes[0]->extra);
+    bool reuse_graph = stored_graph_id != INVALID_GRAPH_ID && (stored_graph_id + MAX_STORED_GRAPHS > rpc_ctx->curr_graph_id);
+    if (reuse_graph) {
+        rpc_msg_graph_recompute_req request;
+        request.graph_id = stored_graph_id;
+        rpc_msg_graph_recompute_rsp response;
+        auto sock = get_socket(rpc_ctx->endpoint);
+        bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request), &response, sizeof(response));
+        RPC_STATUS_ASSERT(status);
+        return (enum ggml_status)response.result;
+    } else {
+        std::vector<uint8_t> input;
+        serialize_graph(cgraph, input);
+        rpc_msg_graph_compute_and_store_rsp response;
+        auto sock = get_socket(rpc_ctx->endpoint);
+        bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE_AND_STORE, input.data(), input.size(), &response, sizeof(response));
+        RPC_STATUS_ASSERT(status);
+        rpc_ctx->curr_graph_id = response.graph_id;
+        cgraph->nodes[0]->extra = reinterpret_cast<void*>(response.graph_id);
+        return (enum ggml_status)response.result;
+    }
 }
 
 static ggml_backend_i ggml_backend_rpc_interface = {
@@ -831,8 +868,9 @@ ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
 
 ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
     ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
-        /* .endpoint = */ endpoint,
-        /* .name     = */ "RPC[" + std::string(endpoint) + "]",
+        /* .endpoint      = */ endpoint,
+        /* .name          = */ "RPC[" + std::string(endpoint) + "]",
+        /*. curr_graph_id = */ 0,
     };
 
     ggml_backend_t backend = new ggml_backend {
@@ -871,7 +909,8 @@ void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, si
 class rpc_server {
 public:
     rpc_server(ggml_backend_t backend, const char * cache_dir)
-        : backend(backend), cache_dir(cache_dir) {
+        : backend(backend), cache_dir(cache_dir), curr_graph_id(0) {
+        stored_graphs.resize(MAX_STORED_GRAPHS);
     }
     ~rpc_server();
 
@@ -887,21 +926,31 @@ class rpc_server {
     bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
     bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
     bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
+    bool graph_compute_and_store(const std::vector<uint8_t> & input, rpc_msg_graph_compute_and_store_rsp & response);
+    bool graph_recompute(const rpc_msg_graph_recompute_req & request, rpc_msg_graph_recompute_rsp & response);
     bool init_tensor(const rpc_msg_init_tensor_req & request);
     bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
 
+    struct stored_graph {
+        ggml_context_ptr ctx_ptr;
+        ggml_cgraph * graph;
+    };
+
 private:
     bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
     ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
     ggml_tensor * create_node(uint64_t id,
                               struct ggml_context * ctx,
                               const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
                               std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map);
-
+    bool store_graph(const std::vector<uint8_t> & input, stored_graph & sg);
 
     ggml_backend_t backend;
     const char * cache_dir;
     std::unordered_set<ggml_backend_buffer_t> buffers;
+    int64_t curr_graph_id;
+    // ring buffer for storing graphs
+    std::vector<stored_graph> stored_graphs;
 };
 
 void rpc_server::hello(rpc_msg_hello_rsp & response) {
@@ -1323,7 +1372,7 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
     return result;
 }
 
-bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
+bool rpc_server::store_graph(const std::vector<uint8_t> & input, stored_graph & sg) {
     // serialization format:
     // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
     if (input.size() < sizeof(uint32_t)) {
@@ -1373,7 +1422,42 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
             return false;
         }
     }
-    ggml_status status = ggml_backend_graph_compute(backend, graph);
+    sg.ctx_ptr.swap(ctx_ptr);
+    sg.graph = graph;
+    return true;
+}
+
+bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
+    stored_graph sg;
+    if (!store_graph(input, sg)) {
+        return false;
+    }
+    ggml_status status = ggml_backend_graph_compute(backend, sg.graph);
+    response.result = status;
+    return true;
+}
+
+bool rpc_server::graph_compute_and_store(const std::vector<uint8_t> & input, rpc_msg_graph_compute_and_store_rsp & response) {
+    int graph_slot = curr_graph_id % MAX_STORED_GRAPHS;
+    if (!store_graph(input, stored_graphs[graph_slot])) {
+        return false;
+    }
+    ggml_status status = ggml_backend_graph_compute(backend, stored_graphs[graph_slot].graph);
+    response.result = status;
+    response.graph_id = curr_graph_id;
+    curr_graph_id++;
+    return true;
+}
+
+bool rpc_server::graph_recompute(const rpc_msg_graph_recompute_req & request, rpc_msg_graph_recompute_rsp & response) {
+    if (request.graph_id < 0) {
+        return false;
+    }
+    int graph_slot = request.graph_id % MAX_STORED_GRAPHS;
+    if (stored_graphs[graph_slot].graph == nullptr) {
+        return false;
+    }
+    ggml_status status = ggml_backend_graph_compute(backend, stored_graphs[graph_slot].graph);
     response.result = status;
     return true;
 }
@@ -1585,6 +1669,34 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
                 }
                 break;
             }
+            case RPC_CMD_GRAPH_COMPUTE_AND_STORE: {
+                std::vector<uint8_t> input;
+                if (!recv_msg(sockfd, input)) {
+                    return;
+                }
+                rpc_msg_graph_compute_and_store_rsp response;
+                if (!server.graph_compute_and_store(input, response)) {
+                    return;
+                }
+                if (!send_msg(sockfd, &response, sizeof(response))) {
+                    return;
+                }
+                break;
+            }
+            case RPC_CMD_GRAPH_RECOMPUTE: {
+                rpc_msg_graph_recompute_req request;
+                if (!recv_msg(sockfd, &request, sizeof(request))) {
+                    return;
+                }
+                rpc_msg_graph_recompute_rsp response;
+                if (!server.graph_recompute(request, response)) {
+                    return;
+                }
+                if (!send_msg(sockfd, &response, sizeof(response))) {
+                    return;
+                }
+                break;
+            }
             case RPC_CMD_GET_DEVICE_MEMORY: {
                 if (!recv_msg(sockfd, nullptr, 0)) {
                     return;
