Skip to content

Commit d1a587f

Browse files
committed
Wrap hooks in refcounter instead of box, same as previously.
This allows to obtain independent clone of closure. Ensure that stack is always clean when running thread hook.
1 parent 5bbd23e commit d1a587f

File tree

4 files changed

+32
-32
lines changed

4 files changed

+32
-32
lines changed

src/state.rs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,7 @@ impl Lua {
516516
let lua = self.lock();
517517
unsafe {
518518
(*lua.extra.get()).hook_triggers = triggers;
519-
(*lua.extra.get()).hook_callback = Some(Box::new(callback));
519+
(*lua.extra.get()).hook_callback = Some(XRc::new(callback));
520520
lua.set_thread_hook(lua.state(), HookKind::Global)
521521
}
522522
}
@@ -564,7 +564,7 @@ impl Lua {
564564
F: Fn(&Lua, Debug) -> Result<VmState> + MaybeSend + 'static,
565565
{
566566
let lua = self.lock();
567-
unsafe { lua.set_thread_hook(lua.state(), HookKind::Thread(triggers, Box::new(callback))) }
567+
unsafe { lua.set_thread_hook(lua.state(), HookKind::Thread(triggers, XRc::new(callback))) }
568568
}
569569

570570
/// Removes a global hook previously set by [`Lua::set_global_hook`].
@@ -644,8 +644,6 @@ impl Lua {
644644
where
645645
F: Fn(&Lua) -> Result<VmState> + MaybeSend + 'static,
646646
{
647-
use std::rc::Rc;
648-
649647
unsafe extern "C-unwind" fn interrupt_proc(state: *mut ffi::lua_State, gc: c_int) {
650648
if gc >= 0 {
651649
// We don't support GC interrupts since they cannot survive Lua exceptions
@@ -654,7 +652,7 @@ impl Lua {
654652
let result = callback_error_ext(state, ptr::null_mut(), move |extra, _| {
655653
let interrupt_cb = (*extra).interrupt_callback.clone();
656654
let interrupt_cb = mlua_expect!(interrupt_cb, "no interrupt callback set in interrupt_proc");
657-
if Rc::strong_count(&interrupt_cb) > 2 {
655+
if XRc::strong_count(&interrupt_cb) > 2 {
658656
return Ok(VmState::Continue); // Don't allow recursion
659657
}
660658
let _guard = StateGuard::new((*extra).raw_lua(), state);
@@ -671,7 +669,7 @@ impl Lua {
671669
// Set interrupt callback
672670
let lua = self.lock();
673671
unsafe {
674-
(*lua.extra.get()).interrupt_callback = Some(Rc::new(callback));
672+
(*lua.extra.get()).interrupt_callback = Some(XRc::new(callback));
675673
(*ffi::lua_callbacks(lua.main_state())).interrupt = Some(interrupt_proc);
676674
}
677675
}

src/state/raw.rs

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -410,15 +410,12 @@ impl RawLua {
410410

411411
unsafe extern "C-unwind" fn global_hook_proc(state: *mut ffi::lua_State, ar: *mut ffi::lua_Debug) {
412412
let status = callback_error_ext(state, ptr::null_mut(), move |extra, _| {
413-
let rawlua = (*extra).raw_lua();
414-
let _guard = StateGuard::new(rawlua, state);
415-
let debug = Debug::new(rawlua, ar);
416-
match (*extra).hook_callback.take() {
417-
Some(hook_cb) => {
418-
// Temporary obtain ownership of the hook callback
419-
let result = hook_cb((*extra).lua(), debug);
420-
(*extra).hook_callback = Some(hook_cb);
421-
result
413+
match (*extra).hook_callback.clone() {
414+
Some(hook_callback) => {
415+
let rawlua = (*extra).raw_lua();
416+
let _guard = StateGuard::new(rawlua, state);
417+
let debug = Debug::new(rawlua, ar);
418+
hook_callback((*extra).lua(), debug)
422419
}
423420
None => {
424421
ffi::lua_sethook(state, None, 0, 0);
@@ -430,11 +427,17 @@ impl RawLua {
430427
}
431428

432429
unsafe extern "C-unwind" fn hook_proc(state: *mut ffi::lua_State, ar: *mut ffi::lua_Debug) {
430+
let top = ffi::lua_gettop(state);
431+
let mut hook_callback_ptr = ptr::null();
433432
ffi::luaL_checkstack(state, 3, ptr::null());
434-
ffi::lua_getfield(state, ffi::LUA_REGISTRYINDEX, HOOKS_KEY);
435-
ffi::lua_pushthread(state);
436-
if ffi::lua_rawget(state, -2) != ffi::LUA_TUSERDATA {
437-
ffi::lua_pop(state, 2);
433+
if ffi::lua_getfield(state, ffi::LUA_REGISTRYINDEX, HOOKS_KEY) == ffi::LUA_TTABLE {
434+
ffi::lua_pushthread(state);
435+
if ffi::lua_rawget(state, -2) == ffi::LUA_TUSERDATA {
436+
hook_callback_ptr = get_internal_userdata::<HookCallback>(state, -1, ptr::null());
437+
}
438+
}
439+
ffi::lua_settop(state, top);
440+
if hook_callback_ptr.is_null() {
438441
ffi::lua_sethook(state, None, 0, 0);
439442
return;
440443
}
@@ -443,13 +446,8 @@ impl RawLua {
443446
let rawlua = (*extra).raw_lua();
444447
let _guard = StateGuard::new(rawlua, state);
445448
let debug = Debug::new(rawlua, ar);
446-
match get_internal_userdata::<HookCallback>(state, -1, ptr::null()).as_ref() {
447-
Some(hook_cb) => hook_cb((*extra).lua(), debug),
448-
None => {
449-
ffi::lua_sethook(state, None, 0, 0);
450-
Ok(VmState::Continue)
451-
}
452-
}
449+
let hook_callback = (*hook_callback_ptr).clone();
450+
hook_callback((*extra).lua(), debug)
453451
});
454452
process_status(state, (*ar).event, status)
455453
}
@@ -482,7 +480,6 @@ impl RawLua {
482480

483481
ffi::lua_pushthread(thread_state);
484482
ffi::lua_xmove(thread_state, state, 1); // key (thread)
485-
let callback: HookCallback = Box::new(callback);
486483
let _ = push_internal_userdata(state, callback, false); // value (hook callback)
487484
ffi::lua_rawset(state, -3); // hooktable[thread] = hook callback
488485
})?;

src/thread.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,12 @@ impl Thread {
272272
F: Fn(&crate::Lua, Debug) -> Result<crate::VmState> + crate::MaybeSend + 'static,
273273
{
274274
let lua = self.0.lua.lock();
275-
unsafe { lua.set_thread_hook(self.state(), HookKind::Thread(triggers, Box::new(callback))) }
275+
unsafe {
276+
lua.set_thread_hook(
277+
self.state(),
278+
HookKind::Thread(triggers, crate::types::XRc::new(callback)),
279+
)
280+
}
276281
}
277282

278283
/// Removes any hook function from this thread.

src/types.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,16 +79,16 @@ pub(crate) enum HookKind {
7979
}
8080

8181
#[cfg(all(feature = "send", not(feature = "luau")))]
82-
pub(crate) type HookCallback = Box<dyn Fn(&Lua, Debug) -> Result<VmState> + Send>;
82+
pub(crate) type HookCallback = XRc<dyn Fn(&Lua, Debug) -> Result<VmState> + Send>;
8383

8484
#[cfg(all(not(feature = "send"), not(feature = "luau")))]
85-
pub(crate) type HookCallback = Box<dyn Fn(&Lua, Debug) -> Result<VmState>>;
85+
pub(crate) type HookCallback = XRc<dyn Fn(&Lua, Debug) -> Result<VmState>>;
8686

8787
#[cfg(all(feature = "send", feature = "luau"))]
88-
pub(crate) type InterruptCallback = std::rc::Rc<dyn Fn(&Lua) -> Result<VmState> + Send>;
88+
pub(crate) type InterruptCallback = XRc<dyn Fn(&Lua) -> Result<VmState> + Send>;
8989

9090
#[cfg(all(not(feature = "send"), feature = "luau"))]
91-
pub(crate) type InterruptCallback = std::rc::Rc<dyn Fn(&Lua) -> Result<VmState>>;
91+
pub(crate) type InterruptCallback = XRc<dyn Fn(&Lua) -> Result<VmState>>;
9292

9393
#[cfg(all(feature = "send", feature = "lua54"))]
9494
pub(crate) type WarnCallback = Box<dyn Fn(&Lua, &str, bool) -> Result<()> + Send>;

0 commit comments

Comments
 (0)