@@ -78,7 +78,8 @@ static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of
7878
7979// RPC commands
8080enum rpc_cmd {
81- RPC_CMD_ALLOC_BUFFER = 0 ,
81+ RPC_CMD_HELLO = 0 ,
82+ RPC_CMD_ALLOC_BUFFER,
8283 RPC_CMD_GET_ALIGNMENT,
8384 RPC_CMD_GET_MAX_SIZE,
8485 RPC_CMD_BUFFER_GET_BASE,
@@ -98,6 +99,12 @@ enum rpc_cmd {
// Payloads larger than this many bytes (10 MiB) are first attempted via
// RPC_CMD_SET_TENSOR_HASH.
const size_t HASH_THRESHOLD = 10 * (1024 * 1024);
100101
// Response payload for RPC_CMD_HELLO: the protocol version reported by the
// server, one byte per component. Field order is part of the message layout.
struct rpc_msg_hello_rsp {
    uint8_t major;  // major protocol version
    uint8_t minor;  // minor protocol version
    uint8_t patch;  // patch protocol version
};
107+
101108struct rpc_msg_get_alloc_size_req {
102109 rpc_tensor tensor;
103110};
@@ -606,6 +613,20 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back
606613 }
607614}
608615
616+ static bool check_server_version (const std::shared_ptr<socket_t > & sock) {
617+ rpc_msg_hello_rsp response;
618+ bool status = send_rpc_cmd (sock, RPC_CMD_HELLO, nullptr , 0 , &response, sizeof (response));
619+ GGML_ASSERT (status);
620+ if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) {
621+ fprintf (stderr, " RPC server version mismatch: %d.%d.%d\n " , response.major , response.minor , response.patch );
622+ return false ;
623+ }
624+ if (response.minor != RPC_PROTO_MINOR_VERSION || response.patch != RPC_PROTO_PATCH_VERSION) {
625+ fprintf (stderr, " WARNING: RPC server version mismatch: %d.%d.%d\n " , response.major , response.minor , response.patch );
626+ }
627+ return true ;
628+ }
629+
609630static size_t get_alignment (const std::shared_ptr<socket_t > & sock) {
610631 rpc_msg_get_alignment_rsp response;
611632 bool status = send_rpc_cmd (sock, RPC_CMD_GET_ALIGNMENT, nullptr , 0 , &response, sizeof (response));
@@ -754,6 +775,9 @@ ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
754775 fprintf (stderr, " Failed to connect to %s\n " , endpoint);
755776 return nullptr ;
756777 }
778+ if (!check_server_version (sock)) {
779+ return nullptr ;
780+ }
757781 size_t alignment = get_alignment (sock);
758782 size_t max_size = get_max_size (sock);
759783 ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
@@ -818,6 +842,7 @@ class rpc_server {
818842 }
819843 ~rpc_server ();
820844
845+ void hello (rpc_msg_hello_rsp & response);
821846 void alloc_buffer (const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
822847 void get_alignment (rpc_msg_get_alignment_rsp & response);
823848 void get_max_size (rpc_msg_get_max_size_rsp & response);
@@ -846,6 +871,13 @@ class rpc_server {
846871 std::unordered_set<ggml_backend_buffer_t > buffers;
847872};
848873
874+ void rpc_server::hello (rpc_msg_hello_rsp & response) {
875+ response.major = RPC_PROTO_MAJOR_VERSION;
876+ response.minor = RPC_PROTO_MINOR_VERSION;
877+ response.patch = RPC_PROTO_PATCH_VERSION;
878+ GGML_PRINT_DEBUG (" [%s] version: %d.%d.%d\n " , __func__, response.major , response.minor , response.patch );
879+ }
880+
849881bool rpc_server::get_alloc_size (const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
850882 ggml_backend_buffer_type_t buft;
851883 struct ggml_init_params params {
@@ -1282,6 +1314,17 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
12821314 break ;
12831315 }
12841316 switch (cmd) {
1317+ case RPC_CMD_HELLO: {
1318+ if (!recv_msg (sockfd, nullptr , 0 )) {
1319+ return ;
1320+ }
1321+ rpc_msg_hello_rsp response;
1322+ server.hello (response);
1323+ if (!send_msg (sockfd, &response, sizeof (response))) {
1324+ return ;
1325+ }
1326+ break ;
1327+ }
12851328 case RPC_CMD_ALLOC_BUFFER: {
12861329 rpc_msg_alloc_buffer_req request;
12871330 if (!recv_msg (sockfd, &request, sizeof (request))) {
0 commit comments