Skip to content

Commit b7bda76

Browse files
committed
rpc : send hash when tensor data is above some fixed threshold
ref #10095
1 parent 960e726 commit b7bda76

File tree

3 files changed

+215
-13
lines changed

3 files changed

+215
-13
lines changed

examples/rpc/rpc-server.cpp

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,20 +24,27 @@
2424
#endif
2525
#include <string>
2626
#include <stdio.h>
27+
#include <filesystem>
28+
29+
namespace fs = std::filesystem;
2730

2831
// Command-line configuration for the RPC server.
struct rpc_server_params {
    std::string host = "127.0.0.1"; // interface to bind the server socket to
    int port = 50052;               // TCP port to listen on (must be 1..65535)
    size_t backend_mem = 0;         // reported backend memory size; 0 means query the backend
    std::string gguf_path = "";     // optional GGUF file used to pre-populate the in-memory tensor cache
    std::string cache_dir = "";     // optional existing directory for the on-disk tensor-data cache
};
3338

3439
// Print usage/help text for the RPC server to stderr.
// `params` supplies the default values shown in the help output.
// NOTE(review): the option/description column spacing below appears collapsed
// (likely by copy/paste whitespace mangling) — confirm alignment against the
// original file before shipping.
static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) {
    fprintf(stderr, "Usage: %s [options]\n\n", argv[0]);
    fprintf(stderr, "options:\n");
    fprintf(stderr, " -h, --help show this help message and exit\n");
    fprintf(stderr, " -H HOST, --host HOST host to bind to (default: %s)\n", params.host.c_str());
    fprintf(stderr, " -p PORT, --port PORT port to bind to (default: %d)\n", params.port);
    fprintf(stderr, " -f PATH, --gguf PATH path to GGUF file\n");
    fprintf(stderr, " -d DIR, --cache-dir DIR local cache dir\n");
    fprintf(stderr, " -m MEM, --mem MEM backend memory size (in MB)\n");
    fprintf(stderr, "\n");
}
4350

@@ -58,6 +65,21 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params &
5865
if (params.port <= 0 || params.port > 65535) {
5966
return false;
6067
}
68+
} else if (arg == "-f" || arg == "--gguf") {
69+
if (++i >= argc) {
70+
return false;
71+
}
72+
params.gguf_path = argv[i];
73+
} else if (arg == "-d" || arg == "--cache-dir") {
74+
if (++i >= argc) {
75+
return false;
76+
}
77+
fs::path cache_dir(argv[i]);
78+
if (!fs::is_directory(cache_dir)) {
79+
fprintf(stderr, "error: cache dir does not exist: %s\n", cache_dir.c_str());
80+
return false;
81+
}
82+
params.cache_dir = argv[i];
6183
} else if (arg == "-m" || arg == "--mem") {
6284
if (++i >= argc) {
6385
return false;
@@ -164,8 +186,10 @@ int main(int argc, char * argv[]) {
164186
} else {
165187
get_backend_memory(&free_mem, &total_mem);
166188
}
189+
const char * gguf_path = params.gguf_path.empty() ? nullptr : params.gguf_path.c_str();
190+
const char * cache_dir = params.cache_dir.empty() ? nullptr : params.cache_dir.c_str();
167191
printf("Starting RPC server on %s, backend memory: %zu MB\n", endpoint.c_str(), free_mem / (1024 * 1024));
168-
ggml_backend_rpc_start_server(backend, endpoint.c_str(), free_mem, total_mem);
192+
ggml_backend_rpc_start_server(backend, endpoint.c_str(), gguf_path, cache_dir, free_mem, total_mem);
169193
ggml_backend_free(backend);
170194
return 0;
171195
}

ggml/include/ggml-rpc.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const c
1717

1818
GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total);
1919

20-
GGML_BACKEND_API void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem);
20+
GGML_BACKEND_API void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint,
21+
const char * gguf_path, const char * cache_dir,
22+
size_t free_mem, size_t total_mem);
2123

2224
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void);
2325

ggml/src/ggml-rpc/ggml-rpc.cpp

Lines changed: 183 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@
2626
# include <unistd.h>
2727
#endif
2828
#include <cstring>
29+
#include <fstream>
30+
#include <filesystem>
31+
32+
namespace fs = std::filesystem;
2933

3034
#ifdef _WIN32
3135
typedef SOCKET sockfd_t;
@@ -80,6 +84,7 @@ enum rpc_cmd {
8084
RPC_CMD_FREE_BUFFER,
8185
RPC_CMD_BUFFER_CLEAR,
8286
RPC_CMD_SET_TENSOR,
87+
RPC_CMD_SET_TENSOR_HASH,
8388
RPC_CMD_GET_TENSOR,
8489
RPC_CMD_COPY_TENSOR,
8590
RPC_CMD_GRAPH_COMPUTE,
@@ -89,6 +94,9 @@ enum rpc_cmd {
8994
RPC_CMD_COUNT,
9095
};
9196

97+
// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
98+
const size_t HASH_THRESHOLD = 1024 * 1024;
99+
92100
struct rpc_msg_get_alloc_size_req {
93101
rpc_tensor tensor;
94102
};
@@ -135,6 +143,10 @@ struct rpc_msg_buffer_clear_req {
135143
uint8_t value;
136144
};
137145

146+
struct rpc_msg_set_tensor_hash_rsp {
147+
uint8_t result;
148+
};
149+
138150
struct rpc_msg_get_tensor_req {
139151
rpc_tensor tensor;
140152
uint64_t offset;
@@ -187,6 +199,18 @@ struct ggml_backend_rpc_buffer_context {
187199

188200
// RPC helper functions
189201

202+
// Computes the 64-bit FNV-1a hash of `len` bytes starting at `data`.
// Returns the FNV offset basis when len == 0.
static uint64_t fnv_hash(const uint8_t * data, size_t len) {
    uint64_t h = 0xcbf29ce484222325ULL;       // FNV-1a 64-bit offset basis
    const uint8_t * end = data + len;
    while (data != end) {
        h = (h ^ *data++) * 0x100000001b3ULL; // xor byte, then multiply by FNV prime
    }
    return h;
}
213+
190214
static std::shared_ptr<socket_t> make_socket(sockfd_t fd) {
191215
#ifdef _WIN32
192216
if (fd == INVALID_SOCKET) {
@@ -483,10 +507,26 @@ static enum ggml_status ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_
483507

484508
static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
485509
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
486-
// input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
510+
rpc_tensor rpc_tensor = serialize_tensor(tensor);
511+
if (size > HASH_THRESHOLD) {
512+
// input serialization format: | rpc_tensor | offset (8 bytes) | hash (8 bytes)
513+
size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + sizeof(uint64_t);
514+
std::vector<uint8_t> input(input_size, 0);
515+
uint64_t hash = fnv_hash((const uint8_t*)data, size);
516+
memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
517+
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
518+
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &hash, sizeof(hash));
519+
rpc_msg_set_tensor_hash_rsp response;
520+
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR_HASH, input.data(), input.size(), &response, sizeof(response));
521+
GGML_ASSERT(status);
522+
if (response.result) {
523+
// the server has the same data, no need to send it
524+
return;
525+
}
526+
}
527+
// input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes)
487528
size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + size;
488529
std::vector<uint8_t> input(input_size, 0);
489-
rpc_tensor rpc_tensor = serialize_tensor(tensor);
490530
memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
491531
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
492532
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
@@ -772,7 +812,10 @@ void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, si
772812

773813
class rpc_server {
774814
public:
775-
rpc_server(ggml_backend_t backend) : backend(backend) {}
815+
rpc_server(ggml_backend_t backend, const fs::path & cache_dir,
816+
const std::unordered_map<uint64_t, ggml_tensor *> & tensor_cache)
817+
: backend(backend), cache_dir(cache_dir), tensor_cache(tensor_cache) {
818+
}
776819
~rpc_server();
777820

778821
void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
@@ -782,13 +825,15 @@ class rpc_server {
782825
bool free_buffer(const rpc_msg_free_buffer_req & request);
783826
bool buffer_clear(const rpc_msg_buffer_clear_req & request);
784827
bool set_tensor(const std::vector<uint8_t> & input);
828+
bool set_tensor_hash(const std::vector<uint8_t> & input, rpc_msg_set_tensor_hash_rsp & response);
785829
bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
786830
bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
787831
bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
788832
bool init_tensor(const rpc_msg_init_tensor_req & request);
789833
bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
790834

791835
private:
836+
bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
792837
ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
793838
ggml_tensor * create_node(uint64_t id,
794839
struct ggml_context * ctx,
@@ -797,7 +842,9 @@ class rpc_server {
797842

798843

799844
ggml_backend_t backend;
845+
fs::path cache_dir;
800846
std::unordered_set<ggml_backend_buffer_t> buffers;
847+
const std::unordered_map<uint64_t, ggml_tensor *> & tensor_cache;
801848
};
802849

803850
bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
@@ -960,11 +1007,97 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
9601007
}
9611008

9621009
const void * data = input.data() + sizeof(rpc_tensor) + sizeof(offset);
1010+
if (size > HASH_THRESHOLD) {
1011+
uint64_t hash = fnv_hash((const uint8_t*)data, size);
1012+
char hash_str[17];
1013+
snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
1014+
// save to cache_dir/hash_str
1015+
fs::path cache_file = cache_dir / hash_str;
1016+
std::ofstream ofs(cache_file, std::ios::binary);
1017+
ofs.write((const char *)data, size);
1018+
}
9631019
ggml_backend_tensor_set(tensor, data, offset, size);
9641020
ggml_free(ctx);
9651021
return true;
9661022
}
9671023

1024+
bool rpc_server::get_cached_file(uint64_t hash, std::vector<uint8_t> & data) {
1025+
char hash_str[17];
1026+
snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
1027+
fs::path cache_file = cache_dir / hash_str;
1028+
if (!fs::exists(cache_file)) {
1029+
return false;
1030+
}
1031+
std::ifstream ifs(cache_file, std::ios::binary);
1032+
ifs.seekg(0, std::ios::end);
1033+
size_t size = ifs.tellg();
1034+
ifs.seekg(0, std::ios::beg);
1035+
data.resize(size);
1036+
ifs.read((char *)data.data(), size);
1037+
return true;
1038+
}
1039+
1040+
bool rpc_server::set_tensor_hash(const std::vector<uint8_t> & input, rpc_msg_set_tensor_hash_rsp & response)
1041+
{
1042+
// serialization format: | rpc_tensor | offset (8 bytes) | hash (8 bytes) |
1043+
if (input.size() != sizeof(rpc_tensor) + 16) {
1044+
return false;
1045+
}
1046+
const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
1047+
uint64_t offset;
1048+
memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
1049+
const uint64_t * hash = (const uint64_t *)(input.data() + sizeof(rpc_tensor) + sizeof(offset));
1050+
bool have_cached_tensor = false;
1051+
ggml_tensor * cached_tensor = nullptr;
1052+
bool have_cached_file = false;
1053+
std::vector<uint8_t> cached_file;
1054+
size_t size = 0;
1055+
if (tensor_cache.find(*hash) != tensor_cache.end()) {
1056+
have_cached_tensor = true;
1057+
cached_tensor = tensor_cache.at(*hash);
1058+
size = ggml_nbytes(cached_tensor);
1059+
} else if (get_cached_file(*hash, cached_file)) {
1060+
have_cached_file = true;
1061+
size = cached_file.size();
1062+
} else {
1063+
response.result = 0;
1064+
return true;
1065+
}
1066+
struct ggml_init_params params {
1067+
/*.mem_size =*/ ggml_tensor_overhead(),
1068+
/*.mem_buffer =*/ NULL,
1069+
/*.no_alloc =*/ true,
1070+
};
1071+
struct ggml_context * ctx = ggml_init(params);
1072+
ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
1073+
if (tensor == nullptr) {
1074+
GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
1075+
ggml_free(ctx);
1076+
return false;
1077+
}
1078+
GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu, hash: %" PRIx64 "\n", __func__, (void*)tensor->buffer, tensor->data, offset, size, *hash);
1079+
1080+
// sanitize tensor->data
1081+
{
1082+
const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
1083+
const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
1084+
1085+
if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
1086+
GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
1087+
}
1088+
}
1089+
if (have_cached_tensor) {
1090+
ggml_backend_tensor_set(tensor, cached_tensor->data, offset, size);
1091+
} else if (have_cached_file) {
1092+
ggml_backend_tensor_set(tensor, cached_file.data(), offset, size);
1093+
} else {
1094+
GGML_ABORT("[%s] no cached tensor or file\n", __func__);
1095+
}
1096+
response.result = 1;
1097+
ggml_free(ctx);
1098+
return true;
1099+
}
1100+
9681101
bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) {
9691102
struct ggml_init_params params {
9701103
/*.mem_size =*/ ggml_tensor_overhead(),
@@ -1148,8 +1281,10 @@ rpc_server::~rpc_server() {
11481281
}
11491282
}
11501283

1151-
static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
1152-
rpc_server server(backend);
1284+
static void rpc_serve_client(ggml_backend_t backend, fs::path cache_dir,
1285+
const std::unordered_map<uint64_t, ggml_tensor *> & tensor_cache,
1286+
sockfd_t sockfd, size_t free_mem, size_t total_mem) {
1287+
rpc_server server(backend, cache_dir, tensor_cache);
11531288
while (true) {
11541289
uint8_t cmd;
11551290
if (!recv_data(sockfd, &cmd, 1)) {
@@ -1260,6 +1395,20 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
12601395
}
12611396
break;
12621397
}
1398+
case RPC_CMD_SET_TENSOR_HASH: {
1399+
std::vector<uint8_t> input;
1400+
if (!recv_msg(sockfd, input)) {
1401+
return;
1402+
}
1403+
rpc_msg_set_tensor_hash_rsp response;
1404+
if (!server.set_tensor_hash(input, response)) {
1405+
return;
1406+
}
1407+
if (!send_msg(sockfd, &response, sizeof(response))) {
1408+
return;
1409+
}
1410+
break;
1411+
}
12631412
case RPC_CMD_INIT_TENSOR: {
12641413
rpc_msg_init_tensor_req request;
12651414
if (!recv_msg(sockfd, &request,sizeof(request))) {
@@ -1335,7 +1484,9 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
13351484
}
13361485
}
13371486

1338-
void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
1487+
void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint,
1488+
const char * gguf_path, const char * cache_dir,
1489+
size_t free_mem, size_t total_mem) {
13391490
std::string host;
13401491
int port;
13411492
if (!parse_endpoint(endpoint, host, port)) {
@@ -1351,6 +1502,28 @@ void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint
13511502
}
13521503
}
13531504
#endif
1505+
gguf_context * ctx = nullptr;
1506+
std::unordered_map<uint64_t, ggml_tensor *> tensor_cache;
1507+
if (gguf_path != nullptr) {
1508+
struct ggml_context * ctx_data = NULL;
1509+
struct gguf_init_params params = {
1510+
/*.no_alloc = */ false,
1511+
/*.ctx = */ &ctx_data,
1512+
};
1513+
ctx = gguf_init_from_file(gguf_path, params);
1514+
if (ctx == nullptr) {
1515+
fprintf(stderr, "Failed to load GGUF file: %s\n", gguf_path);
1516+
return;
1517+
}
1518+
const int n_tensors = gguf_get_n_tensors(ctx);
1519+
for (int i = 0; i < n_tensors; ++i) {
1520+
const char * name = gguf_get_tensor_name(ctx, i);
1521+
ggml_tensor * cur = ggml_get_tensor(ctx_data, name);
1522+
size_t n_bytes = ggml_nbytes(cur);
1523+
uint64_t hash = fnv_hash((const uint8_t *)cur->data, n_bytes);
1524+
tensor_cache[hash] = cur;
1525+
}
1526+
}
13541527
auto server_socket = create_server_socket(host.c_str(), port);
13551528
if (server_socket == nullptr) {
13561529
fprintf(stderr, "Failed to create server socket\n");
@@ -1364,10 +1537,13 @@ void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint
13641537
}
13651538
printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
13661539
fflush(stdout);
1367-
rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
1540+
rpc_serve_client(backend, cache_dir, tensor_cache, client_socket->fd, free_mem, total_mem);
13681541
printf("Client connection closed\n");
13691542
fflush(stdout);
13701543
}
1544+
if (ctx != nullptr) {
1545+
gguf_free(ctx);
1546+
}
13711547
#ifdef _WIN32
13721548
WSACleanup();
13731549
#endif

0 commit comments

Comments
 (0)