Skip to content

Commit f504dc3

Browse files
committed
rpc : add RPC_CMD_HELLO
Add RPC_CMD_HELLO for getting the version of the protocol implemend by the server. Follow the semantic versioning rules at https://semver.org Hopefully this bring better user experience when we make breaking at the protocol level and avoid issues like #12465
1 parent daa4228 commit f504dc3

File tree

3 files changed

+51
-2
lines changed

3 files changed

+51
-2
lines changed

examples/rpc/rpc-server.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,10 @@ int main(int argc, char * argv[]) {
297297
}
298298
cache_dir = cache_dir_str.c_str();
299299
}
300-
printf("Starting RPC server\n");
300+
printf("Starting RPC server v%d.%d.%d\n",
301+
RPC_PROTO_MAJOR_VERSION,
302+
RPC_PROTO_MINOR_VERSION,
303+
RPC_PROTO_PATCH_VERSION);
301304
printf(" endpoint : %s\n", endpoint.c_str());
302305
printf(" local cache : %s\n", cache_dir ? cache_dir : "n/a");
303306
printf(" backend memory : %zu MB\n", free_mem / (1024 * 1024));

ggml/include/ggml-rpc.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
extern "C" {
88
#endif
99

10+
#define RPC_PROTO_MAJOR_VERSION 1
11+
#define RPC_PROTO_MINOR_VERSION 0
12+
#define RPC_PROTO_PATCH_VERSION 0
1013
#define GGML_RPC_MAX_SERVERS 16
1114

1215
// backend API

ggml/src/ggml-rpc/ggml-rpc.cpp

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@ static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of
7878

7979
// RPC commands
8080
enum rpc_cmd {
81-
RPC_CMD_ALLOC_BUFFER = 0,
81+
RPC_CMD_HELLO = 0,
82+
RPC_CMD_ALLOC_BUFFER,
8283
RPC_CMD_GET_ALIGNMENT,
8384
RPC_CMD_GET_MAX_SIZE,
8485
RPC_CMD_BUFFER_GET_BASE,
@@ -98,6 +99,12 @@ enum rpc_cmd {
9899
// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
99100
const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
100101

102+
struct rpc_msg_hello_rsp {
103+
uint8_t major;
104+
uint8_t minor;
105+
uint8_t patch;
106+
};
107+
101108
struct rpc_msg_get_alloc_size_req {
102109
rpc_tensor tensor;
103110
};
@@ -606,6 +613,20 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back
606613
}
607614
}
608615

616+
static bool check_server_version(const std::shared_ptr<socket_t> & sock) {
617+
rpc_msg_hello_rsp response;
618+
bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response));
619+
GGML_ASSERT(status);
620+
if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) {
621+
fprintf(stderr, "RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
622+
return false;
623+
}
624+
if (response.minor != RPC_PROTO_MINOR_VERSION || response.patch != RPC_PROTO_PATCH_VERSION) {
625+
fprintf(stderr, "WARNING: RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
626+
}
627+
return true;
628+
}
629+
609630
static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
610631
rpc_msg_get_alignment_rsp response;
611632
bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, nullptr, 0, &response, sizeof(response));
@@ -754,6 +775,9 @@ ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
754775
fprintf(stderr, "Failed to connect to %s\n", endpoint);
755776
return nullptr;
756777
}
778+
if (!check_server_version(sock)) {
779+
return nullptr;
780+
}
757781
size_t alignment = get_alignment(sock);
758782
size_t max_size = get_max_size(sock);
759783
ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
@@ -818,6 +842,7 @@ class rpc_server {
818842
}
819843
~rpc_server();
820844

845+
void hello(rpc_msg_hello_rsp & response);
821846
void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
822847
void get_alignment(rpc_msg_get_alignment_rsp & response);
823848
void get_max_size(rpc_msg_get_max_size_rsp & response);
@@ -846,6 +871,13 @@ class rpc_server {
846871
std::unordered_set<ggml_backend_buffer_t> buffers;
847872
};
848873

874+
void rpc_server::hello(rpc_msg_hello_rsp & response) {
875+
response.major = RPC_PROTO_MAJOR_VERSION;
876+
response.minor = RPC_PROTO_MINOR_VERSION;
877+
response.patch = RPC_PROTO_PATCH_VERSION;
878+
GGML_PRINT_DEBUG("[%s] version: %d.%d.%d\n", __func__, response.major, response.minor, response.patch);
879+
}
880+
849881
bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
850882
ggml_backend_buffer_type_t buft;
851883
struct ggml_init_params params {
@@ -1282,6 +1314,17 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
12821314
break;
12831315
}
12841316
switch (cmd) {
1317+
case RPC_CMD_HELLO: {
1318+
if (!recv_msg(sockfd, nullptr, 0)) {
1319+
return;
1320+
}
1321+
rpc_msg_hello_rsp response;
1322+
server.hello(response);
1323+
if (!send_msg(sockfd, &response, sizeof(response))) {
1324+
return;
1325+
}
1326+
break;
1327+
}
12851328
case RPC_CMD_ALLOC_BUFFER: {
12861329
rpc_msg_alloc_buffer_req request;
12871330
if (!recv_msg(sockfd, &request, sizeof(request))) {

0 commit comments

Comments
 (0)