@@ -53,6 +53,9 @@ struct socket_t {
5353 }
5454};
5555
// Macro for nicer error messages on server crash.
//
// Evaluates `x` exactly once; when it is false, calls GGML_ABORT with a
// message pointing at the remote RPC server instead of a bare failed-assert
// expression. The body is wrapped in do { ... } while (0) so the macro
// expands to a single statement: it requires a trailing semicolon and cannot
// capture an `else` that follows it at the call site.
#define RPC_STATUS_ASSERT(x)                                                          \
    do {                                                                              \
        if (!(x)) {                                                                   \
            GGML_ABORT("Remote RPC server crashed or returned malformed response");   \
        }                                                                             \
    } while (0)
58+
5659// all RPC structures must be packed
5760#pragma pack(push, 1)
5861// ggml_tensor is serialized into rpc_tensor
@@ -425,7 +428,7 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
425428static bool check_server_version (const std::shared_ptr<socket_t > & sock) {
426429 rpc_msg_hello_rsp response;
427430 bool status = send_rpc_cmd (sock, RPC_CMD_HELLO, nullptr , 0 , &response, sizeof (response));
428- GGML_ASSERT (status);
431+ RPC_STATUS_ASSERT (status);
429432 if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) {
430433 fprintf (stderr, " RPC server version mismatch: %d.%d.%d\n " , response.major , response.minor , response.patch );
431434 return false ;
@@ -481,7 +484,7 @@ static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
481484 ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context ;
482485 rpc_msg_free_buffer_req request = {ctx->remote_ptr };
483486 bool status = send_rpc_cmd (ctx->sock , RPC_CMD_FREE_BUFFER, &request, sizeof (request), nullptr , 0 );
484- GGML_ASSERT (status);
487+ RPC_STATUS_ASSERT (status);
485488 delete ctx;
486489}
487490
@@ -493,7 +496,7 @@ static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
493496 rpc_msg_buffer_get_base_req request = {ctx->remote_ptr };
494497 rpc_msg_buffer_get_base_rsp response;
495498 bool status = send_rpc_cmd (ctx->sock , RPC_CMD_BUFFER_GET_BASE, &request, sizeof (request), &response, sizeof (response));
496- GGML_ASSERT (status);
499+ RPC_STATUS_ASSERT (status);
497500 ctx->base_ptr = reinterpret_cast <void *>(response.base_ptr );
498501 return ctx->base_ptr ;
499502}
@@ -545,7 +548,7 @@ static enum ggml_status ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_
545548 request.tensor = serialize_tensor (tensor);
546549
547550 bool status = send_rpc_cmd (ctx->sock , RPC_CMD_INIT_TENSOR, &request, sizeof (request), nullptr , 0 );
548- GGML_ASSERT (status);
551+ RPC_STATUS_ASSERT (status);
549552 }
550553 return GGML_STATUS_SUCCESS;
551554}
@@ -560,7 +563,7 @@ static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggm
560563 request.hash = fnv_hash ((const uint8_t *)data, size);
561564 rpc_msg_set_tensor_hash_rsp response;
562565 bool status = send_rpc_cmd (ctx->sock , RPC_CMD_SET_TENSOR_HASH, &request, sizeof (request), &response, sizeof (response));
563- GGML_ASSERT (status);
566+ RPC_STATUS_ASSERT (status);
564567 if (response.result ) {
565568 // the server has the same data, no need to send it
566569 return ;
@@ -573,7 +576,7 @@ static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggm
573576 memcpy (input.data () + sizeof (rpc_tensor), &offset, sizeof (offset));
574577 memcpy (input.data () + sizeof (rpc_tensor) + sizeof (offset), data, size);
575578 bool status = send_rpc_cmd (ctx->sock , RPC_CMD_SET_TENSOR, input.data (), input.size ());
576- GGML_ASSERT (status);
579+ RPC_STATUS_ASSERT (status);
577580}
578581
579582static void ggml_backend_rpc_buffer_get_tensor (ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
@@ -583,7 +586,7 @@ static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, con
583586 request.offset = offset;
584587 request.size = size;
585588 bool status = send_rpc_cmd (ctx->sock , RPC_CMD_GET_TENSOR, &request, sizeof (request), data, size);
586- GGML_ASSERT (status);
589+ RPC_STATUS_ASSERT (status);
587590}
588591
589592static bool ggml_backend_rpc_buffer_cpy_tensor (ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
@@ -601,15 +604,15 @@ static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, con
601604 request.dst = serialize_tensor (dst);
602605 rpc_msg_copy_tensor_rsp response;
603606 bool status = send_rpc_cmd (ctx->sock , RPC_CMD_COPY_TENSOR, &request, sizeof (request), &response, sizeof (response));
604- GGML_ASSERT (status);
607+ RPC_STATUS_ASSERT (status);
605608 return response.result ;
606609}
607610
608611static void ggml_backend_rpc_buffer_clear (ggml_backend_buffer_t buffer, uint8_t value) {
609612 ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context ;
610613 rpc_msg_buffer_clear_req request = {ctx->remote_ptr , value};
611614 bool status = send_rpc_cmd (ctx->sock , RPC_CMD_BUFFER_CLEAR, &request, sizeof (request), nullptr , 0 );
612- GGML_ASSERT (status);
615+ RPC_STATUS_ASSERT (status);
613616}
614617
615618static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = {
@@ -635,7 +638,7 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back
635638 rpc_msg_alloc_buffer_rsp response;
636639 auto sock = get_socket (buft_ctx->endpoint );
637640 bool status = send_rpc_cmd (sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof (request), &response, sizeof (response));
638- GGML_ASSERT (status);
641+ RPC_STATUS_ASSERT (status);
639642 if (response.remote_ptr != 0 ) {
640643 ggml_backend_buffer_t buffer = ggml_backend_buffer_init (buft,
641644 ggml_backend_rpc_buffer_interface,
@@ -650,7 +653,7 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back
650653static size_t get_alignment (const std::shared_ptr<socket_t > & sock) {
651654 rpc_msg_get_alignment_rsp response;
652655 bool status = send_rpc_cmd (sock, RPC_CMD_GET_ALIGNMENT, nullptr , 0 , &response, sizeof (response));
653- GGML_ASSERT (status);
656+ RPC_STATUS_ASSERT (status);
654657 return response.alignment ;
655658}
656659
@@ -662,7 +665,7 @@ static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_typ
662665static size_t get_max_size (const std::shared_ptr<socket_t > & sock) {
663666 rpc_msg_get_max_size_rsp response;
664667 bool status = send_rpc_cmd (sock, RPC_CMD_GET_MAX_SIZE, nullptr , 0 , &response, sizeof (response));
665- GGML_ASSERT (status);
668+ RPC_STATUS_ASSERT (status);
666669 return response.max_size ;
667670}
668671
@@ -683,7 +686,7 @@ static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_ty
683686
684687 rpc_msg_get_alloc_size_rsp response;
685688 bool status = send_rpc_cmd (sock, RPC_CMD_GET_ALLOC_SIZE, &request, sizeof (request), &response, sizeof (response));
686- GGML_ASSERT (status);
689+ RPC_STATUS_ASSERT (status);
687690
688691 return response.alloc_size ;
689692 } else {
@@ -761,7 +764,7 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g
761764 rpc_msg_graph_compute_rsp response;
762765 auto sock = get_socket (rpc_ctx->endpoint );
763766 bool status = send_rpc_cmd (sock, RPC_CMD_GRAPH_COMPUTE, input.data (), input.size (), &response, sizeof (response));
764- GGML_ASSERT (status);
767+ RPC_STATUS_ASSERT (status);
765768 return (enum ggml_status)response.result ;
766769}
767770
@@ -835,7 +838,7 @@ bool ggml_backend_is_rpc(ggml_backend_t backend) {
835838static void get_device_memory (const std::shared_ptr<socket_t > & sock, size_t * free, size_t * total) {
836839 rpc_msg_get_device_memory_rsp response;
837840 bool status = send_rpc_cmd (sock, RPC_CMD_GET_DEVICE_MEMORY, nullptr , 0 , &response, sizeof (response));
838- GGML_ASSERT (status);
841+ RPC_STATUS_ASSERT (status);
839842 *free = response.free_mem ;
840843 *total = response.total_mem ;
841844}
0 commit comments