Skip to content

Commit 9519c7a

Browse files
committed
rpc : send hash when tensor data is above some fixed threshold
ref #10095
1 parent 960e726 commit 9519c7a

File tree

6 files changed

+7270
-9
lines changed

6 files changed

+7270
-9
lines changed

examples/rpc/rpc-server.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ struct rpc_server_params {
2929
std::string host = "127.0.0.1";
3030
int port = 50052;
3131
size_t backend_mem = 0;
32+
std::string gguf_path = "";
3233
};
3334

3435
static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) {
@@ -37,6 +38,7 @@ static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) {
3738
fprintf(stderr, " -h, --help show this help message and exit\n");
3839
fprintf(stderr, " -H HOST, --host HOST host to bind to (default: %s)\n", params.host.c_str());
3940
fprintf(stderr, " -p PORT, --port PORT port to bind to (default: %d)\n", params.port);
41+
fprintf(stderr, " -f PATH, --gguf PATH path to GGUF file\n");
4042
fprintf(stderr, " -m MEM, --mem MEM backend memory size (in MB)\n");
4143
fprintf(stderr, "\n");
4244
}
@@ -58,6 +60,11 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params &
5860
if (params.port <= 0 || params.port > 65535) {
5961
return false;
6062
}
63+
} else if (arg == "-f" || arg == "--gguf") {
64+
if (++i >= argc) {
65+
return false;
66+
}
67+
params.gguf_path = argv[i];
6168
} else if (arg == "-m" || arg == "--mem") {
6269
if (++i >= argc) {
6370
return false;
@@ -164,8 +171,9 @@ int main(int argc, char * argv[]) {
164171
} else {
165172
get_backend_memory(&free_mem, &total_mem);
166173
}
174+
const char * gguf_path = params.gguf_path.empty() ? nullptr : params.gguf_path.c_str();
167175
printf("Starting RPC server on %s, backend memory: %zu MB\n", endpoint.c_str(), free_mem / (1024 * 1024));
168-
ggml_backend_rpc_start_server(backend, endpoint.c_str(), free_mem, total_mem);
176+
ggml_backend_rpc_start_server(backend, endpoint.c_str(), gguf_path, free_mem, total_mem);
169177
ggml_backend_free(backend);
170178
return 0;
171179
}

ggml/include/ggml-rpc.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const c
1717

1818
GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total);
1919

20-
GGML_BACKEND_API void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem);
20+
GGML_BACKEND_API void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint,
21+
const char * gguf_path, size_t free_mem, size_t total_mem);
2122

2223
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void);
2324

ggml/src/ggml-rpc/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ message(STATUS "Using RPC backend")
22

33
ggml_add_backend_library(ggml-rpc
44
ggml-rpc.cpp
5+
xxhash.c
56
)
67

78
if (WIN32)

ggml/src/ggml-rpc/ggml-rpc.cpp

Lines changed: 123 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "ggml-rpc.h"
22
#include "ggml-impl.h"
33
#include "ggml-backend-impl.h"
4+
#include "xxhash.h"
45

56
#include <cinttypes>
67
#include <string>
@@ -80,6 +81,7 @@ enum rpc_cmd {
8081
RPC_CMD_FREE_BUFFER,
8182
RPC_CMD_BUFFER_CLEAR,
8283
RPC_CMD_SET_TENSOR,
84+
RPC_CMD_SET_TENSOR_HASH,
8385
RPC_CMD_GET_TENSOR,
8486
RPC_CMD_COPY_TENSOR,
8587
RPC_CMD_GRAPH_COMPUTE,
@@ -89,6 +91,9 @@ enum rpc_cmd {
8991
RPC_CMD_COUNT,
9092
};
9193

94+
// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
95+
const size_t HASH_THRESHOLD = 1024 * 1024;
96+
9297
struct rpc_msg_get_alloc_size_req {
9398
rpc_tensor tensor;
9499
};
@@ -135,6 +140,10 @@ struct rpc_msg_buffer_clear_req {
135140
uint8_t value;
136141
};
137142

143+
// Response for RPC_CMD_SET_TENSOR_HASH.
// result == 1: the server found a cached tensor with a matching hash and
//              copied its data, so the client does not need to send the data.
// result == 0: no match; the client falls back to RPC_CMD_SET_TENSOR and
//              sends the full payload.
struct rpc_msg_set_tensor_hash_rsp {
    uint8_t result;
};
146+
138147
struct rpc_msg_get_tensor_req {
139148
rpc_tensor tensor;
140149
uint64_t offset;
@@ -483,10 +492,26 @@ static enum ggml_status ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_
483492

484493
static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
485494
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
486-
// input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
495+
rpc_tensor rpc_tensor = serialize_tensor(tensor);
496+
if (size > HASH_THRESHOLD) {
497+
// input serialization format: | rpc_tensor | offset (8 bytes) | hash (8 bytes)
498+
size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + sizeof(XXH64_hash_t);
499+
std::vector<uint8_t> input(input_size, 0);
500+
XXH64_hash_t hash = XXH64(data, size, 0);
501+
memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
502+
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
503+
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &hash, sizeof(hash));
504+
rpc_msg_set_tensor_hash_rsp response;
505+
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR_HASH, input.data(), input.size(), &response, sizeof(response));
506+
GGML_ASSERT(status);
507+
if (response.result) {
508+
// the server has the same data, no need to send it
509+
return;
510+
}
511+
}
512+
// input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes)
487513
size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + size;
488514
std::vector<uint8_t> input(input_size, 0);
489-
rpc_tensor rpc_tensor = serialize_tensor(tensor);
490515
memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
491516
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
492517
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
@@ -772,7 +797,9 @@ void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, si
772797

773798
class rpc_server {
774799
public:
775-
rpc_server(ggml_backend_t backend) : backend(backend) {}
800+
rpc_server(ggml_backend_t backend, const std::unordered_map<XXH64_hash_t, ggml_tensor *> & tensor_hashes)
801+
: backend(backend), tensor_hashes(tensor_hashes) {
802+
}
776803
~rpc_server();
777804

778805
void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
@@ -782,6 +809,7 @@ class rpc_server {
782809
bool free_buffer(const rpc_msg_free_buffer_req & request);
783810
bool buffer_clear(const rpc_msg_buffer_clear_req & request);
784811
bool set_tensor(const std::vector<uint8_t> & input);
812+
bool set_tensor_hash(const std::vector<uint8_t> & input, rpc_msg_set_tensor_hash_rsp & response);
785813
bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
786814
bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
787815
bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
@@ -798,6 +826,7 @@ class rpc_server {
798826

799827
ggml_backend_t backend;
800828
std::unordered_set<ggml_backend_buffer_t> buffers;
829+
const std::unordered_map<XXH64_hash_t, ggml_tensor *> & tensor_hashes;
801830
};
802831

803832
bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
@@ -965,6 +994,52 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
965994
return true;
966995
}
967996

997+
bool rpc_server::set_tensor_hash(const std::vector<uint8_t> & input, rpc_msg_set_tensor_hash_rsp & response)
998+
{
999+
// serialization format: | rpc_tensor | offset (8 bytes) | hash (8 bytes) |
1000+
if (input.size() != sizeof(rpc_tensor) + 16) {
1001+
return false;
1002+
}
1003+
const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
1004+
uint64_t offset;
1005+
memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
1006+
const XXH64_hash_t * hash = (const XXH64_hash_t *)(input.data() + sizeof(rpc_tensor) + sizeof(offset));
1007+
if (tensor_hashes.find(*hash) == tensor_hashes.end()) {
1008+
response.result = 0;
1009+
return true;
1010+
}
1011+
ggml_tensor * cached_tensor = tensor_hashes.at(*hash);
1012+
const size_t size = ggml_nbytes(cached_tensor);
1013+
1014+
struct ggml_init_params params {
1015+
/*.mem_size =*/ ggml_tensor_overhead(),
1016+
/*.mem_buffer =*/ NULL,
1017+
/*.no_alloc =*/ true,
1018+
};
1019+
struct ggml_context * ctx = ggml_init(params);
1020+
ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
1021+
if (tensor == nullptr) {
1022+
GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
1023+
ggml_free(ctx);
1024+
return false;
1025+
}
1026+
GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu, hash: %" PRIx64 "\n", __func__, (void*)tensor->buffer, tensor->data, offset, size, *hash);
1027+
1028+
// sanitize tensor->data
1029+
{
1030+
const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
1031+
const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
1032+
1033+
if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
1034+
GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
1035+
}
1036+
}
1037+
ggml_backend_tensor_set(tensor, cached_tensor->data, offset, size);
1038+
response.result = 1;
1039+
ggml_free(ctx);
1040+
return true;
1041+
}
1042+
9681043
bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) {
9691044
struct ggml_init_params params {
9701045
/*.mem_size =*/ ggml_tensor_overhead(),
@@ -1148,8 +1223,9 @@ rpc_server::~rpc_server() {
11481223
}
11491224
}
11501225

1151-
static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
1152-
rpc_server server(backend);
1226+
static void rpc_serve_client(ggml_backend_t backend, const std::unordered_map<XXH64_hash_t, ggml_tensor *> & tensor_hashes,
1227+
sockfd_t sockfd, size_t free_mem, size_t total_mem) {
1228+
rpc_server server(backend, tensor_hashes);
11531229
while (true) {
11541230
uint8_t cmd;
11551231
if (!recv_data(sockfd, &cmd, 1)) {
@@ -1260,6 +1336,20 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
12601336
}
12611337
break;
12621338
}
1339+
case RPC_CMD_SET_TENSOR_HASH: {
1340+
std::vector<uint8_t> input;
1341+
if (!recv_msg(sockfd, input)) {
1342+
return;
1343+
}
1344+
rpc_msg_set_tensor_hash_rsp response;
1345+
if (!server.set_tensor_hash(input, response)) {
1346+
return;
1347+
}
1348+
if (!send_msg(sockfd, &response, sizeof(response))) {
1349+
return;
1350+
}
1351+
break;
1352+
}
12631353
case RPC_CMD_INIT_TENSOR: {
12641354
rpc_msg_init_tensor_req request;
12651355
if (!recv_msg(sockfd, &request,sizeof(request))) {
@@ -1335,7 +1425,8 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
13351425
}
13361426
}
13371427

1338-
void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
1428+
void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, const char * gguf_path,
1429+
size_t free_mem, size_t total_mem) {
13391430
std::string host;
13401431
int port;
13411432
if (!parse_endpoint(endpoint, host, port)) {
@@ -1351,6 +1442,28 @@ void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint
13511442
}
13521443
}
13531444
#endif
1445+
gguf_context * ctx = nullptr;
1446+
std::unordered_map<XXH64_hash_t, ggml_tensor *> tensor_hashes;
1447+
if (gguf_path != nullptr) {
1448+
struct ggml_context * ctx_data = NULL;
1449+
struct gguf_init_params params = {
1450+
/*.no_alloc = */ false,
1451+
/*.ctx = */ &ctx_data,
1452+
};
1453+
ctx = gguf_init_from_file(gguf_path, params);
1454+
if (ctx == nullptr) {
1455+
fprintf(stderr, "Failed to load GGUF file: %s\n", gguf_path);
1456+
return;
1457+
}
1458+
const int n_tensors = gguf_get_n_tensors(ctx);
1459+
for (int i = 0; i < n_tensors; ++i) {
1460+
const char * name = gguf_get_tensor_name(ctx, i);
1461+
ggml_tensor * cur = ggml_get_tensor(ctx_data, name);
1462+
size_t n_bytes = ggml_nbytes(cur);
1463+
XXH64_hash_t hash = XXH64(cur->data, n_bytes, 0);
1464+
tensor_hashes[hash] = cur;
1465+
}
1466+
}
13541467
auto server_socket = create_server_socket(host.c_str(), port);
13551468
if (server_socket == nullptr) {
13561469
fprintf(stderr, "Failed to create server socket\n");
@@ -1364,10 +1477,13 @@ void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint
13641477
}
13651478
printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
13661479
fflush(stdout);
1367-
rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
1480+
rpc_serve_client(backend, tensor_hashes, client_socket->fd, free_mem, total_mem);
13681481
printf("Client connection closed\n");
13691482
fflush(stdout);
13701483
}
1484+
if (ctx != nullptr) {
1485+
gguf_free(ctx);
1486+
}
13711487
#ifdef _WIN32
13721488
WSACleanup();
13731489
#endif

ggml/src/ggml-rpc/xxhash.c

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/*
2+
* xxHash - Extremely Fast Hash algorithm
3+
* Copyright (C) 2012-2023 Yann Collet
4+
*
5+
* BSD 2-Clause License (https://www.opensource.org/licenses/bsd-license.php)
6+
*
7+
* Redistribution and use in source and binary forms, with or without
8+
* modification, are permitted provided that the following conditions are
9+
* met:
10+
*
11+
* * Redistributions of source code must retain the above copyright
12+
* notice, this list of conditions and the following disclaimer.
13+
* * Redistributions in binary form must reproduce the above
14+
* copyright notice, this list of conditions and the following disclaimer
15+
* in the documentation and/or other materials provided with the
16+
* distribution.
17+
*
18+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
19+
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
20+
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
21+
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
22+
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
23+
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
24+
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25+
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26+
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27+
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29+
*
30+
* You can contact the author at:
31+
* - xxHash homepage: https://www.xxhash.com
32+
* - xxHash source repository: https://github.com/Cyan4973/xxHash
33+
*/
34+
35+
/*
36+
* xxhash.c instantiates functions defined in xxhash.h
37+
*/
38+
39+
#define XXH_STATIC_LINKING_ONLY /* access advanced declarations */
40+
#define XXH_IMPLEMENTATION /* access definitions */
41+
42+
#include "xxhash.h"

0 commit comments

Comments
 (0)