Skip to content

Commit f034a26

Browse files
committed
Fix pre-12.8 compatibility
1 parent ee5addf commit f034a26

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

include/cuco/detail/utility/memcpy_async.cuh

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,10 @@ namespace cuco::detail {
2626
* @brief Asynchronous memory copy utility that works around cudaMemcpyAsync bugs
2727
*
2828
* This function provides a drop-in replacement for cudaMemcpyAsync that uses
29-
* cudaMemcpyBatchAsync internally to work around known issues with cudaMemcpyAsync.
30-
* The function automatically handles the different API signatures between CUDA
31-
* runtime versions.
29+
* cudaMemcpyBatchAsync internally to work around known issues with cudaMemcpyAsync
30+
* when available (CUDA 12.8+). For older CUDA versions, it falls back to the
31+
* original cudaMemcpyAsync. The function automatically handles the different API
32+
* signatures between CUDA runtime versions.
3233
*
3334
* @param dst Destination memory address
3435
* @param src Source memory address
@@ -39,7 +40,8 @@ namespace cuco::detail {
3940
inline void memcpy_async(
4041
void* dst, const void* src, size_t count, cudaMemcpyKind kind, cuda::stream_ref stream)
4142
{
42-
// Use cudaMemcpyBatchAsync as a workaround for cudaMemcpyAsync bugs
43+
#if CUDART_VERSION >= 12080
44+
// CUDA 12.8+ - Use cudaMemcpyBatchAsync as a workaround for cudaMemcpyAsync bugs
4345
void* dsts[1] = {dst};
4446
void* srcs[1] = {const_cast<void*>(src)};
4547
size_t sizes[1] = {count};
@@ -50,11 +52,17 @@ inline void memcpy_async(
5052
// CUDA 13.0+ API - no failIdx parameter
5153
CUCO_CUDA_TRY(cudaMemcpyBatchAsync(dsts, srcs, sizes, 1, attrs, attrsIdxs, 1, stream.get()));
5254
#else
53-
// CUDA 12.x API - requires failIdx parameter
55+
// CUDA 12.8-12.x API - requires failIdx parameter
5456
size_t failIdx;
5557
CUCO_CUDA_TRY(
5658
cudaMemcpyBatchAsync(dsts, srcs, sizes, 1, attrs, attrsIdxs, 1, &failIdx, stream.get()));
5759
#endif
60+
61+
#else
62+
// CUDA 12.0-12.7 - Fall back to original cudaMemcpyAsync
63+
// Note: This may still have the original bugs that cudaMemcpyBatchAsync was designed to fix
64+
CUCO_CUDA_TRY(cudaMemcpyAsync(dst, src, count, kind, stream.get()));
65+
#endif
5866
}
5967

6068
} // namespace cuco::detail

0 commit comments

Comments
 (0)