Skip to content

Commit 2272e04

Browse files
committed
rpc : reuse compute graphs
Store compute graphs on the server side and reuse them when possible. Compute graphs are kept in a fixed-size ring buffer, so we can avoid serializing and deserializing the same graph every time. Two new commands are added: RPC_CMD_GRAPH_COMPUTE_AND_STORE, which stores the graph, computes it and returns its ID; and RPC_CMD_GRAPH_RECOMPUTE, which recomputes the stored graph with the given ID. Currently there is no good way to associate an ID with a `ggml_cgraph`, so we abuse `tensor->extra` of the first node for this purpose.
1 parent 0bcb40b commit 2272e04

File tree

2 files changed

+136
-15
lines changed

2 files changed

+136
-15
lines changed

ggml/include/ggml-rpc.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ extern "C" {
88
#endif
99

1010
#define RPC_PROTO_MAJOR_VERSION 3
11-
#define RPC_PROTO_MINOR_VERSION 0
11+
#define RPC_PROTO_MINOR_VERSION 5
1212
#define RPC_PROTO_PATCH_VERSION 0
1313
#define GGML_RPC_MAX_SERVERS 16
1414

ggml/src/ggml-rpc/ggml-rpc.cpp

Lines changed: 135 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -106,13 +106,17 @@ enum rpc_cmd {
106106
RPC_CMD_GET_ALLOC_SIZE,
107107
RPC_CMD_HELLO,
108108
RPC_CMD_DEVICE_COUNT,
109+
RPC_CMD_GRAPH_COMPUTE_AND_STORE,
110+
RPC_CMD_GRAPH_RECOMPUTE,
109111
RPC_CMD_COUNT,
110112
};
111113

112114
static_assert(RPC_CMD_HELLO == 14, "RPC_CMD_HELLO must be always 14");
113115

114116
// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
115-
const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
117+
const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
118+
const int MAX_STORED_GRAPHS = 64;
119+
const uint64_t INVALID_GRAPH_ID = UINT64_MAX;
116120

117121
struct rpc_msg_hello_rsp {
118122
uint8_t major;
@@ -217,6 +221,20 @@ struct rpc_msg_get_device_memory_rsp {
217221
uint64_t free_mem;
218222
uint64_t total_mem;
219223
};
224+
225+
// Wire-format messages for the graph-reuse commands. They sit inside the
// file's packed region, so the in-memory layout matches the wire protocol
// byte for byte; the explicit push/pop here makes that self-contained.
#pragma pack(push, 1)

// Response to RPC_CMD_GRAPH_COMPUTE_AND_STORE: the compute status plus the
// server-assigned ID under which the graph was stored for later reuse.
struct rpc_msg_graph_compute_and_store_rsp {
    uint8_t  result;    // ggml_status of the compute, narrowed to one byte
    uint64_t graph_id;  // ID to pass back via RPC_CMD_GRAPH_RECOMPUTE
};

// Request for RPC_CMD_GRAPH_RECOMPUTE: recompute a previously stored graph.
struct rpc_msg_graph_recompute_req {
    uint64_t graph_id;  // ID returned by a prior COMPUTE_AND_STORE
};

// Response to RPC_CMD_GRAPH_RECOMPUTE: compute status only.
struct rpc_msg_graph_recompute_rsp {
    uint8_t result;     // ggml_status of the compute
};

#pragma pack(pop)
220238
#pragma pack(pop)
221239

222240
// RPC data structures
@@ -238,6 +256,7 @@ struct ggml_backend_rpc_context {
238256
std::string endpoint;
239257
uint32_t device;
240258
std::string name;
259+
uint64_t curr_graph_id;
241260
};
242261

243262
struct ggml_backend_rpc_buffer_context {
@@ -592,6 +611,8 @@ static enum ggml_status ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_
592611
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_INIT_TENSOR, &request, sizeof(request), nullptr, 0);
593612
RPC_STATUS_ASSERT(status);
594613
}
614+
// HACK: use the extra field for storing the graph ID
615+
tensor->extra = reinterpret_cast<void*>(INVALID_GRAPH_ID);
595616
return GGML_STATUS_SUCCESS;
596617
}
597618

@@ -815,13 +836,30 @@ static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::ve
815836

816837
// Compute a graph on the remote server, reusing a previously stored copy
// when possible to avoid re-serializing the same graph on every call.
static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
    ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
    GGML_ASSERT(cgraph->n_nodes > 0);
    // HACK: there is no good way to associate an ID with a ggml_cgraph, so the
    // server-assigned graph ID is stashed in the first node's `extra` field.
    // NOTE(review): this round-trips a uint64_t through void*, which truncates
    // where sizeof(void*) < 8 (32-bit builds) — confirm RPC targets 64-bit only.
    uint64_t stored_graph_id = reinterpret_cast<uint64_t>(cgraph->nodes[0]->extra);
    // The server keeps only MAX_STORED_GRAPHS graphs in a ring buffer, so a
    // stored ID is trusted only while its slot cannot have been evicted yet.
    bool reuse_graph = stored_graph_id != INVALID_GRAPH_ID &&
                       (stored_graph_id + MAX_STORED_GRAPHS > rpc_ctx->curr_graph_id);
    if (reuse_graph) {
        // Cheap path: the server already holds this graph — just recompute it.
        rpc_msg_graph_recompute_req request;
        request.graph_id = stored_graph_id;
        rpc_msg_graph_recompute_rsp response;
        auto sock = get_socket(rpc_ctx->endpoint);
        bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request), &response, sizeof(response));
        RPC_STATUS_ASSERT(status);
        return (enum ggml_status)response.result;
    } else {
        // Full path: serialize the graph, have the server store and compute it,
        // and remember the ID it hands back for future reuse.
        std::vector<uint8_t> input;
        serialize_graph(rpc_ctx->device, cgraph, input);
        rpc_msg_graph_compute_and_store_rsp response;
        auto sock = get_socket(rpc_ctx->endpoint);
        bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE_AND_STORE, input.data(), input.size(), &response, sizeof(response));
        RPC_STATUS_ASSERT(status);
        rpc_ctx->curr_graph_id = response.graph_id;
        cgraph->nodes[0]->extra = reinterpret_cast<void*>(response.graph_id);
        return (enum ggml_status)response.result;
    }
}
826864

827865
static ggml_backend_i ggml_backend_rpc_interface = {
@@ -878,9 +916,10 @@ ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, u
878916
ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) {
879917
std::string dev_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]";
880918
ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
881-
/* .endpoint = */ endpoint,
882-
/* .device = */ device,
883-
/* .name = */ dev_name
919+
/* .endpoint = */ endpoint,
920+
/* .device = */ device,
921+
/* .name = */ dev_name,
922+
/* .curr_graph_id = */ 0,
884923
};
885924
auto reg = ggml_backend_rpc_add_server(endpoint);
886925
ggml_backend_t backend = new ggml_backend {
@@ -921,7 +960,8 @@ void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device,
921960
class rpc_server {
922961
public:
923962
rpc_server(std::vector<ggml_backend_t> backends, const char * cache_dir)
924-
: backends(std::move(backends)), cache_dir(cache_dir) {
963+
: backends(std::move(backends)), cache_dir(cache_dir), curr_graph_id(0) {
964+
stored_graphs.resize(MAX_STORED_GRAPHS);
925965
}
926966
~rpc_server();
927967

@@ -937,22 +977,34 @@ class rpc_server {
937977
bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
938978
bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
939979
bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
980+
bool graph_compute_and_store(const std::vector<uint8_t> & input, rpc_msg_graph_compute_and_store_rsp & response);
981+
bool graph_recompute(const rpc_msg_graph_recompute_req & request, rpc_msg_graph_recompute_rsp & response);
940982
bool init_tensor(const rpc_msg_init_tensor_req & request);
941983
bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
942984
bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);
943985

986+
struct stored_graph {
987+
uint32_t device;
988+
ggml_context_ptr ctx_ptr;
989+
ggml_cgraph * graph;
990+
};
991+
944992
private:
945993
bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
946994
ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
947995
ggml_tensor * create_node(uint64_t id,
948996
struct ggml_context * ctx,
949997
const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
950998
std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map);
999+
bool store_graph(const std::vector<uint8_t> & input, stored_graph & sg);
9511000

9521001

9531002
std::vector<ggml_backend_t> backends;
9541003
const char * cache_dir;
9551004
std::unordered_set<ggml_backend_buffer_t> buffers;
1005+
uint64_t curr_graph_id;
1006+
// ring buffer for storing graphs
1007+
std::vector<stored_graph> stored_graphs;
9561008
};
9571009

9581010
void rpc_server::hello(rpc_msg_hello_rsp & response) {
@@ -1394,7 +1446,7 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
13941446
return result;
13951447
}
13961448

1397-
bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
1449+
bool rpc_server::store_graph(const std::vector<uint8_t> & input, stored_graph & sg) {
13981450
// serialization format:
13991451
// | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
14001452
if (input.size() < 2*sizeof(uint32_t)) {
@@ -1422,7 +1474,6 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
14221474
return false;
14231475
}
14241476
const rpc_tensor * tensors = (const rpc_tensor *)src;
1425-
LOG_DBG("[%s] device: %u, n_nodes: %u, n_tensors: %u\n", __func__, device, n_nodes, n_tensors);
14261477

14271478
size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
14281479

@@ -1454,6 +1505,47 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
14541505
return false;
14551506
}
14561507
}
1508+
sg.ctx_ptr.swap(ctx_ptr);
1509+
sg.graph = graph;
1510+
sg.device = device;
1511+
return true;
1512+
}
1513+
1514+
bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
1515+
stored_graph sg;
1516+
if (!store_graph(input, sg)) {
1517+
return false;
1518+
}
1519+
uint32_t device = sg.device;
1520+
LOG_DBG("[%s] device: %u, input: %zu bytes\n", __func__, device, input.size());
1521+
ggml_status status = ggml_backend_graph_compute(backends[device], sg.graph);
1522+
response.result = status;
1523+
return true;
1524+
}
1525+
1526+
bool rpc_server::graph_compute_and_store(const std::vector<uint8_t> & input, rpc_msg_graph_compute_and_store_rsp & response) {
1527+
int graph_slot = curr_graph_id % MAX_STORED_GRAPHS;
1528+
if (!store_graph(input, stored_graphs[graph_slot])) {
1529+
return false;
1530+
}
1531+
ggml_cgraph * graph = stored_graphs[graph_slot].graph;
1532+
uint32_t device = stored_graphs[graph_slot].device;
1533+
LOG_DBG("[%s] device: %u, input: %zu bytes, graph_id: %" PRIu64 "\n", __func__, device, input.size(), curr_graph_id);
1534+
ggml_status status = ggml_backend_graph_compute(backends[device], graph);
1535+
response.result = status;
1536+
response.graph_id = curr_graph_id;
1537+
curr_graph_id++;
1538+
return true;
1539+
}
1540+
1541+
bool rpc_server::graph_recompute(const rpc_msg_graph_recompute_req & request, rpc_msg_graph_recompute_rsp & response) {
1542+
int graph_slot = request.graph_id % MAX_STORED_GRAPHS;
1543+
if (stored_graphs[graph_slot].graph == nullptr) {
1544+
return false;
1545+
}
1546+
ggml_cgraph * graph = stored_graphs[graph_slot].graph;
1547+
uint32_t device = stored_graphs[graph_slot].device;
1548+
LOG_DBG("[%s] device: %u, graph_id: %" PRIu64 "\n", __func__, device, request.graph_id);
14571549
ggml_status status = ggml_backend_graph_compute(backends[device], graph);
14581550
response.result = status;
14591551
return true;
@@ -1699,6 +1791,35 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
16991791
}
17001792
break;
17011793
}
1794+
1795+
case RPC_CMD_GRAPH_COMPUTE_AND_STORE: {
1796+
std::vector<uint8_t> input;
1797+
if (!recv_msg(sockfd, input)) {
1798+
return;
1799+
}
1800+
rpc_msg_graph_compute_and_store_rsp response;
1801+
if (!server.graph_compute_and_store(input, response)) {
1802+
return;
1803+
}
1804+
if (!send_msg(sockfd, &response, sizeof(response))) {
1805+
return;
1806+
}
1807+
break;
1808+
}
1809+
case RPC_CMD_GRAPH_RECOMPUTE: {
1810+
rpc_msg_graph_recompute_req request;
1811+
if (!recv_msg(sockfd, &request, sizeof(request))) {
1812+
return;
1813+
}
1814+
rpc_msg_graph_recompute_rsp response;
1815+
if (!server.graph_recompute(request, response)) {
1816+
return;
1817+
}
1818+
if (!send_msg(sockfd, &response, sizeof(response))) {
1819+
return;
1820+
}
1821+
break;
1822+
}
17021823
case RPC_CMD_GET_DEVICE_MEMORY: {
17031824
rpc_msg_get_device_memory_req request;
17041825
if (!recv_msg(sockfd, &request, sizeof(request))) {

0 commit comments

Comments
 (0)