Skip to content

Commit 1935bbf

Browse files
committed
Feat: add back set_current, don't return a duration in Linker::complete
1 parent e97db1e commit 1935bbf

File tree

4 files changed

+93
-83
lines changed

4 files changed

+93
-83
lines changed

crates/cust/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ any breaking changes, the API is the same.
2727
- Added `zeroed_async` to `DeviceBox`.
2828
- Added `drop_async` to `DeviceBox`.
2929
- Added `new_async` to `DeviceBox`.
30+
- `Linker::complete` now only returns the built cubin, and not the cubin and a duration.
3031

3132
## 0.2.2 - 12/5/21
3233

crates/cust/src/context/mod.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,30 @@ impl CurrentContext {
576576
unsafe { cuda::cuCtxSetSharedMemConfig(transmute(cfg)).to_result() }
577577
}
578578

579+
/// Set the given context as the current context for this thread.
580+
///
581+
/// # Example
582+
///
583+
/// ```
584+
/// # use cust::device::Device;
585+
/// # use cust::context::{ Context, ContextFlags, CurrentContext };
586+
/// # use std::error::Error;
587+
/// #
588+
/// # fn main () -> Result<(), Box<dyn Error>> {
589+
/// # cust::init(cust::CudaFlags::empty())?;
590+
/// # let device = Device::get_device(0)?;
591+
/// let context = Context::new(device)?;
592+
/// CurrentContext::set_current(&context)?;
593+
/// # Ok(())
594+
/// # }
595+
/// ```
596+
pub fn set_current<C: ContextHandle>(c: &C) -> CudaResult<()> {
597+
unsafe {
598+
cuda::cuCtxSetCurrent(c.get_inner()).to_result()?;
599+
Ok(())
600+
}
601+
}
602+
579603
/// Block to wait for a context's tasks to complete.
580604
pub fn synchronize() -> CudaResult<()> {
581605
unsafe {

crates/cust/src/link.rs

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
//! Functions for linking together multiple PTX files into a module.
22
3-
use std::{mem::MaybeUninit, time::Duration};
3+
use std::mem::MaybeUninit;
4+
use std::ptr::null_mut;
45

56
use crate::sys as cuda;
67

@@ -12,7 +13,6 @@ static UNNAMED: &str = "\0";
1213
#[derive(Debug)]
1314
pub struct Linker {
1415
raw: cuda::CUlinkState,
15-
duration_ptr: *mut *mut f32,
1616
}
1717

1818
impl Linker {
@@ -22,19 +22,11 @@ impl Linker {
2222
// Therefore we use box to alloc the memory for us, then into_raw it so we now have ownership
2323
// of the memory (and dont have any aliasing requirements attached either).
2424

25-
// technically it should be fine to alloc as Box<*mut f32> then dealloc with Box<Box<f32>> but
26-
// in the future rust may make this guarantee untrue so just alloc as Box<Box<f32>> then cast.
27-
let ptr = Box::into_raw(Box::new(Box::new(0.0f32))) as *mut *mut f32;
28-
29-
// cuda shouldnt be modifying this but trust the bindings in that it wants a *mut ptr.
30-
let options = &mut [cuda::CUjit_option_enum::CU_JIT_WALL_TIME];
3125
unsafe {
3226
let mut raw = MaybeUninit::uninit();
33-
cuda::cuLinkCreate_v2(1, options.as_mut_ptr(), ptr.cast(), raw.as_mut_ptr())
34-
.to_result()?;
27+
cuda::cuLinkCreate_v2(0, null_mut(), null_mut(), raw.as_mut_ptr()).to_result()?;
3528
Ok(Self {
3629
raw: raw.assume_init(),
37-
duration_ptr: ptr,
3830
})
3931
}
4032
}
@@ -121,7 +113,7 @@ impl Linker {
121113

122114
/// Runs the linker to generate the final cubin bytes. Also returns a duration
123115
/// for how long it took to run the linker.
124-
pub fn complete(self) -> CudaResult<(Vec<u8>, Duration)> {
116+
pub fn complete(self) -> CudaResult<Vec<u8>> {
125117
let mut cubin = MaybeUninit::uninit();
126118
let mut size = MaybeUninit::uninit();
127119

@@ -134,16 +126,7 @@ impl Linker {
134126
let mut vec = Vec::with_capacity(size);
135127
vec.extend_from_slice(slice);
136128

137-
// now that duration has been written to, retrieve it and deallocate it.
138-
let duration = **self.duration_ptr;
139-
// recreate a box from our pointer, which will take ownership
140-
// of it and then drop it immediately. This is sound because
141-
// complete consumes self.
142-
Box::from_raw(self.duration_ptr as *mut Box<f32>);
143-
144-
// convert to nanos so we dont lose the decimal millisecs.
145-
let duration = Duration::from_nanos((duration * 1e6) as u64);
146-
Ok((vec, duration))
129+
Ok(vec)
147130
}
148131
}
149132
}

crates/cust/src/module.rs

Lines changed: 63 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -175,70 +175,72 @@ impl Module {
175175
}
176176
}
177177

178-
/// Creates a new module by loading a fatbin (fat binary) file.
179-
///
180-
/// Fatbinary files are files that contain multiple ptx or cubin files. The driver will choose already-built
181-
/// cubin if it is present, and otherwise JIT compile any PTX in the file to cubin.
182-
///
183-
/// # Example
184-
///
185-
/// ```
186-
/// # use cust::*;
187-
/// # use std::error::Error;
188-
/// # fn main() -> Result<(), Box<dyn Error>> {
189-
/// # let _ctx = quick_init()?;
190-
/// use cust::module::Module;
191-
/// let fatbin_bytes = std::fs::read("./resources/add.cubin")?;
192-
/// assert!(fatbin_bytes.contains(&0));
193-
/// let module = Module::from_cubin(&fatbin_bytes, &[])?;
194-
/// # Ok(())
195-
/// # }
196-
/// ```
197-
pub fn from_fatbin<T: AsRef<[u8]>>(
198-
bytes: T,
199-
options: &[ModuleJitOption],
200-
) -> CudaResult<Module> {
201-
let mut bytes = bytes.as_ref().to_vec();
202-
bytes.push(0);
203-
// fatbins are just ELF files like cubins, and cuModuleLoadDataEx accepts ptx, cubin, and fatbin.
204-
// We just make the distinction in case we want to do anything extra in the future. As well
205-
// as keep things explicit to anyone reading the code.
206-
Self::from_cubin(bytes, options)
207-
}
178+
// TODO(RDambrosio016): figure out why the heck cuda rejects cubins literally made by nvcc and loaded by fs::read
208179

209-
pub unsafe fn from_fatbin_unchecked<T: AsRef<[u8]>>(
210-
bytes: T,
211-
options: &[ModuleJitOption],
212-
) -> CudaResult<Module> {
213-
Self::from_cubin_unchecked(bytes, options)
214-
}
180+
// /// Creates a new module by loading a fatbin (fat binary) file.
181+
// ///
182+
// /// Fatbinary files are files that contain multiple ptx or cubin files. The driver will choose already-built
183+
// /// cubin if it is present, and otherwise JIT compile any PTX in the file to cubin.
184+
// ///
185+
// /// # Example
186+
// ///
187+
// /// ```
188+
// /// # use cust::*;
189+
// /// # use std::error::Error;
190+
// /// # fn main() -> Result<(), Box<dyn Error>> {
191+
// /// # let _ctx = quick_init()?;
192+
// /// use cust::module::Module;
193+
// /// let fatbin_bytes = std::fs::read("./resources/add.cubin")?;
194+
// /// assert!(fatbin_bytes.contains(&0));
195+
// /// let module = Module::from_cubin(&fatbin_bytes, &[])?;
196+
// /// # Ok(())
197+
// /// # }
198+
// /// ```
199+
// pub fn from_fatbin<T: AsRef<[u8]>>(
200+
// bytes: T,
201+
// options: &[ModuleJitOption],
202+
// ) -> CudaResult<Module> {
203+
// let mut bytes = bytes.as_ref().to_vec();
204+
// bytes.push(0);
205+
// // fatbins are just ELF files like cubins, and cuModuleLoadDataEx accepts ptx, cubin, and fatbin.
206+
// // We just make the distinction in case we want to do anything extra in the future. As well
207+
// // as keep things explicit to anyone reading the code.
208+
// Self::from_cubin(bytes, options)
209+
// }
215210

216-
pub fn from_cubin<T: AsRef<[u8]>>(bytes: T, options: &[ModuleJitOption]) -> CudaResult<Module> {
217-
let bytes = bytes.as_ref();
218-
goblin::elf::Elf::parse(bytes).expect("Cubin/Fatbin was not valid ELF!");
219-
// SAFETY: we verified the bytes were valid ELF
220-
unsafe { Self::from_cubin_unchecked(bytes, options) }
221-
}
211+
// pub unsafe fn from_fatbin_unchecked<T: AsRef<[u8]>>(
212+
// bytes: T,
213+
// options: &[ModuleJitOption],
214+
// ) -> CudaResult<Module> {
215+
// Self::from_cubin_unchecked(bytes, options)
216+
// }
222217

223-
pub unsafe fn from_cubin_unchecked<T: AsRef<[u8]>>(
224-
bytes: T,
225-
options: &[ModuleJitOption],
226-
) -> CudaResult<Module> {
227-
let bytes = bytes.as_ref();
228-
let mut module = Module {
229-
inner: ptr::null_mut(),
230-
};
231-
let (mut options, mut option_values) = ModuleJitOption::into_raw(options);
232-
cuda::cuModuleLoadDataEx(
233-
&mut module.inner as *mut cuda::CUmodule,
234-
bytes.as_ptr() as *const c_void,
235-
options.len() as c_uint,
236-
options.as_mut_ptr(),
237-
option_values.as_mut_ptr(),
238-
)
239-
.to_result()?;
240-
Ok(module)
241-
}
218+
// pub fn from_cubin<T: AsRef<[u8]>>(bytes: T, options: &[ModuleJitOption]) -> CudaResult<Module> {
219+
// let bytes = bytes.as_ref();
220+
// goblin::elf::Elf::parse(bytes).expect("Cubin/Fatbin was not valid ELF!");
221+
// // SAFETY: we verified the bytes were valid ELF
222+
// unsafe { Self::from_cubin_unchecked(bytes, options) }
223+
// }
224+
225+
// pub unsafe fn from_cubin_unchecked<T: AsRef<[u8]>>(
226+
// bytes: T,
227+
// options: &[ModuleJitOption],
228+
// ) -> CudaResult<Module> {
229+
// let bytes = bytes.as_ref();
230+
// let mut module = Module {
231+
// inner: ptr::null_mut(),
232+
// };
233+
// let (mut options, mut option_values) = ModuleJitOption::into_raw(options);
234+
// cuda::cuModuleLoadDataEx(
235+
// &mut module.inner as *mut cuda::CUmodule,
236+
// bytes.as_ptr() as *const c_void,
237+
// options.len() as c_uint,
238+
// options.as_mut_ptr(),
239+
// option_values.as_mut_ptr(),
240+
// )
241+
// .to_result()?;
242+
// Ok(module)
243+
// }
242244

243245
pub fn from_ptx_cstr(cstr: &CStr, options: &[ModuleJitOption]) -> CudaResult<Module> {
244246
unsafe {

0 commit comments

Comments
 (0)