Skip to content

Commit 6ab550c

Browse files
committed
Feat: use cuLaunchHostFunc instead of cuStreamAddCallback internally
1 parent 605dcb0 commit 6ab550c

File tree

2 files changed

+10
-16
lines changed

2 files changed

+10
-16
lines changed

crates/cust/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ overall simplifying the context handling APIs. This does mean that the API chang
1717
The old context handling is fully present in `cust::context::legacy` for anyone who needs it for specific reasons. If you use `quick_init` you don't need to worry about
1818
any breaking changes; the API is the same.
1919

20+
- `Stream::add_callback` now internally uses `cuLaunchHostFunc`, anticipating the deprecation and removal of `cuStreamAddCallback` per the driver docs. However, this means that the callback no longer takes a device status as a parameter and is not executed on context error.
2021
- Added `cust::memory::LockedBox`, same as `LockedBuffer` except for single elements.
2122
- Added `cust::memory::cuda_malloc_async`.
2223
- Added `cust::memory::cuda_free_async`.

crates/cust/src/stream.rs

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
use crate::error::{CudaResult, DropResult, ToResult};
1414
use crate::event::Event;
1515
use crate::function::{BlockSize, Function, GridSize};
16-
use crate::sys::{self as cuda, cudaError_enum, CUstream};
16+
use crate::sys::{self as cuda, CUstream};
1717
use std::ffi::c_void;
1818
use std::mem;
1919
use std::panic;
@@ -147,9 +147,6 @@ impl Stream {
147147
///
148148
/// Callbacks must not make any CUDA API calls.
149149
///
150-
/// The callback will be passed a `CudaResult<()>` indicating the
151-
/// current state of the device with `Ok(())` denoting normal operation.
152-
///
153150
/// # Examples
154151
///
155152
/// ```
@@ -163,23 +160,22 @@ impl Stream {
163160
///
164161
/// // ... queue up some work on the stream
165162
///
166-
/// stream.add_callback(Box::new(|status| {
167-
/// println!("Device status is {:?}", status);
163+
/// stream.add_callback(Box::new(|| {
164+
/// println!("Work is done!");
168165
/// }));
169166
///
170167
/// // ... queue up some more work on the stream
171168
/// # Ok(())
172169
/// # }
173170
pub fn add_callback<T>(&self, callback: Box<T>) -> CudaResult<()>
174171
where
175-
T: FnOnce(CudaResult<()>) + Send,
172+
T: FnOnce() + Send,
176173
{
177174
unsafe {
178-
cuda::cuStreamAddCallback(
175+
cuda::cuLaunchHostFunc(
179176
self.inner,
180177
Some(callback_wrapper::<T>),
181178
Box::into_raw(callback) as *mut c_void,
182-
0,
183179
)
184180
.to_result()
185181
}
@@ -339,16 +335,13 @@ impl Drop for Stream {
339335
}
340336
}
341337
}
342-
unsafe extern "C" fn callback_wrapper<T>(
343-
_stream: CUstream,
344-
status: cudaError_enum,
345-
callback: *mut c_void,
346-
) where
347-
T: FnOnce(CudaResult<()>) + Send,
338+
unsafe extern "C" fn callback_wrapper<T>(callback: *mut c_void)
339+
where
340+
T: FnOnce() + Send,
348341
{
349342
// Stop panics from unwinding across the FFI
350343
let _ = panic::catch_unwind(|| {
351344
let callback: Box<T> = Box::from_raw(callback as *mut T);
352-
callback(status.to_result());
345+
callback();
353346
});
354347
}

0 commit comments

Comments
 (0)