Skip to content

Commit 5225eaa

Browse files
committed
ggml: improve ggml_backend_cuda_cpy_tensor_async
Make device-to-device copies actually async; previously the implementation synchronized on the dst stream. Also implement async host-to-device copies.
1 parent cdf94a1 commit 5225eaa

File tree

2 files changed

+44
-16
lines changed

2 files changed

+44
-16
lines changed

ggml/src/ggml-backend.cpp

Lines changed: 6 additions & 0 deletions
@@ -1394,6 +1394,12 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
                 }
                 ggml_backend_tensor_copy(input, input_cpy);
             }
+            else {
+                if (input_backend->iface.synchronize) {
+                    // async copy succeeded, need to synchronize the input backend to ensure the copy is done before the split backend uses it
+                    input_backend->iface.synchronize(input_backend);
+                }
+            }
         }
     }
 

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2404,22 +2404,42 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_
     ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer;
     ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer;
 
-    if (!ggml_backend_is_cuda(backend_src) || !ggml_backend_is_cuda(backend_dst)) {
-        return false;
+    bool src_is_cuda = ggml_backend_is_cuda(backend_src);
+    if (src_is_cuda) {
+        if (!ggml_backend_buffer_is_cuda(buf_src)) {
+            return false;
+        }
+    }
+
+    bool dst_is_cuda = ggml_backend_is_cuda(backend_dst);
+    if (dst_is_cuda) {
+        if (!ggml_backend_buffer_is_cuda(buf_dst)) {
+            return false;
+        }
     }
 
-    if (!ggml_backend_buffer_is_cuda(src->buffer) || !ggml_backend_buffer_is_cuda(dst->buffer)) {
+    bool src_is_host = !src_is_cuda && ggml_backend_buffer_is_host(buf_src);
+    bool dst_is_host = !dst_is_cuda && ggml_backend_buffer_is_host(buf_dst);
+
+    // async copy supports cuda to cuda, cuda to host, and host to cuda.
+    if (!src_is_cuda && !dst_is_cuda) {
+        return false;
+    }
+    else if (src_is_host && !dst_is_cuda) {
+        return false;
+    }
+    else if (dst_is_host && !src_is_cuda) {
         return false;
     }
 
     // device -> device copy
-    ggml_backend_cuda_context * cuda_ctx_src = (ggml_backend_cuda_context *)backend_src->context;
-    ggml_backend_cuda_context * cuda_ctx_dst = (ggml_backend_cuda_context *)backend_dst->context;
+    ggml_backend_cuda_context * cuda_ctx_src = src_is_cuda ? (ggml_backend_cuda_context *)backend_src->context : nullptr;
+    ggml_backend_cuda_context * cuda_ctx_dst = dst_is_cuda ? (ggml_backend_cuda_context *)backend_dst->context : nullptr;
 
     ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *)buf_src->context;
     ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *)buf_dst->context;
 
-    if (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device) {
+    if ((cuda_ctx_src && cuda_ctx_src->device != buf_ctx_src->device) || (cuda_ctx_dst && cuda_ctx_dst->device != buf_ctx_dst->device)) {
 #ifndef NDEBUG
         GGML_LOG_DEBUG("%s: backend and buffer devices do not match\n", __func__);
 #endif
@@ -2428,7 +2448,11 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_
 
     if (backend_src != backend_dst) {
         // copy on src stream
-        if (cuda_ctx_src->device == cuda_ctx_dst->device) {
+        if (src_is_host) {
+            CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyHostToDevice, cuda_ctx_dst->stream()));
+        } else if (dst_is_host) {
+            CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToHost, cuda_ctx_src->stream()));
+        } else if (cuda_ctx_src->device == cuda_ctx_dst->device) {
             CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream()));
         } else {
 #ifdef GGML_CUDA_NO_PEER_COPY
@@ -2438,16 +2462,14 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_
 #endif
         }
 
-        // record event on src stream after the copy
-        if (!cuda_ctx_src->copy_event) {
-            ggml_cuda_set_device(cuda_ctx_src->device);
-            CUDA_CHECK(cudaEventCreateWithFlags(&cuda_ctx_src->copy_event, cudaEventDisableTiming));
-        }
-
-        CUDA_CHECK(cudaEventRecord(cuda_ctx_src->copy_event, cuda_ctx_src->stream()));
+        if (cuda_ctx_src) {
+            if (!cuda_ctx_src->copy_event) {
+                ggml_cuda_set_device(cuda_ctx_src->device);
+                CUDA_CHECK(cudaEventCreateWithFlags(&cuda_ctx_src->copy_event, cudaEventDisableTiming));
+            }
 
-        // wait on dst stream for the copy to complete
-        CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx_dst->stream(), cuda_ctx_src->copy_event, 0));
+            CUDA_CHECK(cudaEventRecord(cuda_ctx_src->copy_event, cuda_ctx_src->stream()));
+        }
     } else {
         // src and dst are on the same backend
         CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream()));

0 commit comments

Comments (0)