2626# include < unistd.h>
2727#endif
2828#include < cstring>
29+ #include < fstream>
30+ #include < filesystem>
31+
32+ namespace fs = std::filesystem;
2933
3034#ifdef _WIN32
3135typedef SOCKET sockfd_t ;
@@ -80,6 +84,7 @@ enum rpc_cmd {
8084 RPC_CMD_FREE_BUFFER,
8185 RPC_CMD_BUFFER_CLEAR,
8286 RPC_CMD_SET_TENSOR,
87+ RPC_CMD_SET_TENSOR_HASH,
8388 RPC_CMD_GET_TENSOR,
8489 RPC_CMD_COPY_TENSOR,
8590 RPC_CMD_GRAPH_COMPUTE,
@@ -89,6 +94,9 @@ enum rpc_cmd {
8994 RPC_CMD_COUNT,
9095};
9196
// Payloads larger than this are first offered to the server by hash
// (RPC_CMD_SET_TENSOR_HASH); the raw bytes are only sent on a cache miss.
const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
92100struct rpc_msg_get_alloc_size_req {
93101 rpc_tensor tensor;
94102};
@@ -135,6 +143,10 @@ struct rpc_msg_buffer_clear_req {
135143 uint8_t value;
136144};
137145
// Reply to RPC_CMD_SET_TENSOR_HASH: result is 1 when the server already has
// the payload cached (the client may then skip sending the data), 0 otherwise.
struct rpc_msg_set_tensor_hash_rsp {
    uint8_t result;
};
149+
138150struct rpc_msg_get_tensor_req {
139151 rpc_tensor tensor;
140152 uint64_t offset;
@@ -187,6 +199,18 @@ struct ggml_backend_rpc_buffer_context {
187199
188200// RPC helper functions
189201
// Computes the 64-bit FNV-1a hash of a byte buffer.
static uint64_t fnv_hash(const uint8_t * data, size_t len) {
    const uint64_t fnv_prime        = 0x100000001b3ULL;
    const uint64_t fnv_offset_basis = 0xcbf29ce484222325ULL;

    uint64_t hash = fnv_offset_basis;
    for (const uint8_t * p = data; p != data + len; ++p) {
        hash = (hash ^ *p) * fnv_prime;
    }
    return hash;
}
213+
190214static std::shared_ptr<socket_t > make_socket (sockfd_t fd) {
191215#ifdef _WIN32
192216 if (fd == INVALID_SOCKET) {
@@ -483,10 +507,26 @@ static enum ggml_status ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_
483507
484508static void ggml_backend_rpc_buffer_set_tensor (ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
485509 ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context ;
486- // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
510+ rpc_tensor rpc_tensor = serialize_tensor (tensor);
511+ if (size > HASH_THRESHOLD) {
512+ // input serialization format: | rpc_tensor | offset (8 bytes) | hash (8 bytes)
513+ size_t input_size = sizeof (rpc_tensor) + sizeof (uint64_t ) + sizeof (uint64_t );
514+ std::vector<uint8_t > input (input_size, 0 );
515+ uint64_t hash = fnv_hash ((const uint8_t *)data, size);
516+ memcpy (input.data (), &rpc_tensor, sizeof (rpc_tensor));
517+ memcpy (input.data () + sizeof (rpc_tensor), &offset, sizeof (offset));
518+ memcpy (input.data () + sizeof (rpc_tensor) + sizeof (offset), &hash, sizeof (hash));
519+ rpc_msg_set_tensor_hash_rsp response;
520+ bool status = send_rpc_cmd (ctx->sock , RPC_CMD_SET_TENSOR_HASH, input.data (), input.size (), &response, sizeof (response));
521+ GGML_ASSERT (status);
522+ if (response.result ) {
523+ // the server has the same data, no need to send it
524+ return ;
525+ }
526+ }
527+ // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes)
487528 size_t input_size = sizeof (rpc_tensor) + sizeof (uint64_t ) + size;
488529 std::vector<uint8_t > input (input_size, 0 );
489- rpc_tensor rpc_tensor = serialize_tensor (tensor);
490530 memcpy (input.data (), &rpc_tensor, sizeof (rpc_tensor));
491531 memcpy (input.data () + sizeof (rpc_tensor), &offset, sizeof (offset));
492532 memcpy (input.data () + sizeof (rpc_tensor) + sizeof (offset), data, size);
@@ -772,7 +812,9 @@ void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, si
772812
773813class rpc_server {
774814public:
775- rpc_server (ggml_backend_t backend) : backend(backend) {}
815+ rpc_server (ggml_backend_t backend, const char * cache_dir)
816+ : backend(backend), cache_dir(cache_dir) {
817+ }
776818 ~rpc_server ();
777819
778820 void alloc_buffer (const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
@@ -782,13 +824,15 @@ class rpc_server {
782824 bool free_buffer (const rpc_msg_free_buffer_req & request);
783825 bool buffer_clear (const rpc_msg_buffer_clear_req & request);
784826 bool set_tensor (const std::vector<uint8_t > & input);
827+ bool set_tensor_hash (const std::vector<uint8_t > & input, rpc_msg_set_tensor_hash_rsp & response);
785828 bool get_tensor (const rpc_msg_get_tensor_req & request, std::vector<uint8_t > & response);
786829 bool copy_tensor (const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
787830 bool graph_compute (const std::vector<uint8_t > & input, rpc_msg_graph_compute_rsp & response);
788831 bool init_tensor (const rpc_msg_init_tensor_req & request);
789832 bool get_alloc_size (const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
790833
791834private:
835+ bool get_cached_file (uint64_t hash, std::vector<uint8_t > & data);
792836 ggml_tensor * deserialize_tensor (struct ggml_context * ctx, const rpc_tensor * tensor);
793837 ggml_tensor * create_node (uint64_t id,
794838 struct ggml_context * ctx,
@@ -797,6 +841,7 @@ class rpc_server {
797841
798842
799843 ggml_backend_t backend;
844+ const char * cache_dir;
800845 std::unordered_set<ggml_backend_buffer_t > buffers;
801846};
802847
@@ -960,11 +1005,84 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
9601005 }
9611006
9621007 const void * data = input.data () + sizeof (rpc_tensor) + sizeof (offset);
1008+ if (cache_dir && size > HASH_THRESHOLD) {
1009+ uint64_t hash = fnv_hash ((const uint8_t *)data, size);
1010+ char hash_str[17 ];
1011+ snprintf (hash_str, sizeof (hash_str), " %016" PRIx64, hash);
1012+ // save to cache_dir/hash_str
1013+ fs::path cache_file = fs::path (cache_dir) / hash_str;
1014+ std::ofstream ofs (cache_file, std::ios::binary);
1015+ ofs.write ((const char *)data, size);
1016+ }
9631017 ggml_backend_tensor_set (tensor, data, offset, size);
9641018 ggml_free (ctx);
9651019 return true ;
9661020}
9671021
1022+ bool rpc_server::get_cached_file (uint64_t hash, std::vector<uint8_t > & data) {
1023+ if (!cache_dir) {
1024+ return false ;
1025+ }
1026+ char hash_str[17 ];
1027+ snprintf (hash_str, sizeof (hash_str), " %016" PRIx64, hash);
1028+ fs::path cache_file = fs::path (cache_dir) / hash_str;
1029+ if (!fs::exists (cache_file)) {
1030+ return false ;
1031+ }
1032+ std::ifstream ifs (cache_file, std::ios::binary);
1033+ ifs.seekg (0 , std::ios::end);
1034+ size_t size = ifs.tellg ();
1035+ ifs.seekg (0 , std::ios::beg);
1036+ data.resize (size);
1037+ ifs.read ((char *)data.data (), size);
1038+ return true ;
1039+ }
1040+
1041+ bool rpc_server::set_tensor_hash (const std::vector<uint8_t > & input, rpc_msg_set_tensor_hash_rsp & response)
1042+ {
1043+ // serialization format: | rpc_tensor | offset (8 bytes) | hash (8 bytes) |
1044+ if (input.size () != sizeof (rpc_tensor) + 16 ) {
1045+ return false ;
1046+ }
1047+ const rpc_tensor * in_tensor = (const rpc_tensor *)input.data ();
1048+ uint64_t offset;
1049+ memcpy (&offset, input.data () + sizeof (rpc_tensor), sizeof (offset));
1050+ const uint64_t * hash = (const uint64_t *)(input.data () + sizeof (rpc_tensor) + sizeof (offset));
1051+ std::vector<uint8_t > cached_file;
1052+ if (!get_cached_file (*hash, cached_file)) {
1053+ response.result = 0 ;
1054+ return true ;
1055+ }
1056+ size_t size = cached_file.size ();
1057+ struct ggml_init_params params {
1058+ /* .mem_size =*/ ggml_tensor_overhead(),
1059+ /* .mem_buffer =*/ NULL ,
1060+ /* .no_alloc =*/ true ,
1061+ };
1062+ struct ggml_context * ctx = ggml_init (params);
1063+ ggml_tensor * tensor = deserialize_tensor (ctx, in_tensor);
1064+ if (tensor == nullptr ) {
1065+ GGML_LOG_ERROR (" [%s] error deserializing tensor\n " , __func__);
1066+ ggml_free (ctx);
1067+ return false ;
1068+ }
1069+ GGML_PRINT_DEBUG (" [%s] buffer: %p, data: %p, offset: %" PRIu64 " , size: %zu, hash: %" PRIx64 " \n " , __func__, (void *)tensor->buffer , tensor->data , offset, size, *hash);
1070+
1071+ // sanitize tensor->data
1072+ {
1073+ const size_t p0 = (size_t ) ggml_backend_buffer_get_base (tensor->buffer );
1074+ const size_t p1 = p0 + ggml_backend_buffer_get_size (tensor->buffer );
1075+
1076+ if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
1077+ GGML_ABORT (" [%s] tensor->data out of bounds\n " , __func__);
1078+ }
1079+ }
1080+ ggml_backend_tensor_set (tensor, cached_file.data (), offset, size);
1081+ response.result = 1 ;
1082+ ggml_free (ctx);
1083+ return true ;
1084+ }
1085+
9681086bool rpc_server::init_tensor (const rpc_msg_init_tensor_req & request) {
9691087 struct ggml_init_params params {
9701088 /* .mem_size =*/ ggml_tensor_overhead(),
@@ -1148,8 +1266,9 @@ rpc_server::~rpc_server() {
11481266 }
11491267}
11501268
1151- static void rpc_serve_client (ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
1152- rpc_server server (backend);
1269+ static void rpc_serve_client (ggml_backend_t backend, const char * cache_dir,
1270+ sockfd_t sockfd, size_t free_mem, size_t total_mem) {
1271+ rpc_server server (backend, cache_dir);
11531272 while (true ) {
11541273 uint8_t cmd;
11551274 if (!recv_data (sockfd, &cmd, 1 )) {
@@ -1260,6 +1379,20 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
12601379 }
12611380 break ;
12621381 }
1382+ case RPC_CMD_SET_TENSOR_HASH: {
1383+ std::vector<uint8_t > input;
1384+ if (!recv_msg (sockfd, input)) {
1385+ return ;
1386+ }
1387+ rpc_msg_set_tensor_hash_rsp response;
1388+ if (!server.set_tensor_hash (input, response)) {
1389+ return ;
1390+ }
1391+ if (!send_msg (sockfd, &response, sizeof (response))) {
1392+ return ;
1393+ }
1394+ break ;
1395+ }
12631396 case RPC_CMD_INIT_TENSOR: {
12641397 rpc_msg_init_tensor_req request;
12651398 if (!recv_msg (sockfd, &request,sizeof (request))) {
@@ -1335,7 +1468,9 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
13351468 }
13361469}
13371470
1338- void ggml_backend_rpc_start_server (ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
1471+ void ggml_backend_rpc_start_server (ggml_backend_t backend, const char * endpoint,
1472+ const char * cache_dir,
1473+ size_t free_mem, size_t total_mem) {
13391474 std::string host;
13401475 int port;
13411476 if (!parse_endpoint (endpoint, host, port)) {
@@ -1364,7 +1499,7 @@ void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint
13641499 }
13651500 printf (" Accepted client connection, free_mem=%zu, total_mem=%zu\n " , free_mem, total_mem);
13661501 fflush (stdout);
1367- rpc_serve_client (backend, client_socket->fd , free_mem, total_mem);
1502+ rpc_serve_client (backend, cache_dir, client_socket->fd , free_mem, total_mem);
13681503 printf (" Client connection closed\n " );
13691504 fflush (stdout);
13701505 }
0 commit comments