@@ -82,17 +82,18 @@ static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of
8282
8383// RPC commands
8484enum rpc_cmd {
85- ALLOC_BUFFER = 0 ,
86- GET_ALIGNMENT,
87- GET_MAX_SIZE,
88- BUFFER_GET_BASE,
89- FREE_BUFFER,
90- BUFFER_CLEAR,
91- SET_TENSOR,
92- GET_TENSOR,
93- COPY_TENSOR,
94- GRAPH_COMPUTE,
95- GET_DEVICE_MEMORY,
85+ RPC_CMD_ALLOC_BUFFER = 0 ,
86+ RPC_CMD_GET_ALIGNMENT,
87+ RPC_CMD_GET_MAX_SIZE,
88+ RPC_CMD_BUFFER_GET_BASE,
89+ RPC_CMD_FREE_BUFFER,
90+ RPC_CMD_BUFFER_CLEAR,
91+ RPC_CMD_SET_TENSOR,
92+ RPC_CMD_GET_TENSOR,
93+ RPC_CMD_COPY_TENSOR,
94+ RPC_CMD_GRAPH_COMPUTE,
95+ RPC_CMD_GET_DEVICE_MEMORY,
96+ RPC_CMD_COUNT,
9697};
9798
9899// RPC data structures
@@ -330,7 +331,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t
330331 uint64_t remote_ptr = ctx->remote_ptr ;
331332 memcpy (input.data (), &remote_ptr, sizeof (remote_ptr));
332333 std::vector<uint8_t > output;
333- bool status = send_rpc_cmd (ctx->sock , FREE_BUFFER , input, output);
334+ bool status = send_rpc_cmd (ctx->sock , RPC_CMD_FREE_BUFFER , input, output);
334335 GGML_ASSERT (status);
335336 GGML_ASSERT (output.empty ());
336337 delete ctx;
@@ -346,7 +347,7 @@ GGML_CALL static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t b
346347 uint64_t remote_ptr = ctx->remote_ptr ;
347348 memcpy (input.data (), &remote_ptr, sizeof (remote_ptr));
348349 std::vector<uint8_t > output;
349- bool status = send_rpc_cmd (ctx->sock , BUFFER_GET_BASE , input, output);
350+ bool status = send_rpc_cmd (ctx->sock , RPC_CMD_BUFFER_GET_BASE , input, output);
350351 GGML_ASSERT (status);
351352 GGML_ASSERT (output.size () == sizeof (uint64_t ));
352353 // output serialization format: | base_ptr (8 bytes) |
@@ -405,7 +406,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t b
405406 memcpy (input.data () + sizeof (rpc_tensor), &offset, sizeof (offset));
406407 memcpy (input.data () + sizeof (rpc_tensor) + sizeof (offset), data, size);
407408 std::vector<uint8_t > output;
408- bool status = send_rpc_cmd (ctx->sock , SET_TENSOR , input, output);
409+ bool status = send_rpc_cmd (ctx->sock , RPC_CMD_SET_TENSOR , input, output);
409410 GGML_ASSERT (status);
410411}
411412
@@ -419,7 +420,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t b
419420 memcpy (input.data () + sizeof (rpc_tensor), &offset, sizeof (offset));
420421 memcpy (input.data () + sizeof (rpc_tensor) + sizeof (offset), &size, sizeof (size));
421422 std::vector<uint8_t > output;
422- bool status = send_rpc_cmd (ctx->sock , GET_TENSOR , input, output);
423+ bool status = send_rpc_cmd (ctx->sock , RPC_CMD_GET_TENSOR , input, output);
423424 GGML_ASSERT (status);
424425 GGML_ASSERT (output.size () == size);
425426 // output serialization format: | data (size bytes) |
@@ -444,7 +445,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t b
444445 memcpy (input.data (), &rpc_src, sizeof (rpc_src));
445446 memcpy (input.data () + sizeof (rpc_src), &rpc_dst, sizeof (rpc_dst));
446447 std::vector<uint8_t > output;
447- bool status = send_rpc_cmd (ctx->sock , COPY_TENSOR , input, output);
448+ bool status = send_rpc_cmd (ctx->sock , RPC_CMD_COPY_TENSOR , input, output);
448449 GGML_ASSERT (status);
449450 // output serialization format: | result (1 byte) |
450451 GGML_ASSERT (output.size () == 1 );
@@ -459,7 +460,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer
459460 memcpy (input.data (), &ctx->remote_ptr , sizeof (ctx->remote_ptr ));
460461 memcpy (input.data () + sizeof (ctx->remote_ptr ), &value, sizeof (value));
461462 std::vector<uint8_t > output;
462- bool status = send_rpc_cmd (ctx->sock , BUFFER_CLEAR , input, output);
463+ bool status = send_rpc_cmd (ctx->sock , RPC_CMD_BUFFER_CLEAR , input, output);
463464 GGML_ASSERT (status);
464465}
465466
@@ -488,7 +489,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
488489 memcpy (input.data (), &size, sizeof (size));
489490 std::vector<uint8_t > output;
490491 auto sock = get_socket (buft_ctx->endpoint );
491- bool status = send_rpc_cmd (sock, ALLOC_BUFFER , input, output);
492+ bool status = send_rpc_cmd (sock, RPC_CMD_ALLOC_BUFFER , input, output);
492493 GGML_ASSERT (status);
493494 GGML_ASSERT (output.size () == 2 *sizeof (uint64_t ));
494495 // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
@@ -511,7 +512,7 @@ static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
511512 // input serialization format: | 0 bytes |
512513 std::vector<uint8_t > input;
513514 std::vector<uint8_t > output;
514- bool status = send_rpc_cmd (sock, GET_ALIGNMENT , input, output);
515+ bool status = send_rpc_cmd (sock, RPC_CMD_GET_ALIGNMENT , input, output);
515516 GGML_ASSERT (status);
516517 GGML_ASSERT (output.size () == sizeof (uint64_t ));
517518 // output serialization format: | alignment (8 bytes) |
@@ -529,7 +530,7 @@ static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
529530 // input serialization format: | 0 bytes |
530531 std::vector<uint8_t > input;
531532 std::vector<uint8_t > output;
532- bool status = send_rpc_cmd (sock, GET_MAX_SIZE , input, output);
533+ bool status = send_rpc_cmd (sock, RPC_CMD_GET_MAX_SIZE , input, output);
533534 GGML_ASSERT (status);
534535 GGML_ASSERT (output.size () == sizeof (uint64_t ));
535536 // output serialization format: | max_size (8 bytes) |
@@ -622,7 +623,7 @@ GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t
622623 serialize_graph (cgraph, input);
623624 std::vector<uint8_t > output;
624625 auto sock = get_socket (rpc_ctx->endpoint );
625- bool status = send_rpc_cmd (sock, GRAPH_COMPUTE , input, output);
626+ bool status = send_rpc_cmd (sock, RPC_CMD_GRAPH_COMPUTE , input, output);
626627 GGML_ASSERT (status);
627628 GGML_ASSERT (output.size () == 1 );
628629 return (enum ggml_status)output[0 ];
@@ -719,7 +720,7 @@ static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * f
719720 // input serialization format: | 0 bytes |
720721 std::vector<uint8_t > input;
721722 std::vector<uint8_t > output;
722- bool status = send_rpc_cmd (sock, GET_DEVICE_MEMORY , input, output);
723+ bool status = send_rpc_cmd (sock, RPC_CMD_GET_DEVICE_MEMORY , input, output);
723724 GGML_ASSERT (status);
724725 GGML_ASSERT (output.size () == 2 *sizeof (uint64_t ));
725726 // output serialization format: | free (8 bytes) | total (8 bytes) |
@@ -1098,59 +1099,69 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
10981099 if (!recv_data (sockfd, &cmd, 1 )) {
10991100 break ;
11001101 }
1102+ if (cmd >= RPC_CMD_COUNT) {
1103+ // fail fast if the command is invalid
1104+ fprintf (stderr, " Unknown command: %d\n " , cmd);
1105+ break ;
1106+ }
11011107 std::vector<uint8_t > input;
11021108 std::vector<uint8_t > output;
11031109 uint64_t input_size;
11041110 if (!recv_data (sockfd, &input_size, sizeof (input_size))) {
11051111 break ;
11061112 }
1107- input.resize (input_size);
1113+ try {
1114+ input.resize (input_size);
1115+ } catch (const std::bad_alloc & e) {
1116+ fprintf (stderr, " Failed to allocate input buffer of size %" PRIu64 " \n " , input_size);
1117+ break ;
1118+ }
11081119 if (!recv_data (sockfd, input.data (), input_size)) {
11091120 break ;
11101121 }
11111122 bool ok = true ;
11121123 switch (cmd) {
1113- case ALLOC_BUFFER : {
1124+ case RPC_CMD_ALLOC_BUFFER : {
11141125 ok = server.alloc_buffer (input, output);
11151126 break ;
11161127 }
1117- case GET_ALIGNMENT : {
1128+ case RPC_CMD_GET_ALIGNMENT : {
11181129 server.get_alignment (output);
11191130 break ;
11201131 }
1121- case GET_MAX_SIZE : {
1132+ case RPC_CMD_GET_MAX_SIZE : {
11221133 server.get_max_size (output);
11231134 break ;
11241135 }
1125- case BUFFER_GET_BASE : {
1136+ case RPC_CMD_BUFFER_GET_BASE : {
11261137 ok = server.buffer_get_base (input, output);
11271138 break ;
11281139 }
1129- case FREE_BUFFER : {
1140+ case RPC_CMD_FREE_BUFFER : {
11301141 ok = server.free_buffer (input);
11311142 break ;
11321143 }
1133- case BUFFER_CLEAR : {
1144+ case RPC_CMD_BUFFER_CLEAR : {
11341145 ok = server.buffer_clear (input);
11351146 break ;
11361147 }
1137- case SET_TENSOR : {
1148+ case RPC_CMD_SET_TENSOR : {
11381149 ok = server.set_tensor (input);
11391150 break ;
11401151 }
1141- case GET_TENSOR : {
1152+ case RPC_CMD_GET_TENSOR : {
11421153 ok = server.get_tensor (input, output);
11431154 break ;
11441155 }
1145- case COPY_TENSOR : {
1156+ case RPC_CMD_COPY_TENSOR : {
11461157 ok = server.copy_tensor (input, output);
11471158 break ;
11481159 }
1149- case GRAPH_COMPUTE : {
1160+ case RPC_CMD_GRAPH_COMPUTE : {
11501161 ok = server.graph_compute (input, output);
11511162 break ;
11521163 }
1153- case GET_DEVICE_MEMORY : {
1164+ case RPC_CMD_GET_DEVICE_MEMORY : {
11541165 // output serialization format: | free (8 bytes) | total (8 bytes) |
11551166 output.resize (2 *sizeof (uint64_t ), 0 );
11561167 memcpy (output.data (), &free_mem, sizeof (free_mem));
@@ -1203,8 +1214,10 @@ void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free
12031214 return ;
12041215 }
12051216 printf (" Accepted client connection, free_mem=%zu, total_mem=%zu\n " , free_mem, total_mem);
1217+ fflush (stdout);
12061218 rpc_serve_client (backend, client_socket->fd , free_mem, total_mem);
12071219 printf (" Client connection closed\n " );
1220+ fflush (stdout);
12081221 }
12091222#ifdef _WIN32
12101223 WSACleanup ();
0 commit comments