Skip to content

Commit 17aa550

Browse files
committed
Fix edge cases for cudaMemcpyBatchAsync
1 parent: f034a26 · commit: 17aa550

File tree

1 file changed

+29
-26
lines changed

1 file changed

+29
-26
lines changed

include/cuco/detail/utility/memcpy_async.cuh

Lines changed: 29 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -20,49 +20,52 @@
2020

2121
#include <cuda/stream_ref>
2222

23+
#include <cstring>
24+
2325
namespace cuco::detail {
2426

2527
/**
26-
* @brief Asynchronous memory copy utility that works around cudaMemcpyAsync bugs
28+
* @brief Asynchronous memory copy utility using cudaMemcpyBatchAsync when possible
2729
*
28-
* This function provides a drop-in replacement for cudaMemcpyAsync that uses
29-
* cudaMemcpyBatchAsync internally to work around known issues with cudaMemcpyAsync
30-
* when available (CUDA 12.8+). For older CUDA versions, it falls back to the
31-
* original cudaMemcpyAsync. The function automatically handles the different API
32-
* signatures between CUDA runtime versions.
30+
* Uses cudaMemcpyBatchAsync for CUDA 12.8+ with proper edge case handling.
31+
* Falls back to cudaMemcpyAsync for older CUDA versions or edge cases.
3332
*
3433
* @param dst Destination memory address
3534
* @param src Source memory address
3635
* @param count Number of bytes to copy
37-
* @param kind Type of memory copy (cudaMemcpyHostToDevice, cudaMemcpyDeviceToHost, etc.)
38-
* @param stream CUDA stream for the asynchronous operation
36+
* @param kind Memory copy direction
37+
* @param stream CUDA stream for the operation
3938
*/
4039
inline void memcpy_async(
41-
void* dst, const void* src, size_t count, cudaMemcpyKind kind, cuda::stream_ref stream)
40+
void* dst, void const* src, size_t count, cudaMemcpyKind kind, cuda::stream_ref stream)
4241
{
42+
if (dst == nullptr || src == nullptr || count == 0) { return; }
43+
4344
#if CUDART_VERSION >= 12080
44-
// CUDA 12.8+ - Use cudaMemcpyBatchAsync as a workaround for cudaMemcpyAsync bugs
45-
void* dsts[1] = {dst};
46-
void* srcs[1] = {const_cast<void*>(src)};
47-
size_t sizes[1] = {count};
48-
cudaMemcpyAttributes attrs[1] = {{.srcAccessOrder = cudaMemcpySrcAccessOrderStream}};
49-
size_t attrsIdxs[1] = {0};
45+
if (stream.get() == 0) {
46+
CUCO_CUDA_TRY(cudaMemcpyAsync(dst, src, count, kind, stream.get()));
47+
return;
48+
}
49+
50+
void* dsts[1] = {dst};
51+
void* srcs[1] = {const_cast<void*>(src)};
52+
std::size_t sizes[1] = {count};
53+
std::size_t attrs_idxs[1] = {0};
54+
55+
cudaMemcpyAttributes attrs[1] = {};
56+
attrs[0].srcAccessOrder = cudaMemcpySrcAccessOrderStream;
57+
attrs[0].flags = cudaMemcpyFlagPreferOverlapWithCompute;
5058

5159
#if CUDART_VERSION >= 13000
52-
// CUDA 13.0+ API - no failIdx parameter
53-
CUCO_CUDA_TRY(cudaMemcpyBatchAsync(dsts, srcs, sizes, 1, attrs, attrsIdxs, 1, stream.get()));
60+
CUCO_CUDA_TRY(cudaMemcpyBatchAsync(dsts, srcs, sizes, 1, attrs, attrs_idxs, 1, stream.get()));
5461
#else
55-
// CUDA 12.8-12.x API - requires failIdx parameter
56-
size_t failIdx;
62+
std::size_t fail_idx;
5763
CUCO_CUDA_TRY(
58-
cudaMemcpyBatchAsync(dsts, srcs, sizes, 1, attrs, attrsIdxs, 1, &failIdx, stream.get()));
59-
#endif
60-
64+
cudaMemcpyBatchAsync(dsts, srcs, sizes, 1, attrs, attrs_idxs, 1, &fail_idx, stream.get()));
65+
#endif // CUDART_VERSION >= 13000
6166
#else
62-
// CUDA 12.0-12.7 - Fall back to original cudaMemcpyAsync
63-
// Note: This may still have the original bugs that cudaMemcpyBatchAsync was designed to fix
67+
// CUDA < 12.8 - use regular cudaMemcpyAsync
6468
CUCO_CUDA_TRY(cudaMemcpyAsync(dst, src, count, kind, stream.get()));
65-
#endif
69+
#endif // CUDART_VERSION >= 12080
6670
}
67-
6871
} // namespace cuco::detail

0 commit comments

Comments (0)