Skip to content

Commit d399559

Browse files
committed
Add Lua::yield_with to allow yielding Rust async functions and exchange values between Lua coroutine and Rust.
This functionality is similar to `coroutine.yield` and `coroutine.resume` without C restrictions.
1 parent 30735d5 commit d399559

File tree

4 files changed

+158
-12
lines changed

4 files changed

+158
-12
lines changed

src/state.rs

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ use crate::{buffer::Buffer, chunk::Compiler};
3737
use {
3838
crate::types::LightUserData,
3939
std::future::{self, Future},
40+
std::task::Poll,
4041
};
4142

4243
#[cfg(feature = "serde")]
@@ -2079,6 +2080,101 @@ impl Lua {
20792080
LightUserData(&ASYNC_POLL_TERMINATE as *const u8 as *mut std::os::raw::c_void)
20802081
}
20812082

2083+
#[cfg(feature = "async")]
2084+
#[inline(always)]
2085+
pub(crate) fn poll_yield() -> LightUserData {
2086+
static ASYNC_POLL_YIELD: u8 = 0;
2087+
LightUserData(&ASYNC_POLL_YIELD as *const u8 as *mut std::os::raw::c_void)
2088+
}
2089+
2090+
/// Suspends the current async function, returning the provided arguments to caller.
2091+
///
2092+
/// This function is similar to [`coroutine.yield`] but allow yeilding Rust functions
2093+
/// and passing values to the caller.
2094+
/// Please note that you cannot cross [`Thread`] boundaries (e.g. calling `yield_with` on one
2095+
/// thread and resuming on another).
2096+
///
2097+
/// # Examples
2098+
///
2099+
/// Async iterator:
2100+
///
2101+
/// ```
2102+
/// # use mlua::{Lua, Result};
2103+
///
2104+
/// async fn generator(lua: Lua, _: ()) -> Result<()> {
2105+
/// for i in 0..10 {
2106+
/// lua.yield_with::<()>(i).await?;
2107+
/// }
2108+
/// Ok(())
2109+
/// }
2110+
///
2111+
/// fn main() -> Result<()> {
2112+
/// let lua = Lua::new();
2113+
/// lua.globals().set("generator", lua.create_async_function(generator)?)?;
2114+
///
2115+
/// lua.load(r#"
2116+
/// local n = 0
2117+
/// for i in coroutine.wrap(generator) do
2118+
/// n = n + i
2119+
/// end
2120+
/// assert(n == 45)
2121+
/// "#)
2122+
/// .exec()
2123+
/// }
2124+
/// ```
2125+
///
2126+
/// Exchange values on yield:
2127+
///
2128+
/// ```
2129+
/// # use mlua::{Lua, Result, Value};
2130+
///
2131+
/// async fn pingpong(lua: Lua, mut val: i32) -> Result<()> {
2132+
/// loop {
2133+
/// val = lua.yield_with::<i32>(val).await? + 1;
2134+
/// }
2135+
/// Ok(())
2136+
/// }
2137+
///
2138+
/// # fn main() -> Result<()> {
2139+
/// let lua = Lua::new();
2140+
///
2141+
/// let co = lua.create_thread(lua.create_async_function(pingpong)?)?;
2142+
/// assert_eq!(co.resume::<i32>(1)?, 1);
2143+
/// assert_eq!(co.resume::<i32>(2)?, 3);
2144+
/// assert_eq!(co.resume::<i32>(3)?, 4);
2145+
///
2146+
/// # Ok(())
2147+
/// # }
2148+
/// ```
2149+
///
2150+
/// [`coroutine.yield`]: https://www.lua.org/manual/5.4/manual.html#pdf-coroutine.yield
2151+
#[cfg(feature = "async")]
2152+
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
2153+
pub async fn yield_with<R: FromLuaMulti>(&self, args: impl IntoLuaMulti) -> Result<R> {
2154+
let mut args = Some(args.into_lua_multi(self)?);
2155+
future::poll_fn(move |_cx| match args.take() {
2156+
Some(args) => unsafe {
2157+
let lua = self.lock();
2158+
lua.push(Self::poll_yield())?; // yield marker
2159+
if args.len() <= 1 {
2160+
lua.push(args.front())?;
2161+
} else {
2162+
lua.push(lua.create_sequence_from(&args)?)?;
2163+
}
2164+
lua.push(args.len())?;
2165+
Poll::Pending
2166+
},
2167+
None => unsafe {
2168+
let lua = self.lock();
2169+
let state = lua.state();
2170+
let _sg = StackGuard::with_top(state, 0);
2171+
let nvals = ffi::lua_gettop(state);
2172+
Poll::Ready(R::from_stack_multi(nvals, &lua))
2173+
},
2174+
})
2175+
.await
2176+
}
2177+
20822178
/// Returns a weak reference to the Lua instance.
20832179
///
20842180
/// This is useful for creating a reference to the Lua instance that does not prevent it from

src/state/raw.rs

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1278,6 +1278,13 @@ impl RawLua {
12781278
let mut ctx = Context::from_waker(rawlua.waker());
12791279
match fut.as_mut().map(|fut| fut.as_mut().poll(&mut ctx)) {
12801280
Some(Poll::Pending) => {
1281+
let fut_nvals = ffi::lua_gettop(state);
1282+
if fut_nvals >= 3 && ffi::lua_tolightuserdata(state, -3) == Lua::poll_yield().0 {
1283+
// We have some values to yield
1284+
ffi::lua_pushnil(state);
1285+
ffi::lua_replace(state, -4);
1286+
return Ok(3);
1287+
}
12811288
ffi::lua_pushnil(state);
12821289
ffi::lua_pushlightuserdata(state, Lua::poll_pending().0);
12831290
Ok(2)
@@ -1348,6 +1355,7 @@ impl RawLua {
13481355
local poll = get_poll(...)
13491356
local nres, res, res2 = poll()
13501357
while true do
1358+
-- Poll::Ready branch, `nres` is the number of results
13511359
if nres ~= nil then
13521360
if nres == 0 then
13531361
return
@@ -1363,10 +1371,20 @@ impl RawLua {
13631371
return unpack(res, nres)
13641372
end
13651373
end
1366-
-- `res` is a "pending" value
1367-
-- `yield` can return a signal to drop the future that we should propagate
1368-
-- to the poller
1369-
nres, res, res2 = poll(yield(res))
1374+
1375+
-- Poll::Pending branch
1376+
if res2 == nil then
1377+
-- `res` is a "pending" value
1378+
-- `yield` can return a signal to drop the future that we should propagate
1379+
-- to the poller
1380+
nres, res, res2 = poll(yield(res))
1381+
elseif res2 == 0 then
1382+
nres, res, res2 = poll(yield())
1383+
elseif res2 == 1 then
1384+
nres, res, res2 = poll(yield(res))
1385+
else
1386+
nres, res, res2 = poll(yield(unpack(res, res2)))
1387+
end
13701388
end
13711389
"#,
13721390
)
@@ -1378,14 +1396,14 @@ impl RawLua {
13781396

13791397
#[cfg(feature = "async")]
13801398
#[inline]
1381-
pub(crate) unsafe fn waker(&self) -> &Waker {
1382-
(*self.extra.get()).waker.as_ref()
1399+
pub(crate) fn waker(&self) -> &Waker {
1400+
unsafe { (*self.extra.get()).waker.as_ref() }
13831401
}
13841402

13851403
#[cfg(feature = "async")]
13861404
#[inline]
1387-
pub(crate) unsafe fn set_waker(&self, waker: NonNull<Waker>) -> NonNull<Waker> {
1388-
mem::replace(&mut (*self.extra.get()).waker, waker)
1405+
pub(crate) fn set_waker(&self, waker: NonNull<Waker>) -> NonNull<Waker> {
1406+
unsafe { mem::replace(&mut (*self.extra.get()).waker, waker) }
13891407
}
13901408
}
13911409

src/thread.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -604,7 +604,7 @@ impl<R: FromLuaMulti> Future for AsyncThread<R> {
604604

605605
if status.is_yielded() {
606606
if !(nresults == 1 && is_poll_pending(thread_state)) {
607-
// Ignore value returned via yield()
607+
// Ignore values returned via yield()
608608
cx.waker().wake_by_ref();
609609
}
610610
return Poll::Pending;
@@ -635,7 +635,7 @@ struct WakerGuard<'lua, 'a> {
635635
impl<'lua, 'a> WakerGuard<'lua, 'a> {
636636
#[inline]
637637
pub fn new(lua: &'lua RawLua, waker: &'a Waker) -> Result<WakerGuard<'lua, 'a>> {
638-
let prev = unsafe { lua.set_waker(NonNull::from(waker)) };
638+
let prev = lua.set_waker(NonNull::from(waker));
639639
Ok(WakerGuard {
640640
lua,
641641
prev,
@@ -647,7 +647,7 @@ impl<'lua, 'a> WakerGuard<'lua, 'a> {
647647
#[cfg(feature = "async")]
648648
impl Drop for WakerGuard<'_, '_> {
649649
fn drop(&mut self) {
650-
unsafe { self.lua.set_waker(self.prev) };
650+
self.lua.set_waker(self.prev);
651651
}
652652
}
653653

tests/async.rs

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use futures_util::stream::TryStreamExt;
88
use tokio::sync::Mutex;
99

1010
use mlua::{
11-
Error, Function, Lua, LuaOptions, MultiValue, ObjectLike, Result, StdLib, Table, UserData,
11+
Error, Function, Lua, LuaOptions, MultiValue, ObjectLike, Result, StdLib, Table, ThreadStatus, UserData,
1212
UserDataMethods, UserDataRef, Value,
1313
};
1414

@@ -667,3 +667,35 @@ async fn test_async_hook() -> Result<()> {
667667

668668
Ok(())
669669
}
670+
671+
#[test]
672+
fn test_async_yield_with() -> Result<()> {
673+
let lua = Lua::new();
674+
675+
let func = lua.create_async_function(|lua, (mut a, mut b): (i32, i32)| async move {
676+
let zero = lua.yield_with::<MultiValue>(()).await?;
677+
assert!(zero.is_empty());
678+
let one = lua.yield_with::<MultiValue>(a + b).await?;
679+
assert_eq!(one.len(), 1);
680+
681+
for _ in 0..3 {
682+
(a, b) = lua.yield_with((a + b, a * b)).await?;
683+
}
684+
Ok((0, 0))
685+
})?;
686+
687+
let thread = lua.create_thread(func)?;
688+
689+
let zero = thread.resume::<MultiValue>((2, 3))?; // function arguments
690+
assert!(zero.is_empty());
691+
let one = thread.resume::<i32>(())?; // value of "zero" is passed here
692+
assert_eq!(one, 5);
693+
694+
assert_eq!(thread.resume::<(i32, i32)>(1)?, (5, 6)); // value of "one" is passed here
695+
assert_eq!(thread.resume::<(i32, i32)>((10, 11))?, (21, 110));
696+
assert_eq!(thread.resume::<(i32, i32)>((11, 12))?, (23, 132));
697+
assert_eq!(thread.resume::<(i32, i32)>((12, 13))?, (0, 0));
698+
assert_eq!(thread.status(), ThreadStatus::Finished);
699+
700+
Ok(())
701+
}

0 commit comments

Comments
 (0)