Skip to content

Commit 22d6f9b

Browse files
fix: cactus soundness — enforce Sync invariant, guard callback aliasing, lock in Drop (#4103)
* fix: enforce Sync invariant at compile time, guard callback aliasing, lock in Drop - Introduce InferenceGuard: the model's raw FFI handle is now only accessible through the guard returned by lock_inference(), making it a compile error to touch the handle without holding the lock. stop() remains the sole documented exception (atomic-only). - Rewrite token_trampoline to use &CallbackState (shared ref) with Cell/UnsafeCell interior mutability and an in_callback re-entrancy guard, eliminating the previous &mut aliasing risk. - Acquire inference_lock in Model::drop so cactus_destroy waits for any in-flight FFI operation to complete. - Add SAFETY comment for the stack-pinned CallbackState pointer passed to cactus_complete. Co-Authored-By: yujonglee <yujonglee.dev@gmail.com> * style: fix dprint formatting in lock_inference Co-Authored-By: yujonglee <yujonglee.dev@gmail.com> --------- Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Co-authored-by: yujonglee <yujonglee.dev@gmail.com>
1 parent 428619d commit 22d6f9b

File tree

5 files changed

+65
-34
lines changed

5 files changed

+65
-34
lines changed

crates/cactus/src/llm/complete.rs

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
1+
use std::cell::{Cell, UnsafeCell};
12
use std::ffi::{CStr, CString};
23

34
use crate::error::{Error, Result};
45
use crate::ffi_utils::{RESPONSE_BUF_SIZE, parse_buf};
5-
use crate::model::Model;
6+
use crate::model::{InferenceGuard, Model};
67

78
use super::{CompleteOptions, CompletionResult, Message};
89

910
type TokenCallback = unsafe extern "C" fn(*const std::ffi::c_char, u32, *mut std::ffi::c_void);
1011

1112
struct CallbackState<'a, F: FnMut(&str) -> bool> {
12-
on_token: &'a mut F,
13+
on_token: UnsafeCell<&'a mut F>,
1314
model: &'a Model,
14-
stopped: bool,
15+
stopped: Cell<bool>,
16+
in_callback: Cell<bool>,
1517
}
1618

1719
unsafe extern "C" fn token_trampoline<F: FnMut(&str) -> bool>(
@@ -23,21 +25,28 @@ unsafe extern "C" fn token_trampoline<F: FnMut(&str) -> bool>(
2325
return;
2426
}
2527

26-
let state = unsafe { &mut *(user_data as *mut CallbackState<F>) };
27-
if state.stopped {
28+
// SAFETY: We only create a shared reference to CallbackState. Interior
29+
// mutability (Cell/UnsafeCell) handles mutation. The `in_callback` guard
30+
// prevents re-entrant access to the UnsafeCell contents.
31+
let state = unsafe { &*(user_data as *const CallbackState<F>) };
32+
if state.stopped.get() || state.in_callback.get() {
2833
return;
2934
}
35+
state.in_callback.set(true);
3036

3137
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
3238
let chunk = unsafe { CStr::from_ptr(token) }.to_string_lossy();
33-
if !(state.on_token)(&chunk) {
34-
state.stopped = true;
39+
// SAFETY: The `in_callback` flag ensures exclusive access to the closure.
40+
let on_token = unsafe { &mut *state.on_token.get() };
41+
if !on_token(&chunk) {
42+
state.stopped.set(true);
3543
state.model.stop();
3644
}
3745
}));
3846

47+
state.in_callback.set(false);
3948
if result.is_err() {
40-
state.stopped = true;
49+
state.stopped.set(true);
4150
state.model.stop();
4251
}
4352
}
@@ -58,6 +67,7 @@ pub(super) fn complete_error(rc: i32) -> Error {
5867
impl Model {
5968
fn call_complete(
6069
&self,
70+
guard: &InferenceGuard<'_>,
6171
messages_c: &CString,
6272
options_c: &CString,
6373
callback: Option<TokenCallback>,
@@ -67,7 +77,7 @@ impl Model {
6777

6878
let rc = unsafe {
6979
cactus_sys::cactus_complete(
70-
self.raw_handle(),
80+
guard.raw_handle(),
7181
messages_c.as_ptr(),
7282
buf.as_mut_ptr().cast::<std::ffi::c_char>(),
7383
buf.len(),
@@ -86,9 +96,10 @@ impl Model {
8696
messages: &[Message],
8797
options: &CompleteOptions,
8898
) -> Result<CompletionResult> {
89-
let _guard = self.lock_inference();
99+
let guard = self.lock_inference();
90100
let (messages_c, options_c) = serialize_complete_request(messages, options)?;
91-
let (rc, buf) = self.call_complete(&messages_c, &options_c, None, std::ptr::null_mut());
101+
let (rc, buf) =
102+
self.call_complete(&guard, &messages_c, &options_c, None, std::ptr::null_mut());
92103

93104
if rc < 0 {
94105
return Err(complete_error(rc));
@@ -106,23 +117,28 @@ impl Model {
106117
where
107118
F: FnMut(&str) -> bool,
108119
{
109-
let _guard = self.lock_inference();
120+
let guard = self.lock_inference();
110121
let (messages_c, options_c) = serialize_complete_request(messages, options)?;
111122

112-
let mut state = CallbackState {
113-
on_token: &mut on_token,
123+
let state = CallbackState {
124+
on_token: UnsafeCell::new(&mut on_token),
114125
model: self,
115-
stopped: false,
126+
stopped: Cell::new(false),
127+
in_callback: Cell::new(false),
116128
};
117129

130+
// SAFETY: `state` is stack-allocated and lives for the duration of the
131+
// FFI call. The C++ side must not retain this pointer beyond the return
132+
// of `cactus_complete`.
118133
let (rc, buf) = self.call_complete(
134+
&guard,
119135
&messages_c,
120136
&options_c,
121137
Some(token_trampoline::<F>),
122-
(&mut state as *mut CallbackState<F>).cast::<std::ffi::c_void>(),
138+
(&state as *const CallbackState<F> as *mut std::ffi::c_void),
123139
);
124140

125-
if rc < 0 && !state.stopped {
141+
if rc < 0 && !state.stopped.get() {
126142
return Err(complete_error(rc));
127143
}
128144

crates/cactus/src/model.rs

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,23 @@ pub struct Model {
1111
}
1212

1313
unsafe impl Send for Model {}
14-
// SAFETY: All FFI methods that touch model state are serialized by `inference_lock`.
14+
// SAFETY: All FFI methods that touch model state are serialized by `inference_lock`,
15+
// which is enforced at compile time via `InferenceGuard` — the model's raw handle is
16+
// only accessible through the guard returned by `lock_inference()`.
1517
// The sole exception is `stop()`, which only sets a `std::atomic<bool>` on the C++ side.
1618
unsafe impl Sync for Model {}
1719

20+
pub(crate) struct InferenceGuard<'a> {
21+
handle: NonNull<std::ffi::c_void>,
22+
_guard: MutexGuard<'a, ()>,
23+
}
24+
25+
impl InferenceGuard<'_> {
26+
pub(crate) fn raw_handle(&self) -> *mut std::ffi::c_void {
27+
self.handle.as_ptr()
28+
}
29+
}
30+
1831
pub struct ModelBuilder {
1932
model_path: PathBuf,
2033
}
@@ -53,27 +66,29 @@ impl Model {
5366
}
5467

5568
pub fn reset(&mut self) {
56-
let _guard = self.lock_inference();
69+
let guard = self.lock_inference();
5770
unsafe {
58-
cactus_sys::cactus_reset(self.handle.as_ptr());
71+
cactus_sys::cactus_reset(guard.raw_handle());
5972
}
6073
}
6174

62-
pub(crate) fn lock_inference(&self) -> MutexGuard<'_, ()> {
63-
self.inference_lock
75+
pub(crate) fn lock_inference(&self) -> InferenceGuard<'_> {
76+
let guard = self
77+
.inference_lock
6478
.lock()
65-
.unwrap_or_else(|e| e.into_inner())
66-
}
67-
68-
pub(crate) fn raw_handle(&self) -> *mut std::ffi::c_void {
69-
self.handle.as_ptr()
79+
.unwrap_or_else(|e| e.into_inner());
80+
InferenceGuard {
81+
handle: self.handle,
82+
_guard: guard,
83+
}
7084
}
7185
}
7286

7387
impl Drop for Model {
7488
fn drop(&mut self) {
89+
let guard = self.lock_inference();
7590
unsafe {
76-
cactus_sys::cactus_destroy(self.handle.as_ptr());
91+
cactus_sys::cactus_destroy(guard.raw_handle());
7792
}
7893
}
7994
}

crates/cactus/src/stt/batch.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ impl Model {
3232
input: TranscribeInput<'_>,
3333
options: &TranscribeOptions,
3434
) -> Result<TranscriptionResult> {
35-
let _guard = self.lock_inference();
35+
let guard = self.lock_inference();
3636
let prompt_c = CString::new(build_whisper_prompt(options))?;
3737
let options_c = CString::new(serde_json::to_string(options)?)?;
3838
let mut buf = vec![0u8; RESPONSE_BUF_SIZE];
@@ -44,7 +44,7 @@ impl Model {
4444

4545
let rc = unsafe {
4646
cactus_sys::cactus_transcribe(
47-
self.raw_handle(),
47+
guard.raw_handle(),
4848
path_ptr,
4949
prompt_c.as_ptr(),
5050
buf.as_mut_ptr() as *mut std::ffi::c_char,

crates/cactus/src/stt/transcriber.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,11 +96,11 @@ impl std::str::FromStr for StreamResult {
9696

9797
impl<'a> Transcriber<'a> {
9898
pub fn new(model: &'a Model, options: &TranscribeOptions, cloud: CloudConfig) -> Result<Self> {
99-
let _guard = model.lock_inference();
99+
let guard = model.lock_inference();
100100
let options_c = serialize_stream_options(options, &cloud)?;
101101

102102
let raw = unsafe {
103-
cactus_sys::cactus_stream_transcribe_start(model.raw_handle(), options_c.as_ptr())
103+
cactus_sys::cactus_stream_transcribe_start(guard.raw_handle(), options_c.as_ptr())
104104
};
105105

106106
let handle = NonNull::new(raw).ok_or_else(|| {

crates/cactus/src/vad.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ impl Model {
5454
pcm: Option<&[u8]>,
5555
options: &VadOptions,
5656
) -> Result<VadResult> {
57-
let _guard = self.lock_inference();
57+
let guard = self.lock_inference();
5858
let options_c = CString::new(serde_json::to_string(options)?)?;
5959
let mut buf = vec![0u8; RESPONSE_BUF_SIZE];
6060

@@ -64,7 +64,7 @@ impl Model {
6464

6565
let rc = unsafe {
6666
cactus_sys::cactus_vad(
67-
self.raw_handle(),
67+
guard.raw_handle(),
6868
path.map_or(std::ptr::null(), |p| p.as_ptr()),
6969
buf.as_mut_ptr() as *mut std::ffi::c_char,
7070
buf.len(),

0 commit comments

Comments
 (0)