From 7966d054d2ed3445228a80e04c82ffaec5f222c4 Mon Sep 17 00:00:00 2001
From: Koushik Dutta
Date: Mon, 26 May 2025 20:50:52 -0700
Subject: [PATCH] ggml: improve ggml_backend_cuda_cpy_tensor_async

Make device to device actually async; right now it syncs on dst.
Implement host to/from device async.
---
 ggml/src/ggml-backend.cpp       | 19 +++++++++--
 ggml/src/ggml-cuda/ggml-cuda.cu | 56 +++++++++++++++++++++++----------
 2 files changed, 56 insertions(+), 19 deletions(-)

diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
index 1f40f10e87622..bfaeb1131351f 100644
--- a/ggml/src/ggml-backend.cpp
+++ b/ggml/src/ggml-backend.cpp
@@ -1361,6 +1361,11 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
         struct ggml_backend_sched_split * split = &splits[i];
         int split_backend_id = split->backend_id;
         ggml_backend_t split_backend = sched->backends[split_backend_id];
+        bool needs_synchronize[GGML_SCHED_MAX_BACKENDS] = { false };
+        auto queue_synchronize = [&](ggml_backend_t backend) {
+            auto backend_id = ggml_backend_sched_backend_id(sched, backend);
+            needs_synchronize[backend_id] = true;
+        };
 
         // copy the input tensors to the split backend
         for (int j = 0; j < split->n_inputs; j++) {
@@ -1383,9 +1388,10 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
                 } else {
                     ggml_backend_synchronize(split_backend);
                 }
-                // try async copy, but if not possible, we can still use a sync copy without synchronizing the dst backend, since we handle the synchronization here with multiple copies and events
-                // TODO: add public function to facilitate this, since applications do not have direct access to the backend interface
-                if (!split_backend->iface.cpy_tensor_async || !split_backend->iface.cpy_tensor_async(input_backend, split_backend, input, input_cpy)) {
+                if (split_backend->iface.cpy_tensor_async && split_backend->iface.cpy_tensor_async(input_backend, split_backend, input, input_cpy)) {
+                    // async tensor copy occurs on the source stream, queue up a synchronize after all the copies are done to ensure all inputs are ready
+                    queue_synchronize(input_backend);
+                } else {
                     ggml_backend_synchronize(input_backend);
                     if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
                         ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]);
@@ -1397,6 +1403,13 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
             }
         }
 
+        for (int i = 0; i < GGML_SCHED_MAX_BACKENDS; i++) {
+            if (needs_synchronize[i]) {
+                auto backend = sched->backends[i];
+                ggml_backend_synchronize(backend);
+            }
+        }
+
         if (!sched->callback_eval) {
             enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &split->graph);
             if (ec != GGML_STATUS_SUCCESS) {
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index c442a64924303..08dbb64aa4d27 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2404,22 +2404,44 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_
     ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer;
     ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer;
 
-    if (!ggml_backend_is_cuda(backend_src) || !ggml_backend_is_cuda(backend_dst)) {
-        return false;
+    bool src_is_cuda = ggml_backend_is_cuda(backend_src);
+    if (src_is_cuda) {
+        if (!ggml_backend_buffer_is_cuda(buf_src)) {
+            return false;
+        }
+    }
+    else {
+        if (!ggml_backend_buffer_is_host(buf_src)) {
+            return false;
+        }
+    }
+
+    bool dst_is_cuda = ggml_backend_is_cuda(backend_dst);
+    if (dst_is_cuda) {
+        if (!ggml_backend_buffer_is_cuda(buf_dst)) {
+            return false;
+        }
+    }
+    else {
+        if (!ggml_backend_buffer_is_host(buf_dst)) {
+            return false;
+        }
     }
 
-    if (!ggml_backend_buffer_is_cuda(src->buffer) || !ggml_backend_buffer_is_cuda(dst->buffer)) {
+    // async copy supports cuda to cuda, cuda to host, and host to cuda.
+    if (!src_is_cuda && !dst_is_cuda) {
+        // ignore host to host copy
         return false;
     }
 
     // device -> device copy
-    ggml_backend_cuda_context * cuda_ctx_src = (ggml_backend_cuda_context *)backend_src->context;
-    ggml_backend_cuda_context * cuda_ctx_dst = (ggml_backend_cuda_context *)backend_dst->context;
+    ggml_backend_cuda_context * cuda_ctx_src = src_is_cuda ? (ggml_backend_cuda_context *)backend_src->context : nullptr;
+    ggml_backend_cuda_context * cuda_ctx_dst = dst_is_cuda ? (ggml_backend_cuda_context *)backend_dst->context : nullptr;
 
     ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *)buf_src->context;
     ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *)buf_dst->context;
 
-    if (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device) {
+    if ((cuda_ctx_src && cuda_ctx_src->device != buf_ctx_src->device) || (cuda_ctx_dst && cuda_ctx_dst->device != buf_ctx_dst->device)) {
 #ifndef NDEBUG
         GGML_LOG_DEBUG("%s: backend and buffer devices do not match\n", __func__);
 #endif
@@ -2428,7 +2450,11 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_
 
     if (backend_src != backend_dst) {
         // copy on src stream
-        if (cuda_ctx_src->device == cuda_ctx_dst->device) {
+        if (!src_is_cuda) {
+            CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyHostToDevice, cuda_ctx_dst->stream()));
+        } else if (!dst_is_cuda) {
+            CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToHost, cuda_ctx_src->stream()));
+        } else if (cuda_ctx_src->device == cuda_ctx_dst->device) {
             CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream()));
         } else {
 #ifdef GGML_CUDA_NO_PEER_COPY
@@ -2438,16 +2464,14 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_
 #endif
         }
 
-        // record event on src stream after the copy
-        if (!cuda_ctx_src->copy_event) {
-            ggml_cuda_set_device(cuda_ctx_src->device);
-            CUDA_CHECK(cudaEventCreateWithFlags(&cuda_ctx_src->copy_event, cudaEventDisableTiming));
-        }
-
-        CUDA_CHECK(cudaEventRecord(cuda_ctx_src->copy_event, cuda_ctx_src->stream()));
+        if (cuda_ctx_src) {
+            if (!cuda_ctx_src->copy_event) {
+                ggml_cuda_set_device(cuda_ctx_src->device);
+                CUDA_CHECK(cudaEventCreateWithFlags(&cuda_ctx_src->copy_event, cudaEventDisableTiming));
+            }
 
-        // wait on dst stream for the copy to complete
-        CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx_dst->stream(), cuda_ctx_src->copy_event, 0));
+            CUDA_CHECK(cudaEventRecord(cuda_ctx_src->copy_event, cuda_ctx_src->stream()));
+        }
     } else {
        // src and dst are on the same backend
        CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream()));
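
Note (not part of the patch): below is a minimal standalone CUDA sketch of the stream/event pattern the change builds on — queue cudaMemcpyAsync on the source-side stream, record an event on that stream, and have the consumer stream wait on the event rather than blocking on the destination per copy. All names (src_stream, dst_stream, dev_dst, CHECK) are illustrative only and do not come from the ggml code.

// cross_stream_copy_sketch.cu -- illustrative only, not ggml code
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

// simple error-checking helper for this sketch
#define CHECK(call)                                                     \
    do {                                                                \
        cudaError_t err_ = (call);                                      \
        if (err_ != cudaSuccess) {                                      \
            fprintf(stderr, "CUDA error %s at %s:%d\n",                 \
                    cudaGetErrorString(err_), __FILE__, __LINE__);      \
            return 1;                                                   \
        }                                                               \
    } while (0)

int main() {
    const size_t n = 1 << 20;
    // pageable host memory; a truly asynchronous H2D copy would need pinned
    // memory (cudaMallocHost), this is just to keep the sketch short
    std::vector<float> host_src(n, 1.0f);

    float * dev_dst = nullptr;
    CHECK(cudaMalloc(&dev_dst, n * sizeof(float)));

    cudaStream_t src_stream, dst_stream;
    CHECK(cudaStreamCreate(&src_stream));
    CHECK(cudaStreamCreate(&dst_stream));

    cudaEvent_t copy_event;
    CHECK(cudaEventCreateWithFlags(&copy_event, cudaEventDisableTiming));

    // queue the copy on the source-side stream
    CHECK(cudaMemcpyAsync(dev_dst, host_src.data(), n * sizeof(float),
                          cudaMemcpyHostToDevice, src_stream));

    // record completion of the copy on the source stream ...
    CHECK(cudaEventRecord(copy_event, src_stream));
    // ... and make the consumer stream wait on it instead of blocking the host
    CHECK(cudaStreamWaitEvent(dst_stream, copy_event, 0));

    // kernels launched on dst_stream from here on see the copied data;
    // a host-side synchronize is only needed before the host reads results
    CHECK(cudaStreamSynchronize(dst_stream));

    CHECK(cudaEventDestroy(copy_event));
    CHECK(cudaStreamDestroy(src_stream));
    CHECK(cudaStreamDestroy(dst_stream));
    CHECK(cudaFree(dev_dst));
    return 0;
}

The scheduler-side half of the patch applies the same idea one level up: instead of making the destination stream wait per copy, it marks each source backend that issued an async copy and calls ggml_backend_synchronize on it once, after all of the split's inputs have been queued.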