@@ -106,13 +106,17 @@ enum rpc_cmd {
106106 RPC_CMD_GET_ALLOC_SIZE,
107107 RPC_CMD_HELLO,
108108 RPC_CMD_DEVICE_COUNT,
109+ RPC_CMD_GRAPH_COMPUTE_AND_STORE,
110+ RPC_CMD_GRAPH_RECOMPUTE,
109111 RPC_CMD_COUNT,
110112};
111113
112114static_assert (RPC_CMD_HELLO == 14 , " RPC_CMD_HELLO must be always 14" );
113115
// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
constexpr size_t HASH_THRESHOLD = 10 * 1024 * 1024;
// capacity of the server-side ring buffer of stored compute graphs;
// a graph ID is only valid for recompute while fewer than this many
// newer graphs have been stored since it was assigned
constexpr int MAX_STORED_GRAPHS = 64;
// sentinel graph ID meaning "no graph stored yet"; stashed in
// ggml_tensor::extra by the RPC client (see init_tensor / graph_compute)
constexpr uint64_t INVALID_GRAPH_ID = UINT64_MAX;
116120
117121struct rpc_msg_hello_rsp {
118122 uint8_t major;
@@ -217,6 +221,20 @@ struct rpc_msg_get_device_memory_rsp {
217221 uint64_t free_mem;
218222 uint64_t total_mem;
219223};
224+
// response to RPC_CMD_GRAPH_COMPUTE_AND_STORE (packed wire format)
struct rpc_msg_graph_compute_and_store_rsp {
    // ggml_status of the compute, narrowed to one byte
    uint8_t result;
    // server-assigned ID under which the graph was stored;
    // the client passes it back in RPC_CMD_GRAPH_RECOMPUTE
    uint64_t graph_id;
};
229+
// request for RPC_CMD_GRAPH_RECOMPUTE (packed wire format)
struct rpc_msg_graph_recompute_req {
    // ID previously returned in rpc_msg_graph_compute_and_store_rsp
    uint64_t graph_id;
};
233+
// response to RPC_CMD_GRAPH_RECOMPUTE (packed wire format)
struct rpc_msg_graph_recompute_rsp {
    // ggml_status of the compute, narrowed to one byte
    uint8_t result;
};
237+
220238#pragma pack(pop)
221239
222240// RPC data structures
struct ggml_backend_rpc_context {
    std::string endpoint;
    uint32_t    device;
    std::string name;
    // ID of the most recently stored graph reported by the server;
    // used by ggml_backend_rpc_graph_compute to decide whether a
    // previously stored graph is still resident in the server's
    // MAX_STORED_GRAPHS-entry ring buffer
    uint64_t    curr_graph_id;
};
242261
243262struct ggml_backend_rpc_buffer_context {
@@ -592,6 +611,8 @@ static enum ggml_status ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_
592611 bool status = send_rpc_cmd (ctx->sock , RPC_CMD_INIT_TENSOR, &request, sizeof (request), nullptr , 0 );
593612 RPC_STATUS_ASSERT (status);
594613 }
614+ // HACK: use the extra field for storing the graph ID
615+ tensor->extra = reinterpret_cast <void *>(INVALID_GRAPH_ID);
595616 return GGML_STATUS_SUCCESS;
596617}
597618
@@ -815,13 +836,30 @@ static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::ve
815836
816837static enum ggml_status ggml_backend_rpc_graph_compute (ggml_backend_t backend, ggml_cgraph * cgraph) {
817838 ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context ;
818- std::vector<uint8_t > input;
819- serialize_graph (rpc_ctx->device , cgraph, input);
820- rpc_msg_graph_compute_rsp response;
821- auto sock = get_socket (rpc_ctx->endpoint );
822- bool status = send_rpc_cmd (sock, RPC_CMD_GRAPH_COMPUTE, input.data (), input.size (), &response, sizeof (response));
823- RPC_STATUS_ASSERT (status);
824- return (enum ggml_status)response.result ;
839+
840+ GGML_ASSERT (cgraph->n_nodes > 0 );
841+ // HACK: we store the graph ID in the first node's extra field
842+ uint64_t stored_graph_id = reinterpret_cast <uint64_t >(cgraph->nodes [0 ]->extra );
843+ bool reuse_graph = stored_graph_id != INVALID_GRAPH_ID && (stored_graph_id + MAX_STORED_GRAPHS > rpc_ctx->curr_graph_id );
844+ if (reuse_graph) {
845+ rpc_msg_graph_recompute_req request;
846+ request.graph_id = stored_graph_id;
847+ rpc_msg_graph_recompute_rsp response;
848+ auto sock = get_socket (rpc_ctx->endpoint );
849+ bool status = send_rpc_cmd (sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof (request), &response, sizeof (response));
850+ RPC_STATUS_ASSERT (status);
851+ return (enum ggml_status)response.result ;
852+ } else {
853+ std::vector<uint8_t > input;
854+ serialize_graph (rpc_ctx->device , cgraph, input);
855+ rpc_msg_graph_compute_and_store_rsp response;
856+ auto sock = get_socket (rpc_ctx->endpoint );
857+ bool status = send_rpc_cmd (sock, RPC_CMD_GRAPH_COMPUTE_AND_STORE, input.data (), input.size (), &response, sizeof (response));
858+ RPC_STATUS_ASSERT (status);
859+ rpc_ctx->curr_graph_id = response.graph_id ;
860+ cgraph->nodes [0 ]->extra = reinterpret_cast <void *>(response.graph_id );
861+ return (enum ggml_status)response.result ;
862+ }
825863}
826864
827865static ggml_backend_i ggml_backend_rpc_interface = {
@@ -878,9 +916,10 @@ ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, u
878916ggml_backend_t ggml_backend_rpc_init (const char * endpoint, uint32_t device) {
879917 std::string dev_name = " RPC" + std::to_string (device) + " [" + std::string (endpoint) + " ]" ;
880918 ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
881- /* .endpoint = */ endpoint,
882- /* .device = */ device,
883- /* .name = */ dev_name
919+ /* .endpoint = */ endpoint,
920+ /* .device = */ device,
921+ /* .name = */ dev_name,
922+ /* .curr_graph_id = */ 0 ,
884923 };
885924 auto reg = ggml_backend_rpc_add_server (endpoint);
886925 ggml_backend_t backend = new ggml_backend {
@@ -921,7 +960,8 @@ void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device,
921960class rpc_server {
922961public:
923962 rpc_server (std::vector<ggml_backend_t > backends, const char * cache_dir)
924- : backends(std::move(backends)), cache_dir(cache_dir) {
963+ : backends(std::move(backends)), cache_dir(cache_dir), curr_graph_id(0 ) {
964+ stored_graphs.resize (MAX_STORED_GRAPHS);
925965 }
926966 ~rpc_server ();
927967
@@ -937,22 +977,34 @@ class rpc_server {
937977 bool get_tensor (const rpc_msg_get_tensor_req & request, std::vector<uint8_t > & response);
938978 bool copy_tensor (const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
939979 bool graph_compute (const std::vector<uint8_t > & input, rpc_msg_graph_compute_rsp & response);
980+ bool graph_compute_and_store (const std::vector<uint8_t > & input, rpc_msg_graph_compute_and_store_rsp & response);
981+ bool graph_recompute (const rpc_msg_graph_recompute_req & request, rpc_msg_graph_recompute_rsp & response);
940982 bool init_tensor (const rpc_msg_init_tensor_req & request);
941983 bool get_alloc_size (const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
942984 bool get_device_memory (const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);
943985
986+ struct stored_graph {
987+ uint32_t device;
988+ ggml_context_ptr ctx_ptr;
989+ ggml_cgraph * graph;
990+ };
991+
944992private:
945993 bool get_cached_file (uint64_t hash, std::vector<uint8_t > & data);
946994 ggml_tensor * deserialize_tensor (struct ggml_context * ctx, const rpc_tensor * tensor);
947995 ggml_tensor * create_node (uint64_t id,
948996 struct ggml_context * ctx,
949997 const std::unordered_map<uint64_t , const rpc_tensor*> & tensor_ptrs,
950998 std::unordered_map<uint64_t , struct ggml_tensor *> & tensor_map);
999+ bool store_graph (const std::vector<uint8_t > & input, stored_graph & sg);
9511000
9521001
9531002 std::vector<ggml_backend_t > backends;
9541003 const char * cache_dir;
9551004 std::unordered_set<ggml_backend_buffer_t > buffers;
1005+ uint64_t curr_graph_id;
1006+ // ring buffer for storing graphs
1007+ std::vector<stored_graph> stored_graphs;
9561008};
9571009
9581010void rpc_server::hello (rpc_msg_hello_rsp & response) {
@@ -1394,7 +1446,7 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
13941446 return result;
13951447}
13961448
1397- bool rpc_server::graph_compute (const std::vector<uint8_t > & input, rpc_msg_graph_compute_rsp & response ) {
1449+ bool rpc_server::store_graph (const std::vector<uint8_t > & input, stored_graph & sg ) {
13981450 // serialization format:
13991451 // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
14001452 if (input.size () < 2 *sizeof (uint32_t )) {
@@ -1422,7 +1474,6 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
14221474 return false ;
14231475 }
14241476 const rpc_tensor * tensors = (const rpc_tensor *)src;
1425- LOG_DBG (" [%s] device: %u, n_nodes: %u, n_tensors: %u\n " , __func__, device, n_nodes, n_tensors);
14261477
14271478 size_t buf_size = ggml_tensor_overhead ()*(n_nodes + n_tensors) + ggml_graph_overhead_custom (n_nodes, false );
14281479
@@ -1454,6 +1505,47 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
14541505 return false ;
14551506 }
14561507 }
1508+ sg.ctx_ptr .swap (ctx_ptr);
1509+ sg.graph = graph;
1510+ sg.device = device;
1511+ return true ;
1512+ }
1513+
1514+ bool rpc_server::graph_compute (const std::vector<uint8_t > & input, rpc_msg_graph_compute_rsp & response) {
1515+ stored_graph sg;
1516+ if (!store_graph (input, sg)) {
1517+ return false ;
1518+ }
1519+ uint32_t device = sg.device ;
1520+ LOG_DBG (" [%s] device: %u, input: %zu bytes\n " , __func__, device, input.size ());
1521+ ggml_status status = ggml_backend_graph_compute (backends[device], sg.graph );
1522+ response.result = status;
1523+ return true ;
1524+ }
1525+
1526+ bool rpc_server::graph_compute_and_store (const std::vector<uint8_t > & input, rpc_msg_graph_compute_and_store_rsp & response) {
1527+ int graph_slot = curr_graph_id % MAX_STORED_GRAPHS;
1528+ if (!store_graph (input, stored_graphs[graph_slot])) {
1529+ return false ;
1530+ }
1531+ ggml_cgraph * graph = stored_graphs[graph_slot].graph ;
1532+ uint32_t device = stored_graphs[graph_slot].device ;
1533+ LOG_DBG (" [%s] device: %u, input: %zu bytes, graph_id: %" PRIu64 " \n " , __func__, device, input.size (), curr_graph_id);
1534+ ggml_status status = ggml_backend_graph_compute (backends[device], graph);
1535+ response.result = status;
1536+ response.graph_id = curr_graph_id;
1537+ curr_graph_id++;
1538+ return true ;
1539+ }
1540+
1541+ bool rpc_server::graph_recompute (const rpc_msg_graph_recompute_req & request, rpc_msg_graph_recompute_rsp & response) {
1542+ int graph_slot = request.graph_id % MAX_STORED_GRAPHS;
1543+ if (stored_graphs[graph_slot].graph == nullptr ) {
1544+ return false ;
1545+ }
1546+ ggml_cgraph * graph = stored_graphs[graph_slot].graph ;
1547+ uint32_t device = stored_graphs[graph_slot].device ;
1548+ LOG_DBG (" [%s] device: %u, graph_id: %" PRIu64 " \n " , __func__, device, request.graph_id );
14571549 ggml_status status = ggml_backend_graph_compute (backends[device], graph);
14581550 response.result = status;
14591551 return true ;
@@ -1699,6 +1791,35 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
16991791 }
17001792 break ;
17011793 }
1794+
1795+ case RPC_CMD_GRAPH_COMPUTE_AND_STORE: {
1796+ std::vector<uint8_t > input;
1797+ if (!recv_msg (sockfd, input)) {
1798+ return ;
1799+ }
1800+ rpc_msg_graph_compute_and_store_rsp response;
1801+ if (!server.graph_compute_and_store (input, response)) {
1802+ return ;
1803+ }
1804+ if (!send_msg (sockfd, &response, sizeof (response))) {
1805+ return ;
1806+ }
1807+ break ;
1808+ }
1809+ case RPC_CMD_GRAPH_RECOMPUTE: {
1810+ rpc_msg_graph_recompute_req request;
1811+ if (!recv_msg (sockfd, &request, sizeof (request))) {
1812+ return ;
1813+ }
1814+ rpc_msg_graph_recompute_rsp response;
1815+ if (!server.graph_recompute (request, response)) {
1816+ return ;
1817+ }
1818+ if (!send_msg (sockfd, &response, sizeof (response))) {
1819+ return ;
1820+ }
1821+ break ;
1822+ }
17021823 case RPC_CMD_GET_DEVICE_MEMORY: {
17031824 rpc_msg_get_device_memory_req request;
17041825 if (!recv_msg (sockfd, &request, sizeof (request))) {
0 commit comments