@@ -99,12 +99,17 @@ enum rpc_cmd {
9999 RPC_CMD_INIT_TENSOR,
100100 RPC_CMD_GET_ALLOC_SIZE,
101101 RPC_CMD_HELLO,
102+ RPC_CMD_GRAPH_COMPUTE_AND_STORE,
103+ RPC_CMD_GRAPH_RECOMPUTE,
102104 RPC_CMD_COUNT,
103105};
104106
105107// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
106108const size_t HASH_THRESHOLD = 10 * 1024 * 1024 ;
107109
110+ const int MAX_STORED_GRAPHS = 64 ; // number of slots in the server-side ring buffer of stored graphs
111+ const int64_t INVALID_GRAPH_ID = -1 ; // sentinel kept in tensor->extra while no stored graph ID has been recorded
112+
108113struct rpc_msg_hello_rsp {
109114 uint8_t major;
110115 uint8_t minor;
@@ -186,6 +191,19 @@ struct rpc_msg_graph_compute_rsp {
186191 uint8_t result;
187192};
188193
194+ struct rpc_msg_graph_compute_and_store_rsp { // reply to RPC_CMD_GRAPH_COMPUTE_AND_STORE
195+ uint8_t result; // ggml_status returned by the backend compute call
196+ int32_t graph_id; // ID the server assigned to the stored graph; sent back later in rpc_msg_graph_recompute_req
197+ };
198+
199+ struct rpc_msg_graph_recompute_req { // payload of RPC_CMD_GRAPH_RECOMPUTE
200+ int32_t graph_id; // ID previously returned in rpc_msg_graph_compute_and_store_rsp; the server rejects negative values
201+ };
202+
203+ struct rpc_msg_graph_recompute_rsp { // reply to RPC_CMD_GRAPH_RECOMPUTE
204+ uint8_t result; // ggml_status returned by re-running the stored graph
205+ };
206+
189207struct rpc_msg_get_device_memory_rsp {
190208 uint64_t free_mem;
191209 uint64_t total_mem;
@@ -209,6 +227,7 @@ struct ggml_backend_rpc_buffer_type_context {
209227struct ggml_backend_rpc_context {
210228 std::string endpoint;
211229 std::string name;
230+ int32_t curr_graph_id;
212231};
213232
214233struct ggml_backend_rpc_buffer_context {
@@ -563,6 +582,8 @@ static enum ggml_status ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_
563582 bool status = send_rpc_cmd (ctx->sock , RPC_CMD_INIT_TENSOR, &request, sizeof (request), nullptr , 0 );
564583 RPC_STATUS_ASSERT (status);
565584 }
585+ // HACK: use the extra field for storing the graph ID
586+ tensor->extra = reinterpret_cast <void *>(INVALID_GRAPH_ID);
566587 return GGML_STATUS_SUCCESS;
567588}
568589
@@ -772,13 +793,29 @@ static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & o
772793
773794static enum ggml_status ggml_backend_rpc_graph_compute (ggml_backend_t backend, ggml_cgraph * cgraph) {
774795 ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context ;
775- std::vector<uint8_t > input;
776- serialize_graph (cgraph, input);
777- rpc_msg_graph_compute_rsp response;
778- auto sock = get_socket (rpc_ctx->endpoint );
779- bool status = send_rpc_cmd (sock, RPC_CMD_GRAPH_COMPUTE, input.data (), input.size (), &response, sizeof (response));
780- RPC_STATUS_ASSERT (status);
781- return (enum ggml_status)response.result ;
796+ GGML_ASSERT (cgraph->n_nodes > 0 );
797+ // HACK: we store the graph ID in the first node's extra field
798+ int64_t stored_graph_id = reinterpret_cast <int64_t >(cgraph->nodes [0 ]->extra );
799+ bool reuse_graph = stored_graph_id != INVALID_GRAPH_ID && (stored_graph_id + MAX_STORED_GRAPHS > rpc_ctx->curr_graph_id );
800+ if (reuse_graph) {
801+ rpc_msg_graph_recompute_req request;
802+ request.graph_id = stored_graph_id;
803+ rpc_msg_graph_recompute_rsp response;
804+ auto sock = get_socket (rpc_ctx->endpoint );
805+ bool status = send_rpc_cmd (sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof (request), &response, sizeof (response));
806+ RPC_STATUS_ASSERT (status);
807+ return (enum ggml_status)response.result ;
808+ } else {
809+ std::vector<uint8_t > input;
810+ serialize_graph (cgraph, input);
811+ rpc_msg_graph_compute_and_store_rsp response;
812+ auto sock = get_socket (rpc_ctx->endpoint );
813+ bool status = send_rpc_cmd (sock, RPC_CMD_GRAPH_COMPUTE_AND_STORE, input.data (), input.size (), &response, sizeof (response));
814+ RPC_STATUS_ASSERT (status);
815+ rpc_ctx->curr_graph_id = response.graph_id ;
816+ cgraph->nodes [0 ]->extra = reinterpret_cast <void *>(response.graph_id );
817+ return (enum ggml_status)response.result ;
818+ }
782819}
783820
784821static ggml_backend_i ggml_backend_rpc_interface = {
@@ -831,8 +868,9 @@ ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
831868
832869ggml_backend_t ggml_backend_rpc_init (const char * endpoint) {
833870 ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
834- /* .endpoint = */ endpoint,
835- /* .name = */ " RPC[" + std::string (endpoint) + " ]" ,
871+ /* .endpoint = */ endpoint,
872+ /* .name = */ " RPC[" + std::string (endpoint) + " ]" ,
873+ /* . curr_graph_id = */ 0 ,
836874 };
837875
838876 ggml_backend_t backend = new ggml_backend {
@@ -871,7 +909,8 @@ void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, si
871909class rpc_server {
872910public:
873911 rpc_server (ggml_backend_t backend, const char * cache_dir)
874- : backend(backend), cache_dir(cache_dir) {
912+ : backend(backend), cache_dir(cache_dir), curr_graph_id(0 ) {
913+ stored_graphs.resize (MAX_STORED_GRAPHS);
875914 }
876915 ~rpc_server ();
877916
@@ -887,21 +926,31 @@ class rpc_server {
887926 bool get_tensor (const rpc_msg_get_tensor_req & request, std::vector<uint8_t > & response);
888927 bool copy_tensor (const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
889928 bool graph_compute (const std::vector<uint8_t > & input, rpc_msg_graph_compute_rsp & response);
929+ bool graph_compute_and_store (const std::vector<uint8_t > & input, rpc_msg_graph_compute_and_store_rsp & response);
930+ bool graph_recompute (const rpc_msg_graph_recompute_req & request, rpc_msg_graph_recompute_rsp & response);
890931 bool init_tensor (const rpc_msg_init_tensor_req & request);
891932 bool get_alloc_size (const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
892933
934+ struct stored_graph {
935+ ggml_context_ptr ctx_ptr;
936+ ggml_cgraph * graph;
937+ };
938+
893939private:
894940 bool get_cached_file (uint64_t hash, std::vector<uint8_t > & data);
895941 ggml_tensor * deserialize_tensor (struct ggml_context * ctx, const rpc_tensor * tensor);
896942 ggml_tensor * create_node (uint64_t id,
897943 struct ggml_context * ctx,
898944 const std::unordered_map<uint64_t , const rpc_tensor*> & tensor_ptrs,
899945 std::unordered_map<uint64_t , struct ggml_tensor *> & tensor_map);
900-
946+ bool store_graph ( const std::vector< uint8_t > & input, stored_graph & sg);
901947
902948 ggml_backend_t backend;
903949 const char * cache_dir;
904950 std::unordered_set<ggml_backend_buffer_t > buffers;
951+ int64_t curr_graph_id;
952+ // ring buffer for storing graphs
953+ std::vector<stored_graph> stored_graphs;
905954};
906955
907956void rpc_server::hello (rpc_msg_hello_rsp & response) {
@@ -1323,7 +1372,7 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
13231372 return result;
13241373}
13251374
1326- bool rpc_server::graph_compute (const std::vector<uint8_t > & input, rpc_msg_graph_compute_rsp & response ) {
1375+ bool rpc_server::store_graph (const std::vector<uint8_t > & input, stored_graph & sg ) {
13271376 // serialization format:
13281377 // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
13291378 if (input.size () < sizeof (uint32_t )) {
@@ -1373,7 +1422,42 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
13731422 return false ;
13741423 }
13751424 }
1376- ggml_status status = ggml_backend_graph_compute (backend, graph);
1425+ sg.ctx_ptr .swap (ctx_ptr);
1426+ sg.graph = graph;
1427+ return true ;
1428+ }
1429+
1430+ bool rpc_server::graph_compute (const std::vector<uint8_t > & input, rpc_msg_graph_compute_rsp & response) {
1431+ stored_graph sg;
1432+ if (!store_graph (input, sg)) {
1433+ return false ;
1434+ }
1435+ ggml_status status = ggml_backend_graph_compute (backend, sg.graph );
1436+ response.result = status;
1437+ return true ;
1438+ }
1439+
1440+ bool rpc_server::graph_compute_and_store (const std::vector<uint8_t > & input, rpc_msg_graph_compute_and_store_rsp & response) {
1441+ int graph_slot = curr_graph_id % MAX_STORED_GRAPHS;
1442+ if (!store_graph (input, stored_graphs[graph_slot])) {
1443+ return false ;
1444+ }
1445+ ggml_status status = ggml_backend_graph_compute (backend, stored_graphs[graph_slot].graph );
1446+ response.result = status;
1447+ response.graph_id = curr_graph_id;
1448+ curr_graph_id++;
1449+ return true ;
1450+ }
1451+
1452+ bool rpc_server::graph_recompute (const rpc_msg_graph_recompute_req & request, rpc_msg_graph_recompute_rsp & response) {
1453+ if (request.graph_id < 0 ) {
1454+ return false ;
1455+ }
1456+ int graph_slot = request.graph_id % MAX_STORED_GRAPHS;
1457+ if (stored_graphs[graph_slot].graph == nullptr ) {
1458+ return false ;
1459+ }
1460+ ggml_status status = ggml_backend_graph_compute (backend, stored_graphs[graph_slot].graph );
13771461 response.result = status;
13781462 return true ;
13791463}
@@ -1585,6 +1669,34 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
15851669 }
15861670 break ;
15871671 }
1672+ case RPC_CMD_GRAPH_COMPUTE_AND_STORE: {
1673+ std::vector<uint8_t > input;
1674+ if (!recv_msg (sockfd, input)) {
1675+ return ;
1676+ }
1677+ rpc_msg_graph_compute_and_store_rsp response;
1678+ if (!server.graph_compute_and_store (input, response)) {
1679+ return ;
1680+ }
1681+ if (!send_msg (sockfd, &response, sizeof (response))) {
1682+ return ;
1683+ }
1684+ break ;
1685+ }
1686+ case RPC_CMD_GRAPH_RECOMPUTE: {
1687+ rpc_msg_graph_recompute_req request;
1688+ if (!recv_msg (sockfd, &request, sizeof (request))) {
1689+ return ;
1690+ }
1691+ rpc_msg_graph_recompute_rsp response;
1692+ if (!server.graph_recompute (request, response)) {
1693+ return ;
1694+ }
1695+ if (!send_msg (sockfd, &response, sizeof (response))) {
1696+ return ;
1697+ }
1698+ break ;
1699+ }
15881700 case RPC_CMD_GET_DEVICE_MEMORY: {
15891701 if (!recv_msg (sockfd, nullptr , 0 )) {
15901702 return ;
0 commit comments