1
1
/* *
2
- * Copyright 2024, XGBoost Contributors
2
+ * Copyright 2024-2025 , XGBoost Contributors
3
3
*/
4
4
#include " device_helpers.cuh" // for CurrentDevice
5
5
#include " resource.cuh"
@@ -18,14 +18,26 @@ CudaMmapResource::CudaMmapResource(StringView path, std::size_t offset, std::siz
18
18
}},
19
19
n_{length} {
20
20
auto device = dh::CurrentDevice ();
21
+ #if (CUDA_VERSION / 1000) >= 13
22
+ cudaMemLocation loc;
23
+ loc.type = cudaMemLocationTypeDevice;
24
+ loc.id = device;
25
+ #else
26
+ auto loc = device;
27
+ #endif // (CUDA_VERSION / 1000) >= 13
21
28
dh::safe_cuda (
22
- cudaMemAdvise (handle_->base_ptr , handle_->base_size , cudaMemAdviseSetReadMostly, device));
23
- dh::safe_cuda (cudaMemAdvise (handle_->base_ptr , handle_->base_size ,
24
- cudaMemAdviseSetPreferredLocation, device));
29
+ cudaMemAdvise (handle_->base_ptr , handle_->base_size , cudaMemAdviseSetReadMostly, loc));
25
30
dh::safe_cuda (
26
- cudaMemAdvise (handle_->base_ptr , handle_->base_size , cudaMemAdviseSetAccessedBy, device));
31
+ cudaMemAdvise (handle_->base_ptr , handle_->base_size , cudaMemAdviseSetPreferredLocation, loc));
32
+ dh::safe_cuda (
33
+ cudaMemAdvise (handle_->base_ptr , handle_->base_size , cudaMemAdviseSetAccessedBy, loc));
34
+ #if (CUDA_VERSION / 1000) >= 13
35
+ dh::safe_cuda (
36
+ cudaMemPrefetchAsync (handle_->base_ptr , handle_->base_size , loc, 0 , dh::DefaultStream ()));
37
+ #else
27
38
dh::safe_cuda (
28
39
cudaMemPrefetchAsync (handle_->base_ptr , handle_->base_size , device, dh::DefaultStream ()));
40
+ #endif // (CUDA_VERSION / 1000) >= 13
29
41
}
30
42
31
43
[[nodiscard]] void * CudaMmapResource::Data () {
0 commit comments