@@ -631,23 +631,30 @@ static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, con
631
631
RPC_STATUS_ASSERT (status);
632
632
}
633
633
634
+ static bool ggml_backend_buffer_is_rpc (ggml_backend_buffer_t buffer) {
635
+ return buffer->iface .free_buffer == ggml_backend_rpc_buffer_free_buffer;
636
+ }
637
+
634
638
static bool ggml_backend_rpc_buffer_cpy_tensor (ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
635
- // check if src and dst are on the same server
636
- ggml_backend_buffer_t src_buffer = src->buffer ;
637
- ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context ;
638
- ggml_backend_buffer_t dst_buffer = dst->buffer ;
639
- ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context ;
640
- if (src_ctx->sock != dst_ctx->sock ) {
641
- return false ;
639
+ if (ggml_backend_buffer_is_rpc (src->buffer )) {
640
+ // check if src and dst are on the same server
641
+ ggml_backend_buffer_t src_buffer = src->buffer ;
642
+ ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context ;
643
+ ggml_backend_buffer_t dst_buffer = dst->buffer ;
644
+ ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context ;
645
+ if (src_ctx->sock != dst_ctx->sock ) {
646
+ return false ;
647
+ }
648
+ ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context ;
649
+ rpc_msg_copy_tensor_req request;
650
+ request.src = serialize_tensor (src);
651
+ request.dst = serialize_tensor (dst);
652
+ rpc_msg_copy_tensor_rsp response;
653
+ bool status = send_rpc_cmd (ctx->sock , RPC_CMD_COPY_TENSOR, &request, sizeof (request), &response, sizeof (response));
654
+ RPC_STATUS_ASSERT (status);
655
+ return response.result ;
642
656
}
643
- ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context ;
644
- rpc_msg_copy_tensor_req request;
645
- request.src = serialize_tensor (src);
646
- request.dst = serialize_tensor (dst);
647
- rpc_msg_copy_tensor_rsp response;
648
- bool status = send_rpc_cmd (ctx->sock , RPC_CMD_COPY_TENSOR, &request, sizeof (request), &response, sizeof (response));
649
- RPC_STATUS_ASSERT (status);
650
- return response.result ;
657
+ return false ;
651
658
}
652
659
653
660
static void ggml_backend_rpc_buffer_clear (ggml_backend_buffer_t buffer, uint8_t value) {
0 commit comments