diff --git a/crates/cust/build.rs b/crates/cust/build.rs index 77fe3b1a..1b1f8674 100644 --- a/crates/cust/build.rs +++ b/crates/cust/build.rs @@ -10,4 +10,34 @@ fn main() { if driver_version >= 12030 { println!("cargo::rustc-cfg=conditional_node"); } + // In CUDA 13.0 several pairs/trios of functions were merged: + // ``` + // CUresult cuMemAdvise(CUdeviceptr devPtr, size_t count, CUmem_advise advice, CUdevice device); + // CUresult cuMemAdvise_v2(CUdeviceptr devPtr, size_t count, CUmem_advise advice, CUmemLocation location); + // + // CUresult cuMemPrefetchAsync(CUdeviceptr devPtr, size_t count, CUdevice dstDevice, CUstream hStream); + // CUresult cuMemPrefetchAsync_v2(CUdeviceptr devPtr, size_t count, CUmemLocation location, unsigned int flags, CUstream hStream); + // + // CUresult cuGraphGetEdges(CUgraph hGraph, CUgraphNode* from, CUgraphNode* to, size_t* numEdges); + // CUresult cuGraphGetEdges_v2(CUgraph hGraph, CUgraphNode* from, CUgraphNode* to, CUgraphEdgeData* edgeData, size_t* numEdges); + // + // CUresult cuCtxCreate(CUcontext* pctx, unsigned int flags, CUdevice dev); + // CUresult cuCtxCreate_v3(CUcontext* pctx, CUexecAffinityParam* paramsArray, int numParams, unsigned int flags, CUdevice dev); + // CUresult cuCtxCreate_v4(CUcontext* pctx, CUctxCreateParams* ctxCreateParams, unsigned int flags, CUdevice dev); + // ``` + // In each case, the resulting single function has the name of the first function and the type + // signature of the last. + // + // These cfgs let you call these functions and make it work for both pre CUDA-13.0 and CUDA + // 13.0. When support for CUDA 12.x is dropped, these cfgs can be removed. + println!("cargo::rustc-check-cfg=cfg(cuMemAdvise_v2)"); + println!("cargo::rustc-check-cfg=cfg(cuMemPrefetchAsync_v2)"); + println!("cargo::rustc-check-cfg=cfg(cuGraphGetEdges_v2)"); + println!("cargo::rustc-check-cfg=cfg(cuCtxCreate_v4)"); + if driver_version >= 13000 { + println!("cargo::rustc-cfg=cuMemAdvise_v2"); + println!("cargo::rustc-cfg=cuMemPrefetchAsync_v2"); + println!("cargo::rustc-cfg=cuGraphGetEdges_v2"); + println!("cargo::rustc-cfg=cuCtxCreate_v4"); + } } diff --git a/crates/cust/src/context/legacy.rs b/crates/cust/src/context/legacy.rs index 07583741..3e39ce47 100644 --- a/crates/cust/src/context/legacy.rs +++ b/crates/cust/src/context/legacy.rs @@ -262,8 +262,14 @@ impl Context { // lifetime guarantees so we create-and-push, then pop, then the programmer has to // push again. let mut ctx: CUcontext = ptr::null_mut(); - driver_sys::cuCtxCreate(&mut ctx as *mut CUcontext, flags.bits(), device.as_raw()) - .to_result()?; + driver_sys::cuCtxCreate( + &mut ctx as *mut CUcontext, + #[cfg(cuCtxCreate_v4)] + &mut driver_sys::CUctxCreateParams::default(), + flags.bits(), + device.as_raw(), + ) + .to_result()?; Ok(Context { inner: ctx }) } } diff --git a/crates/cust/src/graph.rs b/crates/cust/src/graph.rs index b24e0963..3db77dd5 100644 --- a/crates/cust/src/graph.rs +++ b/crates/cust/src/graph.rs @@ -415,6 +415,8 @@ impl Graph { self.raw, ptr::null_mut(), ptr::null_mut(), + #[cfg(cuGraphGetEdges_v2)] + ptr::null_mut(), size.as_mut_ptr(), ) .to_result()?; @@ -439,6 +441,8 @@ impl Graph { self.raw, from.as_mut_ptr(), to.as_mut_ptr(), + #[cfg(cuGraphGetEdges_v2)] + ptr::null_mut(), &num_edges as *const _ as *mut usize, ) .to_result()?; diff --git a/crates/cust/src/memory/unified.rs b/crates/cust/src/memory/unified.rs index d67693e5..d833a62d 100644 --- a/crates/cust/src/memory/unified.rs +++ b/crates/cust/src/memory/unified.rs @@ -640,10 +640,19 @@ pub trait MemoryAdvise: private::Sealed { let mem_size = std::mem::size_of_val(slice); unsafe { + let id = -1; // -1 is CU_DEVICE_CPU driver_sys::cuMemPrefetchAsync( slice.as_ptr() as driver_sys::CUdeviceptr, mem_size, - -1, // CU_DEVICE_CPU #define + #[cfg(cuMemPrefetchAsync_v2)] + driver_sys::CUmemLocation { + type_: driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE, + id, + }, + #[cfg(not(cuMemPrefetchAsync_v2))] + id, + #[cfg(cuMemPrefetchAsync_v2)] + 0, // flags for future use, must be 0 as of CUDA 13.0 stream.as_inner(), ) .to_result()?; @@ -677,10 +686,19 @@ pub trait MemoryAdvise: private::Sealed { let mem_size = std::mem::size_of_val(slice); unsafe { + let id = device.as_raw(); driver_sys::cuMemPrefetchAsync( slice.as_ptr() as driver_sys::CUdeviceptr, mem_size, - device.as_raw(), + #[cfg(cuMemPrefetchAsync_v2)] + driver_sys::CUmemLocation { + type_: driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE, + id, + }, + #[cfg(not(cuMemPrefetchAsync_v2))] + id, + #[cfg(cuMemPrefetchAsync_v2)] + 0, // flags for future use, must be 0 as of CUDA 13.0 stream.as_inner(), ) .to_result()?; @@ -709,11 +727,18 @@ pub trait MemoryAdvise: private::Sealed { }; unsafe { + let id = 0; driver_sys::cuMemAdvise( slice.as_ptr() as driver_sys::CUdeviceptr, mem_size, advice, - 0, + #[cfg(cuMemAdvise_v2)] + driver_sys::CUmemLocation { + type_: driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE, + id, + }, + #[cfg(not(cuMemAdvise_v2))] + id, ) .to_result()?; } @@ -744,11 +769,18 @@ pub trait MemoryAdvise: private::Sealed { let mem_size = std::mem::size_of_val(slice); unsafe { + let id = preferred_location.map(|d| d.as_raw()).unwrap_or(-1); // -1 is CU_DEVICE_CPU driver_sys::cuMemAdvise( slice.as_ptr() as driver_sys::CUdeviceptr, mem_size, driver_sys::CUmem_advise::CU_MEM_ADVISE_SET_PREFERRED_LOCATION, - preferred_location.map(|d| d.as_raw()).unwrap_or(-1), + #[cfg(cuMemAdvise_v2)] + driver_sys::CUmemLocation { + type_: driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE, + id, + }, + #[cfg(not(cuMemAdvise_v2))] + id, ) .to_result()?; } @@ -761,11 +793,18 @@ pub trait MemoryAdvise: private::Sealed { let mem_size = std::mem::size_of_val(slice); unsafe { + let id = 0; driver_sys::cuMemAdvise( slice.as_ptr() as driver_sys::CUdeviceptr, mem_size, driver_sys::CUmem_advise::CU_MEM_ADVISE_UNSET_PREFERRED_LOCATION, - 0, + #[cfg(cuMemAdvise_v2)] + driver_sys::CUmemLocation { + type_: driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE, + id, + }, + #[cfg(not(cuMemAdvise_v2))] + id, ) .to_result()?; } diff --git a/crates/cust_raw/build/main.rs b/crates/cust_raw/build/main.rs index d2adf6d3..c62c272a 100644 --- a/crates/cust_raw/build/main.rs +++ b/crates/cust_raw/build/main.rs @@ -2,7 +2,7 @@ //! The build script for the cust_raw generates bindings for libraries in the //! CUDA SDK. The build scripts searches for the CUDA SDK by reading the //! `CUDA_PATH`, `CUDA_ROOT`, or `CUDA_TOOLKIT_ROOT_DIR` environment variables -//! in that order. If none of these variables are set to a vaild CUDA Toolkit +//! in that order. If none of these variables are set to a valid CUDA Toolkit //! SDK path, the build script will attempt to search for any SDK in the //! default installation locations for the current platform. //!