Skip to content

Commit 5ccf306

Browse files
authored
Add spawn_local (#176)
1 parent 2aa3107 commit 5ccf306

File tree

9 files changed

+90
-19
lines changed

9 files changed

+90
-19
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ regex = "1.5.5"
3636
tempfile = "3.2.0"
3737
test-log = { version = "0.2.8", default-features = false, features = ["trace"] }
3838
tracing-subscriber = { version = "0.3.9", features = ["env-filter"] }
39+
trybuild = "1.0"
3940
pin-project = "1.1.3"
4041

4142
[lib]

src/future/mod.rs

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,10 @@ use std::task::{Context, Poll, Waker};
1818

1919
pub mod batch_semaphore;
2020

21-
/// Spawn a new async task that the executor will run to completion.
22-
pub fn spawn<T, F>(fut: F) -> JoinHandle<T>
21+
fn spawn_inner<F>(fut: F) -> JoinHandle<F::Output>
2322
where
24-
F: Future<Output = T> + Send + 'static,
25-
T: Send + 'static,
23+
F: Future + 'static,
24+
F::Output: 'static,
2625
{
2726
let stack_size = ExecutionState::with(|s| s.config.stack_size);
2827
let inner = Arc::new(std::sync::Mutex::new(JoinHandleInner::default()));
@@ -33,6 +32,25 @@ where
3332
JoinHandle { task_id, inner }
3433
}
3534

35+
/// Spawn a new async task that the executor will run to completion.
36+
pub fn spawn<F>(fut: F) -> JoinHandle<F::Output>
37+
where
38+
F: Future + Send + 'static,
39+
F::Output: Send + 'static,
40+
{
41+
spawn_inner(fut)
42+
}
43+
44+
/// Spawn a new async task that the executor will run to completion.
45+
/// This is just `spawn` without the `Send` bound, and it mirrors `spawn_local` from Tokio.
46+
pub fn spawn_local<F>(fut: F) -> JoinHandle<F::Output>
47+
where
48+
F: Future + 'static,
49+
F::Output: 'static,
50+
{
51+
spawn_inner(fut)
52+
}
53+
3654
/// An owned permission to join on an async task (await its termination).
3755
#[derive(Debug)]
3856
pub struct JoinHandle<T> {
@@ -120,27 +138,28 @@ impl<T> Future for JoinHandle<T> {
120138
// contains a mutex-wrapped field that stores the value and the waker for the task
121139
// waiting on the join handle. When `poll` returns `Poll::Ready`, the `Wrapper` stores
122140
// the result in the `result` field and wakes the `waker`.
123-
struct Wrapper<T, F> {
141+
struct Wrapper<F: Future> {
124142
future: Pin<Box<F>>,
125-
inner: std::sync::Arc<std::sync::Mutex<JoinHandleInner<T>>>,
143+
inner: std::sync::Arc<std::sync::Mutex<JoinHandleInner<F::Output>>>,
126144
}
127145

128-
impl<T, F> Wrapper<T, F>
146+
impl<F> Wrapper<F>
129147
where
130-
F: Future<Output = T> + Send + 'static,
148+
F: Future + 'static,
149+
F::Output: 'static,
131150
{
132-
fn new(future: F, inner: std::sync::Arc<std::sync::Mutex<JoinHandleInner<T>>>) -> Self {
151+
fn new(future: F, inner: std::sync::Arc<std::sync::Mutex<JoinHandleInner<F::Output>>>) -> Self {
133152
Self {
134153
future: Box::pin(future),
135154
inner,
136155
}
137156
}
138157
}
139158

140-
impl<T, F> Future for Wrapper<T, F>
159+
impl<F> Future for Wrapper<F>
141160
where
142-
F: Future<Output = T> + Send + 'static,
143-
T: Send + 'static,
161+
F: Future + 'static,
162+
F::Output: 'static,
144163
{
145164
type Output = ();
146165

src/runtime/execution.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,7 @@ impl ExecutionState {
372372
/// if it wants to give the new task a chance to run immediately.
373373
pub(crate) fn spawn_future<F>(future: F, stack_size: usize, name: Option<String>) -> TaskId
374374
where
375-
F: Future<Output = ()> + Send + 'static,
375+
F: Future<Output = ()> + 'static,
376376
{
377377
let task_id = Self::with(|state| {
378378
let schedule_len = state.current_schedule.len();

src/runtime/task/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ impl Task {
206206
parent_task_id: Option<TaskId>,
207207
) -> Self
208208
where
209-
F: FnOnce() + Send + 'static,
209+
F: FnOnce() + 'static,
210210
{
211211
assert!(id.0 < clock.time.len());
212212
let mut continuation = ContinuationPool::acquire(stack_size);
@@ -288,7 +288,7 @@ impl Task {
288288
parent_task_id: Option<TaskId>,
289289
) -> Self
290290
where
291-
F: Future<Output = ()> + Send + 'static,
291+
F: Future<Output = ()> + 'static,
292292
{
293293
let mut future = Box::pin(future);
294294

src/runtime/thread/continuation.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,11 @@ pub(crate) struct Continuation {
3232
/// A cell to pass functions into continuations
3333
#[allow(clippy::type_complexity)]
3434
#[derive(Clone)]
35-
struct ContinuationFunction(Rc<Cell<Option<Box<dyn FnOnce() + Send>>>>);
35+
struct ContinuationFunction(Rc<Cell<Option<Box<dyn FnOnce()>>>>);
3636

3737
// Safety: we arrange for the `function` field of `Continuation` to only be accessed by one thread
3838
// at a time: Shuttle tests are single threaded, and continuations are never shared across threads
39-
// by the ContinuationPool, which is thread-local. The function itself already implements `Send`.
39+
// by the ContinuationPool, which is thread-local.
4040
unsafe impl Send for ContinuationFunction {}
4141

4242
/// Inputs that we can pass to a continuation.
@@ -104,7 +104,7 @@ impl Continuation {
104104

105105
/// Provide a new function for the continuation to execute. The continuation must
106106
/// be in reusable state.
107-
pub fn initialize(&mut self, fun: Box<dyn FnOnce() + Send>) {
107+
pub fn initialize(&mut self, fun: Box<dyn FnOnce()>) {
108108
debug_assert_eq!(
109109
self.state,
110110
ContinuationState::NotReady,

tests/basic/thread.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
use shuttle::current::{get_name_for_task, me};
22
use shuttle::sync::{Barrier, Condvar, Mutex};
3-
use shuttle::{check_dfs, check_random, thread};
3+
use shuttle::{check_dfs, check_random, future, thread};
44
use std::collections::HashSet;
5+
use std::rc::Rc;
56
use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};
67
use std::sync::Arc;
78
use test_log::test;
@@ -642,3 +643,14 @@ fn thread_unpark_after_spurious_wakeup() {
642643
None,
643644
)
644645
}
646+
647+
#[test]
648+
fn spawn_local_sanity() {
649+
check_dfs(
650+
|| {
651+
let rc = Rc::new(0);
652+
shuttle::future::block_on(future::spawn_local(async { drop(rc) })).unwrap()
653+
},
654+
None,
655+
);
656+
}

tests/mod.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,12 @@ mod data;
55
mod demo;
66
mod future;
77

8+
#[test]
9+
fn ui() {
10+
let t = trybuild::TestCases::new();
11+
t.compile_fail("tests/ui/*.rs");
12+
}
13+
814
use shuttle::scheduler::{ReplayScheduler, Scheduler};
915
use shuttle::{check_random_with_seed, replay_from_file, Config, FailurePersistence, Runner};
1016
use std::panic::{self, RefUnwindSafe, UnwindSafe};

tests/ui/spawn_not_send.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
use shuttle::future;
2+
use shuttle::check_dfs;
3+
use std::rc::Rc;
4+
5+
fn main() {
6+
check_dfs(
7+
|| {
8+
let rc = Rc::new(0);
9+
shuttle::future::block_on(future::spawn(async { drop(rc) })).unwrap()
10+
},
11+
None,
12+
);
13+
}

tests/ui/spawn_not_send.stderr

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
error: future cannot be sent between threads safely
2+
--> tests/ui/spawn_not_send.rs:9:53
3+
|
4+
9 | shuttle::future::block_on(future::spawn(async { drop(rc) })).unwrap()
5+
| ^^^^^^^^^^^^^^^^^^ future created by async block is not `Send`
6+
|
7+
= help: within `{async block@$DIR/tests/ui/spawn_not_send.rs:9:53: 9:58}`, the trait `Send` is not implemented for `Rc<i32>`
8+
note: captured value is not `Send`
9+
--> tests/ui/spawn_not_send.rs:9:66
10+
|
11+
9 | shuttle::future::block_on(future::spawn(async { drop(rc) })).unwrap()
12+
| ^^ has type `Rc<i32>` which is not `Send`
13+
note: required by a bound in `shuttle::future::spawn`
14+
--> src/future/mod.rs
15+
|
16+
| pub fn spawn<F>(fut: F) -> JoinHandle<F::Output>
17+
| ----- required by a bound in this function
18+
| where
19+
| F: Future + Send + 'static,
20+
| ^^^^ required by this bound in `spawn`

0 commit comments

Comments
 (0)