@@ -92,12 +92,19 @@ enum rpc_cmd {
9292 RPC_CMD_GET_DEVICE_MEMORY,
9393 RPC_CMD_INIT_TENSOR,
9494 RPC_CMD_GET_ALLOC_SIZE,
95+ RPC_CMD_HELLO,
9596 RPC_CMD_COUNT,
9697};
9798
9899// Size threshold in bytes (10 MiB): try RPC_CMD_SET_TENSOR_HASH first when the data size is larger than this
99100const size_t HASH_THRESHOLD = 10 * 1024 * 1024 ;
100101
// Reply payload for RPC_CMD_HELLO.
// The server fills the three fields from RPC_PROTO_MAJOR_VERSION,
// RPC_PROTO_MINOR_VERSION and RPC_PROTO_PATCH_VERSION; the client compares
// them against its own values of the same constants.
struct rpc_msg_hello_rsp {
    uint8_t major;  // protocol major version reported by the server
    uint8_t minor;  // protocol minor version reported by the server
    uint8_t patch;  // protocol patch version reported by the server
};
107+
101108struct rpc_msg_get_alloc_size_req {
102109 rpc_tensor tensor;
103110};
@@ -400,6 +407,20 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
400407
401408// RPC client-side implementation
402409
410+ static bool check_server_version (const std::shared_ptr<socket_t > & sock) {
411+ rpc_msg_hello_rsp response;
412+ bool status = send_rpc_cmd (sock, RPC_CMD_HELLO, nullptr , 0 , &response, sizeof (response));
413+ GGML_ASSERT (status);
414+ if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) {
415+ fprintf (stderr, " RPC server version mismatch: %d.%d.%d\n " , response.major , response.minor , response.patch );
416+ return false ;
417+ }
418+ if (response.minor != RPC_PROTO_MINOR_VERSION || response.patch != RPC_PROTO_PATCH_VERSION) {
419+ fprintf (stderr, " WARNING: RPC server version mismatch: %d.%d.%d\n " , response.major , response.minor , response.patch );
420+ }
421+ return true ;
422+ }
423+
403424static std::shared_ptr<socket_t > get_socket (const std::string & endpoint) {
404425 static std::mutex mutex;
405426 std::lock_guard<std::mutex> lock (mutex);
@@ -433,6 +454,9 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
433454 if (sock == nullptr ) {
434455 return nullptr ;
435456 }
457+ if (!check_server_version (sock)) {
458+ return nullptr ;
459+ }
436460 GGML_PRINT_DEBUG (" [%s] connected to %s, sockfd=%d\n " , __func__, endpoint.c_str (), sock->fd );
437461 sockets[endpoint] = sock;
438462 return sock;
@@ -818,6 +842,7 @@ class rpc_server {
818842 }
819843 ~rpc_server ();
820844
845+ void hello (rpc_msg_hello_rsp & response);
821846 void alloc_buffer (const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
822847 void get_alignment (rpc_msg_get_alignment_rsp & response);
823848 void get_max_size (rpc_msg_get_max_size_rsp & response);
@@ -846,6 +871,13 @@ class rpc_server {
846871 std::unordered_set<ggml_backend_buffer_t > buffers;
847872};
848873
874+ void rpc_server::hello (rpc_msg_hello_rsp & response) {
875+ response.major = RPC_PROTO_MAJOR_VERSION;
876+ response.minor = RPC_PROTO_MINOR_VERSION;
877+ response.patch = RPC_PROTO_PATCH_VERSION;
878+ GGML_PRINT_DEBUG (" [%s] version: %d.%d.%d\n " , __func__, response.major , response.minor , response.patch );
879+ }
880+
849881bool rpc_server::get_alloc_size (const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
850882 ggml_backend_buffer_type_t buft;
851883 struct ggml_init_params params {
@@ -1271,8 +1303,24 @@ rpc_server::~rpc_server() {
12711303static void rpc_serve_client (ggml_backend_t backend, const char * cache_dir,
12721304 sockfd_t sockfd, size_t free_mem, size_t total_mem) {
12731305 rpc_server server (backend, cache_dir);
1306+ uint8_t cmd;
1307+ if (!recv_data (sockfd, &cmd, 1 )) {
1308+ return ;
1309+ }
1310+ // the first command sent by the client must be HELLO
1311+ if (cmd != RPC_CMD_HELLO) {
1312+ fprintf (stderr, " Expected HELLO command, update client\n " );
1313+ return ;
1314+ }
1315+ if (!recv_msg (sockfd, nullptr , 0 )) {
1316+ return ;
1317+ }
1318+ rpc_msg_hello_rsp response;
1319+ server.hello (response);
1320+ if (!send_msg (sockfd, &response, sizeof (response))) {
1321+ return ;
1322+ }
12741323 while (true ) {
1275- uint8_t cmd;
12761324 if (!recv_data (sockfd, &cmd, 1 )) {
12771325 break ;
12781326 }
@@ -1282,6 +1330,10 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
12821330 break ;
12831331 }
12841332 switch (cmd) {
1333+ case RPC_CMD_HELLO: {
1334+ // HELLO command is handled above
1335+ return ;
1336+ }
12851337 case RPC_CMD_ALLOC_BUFFER: {
12861338 rpc_msg_alloc_buffer_req request;
12871339 if (!recv_msg (sockfd, &request, sizeof (request))) {
0 commit comments