@@ -86,6 +86,7 @@ enum rpc_cmd {
8686 RPC_CMD_GET_DEVICE_MEMORY,
8787 RPC_CMD_INIT_TENSOR,
8888 RPC_CMD_GET_ALLOC_SIZE,
89+ RPC_CMD_SUPPORT_OP,
8990 RPC_CMD_COUNT,
9091};
9192
@@ -158,6 +159,11 @@ struct rpc_msg_get_device_memory_rsp {
158159 uint64_t free_mem;
159160 uint64_t total_mem;
160161};
162+
// response for RPC_CMD_SUPPORT_OP: non-zero result means the server's
// backend reports the serialized op as supported
struct rpc_msg_support_op_rsp {
    uint8_t result;
};
166+
161167#pragma pack(pop)
162168
163169// RPC data structures
@@ -438,7 +444,7 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
438444 rpc_tensor result;
439445 result.id = reinterpret_cast <uint64_t >(tensor);
440446 result.type = tensor->type ;
441- if (tensor->buffer ) {
447+ if (tensor->buffer && tensor-> buffer -> context ) {
442448 ggml_backend_buffer_t buffer = tensor->buffer ;
443449 ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context ;
444450 result.buffer = ctx->remote_ptr ;
@@ -767,6 +773,31 @@ void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, si
767773 get_device_memory (sock, free, total);
768774}
769775
776+ static bool ggml_backend_rpc_support_op (const char * endpoint, const ggml_tensor * tensor) {
777+ std::vector<uint8_t > input;
778+ {
779+ std::vector<rpc_tensor> tensors;
780+ for (int i = 0 ; i < GGML_MAX_SRC; i++) {
781+ if (tensor->src [i] == nullptr ) {
782+ break ;
783+ }
784+ tensors.push_back (serialize_tensor (tensor->src [i]));
785+ }
786+ tensors.push_back (serialize_tensor (tensor));
787+ // serialization format: | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
788+ uint32_t n_tensors = tensors.size ();
789+ int input_size = sizeof (uint32_t ) + n_tensors * sizeof (rpc_tensor);
790+ input.resize (input_size, 0 );
791+ memcpy (input.data (), &n_tensors, sizeof (n_tensors));
792+ memcpy (input.data () + sizeof (n_tensors), tensors.data (), n_tensors * sizeof (rpc_tensor));
793+ }
794+ rpc_msg_support_op_rsp response;
795+ auto sock = get_socket (endpoint);
796+ bool status = send_rpc_cmd (sock, RPC_CMD_SUPPORT_OP, input.data (), input.size (), &response, sizeof (response));
797+ GGML_ASSERT (status);
798+ return response.result ;
799+ }
800+
770801// RPC server-side implementation
771802
772803class rpc_server {
@@ -786,6 +817,7 @@ class rpc_server {
786817 bool graph_compute (const std::vector<uint8_t > & input, rpc_msg_graph_compute_rsp & response);
787818 bool init_tensor (const rpc_msg_init_tensor_req & request);
788819 bool get_alloc_size (const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
820+ bool support_op (const std::vector<uint8_t > & input, rpc_msg_support_op_rsp & response);
789821
790822private:
791823 ggml_tensor * deserialize_tensor (struct ggml_context * ctx, const rpc_tensor * tensor);
@@ -829,6 +861,42 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_
829861 return true ;
830862}
831863
864+ bool rpc_server::support_op (const std::vector<uint8_t > & input, rpc_msg_support_op_rsp & response) {
865+ // serialization format: | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
866+ if (input.size () < sizeof (uint32_t )) {
867+ GGML_LOG_ERROR (" [%s] invalid input size\n " , __func__);
868+ return false ;
869+ }
870+ uint32_t n_tensors;
871+ memcpy (&n_tensors, input.data (), sizeof (n_tensors));
872+ if (input.size () < sizeof (uint32_t ) + n_tensors * sizeof (rpc_tensor)) {
873+ GGML_LOG_ERROR (" [%s] invalid input size\n " , __func__);
874+ return false ;
875+ }
876+ const rpc_tensor * tensors = (const rpc_tensor *)(input.data () + sizeof (uint32_t ));
877+ GGML_PRINT_DEBUG (" [%s] n_tensors: %u\n " , __func__, n_tensors);
878+
879+ size_t buf_size = ggml_tensor_overhead ()*n_tensors;
880+ struct ggml_init_params params {
881+ /* .mem_size =*/ buf_size,
882+ /* .mem_buffer =*/ NULL ,
883+ /* .no_alloc =*/ true ,
884+ };
885+ struct ggml_context * ctx = ggml_init (params);
886+ ggml_tensor * tensor = deserialize_tensor (ctx, &tensors[n_tensors-1 ]);
887+ for (uint32_t i = 0 ; i < n_tensors-1 ; i++) {
888+ ggml_tensor * src = deserialize_tensor (ctx, &tensors[i]);
889+ tensor->src [i] = src;
890+ }
891+ response.result = true ;
892+ if (backend->device ->iface .supports_op ) {
893+ response.result = backend->device ->iface .supports_op (backend->device , tensor);
894+ }
895+ ggml_free (ctx);
896+
897+ return true ;
898+ }
899+
832900void rpc_server::alloc_buffer (const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
833901 ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type (backend);
834902 ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer (buft, request.size );
@@ -1326,6 +1394,20 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
13261394 }
13271395 break ;
13281396 }
1397+ case RPC_CMD_SUPPORT_OP: {
1398+ std::vector<uint8_t > input;
1399+ if (!recv_msg (sockfd, input)) {
1400+ return ;
1401+ }
1402+ rpc_msg_support_op_rsp response;
1403+ if (!server.support_op (input, response)) {
1404+ return ;
1405+ }
1406+ if (!send_msg (sockfd, &response, sizeof (response))) {
1407+ return ;
1408+ }
1409+ break ;
1410+ }
13291411 default : {
13301412 fprintf (stderr, " Unknown command: %d\n " , cmd);
13311413 return ;
@@ -1436,10 +1518,26 @@ static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_b
14361518}
14371519
14381520static bool ggml_backend_rpc_device_supports_op (ggml_backend_dev_t dev, const struct ggml_tensor * op) {
1521+ static std::unordered_map<std::string, std::unordered_map<std::string, bool >> caches;
1522+ ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context ;
1523+
1524+ auto &cache = caches[ctx->endpoint ];
1525+ std::string key = op->name ;
1526+ key += std::to_string (op->type );
1527+ for (int i = 0 ; i < GGML_MAX_DIMS; i++) {
1528+ key += std::to_string (op->ne [i]);
1529+ }
1530+ key += std::to_string (op->op );
1531+
1532+ auto it = cache.find (key);
1533+ if (it != cache.end ()) {
1534+ return it->second ;
1535+ }
1536+ bool result = ggml_backend_rpc_support_op (ctx->endpoint .c_str (), op);
1537+ cache[key] = result;
1538+ return result;
1539+
14391540 GGML_UNUSED (dev);
1440- GGML_UNUSED (op);
1441- // TODO: call the remote backend and cache the results
1442- return true ;
14431541}
14441542
14451543static bool ggml_backend_rpc_device_supports_buft (ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
0 commit comments