@@ -58,7 +58,7 @@ struct socket_t {
5858};
5959
6060// ggml_tensor is serialized into rpc_tensor
61- #pragma pack(push, 1)
61+ #pragma pack(1)
6262struct rpc_tensor {
6363 uint64_t id;
6464 uint32_t type;
@@ -76,7 +76,6 @@ struct rpc_tensor {
7676
7777 char padding[4 ];
7878};
79- #pragma pack(pop)
8079
8180static_assert (sizeof (rpc_tensor) % 8 == 0 , " rpc_tensor size must be multiple of 8" );
8281
@@ -96,6 +95,77 @@ enum rpc_cmd {
9695 RPC_CMD_COUNT,
9796};
9897
98+ #pragma pack(1)
99+ struct rpc_msg_alloc_buffer_req {
100+ uint64_t size;
101+ };
102+
103+ #pragma pack(1)
104+ struct rpc_msg_alloc_buffer_rsp {
105+ uint64_t remote_ptr;
106+ uint64_t remote_size;
107+ };
108+
109+ #pragma pack(1)
110+ struct rpc_msg_get_alignment_rsp {
111+ uint64_t alignment;
112+ };
113+
114+ #pragma pack(1)
115+ struct rpc_msg_get_max_size_rsp {
116+ uint64_t max_size;
117+ };
118+
119+ #pragma pack(1)
120+ struct rpc_msg_buffer_get_base_req {
121+ uint64_t remote_ptr;
122+ };
123+
124+ #pragma pack(1)
125+ struct rpc_msg_buffer_get_base_rsp {
126+ uint64_t base_ptr;
127+ };
128+
129+ #pragma pack(1)
130+ struct rpc_msg_free_buffer_req {
131+ uint64_t remote_ptr;
132+ };
133+
134+ #pragma pack(1)
135+ struct rpc_msg_buffer_clear_req {
136+ uint64_t remote_ptr;
137+ uint8_t value;
138+ };
139+
140+ #pragma pack(1)
141+ struct rpc_msg_get_tensor_req {
142+ rpc_tensor tensor;
143+ uint64_t offset;
144+ uint64_t size;
145+ };
146+
147+ #pragma pack(1)
148+ struct rpc_msg_copy_tensor_req {
149+ rpc_tensor src;
150+ rpc_tensor dst;
151+ };
152+
153+ #pragma pack(1)
154+ struct rpc_msg_copy_tensor_rsp {
155+ uint8_t result;
156+ };
157+
158+ #pragma pack(1)
159+ struct rpc_msg_graph_compute_rsp {
160+ uint8_t result;
161+ };
162+
163+ #pragma pack(1)
164+ struct rpc_msg_get_device_memory_rsp {
165+ uint64_t free_mem;
166+ uint64_t total_mem;
167+ };
168+
99169// RPC data structures
100170
101171static ggml_guid_t ggml_backend_rpc_guid () {
@@ -252,28 +322,27 @@ static bool parse_endpoint(const std::string & endpoint, std::string & host, int
252322
253323// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
254324// RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
255- static bool send_rpc_cmd (const std::shared_ptr<socket_t > & sock, enum rpc_cmd cmd, const std::vector< uint8_t > & input, std::vector< uint8_t > & output) {
325+ static bool send_rpc_cmd (const std::shared_ptr<socket_t > & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size ) {
256326 uint8_t cmd_byte = cmd;
257327 if (!send_data (sock->fd , &cmd_byte, sizeof (cmd_byte))) {
258328 return false ;
259329 }
260- uint64_t input_size = input.size ();
261330 if (!send_data (sock->fd , &input_size, sizeof (input_size))) {
262331 return false ;
263332 }
264- if (!send_data (sock->fd , input. data (), input. size () )) {
333+ if (!send_data (sock->fd , input, input_size )) {
265334 return false ;
266335 }
267- uint64_t output_size;
268- if (!recv_data (sock->fd , &output_size, sizeof (output_size))) {
336+ // TODO: currently the output_size is always known, do we need support for commands with variable output size?
337+ // even if we do, we can skip sending output_size from the server for commands with known output size
338+ uint64_t out_size;
339+ if (!recv_data (sock->fd , &out_size, sizeof (out_size))) {
269340 return false ;
270341 }
271- if (output_size == 0 ) {
272- output.clear ();
273- return true ;
342+ if (out_size != output_size) {
343+ return false ;
274344 }
275- output.resize (output_size);
276- if (!recv_data (sock->fd , output.data (), output_size)) {
345+ if (!recv_data (sock->fd , output, output_size)) {
277346 return false ;
278347 }
279348 return true ;
@@ -326,14 +395,9 @@ static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffe
326395
327396static void ggml_backend_rpc_buffer_free_buffer (ggml_backend_buffer_t buffer) {
328397 ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context ;
329- // input serialization format: | remote_ptr (8 bytes) |
330- std::vector<uint8_t > input (sizeof (uint64_t ), 0 );
331- uint64_t remote_ptr = ctx->remote_ptr ;
332- memcpy (input.data (), &remote_ptr, sizeof (remote_ptr));
333- std::vector<uint8_t > output;
334- bool status = send_rpc_cmd (ctx->sock , RPC_CMD_FREE_BUFFER, input, output);
398+ rpc_msg_free_buffer_req request = {ctx->remote_ptr };
399+ bool status = send_rpc_cmd (ctx->sock , RPC_CMD_FREE_BUFFER, &request, sizeof (request), nullptr , 0 );
335400 GGML_ASSERT (status);
336- GGML_ASSERT (output.empty ());
337401 delete ctx;
338402}
339403
@@ -342,20 +406,13 @@ static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
342406 if (ctx->base_cache .find (buffer) != ctx->base_cache .end ()) {
343407 return ctx->base_cache [buffer];
344408 }
345- // input serialization format: | remote_ptr (8 bytes) |
346- std::vector<uint8_t > input (sizeof (uint64_t ), 0 );
347- uint64_t remote_ptr = ctx->remote_ptr ;
348- memcpy (input.data (), &remote_ptr, sizeof (remote_ptr));
349- std::vector<uint8_t > output;
350- bool status = send_rpc_cmd (ctx->sock , RPC_CMD_BUFFER_GET_BASE, input, output);
409+ rpc_msg_buffer_get_base_req request = {ctx->remote_ptr };
410+ rpc_msg_buffer_get_base_rsp response;
411+ bool status = send_rpc_cmd (ctx->sock , RPC_CMD_BUFFER_GET_BASE, &request, sizeof (request), &response, sizeof (response));
351412 GGML_ASSERT (status);
352- GGML_ASSERT (output.size () == sizeof (uint64_t ));
353- // output serialization format: | base_ptr (8 bytes) |
354- uint64_t base_ptr;
355- memcpy (&base_ptr, output.data (), sizeof (base_ptr));
356- void * base = reinterpret_cast <void *>(base_ptr);
357- ctx->base_cache [buffer] = base;
358- return base;
413+ void * base_ptr = reinterpret_cast <void *>(response.base_ptr );
414+ ctx->base_cache [buffer] = base_ptr;
415+ return base_ptr;
359416}
360417
361418static rpc_tensor serialize_tensor (const ggml_tensor * tensor) {
@@ -405,26 +462,18 @@ static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggm
405462 memcpy (input.data (), &rpc_tensor, sizeof (rpc_tensor));
406463 memcpy (input.data () + sizeof (rpc_tensor), &offset, sizeof (offset));
407464 memcpy (input.data () + sizeof (rpc_tensor) + sizeof (offset), data, size);
408- std::vector<uint8_t > output;
409- bool status = send_rpc_cmd (ctx->sock , RPC_CMD_SET_TENSOR, input, output);
465+ bool status = send_rpc_cmd (ctx->sock , RPC_CMD_SET_TENSOR, input.data (), input.size (), nullptr , 0 );
410466 GGML_ASSERT (status);
411467}
412468
413469static void ggml_backend_rpc_buffer_get_tensor (ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
414470 ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context ;
415- // input serialization format: | rpc_tensor | offset (8 bytes) | size (8 bytes) |
416- int input_size = sizeof (rpc_tensor) + 2 *sizeof (uint64_t );
417- std::vector<uint8_t > input (input_size, 0 );
418- rpc_tensor rpc_tensor = serialize_tensor (tensor);
419- memcpy (input.data (), &rpc_tensor, sizeof (rpc_tensor));
420- memcpy (input.data () + sizeof (rpc_tensor), &offset, sizeof (offset));
421- memcpy (input.data () + sizeof (rpc_tensor) + sizeof (offset), &size, sizeof (size));
422- std::vector<uint8_t > output;
423- bool status = send_rpc_cmd (ctx->sock , RPC_CMD_GET_TENSOR, input, output);
471+ rpc_msg_get_tensor_req request;
472+ request.tensor = serialize_tensor (tensor);
473+ request.offset = offset;
474+ request.size = size;
475+ bool status = send_rpc_cmd (ctx->sock , RPC_CMD_GET_TENSOR, &request, sizeof (request), data, size);
424476 GGML_ASSERT (status);
425- GGML_ASSERT (output.size () == size);
426- // output serialization format: | data (size bytes) |
427- memcpy (data, output.data (), size);
428477}
429478
430479static bool ggml_backend_rpc_buffer_cpy_tensor (ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
@@ -437,30 +486,19 @@ static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, con
437486 return false ;
438487 }
439488 ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context ;
440- // input serialization format: | rpc_tensor src | rpc_tensor dst |
441- int input_size = 2 *sizeof (rpc_tensor);
442- std::vector<uint8_t > input (input_size, 0 );
443- rpc_tensor rpc_src = serialize_tensor (src);
444- rpc_tensor rpc_dst = serialize_tensor (dst);
445- memcpy (input.data (), &rpc_src, sizeof (rpc_src));
446- memcpy (input.data () + sizeof (rpc_src), &rpc_dst, sizeof (rpc_dst));
447- std::vector<uint8_t > output;
448- bool status = send_rpc_cmd (ctx->sock , RPC_CMD_COPY_TENSOR, input, output);
489+ rpc_msg_copy_tensor_req request;
490+ request.src = serialize_tensor (src);
491+ request.dst = serialize_tensor (dst);
492+ rpc_msg_copy_tensor_rsp response;
493+ bool status = send_rpc_cmd (ctx->sock , RPC_CMD_COPY_TENSOR, &request, sizeof (request), &response, sizeof (response));
449494 GGML_ASSERT (status);
450- // output serialization format: | result (1 byte) |
451- GGML_ASSERT (output.size () == 1 );
452- return output[0 ];
495+ return response.result ;
453496}
454497
455498static void ggml_backend_rpc_buffer_clear (ggml_backend_buffer_t buffer, uint8_t value) {
456499 ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context ;
457- // serialization format: | bufptr (8 bytes) | value (1 byte) |
458- int input_size = sizeof (uint64_t ) + sizeof (uint8_t );
459- std::vector<uint8_t > input (input_size, 0 );
460- memcpy (input.data (), &ctx->remote_ptr , sizeof (ctx->remote_ptr ));
461- memcpy (input.data () + sizeof (ctx->remote_ptr ), &value, sizeof (value));
462- std::vector<uint8_t > output;
463- bool status = send_rpc_cmd (ctx->sock , RPC_CMD_BUFFER_CLEAR, input, output);
500+ rpc_msg_buffer_clear_req request = {ctx->remote_ptr , value};
501+ bool status = send_rpc_cmd (ctx->sock , RPC_CMD_BUFFER_CLEAR, &request, sizeof (request), nullptr , 0 );
464502 GGML_ASSERT (status);
465503}
466504
@@ -484,42 +522,27 @@ static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t
484522
485523static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer (ggml_backend_buffer_type_t buft, size_t size) {
486524 ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context ;
487- // input serialization format: | size (8 bytes) |
488- int input_size = sizeof (uint64_t );
489- std::vector<uint8_t > input (input_size, 0 );
490- memcpy (input.data (), &size, sizeof (size));
491- std::vector<uint8_t > output;
525+ rpc_msg_alloc_buffer_req request = {size};
526+ rpc_msg_alloc_buffer_rsp response;
492527 auto sock = get_socket (buft_ctx->endpoint );
493- bool status = send_rpc_cmd (sock, RPC_CMD_ALLOC_BUFFER, input, output );
528+ bool status = send_rpc_cmd (sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof (request), &response, sizeof (response) );
494529 GGML_ASSERT (status);
495- GGML_ASSERT (output.size () == 2 *sizeof (uint64_t ));
496- // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
497- uint64_t remote_ptr;
498- memcpy (&remote_ptr, output.data (), sizeof (remote_ptr));
499- size_t remote_size;
500- memcpy (&remote_size, output.data () + sizeof (uint64_t ), sizeof (remote_size));
501- if (remote_ptr != 0 ) {
530+ if (response.remote_ptr != 0 ) {
502531 ggml_backend_buffer_t buffer = ggml_backend_buffer_init (buft,
503532 ggml_backend_rpc_buffer_interface,
504- new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, " RPC[" + std::string (buft_ctx->endpoint ) + " ]" },
505- remote_size);
533+ new ggml_backend_rpc_buffer_context{sock, {}, response. remote_ptr , " RPC[" + std::string (buft_ctx->endpoint ) + " ]" },
534+ response. remote_size );
506535 return buffer;
507536 } else {
508537 return nullptr ;
509538 }
510539}
511540
512541static size_t get_alignment (const std::shared_ptr<socket_t > & sock) {
513- // input serialization format: | 0 bytes |
514- std::vector<uint8_t > input;
515- std::vector<uint8_t > output;
516- bool status = send_rpc_cmd (sock, RPC_CMD_GET_ALIGNMENT, input, output);
542+ rpc_msg_get_alignment_rsp response;
543+ bool status = send_rpc_cmd (sock, RPC_CMD_GET_ALIGNMENT, nullptr , 0 , &response, sizeof (response));
517544 GGML_ASSERT (status);
518- GGML_ASSERT (output.size () == sizeof (uint64_t ));
519- // output serialization format: | alignment (8 bytes) |
520- uint64_t alignment;
521- memcpy (&alignment, output.data (), sizeof (alignment));
522- return alignment;
545+ return response.alignment ;
523546}
524547
525548static size_t ggml_backend_rpc_buffer_type_get_alignment (ggml_backend_buffer_type_t buft) {
@@ -528,16 +551,10 @@ static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_typ
528551}
529552
530553static size_t get_max_size (const std::shared_ptr<socket_t > & sock) {
531- // input serialization format: | 0 bytes |
532- std::vector<uint8_t > input;
533- std::vector<uint8_t > output;
534- bool status = send_rpc_cmd (sock, RPC_CMD_GET_MAX_SIZE, input, output);
554+ rpc_msg_get_max_size_rsp response;
555+ bool status = send_rpc_cmd (sock, RPC_CMD_GET_MAX_SIZE, nullptr , 0 , &response, sizeof (response));
535556 GGML_ASSERT (status);
536- GGML_ASSERT (output.size () == sizeof (uint64_t ));
537- // output serialization format: | max_size (8 bytes) |
538- uint64_t max_size;
539- memcpy (&max_size, output.data (), sizeof (max_size));
540- return max_size;
557+ return response.max_size ;
541558}
542559
543560static size_t ggml_backend_rpc_get_max_size (ggml_backend_buffer_type_t buft) {
@@ -622,12 +639,11 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g
622639 ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context ;
623640 std::vector<uint8_t > input;
624641 serialize_graph (cgraph, input);
625- std::vector< uint8_t > output ;
642+ rpc_msg_graph_compute_rsp response ;
626643 auto sock = get_socket (rpc_ctx->endpoint );
627- bool status = send_rpc_cmd (sock, RPC_CMD_GRAPH_COMPUTE, input, output );
644+ bool status = send_rpc_cmd (sock, RPC_CMD_GRAPH_COMPUTE, input. data (), input. size (), &response, sizeof (response) );
628645 GGML_ASSERT (status);
629- GGML_ASSERT (output.size () == 1 );
630- return (enum ggml_status)output[0 ];
646+ return (enum ggml_status)response.result ;
631647}
632648
633649static ggml_backend_i ggml_backend_rpc_interface = {
@@ -702,19 +718,11 @@ GGML_API bool ggml_backend_is_rpc(ggml_backend_t backend) {
702718}
703719
704720static void get_device_memory (const std::shared_ptr<socket_t > & sock, size_t * free, size_t * total) {
705- // input serialization format: | 0 bytes |
706- std::vector<uint8_t > input;
707- std::vector<uint8_t > output;
708- bool status = send_rpc_cmd (sock, RPC_CMD_GET_DEVICE_MEMORY, input, output);
721+ rpc_msg_get_device_memory_rsp response;
722+ bool status = send_rpc_cmd (sock, RPC_CMD_GET_DEVICE_MEMORY, nullptr , 0 , &response, sizeof (response));
709723 GGML_ASSERT (status);
710- GGML_ASSERT (output.size () == 2 *sizeof (uint64_t ));
711- // output serialization format: | free (8 bytes) | total (8 bytes) |
712- uint64_t free_mem;
713- memcpy (&free_mem, output.data (), sizeof (free_mem));
714- uint64_t total_mem;
715- memcpy (&total_mem, output.data () + sizeof (uint64_t ), sizeof (total_mem));
716- *free = free_mem;
717- *total = total_mem;
724+ *free = response.free_mem ;
725+ *total = response.total_mem ;
718726}
719727
720728GGML_API void ggml_backend_rpc_get_device_memory (const char * endpoint, size_t * free, size_t * total) {
0 commit comments