@@ -92,12 +92,19 @@ enum rpc_cmd {
9292    RPC_CMD_GET_DEVICE_MEMORY,
9393    RPC_CMD_INIT_TENSOR,
9494    RPC_CMD_GET_ALLOC_SIZE,
95+     RPC_CMD_HELLO,
9596    RPC_CMD_COUNT,
9697};
9798
// Try RPC_CMD_SET_TENSOR_HASH first when the data size is larger than this
// threshold (10 MiB, expressed in bytes).
const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
100101
// Response payload of RPC_CMD_HELLO: the protocol version reported by the server.
// The struct is sent and received as raw bytes (sizeof(rpc_msg_hello_rsp)), so its
// layout is part of the wire protocol and must not change.
struct rpc_msg_hello_rsp {
    uint8_t major; // major protocol version
    uint8_t minor; // minor protocol version
    uint8_t patch; // patch protocol version
};

// Three single-byte fields never need padding; lock that in so an accidental
// field change is caught at compile time instead of as a protocol mismatch.
static_assert(sizeof(rpc_msg_hello_rsp) == 3, "rpc_msg_hello_rsp must be exactly 3 bytes on the wire");
107+ 
101108struct  rpc_msg_get_alloc_size_req  {
102109    rpc_tensor tensor;
103110};
@@ -400,6 +407,20 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
400407
401408//  RPC client-side implementation
402409
410+ static  bool  check_server_version (const  std::shared_ptr<socket_t > & sock) {
411+     rpc_msg_hello_rsp response;
412+     bool  status = send_rpc_cmd (sock, RPC_CMD_HELLO, nullptr , 0 , &response, sizeof (response));
413+     GGML_ASSERT (status);
414+     if  (response.major  != RPC_PROTO_MAJOR_VERSION || response.minor  > RPC_PROTO_MINOR_VERSION) {
415+         fprintf (stderr, " RPC server version mismatch: %d.%d.%d\n "  , response.major , response.minor , response.patch );
416+         return  false ;
417+     }
418+     if  (response.minor  != RPC_PROTO_MINOR_VERSION || response.patch  != RPC_PROTO_PATCH_VERSION) {
419+         fprintf (stderr, " WARNING: RPC server version mismatch: %d.%d.%d\n "  , response.major , response.minor , response.patch );
420+     }
421+     return  true ;
422+ }
423+ 
403424static  std::shared_ptr<socket_t > get_socket (const  std::string & endpoint) {
404425    static  std::mutex mutex;
405426    std::lock_guard<std::mutex> lock (mutex);
@@ -433,6 +454,9 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
433454    if  (sock == nullptr ) {
434455        return  nullptr ;
435456    }
457+     if  (!check_server_version (sock)) {
458+         return  nullptr ;
459+     }
436460    GGML_PRINT_DEBUG (" [%s] connected to %s, sockfd=%d\n "  , __func__, endpoint.c_str (), sock->fd );
437461    sockets[endpoint] = sock;
438462    return  sock;
@@ -818,6 +842,7 @@ class rpc_server {
818842    }
819843    ~rpc_server ();
820844
845+     void  hello (rpc_msg_hello_rsp & response);
821846    void  alloc_buffer (const  rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
822847    void  get_alignment (rpc_msg_get_alignment_rsp & response);
823848    void  get_max_size (rpc_msg_get_max_size_rsp & response);
@@ -846,6 +871,13 @@ class rpc_server {
846871    std::unordered_set<ggml_backend_buffer_t > buffers;
847872};
848873
874+ void  rpc_server::hello (rpc_msg_hello_rsp & response) {
875+     response.major  = RPC_PROTO_MAJOR_VERSION;
876+     response.minor  = RPC_PROTO_MINOR_VERSION;
877+     response.patch  = RPC_PROTO_PATCH_VERSION;
878+     GGML_PRINT_DEBUG (" [%s] version: %d.%d.%d\n "  , __func__, response.major , response.minor , response.patch );
879+ }
880+ 
849881bool  rpc_server::get_alloc_size (const  rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
850882    ggml_backend_buffer_type_t  buft;
851883    struct  ggml_init_params  params {
@@ -1271,8 +1303,24 @@ rpc_server::~rpc_server() {
12711303static  void  rpc_serve_client (ggml_backend_t  backend, const  char  * cache_dir,
12721304                             sockfd_t  sockfd, size_t  free_mem, size_t  total_mem) {
12731305    rpc_server server (backend, cache_dir);
1306+     uint8_t  cmd;
1307+     if  (!recv_data (sockfd, &cmd, 1 )) {
1308+         return ;
1309+     }
1310+     //  the first command sent by the client must be HELLO
1311+     if  (cmd != RPC_CMD_HELLO) {
1312+         fprintf (stderr, " Expected HELLO command, update client\n "  );
1313+         return ;
1314+     }
1315+     if  (!recv_msg (sockfd, nullptr , 0 )) {
1316+         return ;
1317+     }
1318+     rpc_msg_hello_rsp response;
1319+     server.hello (response);
1320+     if  (!send_msg (sockfd, &response, sizeof (response))) {
1321+         return ;
1322+     }
12741323    while  (true ) {
1275-         uint8_t  cmd;
12761324        if  (!recv_data (sockfd, &cmd, 1 )) {
12771325            break ;
12781326        }
@@ -1282,6 +1330,10 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
12821330            break ;
12831331        }
12841332        switch  (cmd) {
1333+             case  RPC_CMD_HELLO: {
1334+                 //  HELLO command is handled above
1335+                 return ;
1336+             }
12851337            case  RPC_CMD_ALLOC_BUFFER: {
12861338                rpc_msg_alloc_buffer_req request;
12871339                if  (!recv_msg (sockfd, &request, sizeof (request))) {
0 commit comments