@@ -58,7 +58,7 @@ struct socket_t {
5858};
5959
6060//  ggml_tensor is serialized into rpc_tensor
61- #pragma  pack(push,  1)
61+ #pragma  pack(1)
6262struct  rpc_tensor  {
6363    uint64_t  id;
6464    uint32_t  type;
@@ -76,7 +76,6 @@ struct rpc_tensor {
7676
7777    char  padding[4 ];
7878};
79- #pragma  pack(pop)
8079
8180static_assert (sizeof (rpc_tensor) % 8  == 0 , " rpc_tensor size must be multiple of 8"  );
8281
@@ -96,6 +95,77 @@ enum rpc_cmd {
9695    RPC_CMD_COUNT,
9796};
9897
98+ #pragma  pack(1)
99+ struct  rpc_msg_alloc_buffer_req  {
100+     uint64_t  size;
101+ };
102+ 
103+ #pragma  pack(1)
104+ struct  rpc_msg_alloc_buffer_rsp  {
105+     uint64_t  remote_ptr;
106+     uint64_t  remote_size;
107+ };
108+ 
109+ #pragma  pack(1)
110+ struct  rpc_msg_get_alignment_rsp  {
111+     uint64_t  alignment;
112+ };
113+ 
114+ #pragma  pack(1)
115+ struct  rpc_msg_get_max_size_rsp  {
116+     uint64_t  max_size;
117+ };
118+ 
119+ #pragma  pack(1)
120+ struct  rpc_msg_buffer_get_base_req  {
121+     uint64_t  remote_ptr;
122+ };
123+ 
124+ #pragma  pack(1)
125+ struct  rpc_msg_buffer_get_base_rsp  {
126+     uint64_t  base_ptr;
127+ };
128+ 
129+ #pragma  pack(1)
130+ struct  rpc_msg_free_buffer_req  {
131+     uint64_t  remote_ptr;
132+ };
133+ 
134+ #pragma  pack(1)
135+ struct  rpc_msg_buffer_clear_req  {
136+     uint64_t  remote_ptr;
137+     uint8_t  value;
138+ };
139+ 
140+ #pragma  pack(1)
141+ struct  rpc_msg_get_tensor_req  {
142+     rpc_tensor tensor;
143+     uint64_t  offset;
144+     uint64_t  size;
145+ };
146+ 
147+ #pragma  pack(1)
148+ struct  rpc_msg_copy_tensor_req  {
149+     rpc_tensor src;
150+     rpc_tensor dst;
151+ };
152+ 
153+ #pragma  pack(1)
154+ struct  rpc_msg_copy_tensor_rsp  {
155+     uint8_t  result;
156+ };
157+ 
158+ #pragma  pack(1)
159+ struct  rpc_msg_graph_compute_rsp  {
160+     uint8_t  result;
161+ };
162+ 
163+ #pragma  pack(1)
164+ struct  rpc_msg_get_device_memory_rsp  {
165+     uint64_t  free_mem;
166+     uint64_t  total_mem;
167+ };
168+ 
99169//  RPC data structures
100170
101171static  ggml_guid_t  ggml_backend_rpc_guid () {
@@ -252,28 +322,27 @@ static bool parse_endpoint(const std::string & endpoint, std::string & host, int
252322
253323//  RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
254324//  RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
255- static  bool  send_rpc_cmd (const  std::shared_ptr<socket_t > & sock, enum  rpc_cmd cmd, const  std::vector< uint8_t > &  input, std::vector< uint8_t > &  output) {
325+ static  bool  send_rpc_cmd (const  std::shared_ptr<socket_t > & sock, enum  rpc_cmd cmd, const  void  *  input, size_t  input_size,  void  *  output,  size_t  output_size ) {
256326    uint8_t  cmd_byte = cmd;
257327    if  (!send_data (sock->fd , &cmd_byte, sizeof (cmd_byte))) {
258328        return  false ;
259329    }
260-     uint64_t  input_size = input.size ();
261330    if  (!send_data (sock->fd , &input_size, sizeof (input_size))) {
262331        return  false ;
263332    }
264-     if  (!send_data (sock->fd , input. data (), input. size () )) {
333+     if  (!send_data (sock->fd , input, input_size )) {
265334        return  false ;
266335    }
267-     uint64_t  output_size;
268-     if  (!recv_data (sock->fd , &output_size, sizeof (output_size))) {
336+     //  TODO: currently the output_size is always known, do we need support for commands with variable output size?
337+     //  even if we do, we can skip sending output_size from the server for commands with known output size
338+     uint64_t  out_size;
339+     if  (!recv_data (sock->fd , &out_size, sizeof (out_size))) {
269340        return  false ;
270341    }
271-     if  (output_size == 0 ) {
272-         output.clear ();
273-         return  true ;
342+     if  (out_size != output_size) {
343+         return  false ;
274344    }
275-     output.resize (output_size);
276-     if  (!recv_data (sock->fd , output.data (), output_size)) {
345+     if  (!recv_data (sock->fd , output, output_size)) {
277346        return  false ;
278347    }
279348    return  true ;
@@ -326,14 +395,9 @@ static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffe
326395
327396static  void  ggml_backend_rpc_buffer_free_buffer (ggml_backend_buffer_t  buffer) {
328397    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context ;
329-     //  input serialization format: | remote_ptr (8 bytes) |
330-     std::vector<uint8_t > input (sizeof (uint64_t ), 0 );
331-     uint64_t  remote_ptr = ctx->remote_ptr ;
332-     memcpy (input.data (), &remote_ptr, sizeof (remote_ptr));
333-     std::vector<uint8_t > output;
334-     bool  status = send_rpc_cmd (ctx->sock , RPC_CMD_FREE_BUFFER, input, output);
398+     rpc_msg_free_buffer_req request = {ctx->remote_ptr };
399+     bool  status = send_rpc_cmd (ctx->sock , RPC_CMD_FREE_BUFFER, &request, sizeof (request), nullptr , 0 );
335400    GGML_ASSERT (status);
336-     GGML_ASSERT (output.empty ());
337401    delete  ctx;
338402}
339403
@@ -342,20 +406,13 @@ static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
342406    if  (ctx->base_cache .find (buffer) != ctx->base_cache .end ()) {
343407        return  ctx->base_cache [buffer];
344408    }
345-     //  input serialization format: | remote_ptr (8 bytes) |
346-     std::vector<uint8_t > input (sizeof (uint64_t ), 0 );
347-     uint64_t  remote_ptr = ctx->remote_ptr ;
348-     memcpy (input.data (), &remote_ptr, sizeof (remote_ptr));
349-     std::vector<uint8_t > output;
350-     bool  status = send_rpc_cmd (ctx->sock , RPC_CMD_BUFFER_GET_BASE, input, output);
409+     rpc_msg_buffer_get_base_req request = {ctx->remote_ptr };
410+     rpc_msg_buffer_get_base_rsp response;
411+     bool  status = send_rpc_cmd (ctx->sock , RPC_CMD_BUFFER_GET_BASE, &request, sizeof (request), &response, sizeof (response));
351412    GGML_ASSERT (status);
352-     GGML_ASSERT (output.size () == sizeof (uint64_t ));
353-     //  output serialization format: | base_ptr (8 bytes) |
354-     uint64_t  base_ptr;
355-     memcpy (&base_ptr, output.data (), sizeof (base_ptr));
356-     void  * base = reinterpret_cast <void  *>(base_ptr);
357-     ctx->base_cache [buffer] = base;
358-     return  base;
413+     void  * base_ptr = reinterpret_cast <void  *>(response.base_ptr );
414+     ctx->base_cache [buffer] = base_ptr;
415+     return  base_ptr;
359416}
360417
361418static  rpc_tensor serialize_tensor (const  ggml_tensor * tensor) {
@@ -405,26 +462,18 @@ static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggm
405462    memcpy (input.data (), &rpc_tensor, sizeof (rpc_tensor));
406463    memcpy (input.data () + sizeof (rpc_tensor), &offset, sizeof (offset));
407464    memcpy (input.data () + sizeof (rpc_tensor) + sizeof (offset), data, size);
408-     std::vector<uint8_t > output;
409-     bool  status = send_rpc_cmd (ctx->sock , RPC_CMD_SET_TENSOR, input, output);
465+     bool  status = send_rpc_cmd (ctx->sock , RPC_CMD_SET_TENSOR, input.data (), input.size (), nullptr , 0 );
410466    GGML_ASSERT (status);
411467}
412468
413469static  void  ggml_backend_rpc_buffer_get_tensor (ggml_backend_buffer_t  buffer, const  ggml_tensor * tensor, void  * data, size_t  offset, size_t  size) {
414470    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context ;
415-     //  input serialization format: | rpc_tensor | offset (8 bytes) | size (8 bytes) |
416-     int  input_size = sizeof (rpc_tensor) + 2 *sizeof (uint64_t );
417-     std::vector<uint8_t > input (input_size, 0 );
418-     rpc_tensor rpc_tensor = serialize_tensor (tensor);
419-     memcpy (input.data (), &rpc_tensor, sizeof (rpc_tensor));
420-     memcpy (input.data () + sizeof (rpc_tensor), &offset, sizeof (offset));
421-     memcpy (input.data () + sizeof (rpc_tensor) + sizeof (offset), &size, sizeof (size));
422-     std::vector<uint8_t > output;
423-     bool  status = send_rpc_cmd (ctx->sock , RPC_CMD_GET_TENSOR, input, output);
471+     rpc_msg_get_tensor_req request;
472+     request.tensor  = serialize_tensor (tensor);
473+     request.offset  = offset;
474+     request.size  = size;
475+     bool  status = send_rpc_cmd (ctx->sock , RPC_CMD_GET_TENSOR, &request, sizeof (request), data, size);
424476    GGML_ASSERT (status);
425-     GGML_ASSERT (output.size () == size);
426-     //  output serialization format: | data (size bytes) |
427-     memcpy (data, output.data (), size);
428477}
429478
430479static  bool  ggml_backend_rpc_buffer_cpy_tensor (ggml_backend_buffer_t  buffer, const  ggml_tensor * src, ggml_tensor * dst) {
@@ -437,30 +486,19 @@ static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, con
437486        return  false ;
438487    }
439488    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context ;
440-     //  input serialization format: | rpc_tensor src | rpc_tensor dst |
441-     int  input_size = 2 *sizeof (rpc_tensor);
442-     std::vector<uint8_t > input (input_size, 0 );
443-     rpc_tensor rpc_src = serialize_tensor (src);
444-     rpc_tensor rpc_dst = serialize_tensor (dst);
445-     memcpy (input.data (), &rpc_src, sizeof (rpc_src));
446-     memcpy (input.data () + sizeof (rpc_src), &rpc_dst, sizeof (rpc_dst));
447-     std::vector<uint8_t > output;
448-     bool  status = send_rpc_cmd (ctx->sock , RPC_CMD_COPY_TENSOR, input, output);
489+     rpc_msg_copy_tensor_req request;
490+     request.src  = serialize_tensor (src);
491+     request.dst  = serialize_tensor (dst);
492+     rpc_msg_copy_tensor_rsp response;
493+     bool  status = send_rpc_cmd (ctx->sock , RPC_CMD_COPY_TENSOR, &request, sizeof (request), &response, sizeof (response));
449494    GGML_ASSERT (status);
450-     //  output serialization format: | result (1 byte) |
451-     GGML_ASSERT (output.size () == 1 );
452-     return  output[0 ];
495+     return  response.result ;
453496}
454497
455498static  void  ggml_backend_rpc_buffer_clear (ggml_backend_buffer_t  buffer, uint8_t  value) {
456499    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context ;
457-     //  serialization format: | bufptr (8 bytes) | value (1 byte) |
458-     int  input_size = sizeof (uint64_t ) + sizeof (uint8_t );
459-     std::vector<uint8_t > input (input_size, 0 );
460-     memcpy (input.data (), &ctx->remote_ptr , sizeof (ctx->remote_ptr ));
461-     memcpy (input.data () + sizeof (ctx->remote_ptr ), &value, sizeof (value));
462-     std::vector<uint8_t > output;
463-     bool  status = send_rpc_cmd (ctx->sock , RPC_CMD_BUFFER_CLEAR, input, output);
500+     rpc_msg_buffer_clear_req request = {ctx->remote_ptr , value};
501+     bool  status = send_rpc_cmd (ctx->sock , RPC_CMD_BUFFER_CLEAR, &request, sizeof (request), nullptr , 0 );
464502    GGML_ASSERT (status);
465503}
466504
@@ -484,42 +522,27 @@ static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t
484522
485523static  ggml_backend_buffer_t  ggml_backend_rpc_buffer_type_alloc_buffer (ggml_backend_buffer_type_t  buft, size_t  size) {
486524    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context ;
487-     //  input serialization format: | size (8 bytes) |
488-     int  input_size = sizeof (uint64_t );
489-     std::vector<uint8_t > input (input_size, 0 );
490-     memcpy (input.data (), &size, sizeof (size));
491-     std::vector<uint8_t > output;
525+     rpc_msg_alloc_buffer_req request = {size};
526+     rpc_msg_alloc_buffer_rsp response;
492527    auto  sock = get_socket (buft_ctx->endpoint );
493-     bool  status = send_rpc_cmd (sock, RPC_CMD_ALLOC_BUFFER, input, output );
528+     bool  status = send_rpc_cmd (sock, RPC_CMD_ALLOC_BUFFER, &request,  sizeof (request), &response,  sizeof (response) );
494529    GGML_ASSERT (status);
495-     GGML_ASSERT (output.size () == 2 *sizeof (uint64_t ));
496-     //  output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
497-     uint64_t  remote_ptr;
498-     memcpy (&remote_ptr, output.data (), sizeof (remote_ptr));
499-     size_t  remote_size;
500-     memcpy (&remote_size, output.data () + sizeof (uint64_t ), sizeof (remote_size));
501-     if  (remote_ptr != 0 ) {
530+     if  (response.remote_ptr  != 0 ) {
502531        ggml_backend_buffer_t  buffer = ggml_backend_buffer_init (buft,
503532            ggml_backend_rpc_buffer_interface,
504-             new  ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, " RPC["   + std::string (buft_ctx->endpoint ) + " ]"  },
505-             remote_size);
533+             new  ggml_backend_rpc_buffer_context{sock, {}, response. remote_ptr , " RPC["   + std::string (buft_ctx->endpoint ) + " ]"  },
534+             response. remote_size );
506535        return  buffer;
507536    } else  {
508537        return  nullptr ;
509538    }
510539}
511540
512541static  size_t  get_alignment (const  std::shared_ptr<socket_t > & sock) {
513-     //  input serialization format: | 0 bytes |
514-     std::vector<uint8_t > input;
515-     std::vector<uint8_t > output;
516-     bool  status = send_rpc_cmd (sock, RPC_CMD_GET_ALIGNMENT, input, output);
542+     rpc_msg_get_alignment_rsp response;
543+     bool  status = send_rpc_cmd (sock, RPC_CMD_GET_ALIGNMENT, nullptr , 0 , &response, sizeof (response));
517544    GGML_ASSERT (status);
518-     GGML_ASSERT (output.size () == sizeof (uint64_t ));
519-     //  output serialization format: | alignment (8 bytes) |
520-     uint64_t  alignment;
521-     memcpy (&alignment, output.data (), sizeof (alignment));
522-     return  alignment;
545+     return  response.alignment ;
523546}
524547
525548static  size_t  ggml_backend_rpc_buffer_type_get_alignment (ggml_backend_buffer_type_t  buft) {
@@ -528,16 +551,10 @@ static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_typ
528551}
529552
530553static  size_t  get_max_size (const  std::shared_ptr<socket_t > & sock) {
531-     //  input serialization format: | 0 bytes |
532-     std::vector<uint8_t > input;
533-     std::vector<uint8_t > output;
534-     bool  status = send_rpc_cmd (sock, RPC_CMD_GET_MAX_SIZE, input, output);
554+     rpc_msg_get_max_size_rsp response;
555+     bool  status = send_rpc_cmd (sock, RPC_CMD_GET_MAX_SIZE, nullptr , 0 , &response, sizeof (response));
535556    GGML_ASSERT (status);
536-     GGML_ASSERT (output.size () == sizeof (uint64_t ));
537-     //  output serialization format: | max_size (8 bytes) |
538-     uint64_t  max_size;
539-     memcpy (&max_size, output.data (), sizeof (max_size));
540-     return  max_size;
557+     return  response.max_size ;
541558}
542559
543560static  size_t  ggml_backend_rpc_get_max_size (ggml_backend_buffer_type_t  buft) {
@@ -622,12 +639,11 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g
622639    ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context ;
623640    std::vector<uint8_t > input;
624641    serialize_graph (cgraph, input);
625-     std::vector< uint8_t > output ;
642+     rpc_msg_graph_compute_rsp response ;
626643    auto  sock = get_socket (rpc_ctx->endpoint );
627-     bool  status = send_rpc_cmd (sock, RPC_CMD_GRAPH_COMPUTE, input, output );
644+     bool  status = send_rpc_cmd (sock, RPC_CMD_GRAPH_COMPUTE, input. data (), input. size (), &response,  sizeof (response) );
628645    GGML_ASSERT (status);
629-     GGML_ASSERT (output.size () == 1 );
630-     return  (enum  ggml_status)output[0 ];
646+     return  (enum  ggml_status)response.result ;
631647}
632648
633649static  ggml_backend_i ggml_backend_rpc_interface = {
@@ -702,19 +718,11 @@ GGML_API bool ggml_backend_is_rpc(ggml_backend_t backend) {
702718}
703719
704720static  void  get_device_memory (const  std::shared_ptr<socket_t > & sock, size_t  * free, size_t  * total) {
705-     //  input serialization format: | 0 bytes |
706-     std::vector<uint8_t > input;
707-     std::vector<uint8_t > output;
708-     bool  status = send_rpc_cmd (sock, RPC_CMD_GET_DEVICE_MEMORY, input, output);
721+     rpc_msg_get_device_memory_rsp response;
722+     bool  status = send_rpc_cmd (sock, RPC_CMD_GET_DEVICE_MEMORY, nullptr , 0 , &response, sizeof (response));
709723    GGML_ASSERT (status);
710-     GGML_ASSERT (output.size () == 2 *sizeof (uint64_t ));
711-     //  output serialization format: | free (8 bytes) | total (8 bytes) |
712-     uint64_t  free_mem;
713-     memcpy (&free_mem, output.data (), sizeof (free_mem));
714-     uint64_t  total_mem;
715-     memcpy (&total_mem, output.data () + sizeof (uint64_t ), sizeof (total_mem));
716-     *free = free_mem;
717-     *total = total_mem;
724+     *free = response.free_mem ;
725+     *total = response.total_mem ;
718726}
719727
720728GGML_API void  ggml_backend_rpc_get_device_memory (const  char  * endpoint, size_t  * free, size_t  * total) {
0 commit comments