@@ -2358,33 +2358,35 @@ GGML_CALL static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend,
 }
 
 GGML_CALL static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) {
-    GGML_ASSERT(ggml_backend_is_cuda(backend_src) || ggml_backend_is_cuda(backend_dst));
-
     ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer;
     ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer;
 
-    if (!ggml_backend_buffer_is_cuda(src->buffer)) {
+    if (!ggml_backend_is_cuda(backend_src) || !ggml_backend_is_cuda(backend_dst)) {
         return false;
     }
 
-    if (!ggml_backend_buffer_is_cuda(dst->buffer)) {
+    if (!ggml_backend_buffer_is_cuda(src->buffer) || !ggml_backend_buffer_is_cuda(dst->buffer)) {
         return false;
     }
 
-    // device -> device
+    // device -> device copy
     ggml_backend_cuda_context * cuda_ctx_src = (ggml_backend_cuda_context *)backend_src->context;
     ggml_backend_cuda_context * cuda_ctx_dst = (ggml_backend_cuda_context *)backend_dst->context;
 
-    if (backend_src != backend_dst) {
-        ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *)buf_src->context;
-        ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *)buf_dst->context;
+    ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *)buf_src->context;
+    ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *)buf_dst->context;
 
-        GGML_ASSERT(cuda_ctx_src->device == buf_ctx_src->device);
-        GGML_ASSERT(cuda_ctx_dst->device == buf_ctx_dst->device);
+    if (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device) {
+#ifndef NDEBUG
+        GGML_CUDA_LOG_WARN("%s: backend and buffer devices do not match\n", __func__);
+#endif
+        return false;
+    }
 
+    if (backend_src != backend_dst) {
         // copy on src stream
         if (cuda_ctx_src->device == cuda_ctx_dst->device) {
-            CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_dst->stream()));
+            CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream()));
         } else {
 #ifdef GGML_CUDA_NO_PEER_COPY
             return false;
@@ -2393,7 +2395,7 @@ GGML_CALL static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_
 #endif
         }
 
-        // record event on src stream
+        // record event on src stream after the copy
         if (!cuda_ctx_src->copy_event) {
             ggml_cuda_set_device(cuda_ctx_src->device);
             CUDA_CHECK(cudaEventCreateWithFlags(&cuda_ctx_src->copy_event, cudaEventDisableTiming));
@@ -2405,7 +2407,7 @@ GGML_CALL static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_
         CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx_dst->stream(), cuda_ctx_src->copy_event, 0));
     } else {
         // src and dst are on the same backend
-        CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_dst->stream()));
+        CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream()));
     }
     return true;
 }
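Note on the two cuda_ctx_dst->stream() -> cuda_ctx_src->stream() changes: enqueuing the copy on the source stream orders it after whatever previously enqueued work produced src->data on that stream, and the event recorded on the source stream right after the copy then lets the destination stream wait until the data has actually arrived. Below is a minimal standalone sketch of that record/wait ordering pattern, assuming a single device for simplicity; this is not llama.cpp code, the names (CHECK, src_stream, dst_stream, copy_event) are illustrative, and only the CUDA runtime calls are real:

#include <cuda_runtime.h>
#include <stdio.h>
#include <stdlib.h>

// abort on any CUDA runtime error, similar in spirit to ggml's CUDA_CHECK
#define CHECK(call)                                                        \
    do {                                                                   \
        cudaError_t err_ = (call);                                         \
        if (err_ != cudaSuccess) {                                         \
            fprintf(stderr, "CUDA error: %s at %s:%d\n",                   \
                    cudaGetErrorString(err_), __FILE__, __LINE__);         \
            exit(1);                                                       \
        }                                                                  \
    } while (0)

int main(void) {
    const size_t nbytes = (1 << 20) * sizeof(float);
    float * src = NULL;
    float * dst = NULL;
    CHECK(cudaMalloc(&src, nbytes));
    CHECK(cudaMalloc(&dst, nbytes));

    cudaStream_t src_stream, dst_stream;
    CHECK(cudaStreamCreate(&src_stream));
    CHECK(cudaStreamCreate(&dst_stream));

    // timing is disabled: the event exists purely for cross-stream ordering
    cudaEvent_t copy_event;
    CHECK(cudaEventCreateWithFlags(&copy_event, cudaEventDisableTiming));

    // 1. enqueue the copy on the *source* stream, so it runs after any
    //    previously enqueued work on that stream that writes src
    CHECK(cudaMemcpyAsync(dst, src, nbytes, cudaMemcpyDeviceToDevice, src_stream));
    // 2. record the event on the source stream, after the copy
    CHECK(cudaEventRecord(copy_event, src_stream));
    // 3. block the destination stream (not the host) until the copy is done;
    //    work enqueued on dst_stream after this point may safely read dst
    CHECK(cudaStreamWaitEvent(dst_stream, copy_event, 0));

    CHECK(cudaStreamSynchronize(dst_stream));
    CHECK(cudaEventDestroy(copy_event));
    CHECK(cudaStreamDestroy(src_stream));
    CHECK(cudaStreamDestroy(dst_stream));
    CHECK(cudaFree(src));
    CHECK(cudaFree(dst));
    return 0;
}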