@@ -58,7 +58,7 @@ struct socket_t {
5858};
5959
6060// ggml_tensor is serialized into rpc_tensor
61- #pragma pack(push, 1)
61+ #pragma pack(1)
6262struct rpc_tensor {
6363 uint64_t id;
6464 uint32_t type;
@@ -96,6 +96,17 @@ enum rpc_cmd {
9696 RPC_CMD_COUNT,
9797};
9898
// Request payload sent with RPC_CMD_ALLOC_BUFFER.
// Wire format: | size (8 bytes) |
// The packing directive is scoped with push/pop so that 1-byte packing applies
// to this struct only and does not leak into declarations that follow it.
#pragma pack(push, 1)
struct request_alloc_buffer {
    uint64_t size; // size value passed by the caller of the buffer-type alloc_buffer function
};
#pragma pack(pop)
static_assert(sizeof(request_alloc_buffer) == sizeof(uint64_t), "request_alloc_buffer must be exactly 8 bytes on the wire");
103+
// Response payload received for RPC_CMD_ALLOC_BUFFER.
// Wire format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
// The packing directive is scoped with push/pop so that 1-byte packing applies
// to this struct only and does not leak into declarations that follow it.
#pragma pack(push, 1)
struct response_alloc_buffer {
    uint64_t remote_ptr;  // remote buffer handle; the caller treats 0 as "no buffer"
    uint64_t remote_size; // size value the caller forwards when creating the local buffer object
};
#pragma pack(pop)
static_assert(sizeof(response_alloc_buffer) == 2*sizeof(uint64_t), "response_alloc_buffer must be exactly 16 bytes on the wire");
109+
99110// RPC data structures
100111
101112static ggml_guid_t ggml_backend_rpc_guid () {
@@ -252,30 +263,31 @@ static bool parse_endpoint(const std::string & endpoint, std::string & host, int
252263
253264// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
254265// RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
255- static bool send_rpc_cmd (const std::shared_ptr<socket_t > & sock, enum rpc_cmd cmd, const std::vector< uint8_t > & input, std::vector< uint8_t > & output) {
266+ static bool send_rpc_cmd (const std::shared_ptr<socket_t > & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size ) {
256267 uint8_t cmd_byte = cmd;
257268 if (!send_data (sock->fd , &cmd_byte, sizeof (cmd_byte))) {
258269 return false ;
259270 }
260- uint64_t input_size = input.size ();
261271 if (!send_data (sock->fd , &input_size, sizeof (input_size))) {
262272 return false ;
263273 }
264- if (!send_data (sock->fd , input. data (), input. size () )) {
274+ if (!send_data (sock->fd , input, input_size )) {
265275 return false ;
266276 }
267- uint64_t output_size;
268- if (!recv_data (sock->fd , &output_size, sizeof (output_size))) {
277+ // TODO: currently the output_size is always known, do we need support for commands with variable output size?
278+ // even if we do, we can skip sending output_size from the server for commands with known output size
279+ uint64_t out_size;
280+ if (!recv_data (sock->fd , &out_size, sizeof (out_size))) {
269281 return false ;
270282 }
271- if (output_size == 0 ) {
272- output.clear ();
273- return true ;
274- }
275- output.resize (output_size);
276- if (!recv_data (sock->fd , output.data (), output_size)) {
283+ if (out_size != output_size) {
277284 return false ;
278285 }
286+ if (output_size > 0 ) {
287+ if (!recv_data (sock->fd , output, output_size)) {
288+ return false ;
289+ }
290+ }
279291 return true ;
280292}
281293
@@ -484,25 +496,15 @@ static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t
484496
485497static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer (ggml_backend_buffer_type_t buft, size_t size) {
486498 ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context ;
487- // input serialization format: | size (8 bytes) |
488- int input_size = sizeof (uint64_t );
489- std::vector<uint8_t > input (input_size, 0 );
490- memcpy (input.data (), &size, sizeof (size));
491- std::vector<uint8_t > output;
499+ request_alloc_buffer request = {size};
500+ response_alloc_buffer response;
492501 auto sock = get_socket (buft_ctx->endpoint );
493- bool status = send_rpc_cmd (sock, RPC_CMD_ALLOC_BUFFER, input, output);
494- GGML_ASSERT (status);
495- GGML_ASSERT (output.size () == 2 *sizeof (uint64_t ));
496- // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
497- uint64_t remote_ptr;
498- memcpy (&remote_ptr, output.data (), sizeof (remote_ptr));
499- size_t remote_size;
500- memcpy (&remote_size, output.data () + sizeof (uint64_t ), sizeof (remote_size));
501- if (remote_ptr != 0 ) {
502+ bool status = send_rpc_cmd (sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof (request), &response, sizeof (response));
503+ if (response.remote_ptr != 0 ) {
502504 ggml_backend_buffer_t buffer = ggml_backend_buffer_init (buft,
503505 ggml_backend_rpc_buffer_interface,
504- new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, " RPC[" + std::string (buft_ctx->endpoint ) + " ]" },
505- remote_size);
506+ new ggml_backend_rpc_buffer_context{sock, {}, response. remote_ptr , " RPC[" + std::string (buft_ctx->endpoint ) + " ]" },
507+ response. remote_size );
506508 return buffer;
507509 } else {
508510 return nullptr ;
0 commit comments