@@ -2404,22 +2404,44 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_
24042404 ggml_backend_buffer_t buf_src = src->view_src ? src->view_src ->buffer : src->buffer ;
24052405 ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src ->buffer : dst->buffer ;
24062406
2407- if (!ggml_backend_is_cuda (backend_src) || !ggml_backend_is_cuda (backend_dst)) {
2408- return false ;
2407+ bool src_is_cuda = ggml_backend_is_cuda (backend_src);
2408+ if (src_is_cuda) {
2409+ if (!ggml_backend_buffer_is_cuda (buf_src)) {
2410+ return false ;
2411+ }
2412+ }
2413+ else {
2414+ if (!ggml_backend_buffer_is_host (buf_src)) {
2415+ return false ;
2416+ }
2417+ }
2418+
2419+ bool dst_is_cuda = ggml_backend_is_cuda (backend_dst);
2420+ if (dst_is_cuda) {
2421+ if (!ggml_backend_buffer_is_cuda (buf_dst)) {
2422+ return false ;
2423+ }
2424+ }
2425+ else {
2426+ if (!ggml_backend_buffer_is_host (buf_dst)) {
2427+ return false ;
2428+ }
24092429 }
24102430
2411- if (!ggml_backend_buffer_is_cuda (src->buffer ) || !ggml_backend_buffer_is_cuda (dst->buffer )) {
2431+ // async copy supports cuda to cuda, cuda to host, and host to cuda.
2432+ if (!src_is_cuda && !dst_is_cuda) {
2433+ // ignore host to host copy
24122434 return false ;
24132435 }
24142436
24152437 // device -> device copy
2416- ggml_backend_cuda_context * cuda_ctx_src = (ggml_backend_cuda_context *)backend_src->context ;
2417- ggml_backend_cuda_context * cuda_ctx_dst = (ggml_backend_cuda_context *)backend_dst->context ;
2438+ ggml_backend_cuda_context * cuda_ctx_src = src_is_cuda ? (ggml_backend_cuda_context *)backend_src->context : nullptr ;
2439+ ggml_backend_cuda_context * cuda_ctx_dst = dst_is_cuda ? (ggml_backend_cuda_context *)backend_dst->context : nullptr ;
24182440
24192441 ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *)buf_src->context ;
24202442 ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *)buf_dst->context ;
24212443
2422- if (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device ) {
2444+ if (( cuda_ctx_src && cuda_ctx_src ->device != buf_ctx_src->device ) || ( cuda_ctx_dst && cuda_ctx_dst ->device != buf_ctx_dst->device ) ) {
24232445#ifndef NDEBUG
24242446 GGML_LOG_DEBUG (" %s: backend and buffer devices do not match\n " , __func__);
24252447#endif
@@ -2428,7 +2450,11 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_
24282450
24292451 if (backend_src != backend_dst) {
24302452 // copy on src stream
2431- if (cuda_ctx_src->device == cuda_ctx_dst->device ) {
2453+ if (!src_is_cuda) {
2454+ CUDA_CHECK (cudaMemcpyAsync (dst->data , src->data , ggml_nbytes (dst), cudaMemcpyHostToDevice, cuda_ctx_dst->stream ()));
2455+ } else if (!dst_is_cuda) {
2456+ CUDA_CHECK (cudaMemcpyAsync (dst->data , src->data , ggml_nbytes (dst), cudaMemcpyDeviceToHost, cuda_ctx_src->stream ()));
2457+ } else if (cuda_ctx_src->device == cuda_ctx_dst->device ) {
24322458 CUDA_CHECK (cudaMemcpyAsync (dst->data , src->data , ggml_nbytes (dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream ()));
24332459 } else {
24342460#ifdef GGML_CUDA_NO_PEER_COPY
@@ -2438,16 +2464,14 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_
24382464#endif
24392465 }
24402466
2441- // record event on src stream after the copy
2442- if (!cuda_ctx_src->copy_event ) {
2443- ggml_cuda_set_device (cuda_ctx_src->device );
2444- CUDA_CHECK (cudaEventCreateWithFlags (&cuda_ctx_src->copy_event , cudaEventDisableTiming));
2445- }
2446-
2447- CUDA_CHECK (cudaEventRecord (cuda_ctx_src->copy_event , cuda_ctx_src->stream ()));
2467+ if (cuda_ctx_src) {
2468+ if (!cuda_ctx_src->copy_event ) {
2469+ ggml_cuda_set_device (cuda_ctx_src->device );
2470+ CUDA_CHECK (cudaEventCreateWithFlags (&cuda_ctx_src->copy_event , cudaEventDisableTiming));
2471+ }
24482472
2449- // wait on dst stream for the copy to complete
2450- CUDA_CHECK ( cudaStreamWaitEvent (cuda_ctx_dst-> stream (), cuda_ctx_src-> copy_event , 0 ));
2473+ CUDA_CHECK ( cudaEventRecord (cuda_ctx_src-> copy_event , cuda_ctx_src-> stream ()));
2474+ }
24512475 } else {
24522476 // src and dst are on the same backend
24532477 CUDA_CHECK (cudaMemcpyAsync (dst->data , src->data , ggml_nbytes (dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream ()));
0 commit comments