Skip to content

Commit c3eb5f7

Browse files
authored
fix(runtime): store active task's Waker in Scheduler (#494)
1 parent 1bc48f0 commit c3eb5f7

File tree

7 files changed

+238
-16
lines changed

7 files changed

+238
-16
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ nix = "0.30.1"
5252
once_cell = "1.18.0"
5353
os_pipe = "1.1.4"
5454
paste = "1.0.14"
55+
pin-project-lite = "0.2.16"
5556
rand = "0.9.0"
5657
rustls = { version = "0.23.1", default-features = false }
5758
rustls-native-certs = "0.8.0"

compio-io/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ compio-buf = { workspace = true, features = ["arrayvec", "bytes"] }
1515
futures-util = { workspace = true, features = ["sink"] }
1616
paste = { workspace = true }
1717
thiserror = { workspace = true, optional = true }
18-
pin-project-lite = { version = "0.2.14", optional = true }
18+
pin-project-lite = { workspace = true, optional = true }
1919
serde = { version = "1.0.219", optional = true }
2020
serde_json = { version = "1.0.140", optional = true }
2121

compio-runtime/Cargo.toml

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,9 @@ crossbeam-queue = { workspace = true }
4141
futures-util = { workspace = true }
4242
once_cell = { workspace = true }
4343
scoped-tls = "1.0.1"
44-
slab = { workspace = true, optional = true }
44+
slab = { workspace = true }
4545
socket2 = { workspace = true }
46+
pin-project-lite = { workspace = true }
4647

4748
# Windows specific dependencies
4849
[target.'cfg(windows)'.dependencies]
@@ -61,11 +62,19 @@ block2 = "0.6.0"
6162

6263
[features]
6364
event = ["dep:cfg-if", "compio-buf/arrayvec"]
64-
time = ["dep:slab"]
65+
time = []
6566

6667
# Enable it to always notify the driver when a task schedules.
6768
notify-always = []
6869

70+
[[test]]
71+
name = "custom_loop"
72+
required-features = ["event", "time"]
73+
6974
[[test]]
7075
name = "event"
7176
required-features = ["event"]
77+
78+
[[test]]
79+
name = "drop"
80+
required-features = ["time"]

compio-runtime/src/runtime/mod.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ impl Runtime {
154154

155155
/// Block on the future till it completes.
156156
pub fn block_on<F: Future>(&self, future: F) -> F::Output {
157-
CURRENT_RUNTIME.set(self, || {
157+
self.enter(|| {
158158
let mut result = None;
159159
unsafe { self.spawn_unchecked(async { result = Some(future.await) }) }.detach();
160160
loop {
@@ -348,6 +348,14 @@ impl Runtime {
348348
}
349349
}
350350

351+
impl Drop for Runtime {
352+
fn drop(&mut self) {
353+
self.enter(|| {
354+
self.scheduler.clear();
355+
})
356+
}
357+
}
358+
351359
impl AsRawFd for Runtime {
352360
fn as_raw_fd(&self) -> RawFd {
353361
self.driver.borrow().as_raw_fd()
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
use std::{
2+
future::Future,
3+
pin::Pin,
4+
task::{Context, Poll},
5+
};
6+
7+
use pin_project_lite::pin_project;
8+
9+
/// Calls a function when dropped.
10+
struct Defer<F: FnMut()>(F);
11+
12+
impl<F: FnMut()> Drop for Defer<F> {
13+
fn drop(&mut self) {
14+
(self.0)();
15+
}
16+
}
17+
18+
pin_project! {
19+
/// A future wrapper that runs a hook when dropped.
20+
pub(crate) struct DropHook<Fut, Hook: FnMut()> {
21+
#[pin]
22+
future: Fut,
23+
_hook: Defer<Hook>,
24+
}
25+
}
26+
27+
impl<Fut, Hook: FnMut()> DropHook<Fut, Hook> {
28+
/// Creates a new [`DropHook`].
29+
pub(crate) fn new(future: Fut, hook: Hook) -> Self {
30+
Self {
31+
future,
32+
_hook: Defer(hook),
33+
}
34+
}
35+
}
36+
37+
impl<Fut: Future, Hook: FnMut()> Future for DropHook<Fut, Hook> {
38+
type Output = Fut::Output;
39+
40+
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
41+
self.project().future.poll(cx)
42+
}
43+
}

compio-runtime/src/runtime/scheduler/mod.rs

Lines changed: 84 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1-
use crate::runtime::scheduler::{local_queue::LocalQueue, send_wrapper::SendWrapper};
1+
use crate::runtime::scheduler::{
2+
drop_hook::DropHook, local_queue::LocalQueue, send_wrapper::SendWrapper,
3+
};
24
use async_task::{Runnable, Task};
35
use compio_driver::NotifyHandle;
46
use crossbeam_queue::SegQueue;
5-
use std::{future::Future, marker::PhantomData, sync::Arc};
7+
use slab::Slab;
8+
use std::{cell::RefCell, future::Future, marker::PhantomData, rc::Rc, sync::Arc, task::Waker};
69

10+
mod drop_hook;
711
mod local_queue;
812
mod send_wrapper;
913

@@ -68,13 +72,38 @@ impl TaskQueue {
6872
let local_queue = unsafe { self.local_queue.get_unchecked() };
6973
local_queue.is_empty() && self.sync_queue.is_empty()
7074
}
75+
76+
/// Clears both queues.
77+
///
78+
/// # Safety
79+
///
80+
/// Call this method in the same thread as the creator.
81+
unsafe fn clear(&self) {
82+
// SAFETY: See the safety comment of this method.
83+
let local_queue = unsafe { self.local_queue.get_unchecked() };
84+
85+
while let Some(item) = local_queue.pop() {
86+
drop(item);
87+
}
88+
89+
while let Some(item) = self.sync_queue.pop() {
90+
drop(item);
91+
}
92+
}
7193
}
7294

7395
/// A scheduler for managing and executing tasks.
7496
pub(crate) struct Scheduler {
97+
/// Queue for scheduled tasks.
7598
task_queue: Arc<TaskQueue>,
99+
100+
/// `Waker` of active tasks.
101+
active_tasks: Rc<RefCell<Slab<Waker>>>,
102+
103+
/// Number of scheduler ticks for each `run` invocation.
76104
event_interval: usize,
77-
// `Scheduler` is `!Send` and `!Sync`.
105+
106+
/// Makes this type `!Send` and `!Sync`.
78107
_local_marker: PhantomData<*const ()>,
79108
}
80109

@@ -83,6 +112,7 @@ impl Scheduler {
83112
pub(crate) fn new(event_interval: usize) -> Self {
84113
Self {
85114
task_queue: Arc::new(TaskQueue::new()),
115+
active_tasks: Rc::new(RefCell::new(Slab::new())),
86116
event_interval,
87117
_local_marker: PhantomData,
88118
}
@@ -101,20 +131,39 @@ impl Scheduler {
101131
where
102132
F: Future,
103133
{
134+
let mut active_tasks = self.active_tasks.borrow_mut();
135+
let task_entry = active_tasks.vacant_entry();
136+
137+
let future = {
138+
let active_tasks = self.active_tasks.clone();
139+
let index = task_entry.key();
140+
141+
// Wrap the future with a drop hook to remove the waker upon completion.
142+
DropHook::new(future, move || {
143+
active_tasks.borrow_mut().try_remove(index);
144+
})
145+
};
146+
104147
let schedule = {
105-
// Use `Weak` to break reference cycle.
106-
// `TaskQueue` -> `Runnable` -> `TaskQueue`
148+
// The schedule closure is managed by the `Waker` and may be dropped on another thread,
149+
// so use `Weak` to ensure the `TaskQueue` is always dropped on the creator thread.
107150
let task_queue = Arc::downgrade(&self.task_queue);
108151

109152
move |runnable| {
110-
if let Some(task_queue) = task_queue.upgrade() {
111-
task_queue.push(runnable, &notify);
112-
}
153+
// The `upgrade()` never fails because all tasks are dropped when the `Scheduler` is dropped,
154+
// if a `Waker` is used after that, the schedule closure will never be called.
155+
task_queue.upgrade().unwrap().push(runnable, &notify);
113156
}
114157
};
115158

116159
let (runnable, task) = async_task::spawn_unchecked(future, schedule);
160+
161+
// Store the waker.
162+
task_entry.insert(runnable.waker());
163+
164+
// Schedule the task for execution.
117165
runnable.schedule();
166+
118167
task
119168
}
120169

@@ -124,11 +173,13 @@ impl Scheduler {
124173
pub(crate) fn run(&self) -> bool {
125174
for _ in 0..self.event_interval {
126175
// SAFETY:
127-
// `Scheduler` is `!Send` and `!Sync`, so this method is only called
128-
// on `TaskQueue`'s creator thread.
176+
// This method is only called on `TaskQueue`'s creator thread
177+
// because `Scheduler` is `!Send` and `!Sync`.
129178
let tasks = unsafe { self.task_queue.pop() };
130179

131180
// Run the tasks, which will poll the futures.
181+
//
182+
// SAFETY:
132183
// Since spawned tasks are not required to be `Send`, they must always be polled
133184
// on the same thread. Because `Scheduler` is `!Send` and `!Sync`, this is safe.
134185
match tasks {
@@ -147,8 +198,29 @@ impl Scheduler {
147198
}
148199

149200
// SAFETY:
150-
// `Scheduler` is `!Send` and `!Sync`, so this method is only called
151-
// on `TaskQueue`'s creator thread.
201+
// This method is only called on `TaskQueue`'s creator thread
202+
// because `Scheduler` is `!Send` and `!Sync`.
152203
!unsafe { self.task_queue.is_empty() }
153204
}
205+
206+
/// Clears all active tasks.
207+
///
208+
/// This method **must** be called before the scheduler is dropped.
209+
pub(crate) fn clear(&self) {
210+
// Drain and wake all wakers, which will schedule all active tasks.
211+
self.active_tasks
212+
.borrow_mut()
213+
.drain()
214+
.for_each(|waker| waker.wake());
215+
216+
// Then drop all scheduled tasks, which will drop all futures.
217+
//
218+
// SAFETY:
219+
// Since spawned tasks are not required to be `Send`, they must always be dropped
220+
// on the same thread. Because `Scheduler` is `!Send` and `!Sync`, this is safe.
221+
//
222+
// This method is only called on `TaskQueue`'s creator thread
223+
// because `Scheduler` is `!Send` and `!Sync`.
224+
unsafe { self.task_queue.clear() };
225+
}
154226
}

compio-runtime/tests/drop.rs

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
use futures_util::task::AtomicWaker;
2+
use std::{
3+
future::Future,
4+
pin::Pin,
5+
sync::Arc,
6+
task::{Context, Poll},
7+
thread::{self, ThreadId},
8+
};
9+
10+
struct DropWatcher {
11+
waker: Arc<AtomicWaker>,
12+
thread_id: ThreadId,
13+
}
14+
15+
impl DropWatcher {
16+
fn new(waker: Arc<AtomicWaker>) -> Self {
17+
Self {
18+
waker,
19+
thread_id: thread::current().id(),
20+
}
21+
}
22+
}
23+
24+
impl Future for DropWatcher {
25+
type Output = ();
26+
27+
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
28+
self.waker.register(cx.waker());
29+
Poll::Pending
30+
}
31+
}
32+
33+
impl Drop for DropWatcher {
34+
fn drop(&mut self) {
35+
if self.thread_id != thread::current().id() {
36+
panic!("DropWatcher dropped on a different thread!!!");
37+
}
38+
}
39+
}
40+
41+
#[test]
42+
fn test_drop_with_timer() {
43+
compio_runtime::Runtime::new().unwrap().block_on(async {
44+
compio_runtime::spawn(async {
45+
loop {
46+
compio_runtime::time::sleep(std::time::Duration::from_secs(1)).await;
47+
}
48+
})
49+
.detach();
50+
})
51+
}
52+
53+
#[test]
54+
fn test_wake_after_runtime_drop() {
55+
let waker = Arc::new(AtomicWaker::new());
56+
let waker_clone = waker.clone();
57+
58+
let rt = compio_runtime::Runtime::new().unwrap();
59+
60+
rt.block_on(async move {
61+
compio_runtime::spawn(DropWatcher::new(waker_clone)).detach();
62+
});
63+
64+
drop(rt);
65+
66+
// Use `unwrap()` to ensure there is a waker stored.
67+
waker.take().unwrap().wake();
68+
}
69+
70+
#[test]
71+
fn test_wake_from_another_thread_after_runtime_drop() {
72+
let waker = Arc::new(AtomicWaker::new());
73+
let waker_clone = waker.clone();
74+
75+
let rt = compio_runtime::Runtime::new().unwrap();
76+
77+
rt.block_on(async move {
78+
compio_runtime::spawn(DropWatcher::new(waker_clone)).detach();
79+
});
80+
81+
drop(rt);
82+
83+
thread::spawn(move || {
84+
// Use `unwrap()` to ensure there is a waker stored.
85+
waker.take().unwrap().wake();
86+
})
87+
.join()
88+
.unwrap();
89+
}

0 commit comments

Comments
 (0)