Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions crates/cust/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,34 @@ fn main() {
if driver_version >= 12030 {
println!("cargo::rustc-cfg=conditional_node");
}
// In CUDA 13.0 several pairs/trios of functions were merged:
// ```
// CUresult cuMemAdvise(CUdeviceptr devPtr, size_t count, CUmem_advise advice, CUdevice device);
// CUresult cuMemAdvise_v2(CUdeviceptr devPtr, size_t count, CUmem_advise advice, CUmemLocation location);
//
// CUresult cuMemPrefetchAsync(CUdeviceptr devPtr, size_t count, CUdevice dstDevice, CUstream hStream);
// CUresult cuMemPrefetchAsync_v2(CUdeviceptr devPtr, size_t count, CUmemLocation location, unsigned int flags, CUstream hStream);
//
// CUresult cuGraphGetEdges(CUgraph hGraph, CUgraphNode* from, CUgraphNode* to, size_t* numEdges);
// CUresult cuGraphGetEdges_v2(CUgraph hGraph, CUgraphNode* from, CUgraphNode* to, CUgraphEdgeData* edgeData, size_t* numEdges);
//
// CUresult cuCtxCreate(CUcontext* pctx, unsigned int flags, CUdevice dev);
// CUresult cuCtxCreate_v3(CUcontext* pctx, CUexecAffinityParam* paramsArray, int numParams, unsigned int flags, CUdevice dev);
// CUresult cuCtxCreate_v4(CUcontext* pctx, CUctxCreateParams* ctxCreateParams, unsigned int flags, CUdevice dev);
// ```
// In each case, the resulting single function has the name of the first function and the type
// signature of the last.
//
// These cfgs let you call these functions and make it work for both pre CUDA-13.0 and CUDA
// 13.0. When support for CUDA 12.x is dropped, these cfgs can be removed.
println!("cargo::rustc-check-cfg=cfg(cuMemAdvise_v2)");
println!("cargo::rustc-check-cfg=cfg(cuMemPrefetchAsync_v2)");
println!("cargo::rustc-check-cfg=cfg(cuGraphGetEdges_v2)");
println!("cargo::rustc-check-cfg=cfg(cuCtxCreate_v4)");
if driver_version >= 13000 {
println!("cargo::rustc-cfg=cuMemAdvise_v2");
println!("cargo::rustc-cfg=cuMemPrefetchAsync_v2");
println!("cargo::rustc-cfg=cuGraphGetEdges_v2");
println!("cargo::rustc-cfg=cuCtxCreate_v4");
}
}
10 changes: 8 additions & 2 deletions crates/cust/src/context/legacy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,14 @@ impl Context {
// lifetime guarantees so we create-and-push, then pop, then the programmer has to
// push again.
let mut ctx: CUcontext = ptr::null_mut();
driver_sys::cuCtxCreate(&mut ctx as *mut CUcontext, flags.bits(), device.as_raw())
.to_result()?;
driver_sys::cuCtxCreate(
&mut ctx as *mut CUcontext,
#[cfg(cuCtxCreate_v4)]
&mut driver_sys::CUctxCreateParams::default(),
flags.bits(),
device.as_raw(),
)
.to_result()?;
Ok(Context { inner: ctx })
}
}
Expand Down
4 changes: 4 additions & 0 deletions crates/cust/src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,8 @@ impl Graph {
self.raw,
ptr::null_mut(),
ptr::null_mut(),
#[cfg(cuGraphGetEdges_v2)]
ptr::null_mut(),
size.as_mut_ptr(),
)
.to_result()?;
Expand All @@ -439,6 +441,8 @@ impl Graph {
self.raw,
from.as_mut_ptr(),
to.as_mut_ptr(),
#[cfg(cuGraphGetEdges_v2)]
ptr::null_mut(),
&num_edges as *const _ as *mut usize,
)
.to_result()?;
Expand Down
49 changes: 44 additions & 5 deletions crates/cust/src/memory/unified.rs
Original file line number Diff line number Diff line change
Expand Up @@ -640,10 +640,19 @@ pub trait MemoryAdvise<T: DeviceCopy>: private::Sealed {
let mem_size = std::mem::size_of_val(slice);

unsafe {
let id = -1; // -1 is CU_DEVICE_CPU
driver_sys::cuMemPrefetchAsync(
slice.as_ptr() as driver_sys::CUdeviceptr,
mem_size,
-1, // CU_DEVICE_CPU #define
#[cfg(cuMemPrefetchAsync_v2)]
driver_sys::CUmemLocation {
type_: driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE,
id,
},
#[cfg(not(cuMemPrefetchAsync_v2))]
id,
#[cfg(cuMemPrefetchAsync_v2)]
0, // flags for future use, must be 0 as of CUDA 13.0
stream.as_inner(),
)
.to_result()?;
Expand Down Expand Up @@ -677,10 +686,19 @@ pub trait MemoryAdvise<T: DeviceCopy>: private::Sealed {
let mem_size = std::mem::size_of_val(slice);

unsafe {
let id = device.as_raw();
driver_sys::cuMemPrefetchAsync(
slice.as_ptr() as driver_sys::CUdeviceptr,
mem_size,
device.as_raw(),
#[cfg(cuMemPrefetchAsync_v2)]
driver_sys::CUmemLocation {
type_: driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE,
id,
},
#[cfg(not(cuMemPrefetchAsync_v2))]
id,
#[cfg(cuMemPrefetchAsync_v2)]
0, // flags for future use, must be 0 as of CUDA 13.0
stream.as_inner(),
)
.to_result()?;
Expand Down Expand Up @@ -709,11 +727,18 @@ pub trait MemoryAdvise<T: DeviceCopy>: private::Sealed {
};

unsafe {
let id = 0;
driver_sys::cuMemAdvise(
slice.as_ptr() as driver_sys::CUdeviceptr,
mem_size,
advice,
0,
#[cfg(cuMemAdvise_v2)]
driver_sys::CUmemLocation {
type_: driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE,
id,
},
#[cfg(not(cuMemAdvise_v2))]
id,
)
.to_result()?;
}
Expand Down Expand Up @@ -744,11 +769,18 @@ pub trait MemoryAdvise<T: DeviceCopy>: private::Sealed {
let mem_size = std::mem::size_of_val(slice);

unsafe {
let id = preferred_location.map(|d| d.as_raw()).unwrap_or(-1); // -1 is CU_DEVICE_CPU
driver_sys::cuMemAdvise(
slice.as_ptr() as driver_sys::CUdeviceptr,
mem_size,
driver_sys::CUmem_advise::CU_MEM_ADVISE_SET_PREFERRED_LOCATION,
preferred_location.map(|d| d.as_raw()).unwrap_or(-1),
#[cfg(cuMemAdvise_v2)]
driver_sys::CUmemLocation {
type_: driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE,
id,
},
#[cfg(not(cuMemAdvise_v2))]
id,
)
.to_result()?;
}
Expand All @@ -761,11 +793,18 @@ pub trait MemoryAdvise<T: DeviceCopy>: private::Sealed {
let mem_size = std::mem::size_of_val(slice);

unsafe {
let id = 0;
driver_sys::cuMemAdvise(
slice.as_ptr() as driver_sys::CUdeviceptr,
mem_size,
driver_sys::CUmem_advise::CU_MEM_ADVISE_UNSET_PREFERRED_LOCATION,
0,
#[cfg(cuMemAdvise_v2)]
driver_sys::CUmemLocation {
type_: driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE,
id,
},
#[cfg(not(cuMemAdvise_v2))]
id,
)
.to_result()?;
}
Expand Down
2 changes: 1 addition & 1 deletion crates/cust_raw/build/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//! The build script for the cust_raw generates bindings for libraries in the
//! CUDA SDK. The build scripts searches for the CUDA SDK by reading the
//! `CUDA_PATH`, `CUDA_ROOT`, or `CUDA_TOOLKIT_ROOT_DIR` environment variables
//! in that order. If none of these variables are set to a vaild CUDA Toolkit
//! in that order. If none of these variables are set to a valid CUDA Toolkit
//! SDK path, the build script will attempt to search for any SDK in the
//! default installation locations for the current platform.
//!
Expand Down
Loading