@@ -2404,22 +2404,42 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_
24042404 ggml_backend_buffer_t buf_src = src->view_src ? src->view_src ->buffer : src->buffer ;
24052405 ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src ->buffer : dst->buffer ;
24062406
2407- if (!ggml_backend_is_cuda (backend_src) || !ggml_backend_is_cuda (backend_dst)) {
2408- return false ;
2407+ bool src_is_cuda = ggml_backend_is_cuda (backend_src);
2408+ if (src_is_cuda) {
2409+ if (!ggml_backend_buffer_is_cuda (buf_src)) {
2410+ return false ;
2411+ }
2412+ }
2413+
2414+ bool dst_is_cuda = ggml_backend_is_cuda (backend_dst);
2415+ if (dst_is_cuda) {
2416+ if (!ggml_backend_buffer_is_cuda (buf_dst)) {
2417+ return false ;
2418+ }
24092419 }
24102420
2411- if (!ggml_backend_buffer_is_cuda (src->buffer ) || !ggml_backend_buffer_is_cuda (dst->buffer )) {
2421+ bool src_is_host = !src_is_cuda && ggml_backend_buffer_is_host (buf_src);
2422+ bool dst_is_host = !dst_is_cuda && ggml_backend_buffer_is_host (buf_dst);
2423+
2424+ // async copy supports cuda to cuda, cuda to host, and host to cuda.
2425+ if (!src_is_cuda && !dst_is_cuda) {
2426+ return false ;
2427+ }
2428+ else if (src_is_host && !dst_is_cuda) {
2429+ return false ;
2430+ }
2431+ else if (dst_is_host && !src_is_cuda) {
24122432 return false ;
24132433 }
24142434
24152435 // device -> device copy
2416- ggml_backend_cuda_context * cuda_ctx_src = (ggml_backend_cuda_context *)backend_src->context ;
2417- ggml_backend_cuda_context * cuda_ctx_dst = (ggml_backend_cuda_context *)backend_dst->context ;
2436+ ggml_backend_cuda_context * cuda_ctx_src = src_is_cuda ? (ggml_backend_cuda_context *)backend_src->context : nullptr ;
2437+ ggml_backend_cuda_context * cuda_ctx_dst = dst_is_cuda ? (ggml_backend_cuda_context *)backend_dst->context : nullptr ;
24182438
24192439 ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *)buf_src->context ;
24202440 ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *)buf_dst->context ;
24212441
2422- if (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device ) {
2442+ if (( cuda_ctx_src && cuda_ctx_src ->device != buf_ctx_src->device ) || ( cuda_ctx_dst && cuda_ctx_dst ->device != buf_ctx_dst->device ) ) {
24232443#ifndef NDEBUG
24242444 GGML_LOG_DEBUG (" %s: backend and buffer devices do not match\n " , __func__);
24252445#endif
@@ -2428,7 +2448,11 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_
24282448
24292449 if (backend_src != backend_dst) {
24302450 // copy on src stream
2431- if (cuda_ctx_src->device == cuda_ctx_dst->device ) {
2451+ if (src_is_host) {
2452+ CUDA_CHECK (cudaMemcpyAsync (dst->data , src->data , ggml_nbytes (dst), cudaMemcpyHostToDevice, cuda_ctx_dst->stream ()));
2453+ } else if (dst_is_host) {
2454+ CUDA_CHECK (cudaMemcpyAsync (dst->data , src->data , ggml_nbytes (dst), cudaMemcpyDeviceToHost, cuda_ctx_src->stream ()));
2455+ } else if (cuda_ctx_src->device == cuda_ctx_dst->device ) {
24322456 CUDA_CHECK (cudaMemcpyAsync (dst->data , src->data , ggml_nbytes (dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream ()));
24332457 } else {
24342458#ifdef GGML_CUDA_NO_PEER_COPY
@@ -2438,16 +2462,14 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_
24382462#endif
24392463 }
24402464
2441- // record event on src stream after the copy
2442- if (!cuda_ctx_src->copy_event ) {
2443- ggml_cuda_set_device (cuda_ctx_src->device );
2444- CUDA_CHECK (cudaEventCreateWithFlags (&cuda_ctx_src->copy_event , cudaEventDisableTiming));
2445- }
2446-
2447- CUDA_CHECK (cudaEventRecord (cuda_ctx_src->copy_event , cuda_ctx_src->stream ()));
2465+ if (cuda_ctx_src) {
2466+ if (!cuda_ctx_src->copy_event ) {
2467+ ggml_cuda_set_device (cuda_ctx_src->device );
2468+ CUDA_CHECK (cudaEventCreateWithFlags (&cuda_ctx_src->copy_event , cudaEventDisableTiming));
2469+ }
24482470
2449- // wait on dst stream for the copy to complete
2450- CUDA_CHECK ( cudaStreamWaitEvent (cuda_ctx_dst-> stream (), cuda_ctx_src-> copy_event , 0 ));
2471+ CUDA_CHECK ( cudaEventRecord (cuda_ctx_src-> copy_event , cuda_ctx_src-> stream ()));
2472+ }
24512473 } else {
24522474 // src and dst are on the same backend
24532475 CUDA_CHECK (cudaMemcpyAsync (dst->data , src->data , ggml_nbytes (dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream ()));
0 commit comments