Skip to content

Commit f392839

Browse files
authored
rpc : check src buffer when copying tensor (#16421)
Only dst buffer is guaranteed to be an RPC buffer. Add check for the src one.
1 parent 898acba commit f392839

File tree

1 file changed

+22
-15
lines changed

1 file changed

+22
-15
lines changed

ggml/src/ggml-rpc/ggml-rpc.cpp

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -631,23 +631,30 @@ static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, con
631631
RPC_STATUS_ASSERT(status);
632632
}
633633

634+
static bool ggml_backend_buffer_is_rpc(ggml_backend_buffer_t buffer) {
635+
return buffer->iface.free_buffer == ggml_backend_rpc_buffer_free_buffer;
636+
}
637+
634638
static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
635-
// check if src and dst are on the same server
636-
ggml_backend_buffer_t src_buffer = src->buffer;
637-
ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
638-
ggml_backend_buffer_t dst_buffer = dst->buffer;
639-
ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
640-
if (src_ctx->sock != dst_ctx->sock) {
641-
return false;
639+
if (ggml_backend_buffer_is_rpc(src->buffer)) {
640+
// check if src and dst are on the same server
641+
ggml_backend_buffer_t src_buffer = src->buffer;
642+
ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
643+
ggml_backend_buffer_t dst_buffer = dst->buffer;
644+
ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
645+
if (src_ctx->sock != dst_ctx->sock) {
646+
return false;
647+
}
648+
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
649+
rpc_msg_copy_tensor_req request;
650+
request.src = serialize_tensor(src);
651+
request.dst = serialize_tensor(dst);
652+
rpc_msg_copy_tensor_rsp response;
653+
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
654+
RPC_STATUS_ASSERT(status);
655+
return response.result;
642656
}
643-
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
644-
rpc_msg_copy_tensor_req request;
645-
request.src = serialize_tensor(src);
646-
request.dst = serialize_tensor(dst);
647-
rpc_msg_copy_tensor_rsp response;
648-
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
649-
RPC_STATUS_ASSERT(status);
650-
return response.result;
657+
return false;
651658
}
652659

653660
static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {

0 commit comments

Comments
 (0)