@@ -631,23 +631,30 @@ static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, con
631631 RPC_STATUS_ASSERT (status);
632632}
633633
634+ static bool ggml_backend_buffer_is_rpc (ggml_backend_buffer_t buffer) {
635+ return buffer->iface .free_buffer == ggml_backend_rpc_buffer_free_buffer;
636+ }
637+
634638static bool ggml_backend_rpc_buffer_cpy_tensor (ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
635- // check if src and dst are on the same server
636- ggml_backend_buffer_t src_buffer = src->buffer ;
637- ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context ;
638- ggml_backend_buffer_t dst_buffer = dst->buffer ;
639- ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context ;
640- if (src_ctx->sock != dst_ctx->sock ) {
641- return false ;
639+ if (ggml_backend_buffer_is_rpc (src->buffer )) {
640+ // check if src and dst are on the same server
641+ ggml_backend_buffer_t src_buffer = src->buffer ;
642+ ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context ;
643+ ggml_backend_buffer_t dst_buffer = dst->buffer ;
644+ ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context ;
645+ if (src_ctx->sock != dst_ctx->sock ) {
646+ return false ;
647+ }
648+ ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context ;
649+ rpc_msg_copy_tensor_req request;
650+ request.src = serialize_tensor (src);
651+ request.dst = serialize_tensor (dst);
652+ rpc_msg_copy_tensor_rsp response;
653+ bool status = send_rpc_cmd (ctx->sock , RPC_CMD_COPY_TENSOR, &request, sizeof (request), &response, sizeof (response));
654+ RPC_STATUS_ASSERT (status);
655+ return response.result ;
642656 }
643- ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context ;
644- rpc_msg_copy_tensor_req request;
645- request.src = serialize_tensor (src);
646- request.dst = serialize_tensor (dst);
647- rpc_msg_copy_tensor_rsp response;
648- bool status = send_rpc_cmd (ctx->sock , RPC_CMD_COPY_TENSOR, &request, sizeof (request), &response, sizeof (response));
649- RPC_STATUS_ASSERT (status);
650- return response.result ;
657+ return false ;
651658}
652659
653660static void ggml_backend_rpc_buffer_clear (ggml_backend_buffer_t buffer, uint8_t value) {
0 commit comments