Skip to content

Commit 528299d

Browse files
committed
Replace cub::Max() with cuda::maximum<> in kernel reductions
1 parent 1c7f0e8 commit 528299d

File tree

1 file changed

+22
-1
lines changed

1 file changed

+22
-1
lines changed

csrc/pythonInterface.cpp

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,27 @@
55

66
#if BUILD_CUDA
77
#include <ops.cuh>
8+
#include <cuda_runtime_api.h>
9+
10+
#if CUDART_VERSION >= 13000
11+
static inline cudaError_t bnb_cudaMemPrefetchAsync(const void* ptr,
12+
size_t bytes,
13+
int device,
14+
cudaStream_t stream) {
15+
cudaMemLocation loc{};
16+
loc.type = cudaMemLocationTypeDevice;
17+
loc.id = device;
18+
// flags = 0
19+
return cudaMemPrefetchAsync(ptr, bytes, loc, 0u, stream);
20+
}
21+
#else
22+
static inline cudaError_t bnb_cudaMemPrefetchAsync(const void* ptr,
23+
size_t bytes,
24+
int device,
25+
cudaStream_t stream) {
26+
return cudaMemPrefetchAsync(ptr, bytes, device, stream);
27+
}
28+
#endif
829
#endif
930
#if BUILD_HIP
1031
#include <ops_hip.cuh>
@@ -623,7 +644,7 @@ void cprefetch(void* ptr, size_t bytes, int device) {
623644
if (hasPrefetch == 0)
624645
return;
625646

626-
CUDA_CHECK_RETURN(cudaMemPrefetchAsync(ptr, bytes, device, 0));
647+
CUDA_CHECK_RETURN(bnb_cudaMemPrefetchAsync(ptr, bytes, device, 0));
627648
CUDA_CHECK_RETURN(cudaPeekAtLastError());
628649
}
629650

0 commit comments

Comments
 (0)