Skip to content

Commit 3dd86d9

Browse files
committed
Feat: finalize module changes
1 parent 0c700db commit 3dd86d9

File tree

6 files changed

+105
-85
lines changed

6 files changed

+105
-85
lines changed

crates/cust/CHANGELOG.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,9 @@ Instead you can now use `DeviceSlice::index` which behaves the same.
7474
- Added `ModuleJitOption`, `JitFallback`, `JitTarget`, and `OptLevel` for specifying options when loading a module. Note that
7575
`ModuleJitOption::MaxRegisters` does not seem to work currently, but NVIDIA is looking into it.
7676
You can achieve the same goal by compiling the ptx to cubin using nvcc then loading that: `nvcc --cubin foo.ptx -maxrregcount=REGS`
77-
- Added `Module::from_fatbin` and `Module::from_fatbin_unchecked`.
78-
- Added `Module::from_cubin` and `Module::from_cubin_unchecked`.
79-
- Added `Module::from_ptr` and `Module::from_ptx_cstr`.
77+
- Added `Module::from_fatbin`.
78+
- Added `Module::from_cubin`.
79+
- Added `Module::from_ptx` and `Module::from_ptx_cstr`.
8080
- `Stream`, `Module`, `Linker`, `Function`, `Event`, `UnifiedBox`, `ArrayObject`, `LockedBuffer`, `LockedBox`, `DeviceSlice`, `DeviceBuffer`, and `DeviceBox` all now impl `Send` and `Sync`, this makes
8181
it much easier to write multigpu code. The CUDA API is fully thread-safe except for graph objects.
8282

@@ -98,6 +98,7 @@ it much easier to write multigpu code. The CUDA API is fully thread-safe except
9898
- `DeviceSlice::as_ptr` and `DeviceSlice::as_ptr_mut` now both return a `DevicePointer<T>`.
9999
- `DeviceSlice` is now `Clone` and `Copy`.
100100
- `DevicePointer::as_raw` now returns a `CUdeviceptr`, not a `*const T` (use `DevicePointer::as_ptr`).
101+
- Fixed typo in `CudaError`, `InvalidSouce` is now `InvalidSource`, no more invalid sauce 🍅🥣
101102

102103
## 0.2.2 - 12/5/21
103104

crates/cust/resources/add.cubin

704 Bytes
Binary file not shown.

crates/cust/resources/add.fatbin

704 Bytes
Binary file not shown.

crates/cust/src/error.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ pub enum CudaError {
5252
InvalidPtx = 218,
5353
InvalidGraphicsContext = 219,
5454
NvlinkUncorrectable = 220,
55-
InvalidSouce = 300,
55+
InvalidSource = 300,
5656
FileNotFound = 301,
5757
SharedObjectSymbolNotFound = 302,
5858
SharedObjectInitFailed = 303,
@@ -165,7 +165,7 @@ impl ToResult for cudaError_enum {
165165
Err(CudaError::InvalidGraphicsContext)
166166
}
167167
cudaError_enum::CUDA_ERROR_NVLINK_UNCORRECTABLE => Err(CudaError::NvlinkUncorrectable),
168-
cudaError_enum::CUDA_ERROR_INVALID_SOURCE => Err(CudaError::InvalidSouce),
168+
cudaError_enum::CUDA_ERROR_INVALID_SOURCE => Err(CudaError::InvalidSource),
169169
cudaError_enum::CUDA_ERROR_FILE_NOT_FOUND => Err(CudaError::FileNotFound),
170170
cudaError_enum::CUDA_ERROR_SHARED_OBJECT_SYMBOL_NOT_FOUND => {
171171
Err(CudaError::SharedObjectSymbolNotFound)

crates/cust/src/lib.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ mod surface;
7575
mod texture;
7676
pub mod util;
7777

78-
pub use cust_core;
7978
pub use cust_derive::DeviceCopy;
8079
pub use cust_raw as sys;
8180

crates/cust/src/module.rs

Lines changed: 99 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -178,91 +178,111 @@ impl Module {
178178
}
179179
}
180180

181-
// TODO(RDambrosio016): figure out why the heck cuda rejects cubins literally made by nvcc and loaded by fs::read
182-
183-
// /// Creates a new module by loading a fatbin (fat binary) file.
184-
// ///
185-
// /// Fatbinary files are files that contain multiple ptx or cubin files. The driver will choose already-built
186-
// /// cubin if it is present, and otherwise JIT compile any PTX in the file to cubin.
187-
// ///
188-
// /// # Example
189-
// ///
190-
// /// ```
191-
// /// # use cust::*;
192-
// /// # use std::error::Error;
193-
// /// # fn main() -> Result<(), Box<dyn Error>> {
194-
// /// # let _ctx = quick_init()?;
195-
// /// use cust::module::Module;
196-
// /// let fatbin_bytes = std::fs::read("./resources/add.cubin")?;
197-
// /// assert!(fatbin_bytes.contains(&0));
198-
// /// let module = Module::from_cubin(&fatbin_bytes, &[])?;
199-
// /// # Ok(())
200-
// /// # }
201-
// /// ```
202-
// pub fn from_fatbin<T: AsRef<[u8]>>(
203-
// bytes: T,
204-
// options: &[ModuleJitOption],
205-
// ) -> CudaResult<Module> {
206-
// let mut bytes = bytes.as_ref().to_vec();
207-
// bytes.push(0);
208-
// // fatbins are just ELF files like cubins, and cuModuleLoadDataEx accepts ptx, cubin, and fatbin.
209-
// // We just make the distinction in case we want to do anything extra in the future. As well
210-
// // as keep things explicit to anyone reading the code.
211-
// Self::from_cubin(bytes, options)
212-
// }
213-
214-
// pub unsafe fn from_fatbin_unchecked<T: AsRef<[u8]>>(
215-
// bytes: T,
216-
// options: &[ModuleJitOption],
217-
// ) -> CudaResult<Module> {
218-
// Self::from_cubin_unchecked(bytes, options)
219-
// }
181+
/// Creates a new module by loading a fatbin (fat binary) file.
182+
///
183+
/// Fatbinary files are files that contain multiple ptx or cubin files. The driver will choose already-built
184+
/// cubin if it is present, and otherwise JIT compile any PTX in the file to cubin.
185+
///
186+
/// # Example
187+
///
188+
/// ```
189+
/// # use cust::*;
190+
/// # use std::error::Error;
191+
/// # fn main() -> Result<(), Box<dyn Error>> {
192+
/// # let _ctx = quick_init()?;
193+
/// use cust::module::Module;
194+
/// let fatbin_bytes = std::fs::read("./resources/add.fatbin")?;
195+
/// // will return InvalidSource if the fatbin does not contain any compatible code, meaning, either
196+
/// // cubin compiled for the same device architecture OR PTX that can be JITted into valid code.
197+
/// let module = Module::from_fatbin(&fatbin_bytes, &[])?;
198+
/// # Ok(())
199+
/// # }
200+
/// ```
201+
pub fn from_fatbin<T: AsRef<[u8]>>(
202+
bytes: T,
203+
options: &[ModuleJitOption],
204+
) -> CudaResult<Module> {
205+
// fatbins can be loaded just like cubins, we just use different methods so it's explicit.
206+
// please don't use from_cubin for fatbins, that is pure chaos and ferris will come to your house
207+
Self::from_cubin(bytes, options)
208+
}
220209

221-
// pub fn from_cubin<T: AsRef<[u8]>>(bytes: T, options: &[ModuleJitOption]) -> CudaResult<Module> {
222-
// let bytes = bytes.as_ref();
223-
// goblin::elf::Elf::parse(bytes).expect("Cubin/Fatbin was not valid ELF!");
224-
// // SAFETY: we verified the bytes were valid ELF
225-
// unsafe { Self::from_cubin_unchecked(bytes, options) }
226-
// }
210+
/// Creates a new module by loading a cubin (CUDA Binary) file.
211+
///
212+
/// Cubins are architecture/compute-capability specific files generated as the final step of the CUDA compilation
213+
/// process. They cannot be interchanged across compute capabilities unlike PTX (to some degree). You can create one
214+
/// using the PTX compiler APIs, the cust [`Linker`](crate::link::Linker), or nvcc (`nvcc a.ptx --cubin -arch=sm_XX`).
215+
///
216+
/// # Example
217+
///
218+
/// ```
219+
/// # use cust::*;
220+
/// # use std::error::Error;
221+
/// # fn main() -> Result<(), Box<dyn Error>> {
222+
/// # let _ctx = quick_init()?;
223+
/// use cust::module::Module;
224+
/// let cubin_bytes = std::fs::read("./resources/add.cubin")?;
225+
/// // will return InvalidSource if the cubin arch doesn't match the context's device arch!
226+
/// let module = Module::from_cubin(&cubin_bytes, &[])?;
227+
/// # Ok(())
228+
/// # }
229+
/// ```
230+
pub fn from_cubin<T: AsRef<[u8]>>(bytes: T, options: &[ModuleJitOption]) -> CudaResult<Module> {
231+
// it is very unclear whether cuda wants or doesn't want a null terminator. The method works
232+
// whether you have one or not. So for safety we just add one. In theory you can figure out the
233+
// length of an ELF image without a null terminator. But the docs are confusing, so we add one just
234+
// to be sure.
235+
let mut bytes = bytes.as_ref().to_vec();
236+
bytes.push(0);
237+
// SAFETY: the image is known to be dereferenceable
238+
unsafe { Self::load_module(bytes.as_ptr() as *const c_void, options) }
239+
}
227240

228-
// pub unsafe fn from_cubin_unchecked<T: AsRef<[u8]>>(
229-
// bytes: T,
230-
// options: &[ModuleJitOption],
231-
// ) -> CudaResult<Module> {
232-
// let bytes = bytes.as_ref();
233-
// let mut module = Module {
234-
// inner: ptr::null_mut(),
235-
// };
236-
// let (mut options, mut option_values) = ModuleJitOption::into_raw(options);
237-
// cuda::cuModuleLoadDataEx(
238-
// &mut module.inner as *mut cuda::CUmodule,
239-
// bytes.as_ptr() as *const c_void,
240-
// options.len() as c_uint,
241-
// options.as_mut_ptr(),
242-
// option_values.as_mut_ptr(),
243-
// )
244-
// .to_result()?;
245-
// Ok(module)
246-
// }
241+
unsafe fn load_module(image: *const c_void, options: &[ModuleJitOption]) -> CudaResult<Module> {
242+
let mut module = Module {
243+
inner: ptr::null_mut(),
244+
};
245+
let (mut options, mut option_values) = ModuleJitOption::into_raw(options);
246+
cuda::cuModuleLoadDataEx(
247+
&mut module.inner as *mut cuda::CUmodule,
248+
image,
249+
options.len() as c_uint,
250+
options.as_mut_ptr(),
251+
option_values.as_mut_ptr(),
252+
)
253+
.to_result()?;
254+
Ok(module)
255+
}
247256

257+
/// Creates a new module from a [`CStr`] pointing to PTX code.
258+
///
259+
/// The driver will JIT the PTX into arch-specific cubin or pick already-cached cubin if available.
248260
pub fn from_ptx_cstr(cstr: &CStr, options: &[ModuleJitOption]) -> CudaResult<Module> {
249-
unsafe {
250-
let mut module = Module {
251-
inner: ptr::null_mut(),
252-
};
253-
let (mut options, mut option_values) = ModuleJitOption::into_raw(options);
254-
cuda::cuModuleLoadDataEx(
255-
&mut module.inner as *mut cuda::CUmodule,
256-
cstr.as_ptr() as *const c_void,
257-
options.len() as c_uint,
258-
options.as_mut_ptr(),
259-
option_values.as_mut_ptr(),
260-
)
261-
.to_result()?;
262-
Ok(module)
263-
}
261+
// SAFETY: the image is known to be dereferenceable
262+
unsafe { Self::load_module(cstr.as_ptr() as *const c_void, options) }
264263
}
265264

265+
/// Creates a new module from a PTX string, allocating an intermediate buffer for the [`CString`].
266+
///
267+
/// The driver will JIT the PTX into arch-specific cubin or pick already-cached cubin if available.
268+
///
269+
/// # Panics
270+
///
271+
/// Panics if `string` contains a nul.
272+
///
273+
/// # Example
274+
///
275+
/// ```
276+
/// # use cust::*;
277+
/// # use std::error::Error;
278+
/// # fn main() -> Result<(), Box<dyn Error>> {
279+
/// # let _ctx = quick_init()?;
280+
/// use cust::module::Module;
281+
/// let ptx = std::fs::read("./resources/add.ptx")?;
282+
/// let module = Module::from_ptx(&ptx, &[])?;
283+
/// # Ok(())
284+
/// # }
285+
/// ```
266286
pub fn from_ptx<T: AsRef<str>>(string: T, options: &[ModuleJitOption]) -> CudaResult<Module> {
267287
let cstr = CString::new(string.as_ref())
268288
.expect("string given to Module::from_str contained nul bytes");

0 commit comments

Comments
 (0)