Skip to content

Commit b80a1c8

Browse files
committed
Add task library (based on tokio::task)
1 parent 3ba3314 commit b80a1c8

File tree

8 files changed

+562
-20
lines changed

8 files changed

+562
-20
lines changed

Cargo.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ json = ["mlua/serde", "dep:ouroboros", "dep:serde", "dep:serde_json"]
2020
regex = ["dep:regex", "dep:ouroboros", "dep:quick_cache"]
2121
yaml = ["mlua/serde", "dep:ouroboros", "dep:serde", "dep:serde_yaml"]
2222
http = ["dep:http"]
23+
task = ["dep:tokio", "dep:tokio-util", "mlua/async"]
2324

2425
[dependencies]
2526
mlua = { version = "0.11" }
@@ -33,3 +34,10 @@ quick_cache = { version = "0.6", optional = true }
3334

3435
# http
3536
http = { version = "1.3", optional = true }
37+
38+
# tokio
39+
tokio = { version = "1", features = ["full"], optional = true }
40+
tokio-util = { version = "0.7", features = ["time"], optional = true }
41+
42+
[dev-dependencies]
43+
tokio = { version = "1", features = ["full"] }

src/lib.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@ pub(crate) const METAMETHOD_ITER: &str = if cfg!(feature = "luau") { "__iter" }
55
#[macro_use]
66
mod macros;
77
mod types;
8+
mod util;
89

910
pub(crate) mod terminal;
10-
pub(crate) mod time;
1111

1212
pub mod assertions;
1313
pub mod bytes;
1414
pub mod env;
1515
pub mod testing;
16+
pub mod time;
1617

1718
#[cfg(feature = "json")]
1819
pub mod json;
@@ -21,4 +22,8 @@ pub mod regex;
2122
#[cfg(feature = "yaml")]
2223
pub mod yaml;
2324

25+
#[cfg(feature = "http")]
2426
pub mod http;
27+
28+
#[cfg(feature = "task")]
29+
pub mod task;

src/macros.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,25 @@ macro_rules! lua_try {
66
}
77
};
88
}
9+
10+
macro_rules! defer {
11+
($($item:tt)*) => {
12+
let _guard = crate::util::defer(|| { $($item)* });
13+
};
14+
}
15+
16+
macro_rules! opt_param {
17+
($table:expr, $name:expr) => {
18+
match ($table.as_ref())
19+
.map(|t| t.raw_get::<Option<_>>($name))
20+
.transpose()
21+
{
22+
Ok(Some(v)) => Ok(v),
23+
Ok(None) => Ok(None),
24+
Err(err) => {
25+
use mlua::ErrorContext as _;
26+
Err(err.with_context(|_| format!("invalid `{}`", $name)))
27+
}
28+
}
29+
};
30+
}

src/task.rs

Lines changed: 274 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,274 @@
1+
use std::cell::RefCell;
2+
use std::panic;
3+
use std::rc::Rc;
4+
use std::time::Instant;
5+
6+
use mlua::{
7+
Either, ExternalError, Function, Lua, MetaMethod, MultiValue, Result, Table, UserData, UserDataFields,
8+
UserDataMethods, UserDataRef, UserDataRegistry, Value,
9+
};
10+
use tokio::task::{AbortHandle, JoinHandle, JoinSet};
11+
use tokio::time::{Instant as TokioInstant, MissedTickBehavior};
12+
use tokio_util::time::FutureExt as _;
13+
14+
use crate::time::Duration;
15+
16+
#[derive(Clone, Default)]
17+
struct Params {
18+
name: Option<String>,
19+
timeout: Option<Duration>,
20+
}
21+
22+
pub struct TaskHandle {
23+
name: Option<String>,
24+
started: Rc<RefCell<Option<Instant>>>,
25+
elapsed: Rc<RefCell<Option<Duration>>>,
26+
handle: Either<Option<JoinHandle<Result<Value>>>, AbortHandle>,
27+
}
28+
29+
impl UserData for TaskHandle {
30+
fn register(registry: &mut UserDataRegistry<Self>) {
31+
registry.add_field_method_get("id", |_, this| match this.handle.as_ref() {
32+
Either::Left(jh) => Ok(jh.as_ref().map(|h| h.id().to_string())),
33+
Either::Right(ah) => Ok(Some(ah.id().to_string())),
34+
});
35+
36+
registry.add_field_method_get("name", |lua, this| lua.pack(this.name.as_deref()));
37+
38+
registry.add_async_method_mut("join", |_, mut this, ()| async move {
39+
if this.handle.is_right() {
40+
return Ok(Err("cannot join grouped task".into_lua_err()));
41+
}
42+
match this.handle.as_mut().left().and_then(|h| h.take()) {
43+
Some(jh) => match jh.await {
44+
Ok(res) => Ok(res),
45+
Err(err) if err.is_panic() => panic::resume_unwind(err.into_panic()),
46+
Err(err) => Ok(Err(err.into_lua_err())),
47+
},
48+
None => Ok(Err("task already joined".into_lua_err())),
49+
}
50+
});
51+
52+
registry.add_async_method("abort", |_, this, ()| async move {
53+
match this.handle.as_ref() {
54+
Either::Left(Some(jh)) => {
55+
jh.abort();
56+
}
57+
Either::Left(None) => {}
58+
Either::Right(ah) => {
59+
ah.abort();
60+
}
61+
}
62+
Ok(())
63+
});
64+
65+
registry.add_method("elapsed", |_, this, ()| match *this.elapsed.borrow() {
66+
Some(dur) => Ok(Some(dur)),
67+
None => Ok(this.started.borrow().map(|s| Duration(s.elapsed()))),
68+
});
69+
70+
registry.add_method("is_finished", |_, this, ()| match this.handle.as_ref() {
71+
Either::Left(Some(jh)) => Ok(jh.is_finished()),
72+
Either::Left(None) => Ok(true),
73+
Either::Right(ah) => Ok(ah.is_finished()),
74+
});
75+
}
76+
}
77+
78+
pub struct Task {
79+
func: Function,
80+
params: Params,
81+
}
82+
83+
impl Task {
84+
fn new(func: Function, params: Option<Table>) -> Result<Self> {
85+
let name: Option<String> = opt_param!(params, "name")?;
86+
let timeout: Option<Duration> = opt_param!(params, "timeout")?;
87+
Ok(Self {
88+
func,
89+
params: Params { name, timeout },
90+
})
91+
}
92+
}
93+
94+
impl UserData for Task {}
95+
96+
pub struct Group(JoinSet<Result<Value>>);
97+
98+
impl Group {
99+
fn new() -> Self {
100+
Group(JoinSet::new())
101+
}
102+
}
103+
104+
impl UserData for Group {
105+
fn register(registry: &mut UserDataRegistry<Self>) {
106+
registry.add_method_mut(
107+
"spawn",
108+
|_, this, (func, args): (Either<Function, UserDataRef<Task>>, MultiValue)| {
109+
let Params { name, timeout } = (func.as_ref())
110+
.right()
111+
.map_or(Params::default(), |ud| ud.params.clone());
112+
113+
let started = Rc::new(RefCell::new(None));
114+
let elapsed = Rc::new(RefCell::new(None));
115+
let (started2, elapsed2) = (started.clone(), elapsed.clone());
116+
117+
let fut = match func {
118+
Either::Left(f) => f.call_async(args),
119+
Either::Right(ud) => ud.func.call_async(args),
120+
};
121+
122+
let abort_handle = this.0.spawn_local(async move {
123+
*started2.borrow_mut() = Some(Instant::now());
124+
defer! {
125+
*elapsed2.borrow_mut() = Some(Duration(started2.borrow().unwrap().elapsed()));
126+
}
127+
128+
let result = match timeout {
129+
Some(dur) => fut.timeout(dur.0).await,
130+
None => Ok(fut.await),
131+
};
132+
result
133+
.map_err(|_| "task exceeded timeout".into_lua_err())
134+
.flatten()
135+
});
136+
137+
Ok(TaskHandle {
138+
name,
139+
started,
140+
elapsed,
141+
handle: Either::Right(abort_handle),
142+
})
143+
},
144+
);
145+
146+
registry.add_method("len", |_, this, ()| Ok(this.0.len()));
147+
148+
registry.add_async_method_mut("join_next", |_, mut this, ()| async move {
149+
match this.0.join_next().await {
150+
Some(Ok(res)) => Ok(Ok(Some(lua_try!(res)))),
151+
Some(Err(err)) if err.is_panic() => panic::resume_unwind(err.into_panic()),
152+
Some(Err(err)) => Ok(Err(err.to_string())),
153+
None => Ok(Ok(None)),
154+
}
155+
});
156+
157+
registry.add_async_method_mut("join_all", |_, mut this, ()| async move {
158+
let mut results = Vec::with_capacity(this.0.len());
159+
while let Some(res) = this.0.join_next().await {
160+
match res {
161+
Ok(Ok(val)) => results.push(val),
162+
Ok(Err(err)) => results.push(Value::Error(Box::new(err))),
163+
Err(err) if err.is_panic() => panic::resume_unwind(err.into_panic()),
164+
Err(err) => results.push(Value::Error(Box::new(err.into_lua_err()))),
165+
}
166+
}
167+
Ok(results)
168+
});
169+
170+
registry.add_method_mut("abort_all", |_, this, ()| {
171+
this.0.abort_all();
172+
Ok(())
173+
});
174+
175+
registry.add_method_mut("detach_all", |_, this, ()| {
176+
this.0.detach_all();
177+
Ok(())
178+
});
179+
180+
registry.add_meta_method(MetaMethod::Len, |_, this, ()| Ok(this.0.len()));
181+
}
182+
}
183+
184+
fn spawn_inner(params: Params, fut: impl Future<Output = Result<Value>> + 'static) -> Result<TaskHandle> {
185+
let Params { name, timeout } = params;
186+
187+
let started = Rc::new(RefCell::new(None));
188+
let elapsed = Rc::new(RefCell::new(None));
189+
let (started2, elapsed2) = (started.clone(), elapsed.clone());
190+
191+
let handle = tokio::task::spawn_local(async move {
192+
*started2.borrow_mut() = Some(Instant::now());
193+
defer! {
194+
*elapsed2.borrow_mut() = Some(Duration(started2.borrow().unwrap().elapsed()));
195+
}
196+
197+
let result = match timeout {
198+
Some(dur) => fut.timeout(dur.0).await,
199+
None => Ok(fut.await),
200+
};
201+
result
202+
.map_err(|_| "task exceeded timeout".into_lua_err())
203+
.flatten()
204+
});
205+
206+
Ok(TaskHandle {
207+
name,
208+
started,
209+
elapsed,
210+
handle: Either::Left(Some(handle)),
211+
})
212+
}
213+
214+
pub fn spawn(_: &Lua, (func, args): (Either<Function, UserDataRef<Task>>, MultiValue)) -> Result<TaskHandle> {
215+
let params = (func.as_ref())
216+
.right()
217+
.map_or(Params::default(), |ud| ud.params.clone());
218+
219+
spawn_inner(params, async move {
220+
match func {
221+
Either::Left(f) => f.call_async(args).await,
222+
Either::Right(ud) => ud.func.call_async(args).await,
223+
}
224+
})
225+
}
226+
227+
pub fn spawn_every(
228+
_: &Lua,
229+
(dur, func, args): (Duration, Either<Function, UserDataRef<Task>>, MultiValue),
230+
) -> Result<TaskHandle> {
231+
let (func, params) = match func {
232+
Either::Left(f) => (f, Params::default()),
233+
Either::Right(ud) => (ud.func.clone(), ud.params.clone()),
234+
};
235+
236+
spawn_inner(params, async move {
237+
let mut interval = tokio::time::interval_at(TokioInstant::now() + dur.0, dur.0);
238+
interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
239+
loop {
240+
interval.tick().await;
241+
func.call_async::<()>(args.clone()).await?;
242+
}
243+
})
244+
}
245+
246+
pub async fn sleep(_: Lua, dur: Duration) -> Result<()> {
247+
tokio::time::sleep(dur.0).await;
248+
Ok(())
249+
}
250+
251+
pub async fn yield_now(_: Lua, _: ()) -> Result<()> {
252+
tokio::task::yield_now().await;
253+
Ok(())
254+
}
255+
256+
/// A loader for the `task` module.
257+
fn loader(lua: &Lua) -> Result<Table> {
258+
let t = lua.create_table()?;
259+
t.set("create", Function::wrap(Task::new))?;
260+
t.set("group", Function::wrap_raw(Group::new))?;
261+
t.set("spawn", lua.create_function(spawn)?)?;
262+
t.set("spawn_every", lua.create_function(spawn_every)?)?;
263+
t.set("sleep", lua.create_async_function(sleep)?)?;
264+
t.set("yield", lua.create_async_function(yield_now)?)?;
265+
Ok(t)
266+
}
267+
268+
/// Registers the `task` module in the given Lua state.
269+
pub fn register(lua: &Lua, name: Option<&str>) -> Result<Table> {
270+
let name = name.unwrap_or("@task");
271+
let value = loader(lua)?;
272+
lua.register_module(name, &value)?;
273+
Ok(value)
274+
}

0 commit comments

Comments
 (0)