diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp
index 97873acc77dee..9d0d4833ae1a8 100644
--- a/ggml/src/ggml-rpc/ggml-rpc.cpp
+++ b/ggml/src/ggml-rpc/ggml-rpc.cpp
@@ -86,6 +86,7 @@ enum rpc_cmd {
     RPC_CMD_GET_DEVICE_MEMORY,
     RPC_CMD_INIT_TENSOR,
     RPC_CMD_GET_ALLOC_SIZE,
+    RPC_CMD_SUPPORT_OP,
     RPC_CMD_COUNT,
 };
 
@@ -158,6 +159,11 @@ struct rpc_msg_get_device_memory_rsp {
     uint64_t free_mem;
     uint64_t total_mem;
 };
+
+struct rpc_msg_support_op_rsp {
+    uint8_t result;
+};
+
 #pragma pack(pop)
 
 // RPC data structures
@@ -438,7 +444,7 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
     rpc_tensor result;
     result.id = reinterpret_cast<uint64_t>(tensor);
     result.type = tensor->type;
-    if (tensor->buffer) {
+    if (tensor->buffer && tensor->buffer->context) {
         ggml_backend_buffer_t buffer = tensor->buffer;
         ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
         result.buffer = ctx->remote_ptr;
@@ -767,6 +773,31 @@ void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
     get_device_memory(sock, free, total);
 }
 
+static bool ggml_backend_rpc_support_op(const char * endpoint, const ggml_tensor * tensor) {
+    std::vector<uint8_t> input;
+    {
+        std::vector<rpc_tensor> tensors;
+        for (int i = 0; i < GGML_MAX_SRC; i++) {
+            if (tensor->src[i] == nullptr) {
+                break;
+            }
+            tensors.push_back(serialize_tensor(tensor->src[i]));
+        }
+        tensors.push_back(serialize_tensor(tensor));
+        // serialization format: | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
+        uint32_t n_tensors = tensors.size();
+        int input_size = sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
+        input.resize(input_size, 0);
+        memcpy(input.data(), &n_tensors, sizeof(n_tensors));
+        memcpy(input.data() + sizeof(n_tensors), tensors.data(), n_tensors * sizeof(rpc_tensor));
+    }
+    rpc_msg_support_op_rsp response;
+    auto sock = get_socket(endpoint);
+    bool status = send_rpc_cmd(sock, RPC_CMD_SUPPORT_OP, input.data(), input.size(), &response, sizeof(response));
+    GGML_ASSERT(status);
+    return response.result;
+}
+
 // RPC server-side implementation
 
 class rpc_server {
@@ -786,6 +817,7 @@ class rpc_server {
     bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
     bool init_tensor(const rpc_msg_init_tensor_req & request);
     bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
+    bool support_op(const std::vector<uint8_t> & input, rpc_msg_support_op_rsp & response);
 
 private:
     ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
@@ -829,6 +861,42 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
     return true;
 }
 
+bool rpc_server::support_op(const std::vector<uint8_t> & input, rpc_msg_support_op_rsp & response) {
+    // serialization format: | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
+    if (input.size() < sizeof(uint32_t)) {
+        GGML_LOG_ERROR("[%s] invalid input size\n", __func__);
+        return false;
+    }
+    uint32_t n_tensors;
+    memcpy(&n_tensors, input.data(), sizeof(n_tensors));
+    if (input.size() < sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor)) {
+        GGML_LOG_ERROR("[%s] invalid input size\n", __func__);
+        return false;
+    }
+    const rpc_tensor * tensors = (const rpc_tensor *)(input.data() + sizeof(uint32_t));
+    GGML_PRINT_DEBUG("[%s] n_tensors: %u\n", __func__, n_tensors);
+
+    size_t buf_size = ggml_tensor_overhead()*n_tensors;
+    struct ggml_init_params params {
+        /*.mem_size   =*/ buf_size,
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ true,
+    };
+    struct ggml_context * ctx = ggml_init(params);
+    ggml_tensor * tensor = deserialize_tensor(ctx, &tensors[n_tensors-1]);
+    for (uint32_t i = 0; i < n_tensors-1; i++) {
+        ggml_tensor * src = deserialize_tensor(ctx, &tensors[i]);
+        tensor->src[i] = src;
+    }
+    response.result = true;
+    if (backend->device->iface.supports_op) {
+        response.result = backend->device->iface.supports_op(backend->device, tensor);
+    }
+    ggml_free(ctx);
+
+    return true;
+}
+
 void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
     ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
     ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size);
@@ -1326,6 +1394,20 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
                 }
                 break;
             }
+            case RPC_CMD_SUPPORT_OP: {
+                std::vector<uint8_t> input;
+                if (!recv_msg(sockfd, input)) {
+                    return;
+                }
+                rpc_msg_support_op_rsp response;
+                if (!server.support_op(input, response)) {
+                    return;
+                }
+                if (!send_msg(sockfd, &response, sizeof(response))) {
+                    return;
+                }
+                break;
+            }
             default: {
                 fprintf(stderr, "Unknown command: %d\n", cmd);
                 return;
             }
@@ -1436,10 +1518,26 @@ static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) {
 }
 
 static bool ggml_backend_rpc_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+    static std::unordered_map<std::string, std::unordered_map<std::string, bool>> caches;
+    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
+
+    auto &cache = caches[ctx->endpoint];
+    std::string key = op->name;
+    key += std::to_string(op->type);
+    for (int i = 0; i < GGML_MAX_DIMS; i++) {
+        key += std::to_string(op->ne[i]);
+    }
+    key += std::to_string(op->op);
+
+    auto it = cache.find(key);
+    if (it != cache.end()) {
+        return it->second;
+    }
+    bool result = ggml_backend_rpc_support_op(ctx->endpoint.c_str(), op);
+    cache[key] = result;
+    return result;
+
     GGML_UNUSED(dev);
-    GGML_UNUSED(op);
-    //TODO: call the remote backend and cache the results
-    return true;
 }
 
 static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
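
A minimal usage sketch (not part of the patch) showing how the new RPC_CMD_SUPPORT_OP path is exercised from the client side through the public device interface. It assumes ggml_backend_rpc_add_device() is available at this revision; remote_supports_mul_mat and the "localhost:50052"-style endpoint string are hypothetical names used only for illustration.

    #include "ggml.h"
    #include "ggml-backend.h"
    #include "ggml-rpc.h"

    // Build a dummy MUL_MAT op and ask the RPC device whether it is supported.
    // ggml_backend_dev_supports_op() dispatches to the device's supports_op callback,
    // which with this patch serializes op->src[] plus op and sends RPC_CMD_SUPPORT_OP.
    static bool remote_supports_mul_mat(const char * endpoint) {
        ggml_backend_dev_t dev = ggml_backend_rpc_add_device(endpoint);
        struct ggml_init_params params = {
            /*.mem_size   =*/ 3*ggml_tensor_overhead(),
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ true,   // metadata only, no tensor data is allocated
        };
        struct ggml_context * ctx = ggml_init(params);
        struct ggml_tensor * a  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 64);
        struct ggml_tensor * b  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 64);
        struct ggml_tensor * op = ggml_mul_mat(ctx, a, b);
        bool ok = ggml_backend_dev_supports_op(dev, op);
        ggml_free(ctx);
        return ok;
    }

Because a, b and op have no backend buffer here, the relaxed check added to serialize_tensor() (tensor->buffer && tensor->buffer->context) is what allows such unallocated tensors to be serialized for the probe.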