2626# include < unistd.h>
2727#endif
2828#include < cstring>
29+ #include < fstream>
30+ #include < filesystem>
31+
32+ namespace fs = std::filesystem;
2933
3034#ifdef _WIN32
3135typedef SOCKET sockfd_t ;
@@ -80,6 +84,7 @@ enum rpc_cmd {
8084 RPC_CMD_FREE_BUFFER,
8185 RPC_CMD_BUFFER_CLEAR,
8286 RPC_CMD_SET_TENSOR,
87+ RPC_CMD_SET_TENSOR_HASH,
8388 RPC_CMD_GET_TENSOR,
8489 RPC_CMD_COPY_TENSOR,
8590 RPC_CMD_GRAPH_COMPUTE,
@@ -89,6 +94,9 @@ enum rpc_cmd {
8994 RPC_CMD_COUNT,
9095};
9196
// Payloads larger than this (in bytes) are first offered to the server by
// hash via RPC_CMD_SET_TENSOR_HASH; the raw bytes are only transferred if
// the server does not already have a cached copy.
// constexpr: this is a compile-time constant, not a runtime value.
static constexpr size_t HASH_THRESHOLD = 10 * 1024 * 1024;
99+
// Request payload asking the server for the allocation size required by the
// given serialized tensor (handled by rpc_server::get_alloc_size).
struct rpc_msg_get_alloc_size_req {
    rpc_tensor tensor;
};
@@ -135,6 +143,10 @@ struct rpc_msg_buffer_clear_req {
135143 uint8_t value;
136144};
137145
// Response for RPC_CMD_SET_TENSOR_HASH.
// result == 1: the server found the hashed payload in its local cache and
// applied it to the tensor; result == 0: cache miss, the client must fall
// back to sending the raw bytes with RPC_CMD_SET_TENSOR.
struct rpc_msg_set_tensor_hash_rsp {
    uint8_t result;
};
149+
138150struct rpc_msg_get_tensor_req {
139151 rpc_tensor tensor;
140152 uint64_t offset;
@@ -187,6 +199,18 @@ struct ggml_backend_rpc_buffer_context {
187199
188200// RPC helper functions
189201
// Computes the 64-bit FNV-1a hash of a byte buffer.
// Used to identify large tensor payloads so the server can satisfy a
// transfer from its local cache instead of receiving the bytes again.
static uint64_t fnv_hash(const uint8_t * data, size_t len) {
    const uint64_t fnv_prime = 0x100000001b3ULL;
    uint64_t hash = 0xcbf29ce484222325ULL;  // FNV-1a 64-bit offset basis

    for (const uint8_t * p = data; p != data + len; ++p) {
        hash = (hash ^ *p) * fnv_prime;
    }
    return hash;
}
213+
190214static std::shared_ptr<socket_t > make_socket (sockfd_t fd) {
191215#ifdef _WIN32
192216 if (fd == INVALID_SOCKET) {
@@ -483,10 +507,26 @@ static enum ggml_status ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_
483507
484508static void ggml_backend_rpc_buffer_set_tensor (ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
485509 ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context ;
486- // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
510+ rpc_tensor rpc_tensor = serialize_tensor (tensor);
511+ if (size > HASH_THRESHOLD) {
512+ // input serialization format: | rpc_tensor | offset (8 bytes) | hash (8 bytes)
513+ size_t input_size = sizeof (rpc_tensor) + sizeof (uint64_t ) + sizeof (uint64_t );
514+ std::vector<uint8_t > input (input_size, 0 );
515+ uint64_t hash = fnv_hash ((const uint8_t *)data, size);
516+ memcpy (input.data (), &rpc_tensor, sizeof (rpc_tensor));
517+ memcpy (input.data () + sizeof (rpc_tensor), &offset, sizeof (offset));
518+ memcpy (input.data () + sizeof (rpc_tensor) + sizeof (offset), &hash, sizeof (hash));
519+ rpc_msg_set_tensor_hash_rsp response;
520+ bool status = send_rpc_cmd (ctx->sock , RPC_CMD_SET_TENSOR_HASH, input.data (), input.size (), &response, sizeof (response));
521+ GGML_ASSERT (status);
522+ if (response.result ) {
523+ // the server has the same data, no need to send it
524+ return ;
525+ }
526+ }
527+ // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes)
487528 size_t input_size = sizeof (rpc_tensor) + sizeof (uint64_t ) + size;
488529 std::vector<uint8_t > input (input_size, 0 );
489- rpc_tensor rpc_tensor = serialize_tensor (tensor);
490530 memcpy (input.data (), &rpc_tensor, sizeof (rpc_tensor));
491531 memcpy (input.data () + sizeof (rpc_tensor), &offset, sizeof (offset));
492532 memcpy (input.data () + sizeof (rpc_tensor) + sizeof (offset), data, size);
@@ -772,7 +812,9 @@ void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, si
772812
773813class rpc_server {
774814public:
775- rpc_server (ggml_backend_t backend) : backend(backend) {}
815+ rpc_server (ggml_backend_t backend, const char * cache_dir)
816+ : backend(backend), cache_dir(cache_dir) {
817+ }
776818 ~rpc_server ();
777819
778820 void alloc_buffer (const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
@@ -782,13 +824,15 @@ class rpc_server {
782824 bool free_buffer (const rpc_msg_free_buffer_req & request);
783825 bool buffer_clear (const rpc_msg_buffer_clear_req & request);
784826 bool set_tensor (const std::vector<uint8_t > & input);
827+ bool set_tensor_hash (const std::vector<uint8_t > & input, rpc_msg_set_tensor_hash_rsp & response);
785828 bool get_tensor (const rpc_msg_get_tensor_req & request, std::vector<uint8_t > & response);
786829 bool copy_tensor (const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
787830 bool graph_compute (const std::vector<uint8_t > & input, rpc_msg_graph_compute_rsp & response);
788831 bool init_tensor (const rpc_msg_init_tensor_req & request);
789832 bool get_alloc_size (const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
790833
791834private:
835+ bool get_cached_file (uint64_t hash, std::vector<uint8_t > & data);
792836 ggml_tensor * deserialize_tensor (struct ggml_context * ctx, const rpc_tensor * tensor);
793837 ggml_tensor * create_node (uint64_t id,
794838 struct ggml_context * ctx,
@@ -797,6 +841,7 @@ class rpc_server {
797841
798842
799843 ggml_backend_t backend;
844+ const char * cache_dir;
800845 std::unordered_set<ggml_backend_buffer_t > buffers;
801846};
802847
@@ -960,11 +1005,85 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
9601005 }
9611006
9621007 const void * data = input.data () + sizeof (rpc_tensor) + sizeof (offset);
1008+ if (cache_dir && size > HASH_THRESHOLD) {
1009+ uint64_t hash = fnv_hash ((const uint8_t *)data, size);
1010+ char hash_str[17 ];
1011+ snprintf (hash_str, sizeof (hash_str), " %016" PRIx64, hash);
1012+ // save to cache_dir/hash_str
1013+ fs::path cache_file = fs::path (cache_dir) / hash_str;
1014+ std::ofstream ofs (cache_file, std::ios::binary);
1015+ ofs.write ((const char *)data, size);
1016+ printf (" [%s] saved to '%s'\n " , __func__, cache_file.c_str ());
1017+ }
9631018 ggml_backend_tensor_set (tensor, data, offset, size);
9641019 ggml_free (ctx);
9651020 return true ;
9661021}
9671022
1023+ bool rpc_server::get_cached_file (uint64_t hash, std::vector<uint8_t > & data) {
1024+ if (!cache_dir) {
1025+ return false ;
1026+ }
1027+ char hash_str[17 ];
1028+ snprintf (hash_str, sizeof (hash_str), " %016" PRIx64, hash);
1029+ fs::path cache_file = fs::path (cache_dir) / hash_str;
1030+ if (!fs::exists (cache_file)) {
1031+ return false ;
1032+ }
1033+ std::ifstream ifs (cache_file, std::ios::binary);
1034+ ifs.seekg (0 , std::ios::end);
1035+ size_t size = ifs.tellg ();
1036+ ifs.seekg (0 , std::ios::beg);
1037+ data.resize (size);
1038+ ifs.read ((char *)data.data (), size);
1039+ return true ;
1040+ }
1041+
1042+ bool rpc_server::set_tensor_hash (const std::vector<uint8_t > & input, rpc_msg_set_tensor_hash_rsp & response)
1043+ {
1044+ // serialization format: | rpc_tensor | offset (8 bytes) | hash (8 bytes) |
1045+ if (input.size () != sizeof (rpc_tensor) + 16 ) {
1046+ return false ;
1047+ }
1048+ const rpc_tensor * in_tensor = (const rpc_tensor *)input.data ();
1049+ uint64_t offset;
1050+ memcpy (&offset, input.data () + sizeof (rpc_tensor), sizeof (offset));
1051+ const uint64_t * hash = (const uint64_t *)(input.data () + sizeof (rpc_tensor) + sizeof (offset));
1052+ std::vector<uint8_t > cached_file;
1053+ if (!get_cached_file (*hash, cached_file)) {
1054+ response.result = 0 ;
1055+ return true ;
1056+ }
1057+ size_t size = cached_file.size ();
1058+ struct ggml_init_params params {
1059+ /* .mem_size =*/ ggml_tensor_overhead(),
1060+ /* .mem_buffer =*/ NULL ,
1061+ /* .no_alloc =*/ true ,
1062+ };
1063+ struct ggml_context * ctx = ggml_init (params);
1064+ ggml_tensor * tensor = deserialize_tensor (ctx, in_tensor);
1065+ if (tensor == nullptr ) {
1066+ GGML_LOG_ERROR (" [%s] error deserializing tensor\n " , __func__);
1067+ ggml_free (ctx);
1068+ return false ;
1069+ }
1070+ GGML_PRINT_DEBUG (" [%s] buffer: %p, data: %p, offset: %" PRIu64 " , size: %zu, hash: %" PRIx64 " \n " , __func__, (void *)tensor->buffer , tensor->data , offset, size, *hash);
1071+
1072+ // sanitize tensor->data
1073+ {
1074+ const size_t p0 = (size_t ) ggml_backend_buffer_get_base (tensor->buffer );
1075+ const size_t p1 = p0 + ggml_backend_buffer_get_size (tensor->buffer );
1076+
1077+ if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
1078+ GGML_ABORT (" [%s] tensor->data out of bounds\n " , __func__);
1079+ }
1080+ }
1081+ ggml_backend_tensor_set (tensor, cached_file.data (), offset, size);
1082+ response.result = 1 ;
1083+ ggml_free (ctx);
1084+ return true ;
1085+ }
1086+
9681087bool rpc_server::init_tensor (const rpc_msg_init_tensor_req & request) {
9691088 struct ggml_init_params params {
9701089 /* .mem_size =*/ ggml_tensor_overhead(),
@@ -1148,8 +1267,9 @@ rpc_server::~rpc_server() {
11481267 }
11491268}
11501269
1151- static void rpc_serve_client (ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
1152- rpc_server server (backend);
1270+ static void rpc_serve_client (ggml_backend_t backend, const char * cache_dir,
1271+ sockfd_t sockfd, size_t free_mem, size_t total_mem) {
1272+ rpc_server server (backend, cache_dir);
11531273 while (true ) {
11541274 uint8_t cmd;
11551275 if (!recv_data (sockfd, &cmd, 1 )) {
@@ -1260,6 +1380,20 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
12601380 }
12611381 break ;
12621382 }
1383+ case RPC_CMD_SET_TENSOR_HASH: {
1384+ std::vector<uint8_t > input;
1385+ if (!recv_msg (sockfd, input)) {
1386+ return ;
1387+ }
1388+ rpc_msg_set_tensor_hash_rsp response;
1389+ if (!server.set_tensor_hash (input, response)) {
1390+ return ;
1391+ }
1392+ if (!send_msg (sockfd, &response, sizeof (response))) {
1393+ return ;
1394+ }
1395+ break ;
1396+ }
12631397 case RPC_CMD_INIT_TENSOR: {
12641398 rpc_msg_init_tensor_req request;
12651399 if (!recv_msg (sockfd, &request,sizeof (request))) {
@@ -1335,7 +1469,9 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
13351469 }
13361470}
13371471
1338- void ggml_backend_rpc_start_server (ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
1472+ void ggml_backend_rpc_start_server (ggml_backend_t backend, const char * endpoint,
1473+ const char * cache_dir,
1474+ size_t free_mem, size_t total_mem) {
13391475 std::string host;
13401476 int port;
13411477 if (!parse_endpoint (endpoint, host, port)) {
@@ -1364,7 +1500,7 @@ void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint
13641500 }
13651501 printf (" Accepted client connection, free_mem=%zu, total_mem=%zu\n " , free_mem, total_mem);
13661502 fflush (stdout);
1367- rpc_serve_client (backend, client_socket->fd , free_mem, total_mem);
1503+ rpc_serve_client (backend, cache_dir, client_socket->fd , free_mem, total_mem);
13681504 printf (" Client connection closed\n " );
13691505 fflush (stdout);
13701506 }
0 commit comments