@@ -93,9 +93,23 @@ enum rpc_cmd {
     RPC_CMD_COPY_TENSOR,
     RPC_CMD_GRAPH_COMPUTE,
     RPC_CMD_GET_DEVICE_MEMORY,
+    RPC_CMD_INIT_TENSOR,
+    RPC_CMD_GET_ALLOC_SIZE,
     RPC_CMD_COUNT,
 };
 
+struct rpc_msg_get_alloc_size_req {
+    rpc_tensor tensor;
+};
+
+struct rpc_msg_get_alloc_size_rsp {
+    uint64_t alloc_size;
+};
+
+struct rpc_msg_init_tensor_req {
+    rpc_tensor tensor;
+};
+
 struct rpc_msg_alloc_buffer_req {
     uint64_t size;
 };
@@ -461,10 +475,18 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
 }
 
 static void ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
-    UNUSED(buffer);
-    if (ggml_is_quantized(tensor->type)) {
-        // TODO: this check is due to MATRIX_ROW_PADDING in CUDA and should be generalized
-        GGML_ASSERT(tensor->ne[0] % 512 == 0 && "unsupported quantized tensor");
+    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
+
+    // The CUDA backend on the server pads quantized tensor rows to a multiple of 512 (MATRIX_ROW_PADDING).
+    // To save bandwidth, the server-side init_tensor is only invoked when it is actually needed:
+    // for quantized, non-view tensors whose row length is not already a multiple of 512.
+    if (ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr)) {
+        rpc_msg_init_tensor_req request;
+
+        request.tensor = serialize_tensor(tensor);
+
+        bool status = send_rpc_cmd(ctx->sock, RPC_CMD_INIT_TENSOR, &request, sizeof(request), nullptr, 0);
+        GGML_ASSERT(status);
     }
 }
 
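The gating condition is the heart of this hunk. Restated as a standalone predicate (a sketch; `needs_remote_init` and `kCudaRowPadding` are invented names for illustration, not part of the patch):

```cpp
#include <cstdint>

constexpr int64_t kCudaRowPadding = 512; // MATRIX_ROW_PADDING on the CUDA server

bool needs_remote_init(bool is_quantized, int64_t ne0, bool is_view) {
    // Pay the RPC round-trip only for quantized, non-view tensors whose row
    // length is not already a multiple of the padding unit.
    return is_quantized && (ne0 % kCudaRowPadding != 0) && !is_view;
}

int main() {
    // A quantized row of 992 elements needs server-side init; 1024 does not.
    return (needs_remote_init(true, 992, false) && !needs_remote_init(true, 1024, false)) ? 0 : 1;
}
```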
@@ -577,8 +599,23 @@ static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
 }
 
 static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
-    UNUSED(buft);
-    return ggml_nbytes(tensor);
+    // See comments in init_tensor.
+    if (ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr)) {
+        ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
+        auto sock = get_socket(buft_ctx->endpoint);
+
+        rpc_msg_get_alloc_size_req request;
+
+        request.tensor = serialize_tensor(tensor);
+
+        rpc_msg_get_alloc_size_rsp response;
+        bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALLOC_SIZE, &request, sizeof(request), &response, sizeof(response));
+        GGML_ASSERT(status);
+
+        return response.alloc_size;
+    } else {
+        return ggml_nbytes(tensor);
+    }
 }
 
 static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
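To see why the client cannot simply use `ggml_nbytes()` here, the padding arithmetic the CUDA server is described as performing can be reproduced locally. A self-contained sketch; the 512-element padding and the Q4_0 block constants are assumptions taken from the comments above, not calls into ggml:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t ne0         = 992; // multiple of the Q4_0 block size (32), not of 512
    const int64_t row_padding = 512; // assumed CUDA MATRIX_ROW_PADDING
    const int64_t block_size  = 32;  // Q4_0: elements per block
    const int64_t block_bytes = 18;  // Q4_0: bytes per block

    const int64_t padded_ne0 = ((ne0 + row_padding - 1) / row_padding) * row_padding;

    // ggml_nbytes() reports the unpadded size; the CUDA server allocates the padded one.
    printf("unpadded row: %lld bytes\n", (long long)(ne0        / block_size * block_bytes)); // 558
    printf("padded row:   %lld bytes\n", (long long)(padded_ne0 / block_size * block_bytes)); // 576
    return 0;
}
```

For a 992-element Q4_0 row the padded allocation is 576 bytes versus 558 unpadded, which is exactly the discrepancy the GET_ALLOC_SIZE round-trip exists to resolve.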
@@ -757,6 +794,8 @@ class rpc_server {
     bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
     bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
     bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
+    bool init_tensor(const rpc_msg_init_tensor_req & request);
+    bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
 
 private:
     ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
@@ -770,6 +809,36 @@ class rpc_server {
     std::unordered_set<ggml_backend_buffer_t> buffers;
 };
 
+bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
+    ggml_backend_buffer_type_t buft;
+    struct ggml_init_params params {
+        /*.mem_size   =*/ ggml_tensor_overhead(),
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ true,
+    };
+
+    struct ggml_context * ctx = ggml_init(params);
+    ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
+
+    if (tensor == nullptr) {
+        GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n");
+        ggml_free(ctx);
+        return false;
+    }
+
+    if (tensor->buffer == nullptr) {
+        // No buffer allocated yet: fall back to the backend's default buffer type.
+        buft = ggml_backend_get_default_buffer_type(backend);
+    } else {
+        buft = tensor->buffer->buft;
+    }
+
+    response.alloc_size = ggml_backend_buft_get_alloc_size(buft, tensor);
+
+    ggml_free(ctx);
+    return true;
+}
+
 void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
     ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
     ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size);
@@ -905,6 +974,40 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
     return true;
 }
 
+bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) {
+    struct ggml_init_params params {
+        /*.mem_size   =*/ ggml_tensor_overhead(),
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ true,
+    };
+    struct ggml_context * ctx = ggml_init(params);
+    ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
+    if (tensor == nullptr) {
+        GGML_LOG_ERROR("Null tensor pointer passed to server init_tensor function.\n");
+        ggml_free(ctx);
+        return false;
+    }
+
+    // Call the backend's buffer_init_tensor function
+    ggml_backend_buffer_t buffer = tensor->buffer;
+    if (buffer && buffer->iface.init_tensor) {
+        buffer->iface.init_tensor(buffer, tensor);
+    } else {
+        GGML_LOG_ERROR("Null buffer for tensor passed to init_tensor function\n");
+    }
+
+    if (tensor->extra != nullptr) {
+        // The backend-populated extra pointer would have to be shipped to the client or tracked
+        // server-side; neither is implemented yet, so reject such tensors for now.
+        GGML_LOG_ERROR("tensor->extra populated by the backend, this is currently unsupported.\n");
+        ggml_free(ctx);
+        return false;
+    }
+
+    ggml_free(ctx);
+    return true;
+}
+
 bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response) {
     struct ggml_init_params params {
         /*.mem_size   =*/ ggml_tensor_overhead(),
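Both new handlers repeat the same shape: `ggml_init()`, `deserialize_tensor()`, then `ggml_free(ctx)` on every exit path. Not part of the patch, but a small RAII guard (a sketch against the public `ggml.h` API only) would make those early returns leak-proof by construction:

```cpp
#include "ggml.h"

// Illustrative only: a scope guard so early returns cannot leak the context.
struct ggml_ctx_guard {
    ggml_context * ctx;
    explicit ggml_ctx_guard(ggml_init_params params) : ctx(ggml_init(params)) {}
    ~ggml_ctx_guard() { ggml_free(ctx); }
    ggml_ctx_guard(const ggml_ctx_guard &) = delete;
    ggml_ctx_guard & operator=(const ggml_ctx_guard &) = delete;
};
```

Each handler could then declare `ggml_ctx_guard guard{params};` once and drop the repeated `ggml_free(ctx)` calls.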
@@ -1058,6 +1161,18 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
                 }
                 break;
             }
+            case RPC_CMD_GET_ALLOC_SIZE: {
+                rpc_msg_get_alloc_size_req request;
+                if (!recv_msg(sockfd, &request, sizeof(request))) {
+                    return;
+                }
+                rpc_msg_get_alloc_size_rsp response;
+                server.get_alloc_size(request, response);
+                if (!send_msg(sockfd, &response, sizeof(response))) {
+                    return;
+                }
+                break;
+            }
             case RPC_CMD_GET_ALIGNMENT: {
                 if (!recv_msg(sockfd, nullptr, 0)) {
                     return;
@@ -1133,6 +1248,19 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
                 }
                 break;
             }
+            case RPC_CMD_INIT_TENSOR: {
+                rpc_msg_init_tensor_req request;
+                if (!recv_msg(sockfd, &request, sizeof(request))) {
+                    return;
+                }
+                if (!server.init_tensor(request)) {
+                    return;
+                }
+                if (!send_msg(sockfd, nullptr, 0)) {
+                    return;
+                }
+                break;
+            }
             case RPC_CMD_GET_TENSOR: {
                 rpc_msg_get_tensor_req request;
                 if (!recv_msg(sockfd, &request, sizeof(request))) {