2626# include < unistd.h>
2727#endif
2828#include < cstring>
29+ #include < fstream>
30+ #include < filesystem>
31+
32+ namespace fs = std::filesystem;
2933
3034#ifdef _WIN32
3135typedef SOCKET sockfd_t ;
@@ -80,6 +84,7 @@ enum rpc_cmd {
8084 RPC_CMD_FREE_BUFFER,
8185 RPC_CMD_BUFFER_CLEAR,
8286 RPC_CMD_SET_TENSOR,
87+ RPC_CMD_SET_TENSOR_HASH,
8388 RPC_CMD_GET_TENSOR,
8489 RPC_CMD_COPY_TENSOR,
8590 RPC_CMD_GRAPH_COMPUTE,
@@ -89,6 +94,9 @@ enum rpc_cmd {
8994 RPC_CMD_COUNT,
9095};
9196
97+ // Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
98+ const size_t HASH_THRESHOLD = 1024 * 1024 ;
99+
92100struct rpc_msg_get_alloc_size_req {
93101 rpc_tensor tensor;
94102};
@@ -135,6 +143,10 @@ struct rpc_msg_buffer_clear_req {
135143 uint8_t value;
136144};
137145
// Reply to RPC_CMD_SET_TENSOR_HASH.
// result != 0 means the server already has data with the given hash cached,
// so the client can skip sending the raw bytes; result == 0 means the client
// must fall back to a full RPC_CMD_SET_TENSOR upload.
struct rpc_msg_set_tensor_hash_rsp {
    uint8_t result;
};
149+
138150struct rpc_msg_get_tensor_req {
139151 rpc_tensor tensor;
140152 uint64_t offset;
@@ -187,6 +199,18 @@ struct ggml_backend_rpc_buffer_context {
187199
188200// RPC helper functions
189201
// Computes the 64-bit FNV-1a hash of a byte buffer.
// For an empty buffer this returns the FNV offset basis (0xcbf29ce484222325).
static uint64_t fnv_hash(const uint8_t * data, size_t len) {
    const uint64_t fnv_prime = 0x100000001b3ULL;          // 64-bit FNV prime
    uint64_t hash            = 0xcbf29ce484222325ULL;     // 64-bit FNV offset basis

    // FNV-1a order: xor the byte in first, then multiply by the prime
    for (const uint8_t * p = data, * end = data + len; p != end; ++p) {
        hash = (hash ^ *p) * fnv_prime;
    }
    return hash;
}
213+
190214static std::shared_ptr<socket_t > make_socket (sockfd_t fd) {
191215#ifdef _WIN32
192216 if (fd == INVALID_SOCKET) {
@@ -483,10 +507,26 @@ static enum ggml_status ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_
483507
484508static void ggml_backend_rpc_buffer_set_tensor (ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
485509 ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context ;
486- // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
510+ rpc_tensor rpc_tensor = serialize_tensor (tensor);
511+ if (size > HASH_THRESHOLD) {
512+ // input serialization format: | rpc_tensor | offset (8 bytes) | hash (8 bytes)
513+ size_t input_size = sizeof (rpc_tensor) + sizeof (uint64_t ) + sizeof (uint64_t );
514+ std::vector<uint8_t > input (input_size, 0 );
515+ uint64_t hash = fnv_hash ((const uint8_t *)data, size);
516+ memcpy (input.data (), &rpc_tensor, sizeof (rpc_tensor));
517+ memcpy (input.data () + sizeof (rpc_tensor), &offset, sizeof (offset));
518+ memcpy (input.data () + sizeof (rpc_tensor) + sizeof (offset), &hash, sizeof (hash));
519+ rpc_msg_set_tensor_hash_rsp response;
520+ bool status = send_rpc_cmd (ctx->sock , RPC_CMD_SET_TENSOR_HASH, input.data (), input.size (), &response, sizeof (response));
521+ GGML_ASSERT (status);
522+ if (response.result ) {
523+ // the server has the same data, no need to send it
524+ return ;
525+ }
526+ }
527+ // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes)
487528 size_t input_size = sizeof (rpc_tensor) + sizeof (uint64_t ) + size;
488529 std::vector<uint8_t > input (input_size, 0 );
489- rpc_tensor rpc_tensor = serialize_tensor (tensor);
490530 memcpy (input.data (), &rpc_tensor, sizeof (rpc_tensor));
491531 memcpy (input.data () + sizeof (rpc_tensor), &offset, sizeof (offset));
492532 memcpy (input.data () + sizeof (rpc_tensor) + sizeof (offset), data, size);
@@ -772,7 +812,10 @@ void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, si
772812
773813class rpc_server {
774814public:
775- rpc_server (ggml_backend_t backend) : backend(backend) {}
815+ rpc_server (ggml_backend_t backend, const fs::path & cache_dir,
816+ const std::unordered_map<uint64_t , ggml_tensor *> & tensor_cache)
817+ : backend(backend), cache_dir(cache_dir), tensor_cache(tensor_cache) {
818+ }
776819 ~rpc_server ();
777820
778821 void alloc_buffer (const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
@@ -782,13 +825,15 @@ class rpc_server {
782825 bool free_buffer (const rpc_msg_free_buffer_req & request);
783826 bool buffer_clear (const rpc_msg_buffer_clear_req & request);
784827 bool set_tensor (const std::vector<uint8_t > & input);
828+ bool set_tensor_hash (const std::vector<uint8_t > & input, rpc_msg_set_tensor_hash_rsp & response);
785829 bool get_tensor (const rpc_msg_get_tensor_req & request, std::vector<uint8_t > & response);
786830 bool copy_tensor (const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
787831 bool graph_compute (const std::vector<uint8_t > & input, rpc_msg_graph_compute_rsp & response);
788832 bool init_tensor (const rpc_msg_init_tensor_req & request);
789833 bool get_alloc_size (const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
790834
791835private:
836+ bool get_cached_file (uint64_t hash, std::vector<uint8_t > & data);
792837 ggml_tensor * deserialize_tensor (struct ggml_context * ctx, const rpc_tensor * tensor);
793838 ggml_tensor * create_node (uint64_t id,
794839 struct ggml_context * ctx,
@@ -797,7 +842,9 @@ class rpc_server {
797842
798843
799844 ggml_backend_t backend;
845+ fs::path cache_dir;
800846 std::unordered_set<ggml_backend_buffer_t > buffers;
847+ const std::unordered_map<uint64_t , ggml_tensor *> & tensor_cache;
801848};
802849
803850bool rpc_server::get_alloc_size (const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
@@ -960,11 +1007,97 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
9601007 }
9611008
9621009 const void * data = input.data () + sizeof (rpc_tensor) + sizeof (offset);
1010+ if (size > HASH_THRESHOLD) {
1011+ uint64_t hash = fnv_hash ((const uint8_t *)data, size);
1012+ char hash_str[17 ];
1013+ snprintf (hash_str, sizeof (hash_str), " %016" PRIx64, hash);
1014+ // save to cache_dir/hash_str
1015+ fs::path cache_file = cache_dir / hash_str;
1016+ std::ofstream ofs (cache_file, std::ios::binary);
1017+ ofs.write ((const char *)data, size);
1018+ }
9631019 ggml_backend_tensor_set (tensor, data, offset, size);
9641020 ggml_free (ctx);
9651021 return true ;
9661022}
9671023
1024+ bool rpc_server::get_cached_file (uint64_t hash, std::vector<uint8_t > & data) {
1025+ char hash_str[17 ];
1026+ snprintf (hash_str, sizeof (hash_str), " %016" PRIx64, hash);
1027+ fs::path cache_file = cache_dir / hash_str;
1028+ if (!fs::exists (cache_file)) {
1029+ return false ;
1030+ }
1031+ std::ifstream ifs (cache_file, std::ios::binary);
1032+ ifs.seekg (0 , std::ios::end);
1033+ size_t size = ifs.tellg ();
1034+ ifs.seekg (0 , std::ios::beg);
1035+ data.resize (size);
1036+ ifs.read ((char *)data.data (), size);
1037+ return true ;
1038+ }
1039+
1040+ bool rpc_server::set_tensor_hash (const std::vector<uint8_t > & input, rpc_msg_set_tensor_hash_rsp & response)
1041+ {
1042+ // serialization format: | rpc_tensor | offset (8 bytes) | hash (8 bytes) |
1043+ if (input.size () != sizeof (rpc_tensor) + 16 ) {
1044+ return false ;
1045+ }
1046+ const rpc_tensor * in_tensor = (const rpc_tensor *)input.data ();
1047+ uint64_t offset;
1048+ memcpy (&offset, input.data () + sizeof (rpc_tensor), sizeof (offset));
1049+ const uint64_t * hash = (const uint64_t *)(input.data () + sizeof (rpc_tensor) + sizeof (offset));
1050+ bool have_cached_tensor = false ;
1051+ ggml_tensor * cached_tensor = nullptr ;
1052+ bool have_cached_file = false ;
1053+ std::vector<uint8_t > cached_file;
1054+ size_t size = 0 ;
1055+ if (tensor_cache.find (*hash) != tensor_cache.end ()) {
1056+ have_cached_tensor = true ;
1057+ cached_tensor = tensor_cache.at (*hash);
1058+ size = ggml_nbytes (cached_tensor);
1059+ } else if (get_cached_file (*hash, cached_file)) {
1060+ have_cached_file = true ;
1061+ size = cached_file.size ();
1062+ } else {
1063+ response.result = 0 ;
1064+ return true ;
1065+ }
1066+ struct ggml_init_params params {
1067+ /* .mem_size =*/ ggml_tensor_overhead(),
1068+ /* .mem_buffer =*/ NULL ,
1069+ /* .no_alloc =*/ true ,
1070+ };
1071+ struct ggml_context * ctx = ggml_init (params);
1072+ ggml_tensor * tensor = deserialize_tensor (ctx, in_tensor);
1073+ if (tensor == nullptr ) {
1074+ GGML_LOG_ERROR (" [%s] error deserializing tensor\n " , __func__);
1075+ ggml_free (ctx);
1076+ return false ;
1077+ }
1078+ GGML_PRINT_DEBUG (" [%s] buffer: %p, data: %p, offset: %" PRIu64 " , size: %zu, hash: %" PRIx64 " \n " , __func__, (void *)tensor->buffer , tensor->data , offset, size, *hash);
1079+
1080+ // sanitize tensor->data
1081+ {
1082+ const size_t p0 = (size_t ) ggml_backend_buffer_get_base (tensor->buffer );
1083+ const size_t p1 = p0 + ggml_backend_buffer_get_size (tensor->buffer );
1084+
1085+ if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
1086+ GGML_ABORT (" [%s] tensor->data out of bounds\n " , __func__);
1087+ }
1088+ }
1089+ if (have_cached_tensor) {
1090+ ggml_backend_tensor_set (tensor, cached_tensor->data , offset, size);
1091+ } else if (have_cached_file) {
1092+ ggml_backend_tensor_set (tensor, cached_file.data (), offset, size);
1093+ } else {
1094+ GGML_ABORT (" [%s] no cached tensor or file\n " , __func__);
1095+ }
1096+ response.result = 1 ;
1097+ ggml_free (ctx);
1098+ return true ;
1099+ }
1100+
9681101bool rpc_server::init_tensor (const rpc_msg_init_tensor_req & request) {
9691102 struct ggml_init_params params {
9701103 /* .mem_size =*/ ggml_tensor_overhead(),
@@ -1148,8 +1281,10 @@ rpc_server::~rpc_server() {
11481281 }
11491282}
11501283
1151- static void rpc_serve_client (ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
1152- rpc_server server (backend);
1284+ static void rpc_serve_client (ggml_backend_t backend, fs::path cache_dir,
1285+ const std::unordered_map<uint64_t , ggml_tensor *> & tensor_cache,
1286+ sockfd_t sockfd, size_t free_mem, size_t total_mem) {
1287+ rpc_server server (backend, cache_dir, tensor_cache);
11531288 while (true ) {
11541289 uint8_t cmd;
11551290 if (!recv_data (sockfd, &cmd, 1 )) {
@@ -1260,6 +1395,20 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
12601395 }
12611396 break ;
12621397 }
1398+ case RPC_CMD_SET_TENSOR_HASH: {
1399+ std::vector<uint8_t > input;
1400+ if (!recv_msg (sockfd, input)) {
1401+ return ;
1402+ }
1403+ rpc_msg_set_tensor_hash_rsp response;
1404+ if (!server.set_tensor_hash (input, response)) {
1405+ return ;
1406+ }
1407+ if (!send_msg (sockfd, &response, sizeof (response))) {
1408+ return ;
1409+ }
1410+ break ;
1411+ }
12631412 case RPC_CMD_INIT_TENSOR: {
12641413 rpc_msg_init_tensor_req request;
12651414 if (!recv_msg (sockfd, &request,sizeof (request))) {
@@ -1335,7 +1484,9 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
13351484 }
13361485}
13371486
1338- void ggml_backend_rpc_start_server (ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
1487+ void ggml_backend_rpc_start_server (ggml_backend_t backend, const char * endpoint,
1488+ const char * gguf_path, const char * cache_dir,
1489+ size_t free_mem, size_t total_mem) {
13391490 std::string host;
13401491 int port;
13411492 if (!parse_endpoint (endpoint, host, port)) {
@@ -1351,6 +1502,28 @@ void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint
13511502 }
13521503 }
13531504#endif
1505+ gguf_context * ctx = nullptr ;
1506+ std::unordered_map<uint64_t , ggml_tensor *> tensor_cache;
1507+ if (gguf_path != nullptr ) {
1508+ struct ggml_context * ctx_data = NULL ;
1509+ struct gguf_init_params params = {
1510+ /* .no_alloc = */ false ,
1511+ /* .ctx = */ &ctx_data,
1512+ };
1513+ ctx = gguf_init_from_file (gguf_path, params);
1514+ if (ctx == nullptr ) {
1515+ fprintf (stderr, " Failed to load GGUF file: %s\n " , gguf_path);
1516+ return ;
1517+ }
1518+ const int n_tensors = gguf_get_n_tensors (ctx);
1519+ for (int i = 0 ; i < n_tensors; ++i) {
1520+ const char * name = gguf_get_tensor_name (ctx, i);
1521+ ggml_tensor * cur = ggml_get_tensor (ctx_data, name);
1522+ size_t n_bytes = ggml_nbytes (cur);
1523+ uint64_t hash = fnv_hash ((const uint8_t *)cur->data , n_bytes);
1524+ tensor_cache[hash] = cur;
1525+ }
1526+ }
13541527 auto server_socket = create_server_socket (host.c_str (), port);
13551528 if (server_socket == nullptr ) {
13561529 fprintf (stderr, " Failed to create server socket\n " );
@@ -1364,10 +1537,13 @@ void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint
13641537 }
13651538 printf (" Accepted client connection, free_mem=%zu, total_mem=%zu\n " , free_mem, total_mem);
13661539 fflush (stdout);
1367- rpc_serve_client (backend, client_socket->fd , free_mem, total_mem);
1540+ rpc_serve_client (backend, cache_dir, tensor_cache, client_socket->fd , free_mem, total_mem);
13681541 printf (" Client connection closed\n " );
13691542 fflush (stdout);
13701543 }
1544+ if (ctx != nullptr ) {
1545+ gguf_free (ctx);
1546+ }
13711547#ifdef _WIN32
13721548 WSACleanup ();
13731549#endif
0 commit comments