Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 22 additions & 15 deletions ggml/src/ggml-rpc/ggml-rpc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -631,23 +631,30 @@ static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, con
RPC_STATUS_ASSERT(status);
}

static bool ggml_backend_buffer_is_rpc(ggml_backend_buffer_t buffer) {
return buffer->iface.free_buffer == ggml_backend_rpc_buffer_free_buffer;
}

static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
// check if src and dst are on the same server
ggml_backend_buffer_t src_buffer = src->buffer;
ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
ggml_backend_buffer_t dst_buffer = dst->buffer;
ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
if (src_ctx->sock != dst_ctx->sock) {
return false;
if (ggml_backend_buffer_is_rpc(src->buffer)) {
// check if src and dst are on the same server
ggml_backend_buffer_t src_buffer = src->buffer;
ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
ggml_backend_buffer_t dst_buffer = dst->buffer;
ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
if (src_ctx->sock != dst_ctx->sock) {
return false;
}
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
rpc_msg_copy_tensor_req request;
request.src = serialize_tensor(src);
request.dst = serialize_tensor(dst);
rpc_msg_copy_tensor_rsp response;
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
RPC_STATUS_ASSERT(status);
return response.result;
}
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
rpc_msg_copy_tensor_req request;
request.src = serialize_tensor(src);
request.dst = serialize_tensor(dst);
rpc_msg_copy_tensor_rsp response;
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
RPC_STATUS_ASSERT(status);
return response.result;
return false;
}

static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
Expand Down
Loading