diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 1a8739e788e76..aad48d62a850c 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -631,23 +631,30 @@ static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, con RPC_STATUS_ASSERT(status); } +static bool ggml_backend_buffer_is_rpc(ggml_backend_buffer_t buffer) { + return buffer->iface.free_buffer == ggml_backend_rpc_buffer_free_buffer; +} + static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { - // check if src and dst are on the same server - ggml_backend_buffer_t src_buffer = src->buffer; - ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context; - ggml_backend_buffer_t dst_buffer = dst->buffer; - ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context; - if (src_ctx->sock != dst_ctx->sock) { - return false; + if (ggml_backend_buffer_is_rpc(src->buffer)) { + // check if src and dst are on the same server + ggml_backend_buffer_t src_buffer = src->buffer; + ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context; + ggml_backend_buffer_t dst_buffer = dst->buffer; + ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context; + if (src_ctx->sock != dst_ctx->sock) { + return false; + } + ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; + rpc_msg_copy_tensor_req request; + request.src = serialize_tensor(src); + request.dst = serialize_tensor(dst); + rpc_msg_copy_tensor_rsp response; + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response)); + RPC_STATUS_ASSERT(status); + return response.result; } - ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; - rpc_msg_copy_tensor_req request; - request.src = serialize_tensor(src); - request.dst = serialize_tensor(dst); - rpc_msg_copy_tensor_rsp response; - bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response)); - RPC_STATUS_ASSERT(status); - return response.result; + return false; } static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {