20 | 20 |
21 | 21 | #include <cuda/stream_ref> |
22 | 22 |
| 23 | +#include <cstring> |
| 24 | + |
23 | 25 | namespace cuco::detail { |
24 | 26 |
25 | 27 | /** |
26 | | - * @brief Asynchronous memory copy utility that works around cudaMemcpyAsync bugs |
| 28 | + * @brief Asynchronous memory copy utility using cudaMemcpyBatchAsync when possible |
27 | 29 | * |
28 | | - * This function provides a drop-in replacement for cudaMemcpyAsync that uses |
29 | | - * cudaMemcpyBatchAsync internally to work around known issues with cudaMemcpyAsync |
30 | | - * when available (CUDA 12.8+). For older CUDA versions, it falls back to the |
31 | | - * original cudaMemcpyAsync. The function automatically handles the different API |
32 | | - * signatures between CUDA runtime versions. |
| 30 | + * Uses cudaMemcpyBatchAsync on CUDA 12.8+ with a non-default stream; otherwise falls |
| 31 | + * back to cudaMemcpyAsync. Null pointers and zero-byte copies are treated as no-ops. |
33 | 32 | * |
34 | 33 | * @param dst Destination memory address |
35 | 34 | * @param src Source memory address |
36 | 35 | * @param count Number of bytes to copy |
37 | | - * @param kind Type of memory copy (cudaMemcpyHostToDevice, cudaMemcpyDeviceToHost, etc.) |
38 | | - * @param stream CUDA stream for the asynchronous operation |
| 36 | + * @param kind Memory copy direction |
| 37 | + * @param stream CUDA stream for the operation |
39 | 38 | */ |
40 | 39 | inline void memcpy_async( |
41 | | - void* dst, const void* src, size_t count, cudaMemcpyKind kind, cuda::stream_ref stream) |
| 40 | + void* dst, void const* src, size_t count, cudaMemcpyKind kind, cuda::stream_ref stream) |
42 | 41 | { |
| 42 | + if (dst == nullptr || src == nullptr || count == 0) { return; } |
| 43 | + |
43 | 44 | #if CUDART_VERSION >= 12080 |
44 | | - // CUDA 12.8+ - Use cudaMemcpyBatchAsync as a workaround for cudaMemcpyAsync bugs |
45 | | - void* dsts[1] = {dst}; |
46 | | - void* srcs[1] = {const_cast<void*>(src)}; |
47 | | - size_t sizes[1] = {count}; |
48 | | - cudaMemcpyAttributes attrs[1] = {{.srcAccessOrder = cudaMemcpySrcAccessOrderStream}}; |
49 | | - size_t attrsIdxs[1] = {0}; |
| 45 | + if (stream.get() == 0) { |
| 46 | + CUCO_CUDA_TRY(cudaMemcpyAsync(dst, src, count, kind, stream.get())); |
| 47 | + return; |
| 48 | + } |
| 49 | + |
| 50 | + void* dsts[1] = {dst}; |
| 51 | + void* srcs[1] = {const_cast<void*>(src)}; |
| 52 | + std::size_t sizes[1] = {count}; |
| 53 | + std::size_t attrs_idxs[1] = {0}; |
| 54 | + |
| 55 | + cudaMemcpyAttributes attrs[1] = {}; |
| 56 | + attrs[0].srcAccessOrder = cudaMemcpySrcAccessOrderStream; |
| 57 | + attrs[0].flags = cudaMemcpyFlagPreferOverlapWithCompute; |
50 | 58 |
51 | 59 | #if CUDART_VERSION >= 13000 |
52 | | - // CUDA 13.0+ API - no failIdx parameter |
53 | | - CUCO_CUDA_TRY(cudaMemcpyBatchAsync(dsts, srcs, sizes, 1, attrs, attrsIdxs, 1, stream.get())); |
| 60 | + CUCO_CUDA_TRY(cudaMemcpyBatchAsync(dsts, srcs, sizes, 1, attrs, attrs_idxs, 1, stream.get())); |
54 | 61 | #else |
55 | | - // CUDA 12.8-12.x API - requires failIdx parameter |
56 | | - size_t failIdx; |
| 62 | + std::size_t fail_idx; |
57 | 63 | CUCO_CUDA_TRY( |
58 | | - cudaMemcpyBatchAsync(dsts, srcs, sizes, 1, attrs, attrsIdxs, 1, &failIdx, stream.get())); |
59 | | -#endif |
60 | | - |
| 64 | + cudaMemcpyBatchAsync(dsts, srcs, sizes, 1, attrs, attrs_idxs, 1, &fail_idx, stream.get())); |
| 65 | +#endif // CUDART_VERSION >= 13000 |
61 | 66 | #else |
62 | | - // CUDA 12.0-12.7 - Fall back to original cudaMemcpyAsync |
63 | | - // Note: This may still have the original bugs that cudaMemcpyBatchAsync was designed to fix |
| 67 | + // CUDA < 12.8 - use regular cudaMemcpyAsync |
64 | 68 | CUCO_CUDA_TRY(cudaMemcpyAsync(dst, src, count, kind, stream.get())); |
65 | | -#endif |
| 69 | +#endif // CUDART_VERSION >= 12080 |
66 | 70 | } |
67 | | - |
68 | 71 | } // namespace cuco::detail |
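
For context, a minimal usage sketch of the helper this diff modifies. The include path, buffer sizes, and setup code below are illustrative assumptions and are not part of the change:

```cpp
#include <cuco/detail/utility/memcpy_async.cuh>  // hypothetical path; use the actual header from this PR

#include <cuda/stream_ref>
#include <cuda_runtime.h>

#include <cstddef>
#include <cstdint>

int main()
{
  constexpr std::size_t n     = 1024;
  constexpr std::size_t bytes = n * sizeof(std::uint32_t);

  // Pinned host buffer and device buffer; pinned memory keeps the copy truly asynchronous.
  std::uint32_t* host   = nullptr;
  std::uint32_t* device = nullptr;
  cudaMallocHost(&host, bytes);
  cudaMalloc(&device, bytes);
  for (std::size_t i = 0; i < n; ++i) { host[i] = static_cast<std::uint32_t>(i); }

  cudaStream_t stream{};
  cudaStreamCreate(&stream);

  // Host-to-device copy; on CUDA 12.8+ with a non-default stream this dispatches to
  // cudaMemcpyBatchAsync internally, otherwise it falls back to cudaMemcpyAsync.
  cuco::detail::memcpy_async(device, host, bytes, cudaMemcpyHostToDevice, cuda::stream_ref{stream});

  cudaStreamSynchronize(stream);

  cudaFreeHost(host);
  cudaFree(device);
  cudaStreamDestroy(stream);
  return 0;
}
```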