Skip to content

Commit 5e969cb

Browse files
authored
refactor(runtime): introduce Scheduler for task scheduling (#491)
1 parent 4a5c750 commit 5e969cb

File tree

4 files changed

+214
-93
lines changed

4 files changed

+214
-93
lines changed

compio-runtime/src/runtime/mod.rs

Lines changed: 15 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,20 @@
11
use std::{
22
any::Any,
33
cell::{Cell, RefCell},
4-
collections::{HashSet, VecDeque},
4+
collections::HashSet,
55
future::{Future, ready},
66
io,
7-
marker::PhantomData,
87
panic::AssertUnwindSafe,
9-
rc::Rc,
10-
sync::Arc,
118
task::{Context, Poll},
129
time::Duration,
1310
};
1411

15-
use async_task::{Runnable, Task};
12+
use async_task::Task;
1613
use compio_buf::IntoInner;
1714
use compio_driver::{
18-
AsRawFd, DriverType, Key, NotifyHandle, OpCode, Proactor, ProactorBuilder, PushEntry, RawFd,
19-
op::Asyncify,
15+
AsRawFd, DriverType, Key, OpCode, Proactor, ProactorBuilder, PushEntry, RawFd, op::Asyncify,
2016
};
2117
use compio_log::{debug, instrument};
22-
use crossbeam_queue::SegQueue;
2318
use futures_util::{FutureExt, future::Either};
2419

2520
pub(crate) mod op;
@@ -29,76 +24,22 @@ pub(crate) mod time;
2924
mod buffer_pool;
3025
pub use buffer_pool::*;
3126

32-
mod send_wrapper;
33-
use send_wrapper::SendWrapper;
27+
mod scheduler;
3428

3529
#[cfg(feature = "time")]
3630
use crate::runtime::time::{TimerFuture, TimerKey, TimerRuntime};
37-
use crate::{BufResult, affinity::bind_to_cpu_set, runtime::op::OpFuture};
31+
use crate::{
32+
BufResult,
33+
affinity::bind_to_cpu_set,
34+
runtime::{op::OpFuture, scheduler::Scheduler},
35+
};
3836

3937
scoped_tls::scoped_thread_local!(static CURRENT_RUNTIME: Runtime);
4038

4139
/// Type alias for `Task<Result<T, Box<dyn Any + Send>>>`, which resolves to an
4240
/// `Err` when the spawned future panicked.
4341
pub type JoinHandle<T> = Task<Result<T, Box<dyn Any + Send>>>;
4442

45-
struct RunnableQueue {
46-
local_runnables: SendWrapper<RefCell<VecDeque<Runnable>>>,
47-
sync_runnables: SegQueue<Runnable>,
48-
}
49-
50-
impl RunnableQueue {
51-
pub fn new() -> Self {
52-
Self {
53-
local_runnables: SendWrapper::new(RefCell::new(VecDeque::new())),
54-
sync_runnables: SegQueue::new(),
55-
}
56-
}
57-
58-
pub fn schedule(&self, runnable: Runnable, handle: &NotifyHandle) {
59-
if let Some(runnables) = self.local_runnables.get() {
60-
runnables.borrow_mut().push_back(runnable);
61-
#[cfg(feature = "notify-always")]
62-
handle.notify().ok();
63-
} else {
64-
self.sync_runnables.push(runnable);
65-
handle.notify().ok();
66-
}
67-
}
68-
69-
/// SAFETY: call in the main thread
70-
pub unsafe fn run(&self, event_interval: usize) -> bool {
71-
let local_runnables = self.local_runnables.get_unchecked();
72-
73-
for _ in 0..event_interval {
74-
let local_task = local_runnables.borrow_mut().pop_front();
75-
76-
// Perform an empty check as a fast path, since `pop()` is more expensive.
77-
let sync_task = if self.sync_runnables.is_empty() {
78-
None
79-
} else {
80-
self.sync_runnables.pop()
81-
};
82-
83-
match (local_task, sync_task) {
84-
(Some(local), Some(sync)) => {
85-
local.run();
86-
sync.run();
87-
}
88-
(Some(local), None) => {
89-
local.run();
90-
}
91-
(None, Some(sync)) => {
92-
sync.run();
93-
}
94-
(None, None) => break,
95-
}
96-
}
97-
98-
!(local_runnables.borrow().is_empty() && self.sync_runnables.is_empty())
99-
}
100-
}
101-
10243
thread_local! {
10344
static RUNTIME_ID: Cell<u64> = const { Cell::new(0) };
10445
}
@@ -107,10 +48,9 @@ thread_local! {
10748
/// sent to other threads.
10849
pub struct Runtime {
10950
driver: RefCell<Proactor>,
110-
runnables: Arc<RunnableQueue>,
51+
scheduler: Scheduler,
11152
#[cfg(feature = "time")]
11253
timer_runtime: RefCell<TimerRuntime>,
113-
event_interval: usize,
11454
// Runtime id is used to check if the buffer pool is belonged to this runtime or not.
11555
// Without this, if user enable `io-uring-buf-ring` feature then:
11656
// 1. Create a buffer pool at runtime1
@@ -119,9 +59,6 @@ pub struct Runtime {
11959
// - buffer pool will return a wrong buffer which the buffer's data is uninit, that will cause
12060
// UB
12161
id: u64,
122-
// Other fields don't make it !Send, but actually `local_runnables` implies it should be !Send,
123-
// otherwise it won't be valid if the runtime is sent to other threads.
124-
_p: PhantomData<Rc<VecDeque<Runnable>>>,
12562
}
12663

12764
impl Runtime {
@@ -148,12 +85,10 @@ impl Runtime {
14885
}
14986
Ok(Self {
15087
driver: RefCell::new(proactor_builder.build()?),
151-
runnables: Arc::new(RunnableQueue::new()),
88+
scheduler: Scheduler::new(*event_interval),
15289
#[cfg(feature = "time")]
15390
timer_runtime: RefCell::new(TimerRuntime::new()),
154-
event_interval: *event_interval,
15591
id,
156-
_p: PhantomData,
15792
})
15893
}
15994

@@ -202,22 +137,10 @@ impl Runtime {
202137
///
203138
/// The caller should ensure the captured lifetime long enough.
204139
pub unsafe fn spawn_unchecked<F: Future>(&self, future: F) -> Task<F::Output> {
205-
let schedule = {
206-
// Use `Weak` to break reference cycle.
207-
// `RunnableQueue` -> `Runnable` -> `RunnableQueue`
208-
let runnables = Arc::downgrade(&self.runnables);
209-
let handle = self.driver.borrow().handle();
210-
211-
move |runnable| {
212-
if let Some(runnables) = runnables.upgrade() {
213-
runnables.schedule(runnable, &handle);
214-
}
215-
}
216-
};
140+
let notify = self.driver.borrow().handle();
217141

218-
let (runnable, task) = async_task::spawn_unchecked(future, schedule);
219-
runnable.schedule();
220-
task
142+
// SAFETY: See the safety comment of this method.
143+
unsafe { self.scheduler.spawn_unchecked(future, notify) }
221144
}
222145

223146
/// Low level API to control the runtime.
@@ -226,8 +149,7 @@ impl Runtime {
226149
///
227150
/// The return value indicates whether there are still tasks in the queue.
228151
pub fn run(&self) -> bool {
229-
// SAFETY: self is !Send + !Sync.
230-
unsafe { self.runnables.run(self.event_interval) }
152+
self.scheduler.run()
231153
}
232154

233155
/// Block on the future till it completes.
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
use std::{cell::UnsafeCell, collections::VecDeque};
2+
3+
/// A queue that is `!Sync` with interior mutability.
4+
pub(crate) struct LocalQueue<T> {
5+
queue: UnsafeCell<VecDeque<T>>,
6+
}
7+
8+
impl<T> LocalQueue<T> {
9+
/// Creates an empty `LocalQueue`.
10+
pub(crate) const fn new() -> Self {
11+
Self {
12+
queue: UnsafeCell::new(VecDeque::new()),
13+
}
14+
}
15+
16+
/// Pushes an item to the back of the queue.
17+
pub(crate) fn push(&self, item: T) {
18+
// SAFETY:
19+
// Exclusive mutable access because:
20+
// - The mutable reference is created and used immediately within this scope.
21+
// - `LocalQueue` is `!Sync`, so no other threads can access it concurrently.
22+
let queue = unsafe { &mut *self.queue.get() };
23+
queue.push_back(item);
24+
}
25+
26+
/// Pops an item from the front of the queue, returning `None` if empty.
27+
pub(crate) fn pop(&self) -> Option<T> {
28+
// SAFETY:
29+
// Exclusive mutable access because:
30+
// - The mutable reference is created and used immediately within this scope.
31+
// - `LocalQueue` is `!Sync`, so no other threads can access it concurrently.
32+
let queue = unsafe { &mut *self.queue.get() };
33+
queue.pop_front()
34+
}
35+
36+
/// Returns `true` if the queue is empty.
37+
pub(crate) fn is_empty(&self) -> bool {
38+
// SAFETY:
39+
// Exclusive mutable access because:
40+
// - The mutable reference is created and used immediately within this scope.
41+
// - `LocalQueue` is `!Sync`, so no other threads can access it concurrently.
42+
let queue = unsafe { &mut *self.queue.get() };
43+
queue.is_empty()
44+
}
45+
}
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
use crate::runtime::scheduler::{local_queue::LocalQueue, send_wrapper::SendWrapper};
2+
use async_task::{Runnable, Task};
3+
use compio_driver::NotifyHandle;
4+
use crossbeam_queue::SegQueue;
5+
use std::{future::Future, marker::PhantomData, sync::Arc};
6+
7+
mod local_queue;
8+
mod send_wrapper;
9+
10+
/// A task queue consisting of a local queue and a synchronized queue.
11+
struct TaskQueue {
12+
local_queue: SendWrapper<LocalQueue<Runnable>>,
13+
sync_queue: SegQueue<Runnable>,
14+
}
15+
16+
impl TaskQueue {
17+
/// Creates a new `TaskQueue`.
18+
fn new() -> Self {
19+
Self {
20+
local_queue: SendWrapper::new(LocalQueue::new()),
21+
sync_queue: SegQueue::new(),
22+
}
23+
}
24+
25+
/// Pushes a `Runnable` task to the appropriate queue.
26+
///
27+
/// If the current thread is the same as the creator thread, push to the local queue.
28+
/// Otherwise, push to the sync queue.
29+
fn push(&self, runnable: Runnable, notify: &NotifyHandle) {
30+
if let Some(local_queue) = self.local_queue.get() {
31+
local_queue.push(runnable);
32+
#[cfg(feature = "notify-always")]
33+
notify.notify().ok();
34+
} else {
35+
self.sync_queue.push(runnable);
36+
notify.notify().ok();
37+
}
38+
}
39+
40+
/// Pops at most one task from each queue and returns them as `(local_task, sync_task)`.
41+
///
42+
/// # Safety
43+
///
44+
/// Call this method in the same thread as the creator.
45+
unsafe fn pop(&self) -> (Option<Runnable>, Option<Runnable>) {
46+
// SAFETY: See the safety comment of this method.
47+
let local_queue = unsafe { self.local_queue.get_unchecked() };
48+
49+
let local_task = local_queue.pop();
50+
51+
// Perform an empty check as a fast path, since `SegQueue::pop()` is more expensive.
52+
let sync_task = if self.sync_queue.is_empty() {
53+
None
54+
} else {
55+
self.sync_queue.pop()
56+
};
57+
58+
(local_task, sync_task)
59+
}
60+
61+
/// Returns `true` if both queues are empty.
62+
///
63+
/// # Safety
64+
///
65+
/// Call this method in the same thread as the creator.
66+
unsafe fn is_empty(&self) -> bool {
67+
// SAFETY: See the safety comment of this method.
68+
let local_queue = unsafe { self.local_queue.get_unchecked() };
69+
local_queue.is_empty() && self.sync_queue.is_empty()
70+
}
71+
}
72+
73+
/// A scheduler for managing and executing tasks.
74+
pub(crate) struct Scheduler {
75+
task_queue: Arc<TaskQueue>,
76+
event_interval: usize,
77+
// `Scheduler` is `!Send` and `!Sync`.
78+
_local_marker: PhantomData<*const ()>,
79+
}
80+
81+
impl Scheduler {
82+
/// Creates a new `Scheduler`.
83+
pub(crate) fn new(event_interval: usize) -> Self {
84+
Self {
85+
task_queue: Arc::new(TaskQueue::new()),
86+
event_interval,
87+
_local_marker: PhantomData,
88+
}
89+
}
90+
91+
/// Spawns a new asynchronous task, returning a [`Task`] for it.
92+
///
93+
/// # Safety
94+
///
95+
/// The caller should ensure the captured lifetime long enough.
96+
pub(crate) unsafe fn spawn_unchecked<F>(
97+
&self,
98+
future: F,
99+
notify: NotifyHandle,
100+
) -> Task<F::Output>
101+
where
102+
F: Future,
103+
{
104+
let schedule = {
105+
// Use `Weak` to break reference cycle.
106+
// `TaskQueue` -> `Runnable` -> `TaskQueue`
107+
let task_queue = Arc::downgrade(&self.task_queue);
108+
109+
move |runnable| {
110+
if let Some(task_queue) = task_queue.upgrade() {
111+
task_queue.push(runnable, &notify);
112+
}
113+
}
114+
};
115+
116+
let (runnable, task) = async_task::spawn_unchecked(future, schedule);
117+
runnable.schedule();
118+
task
119+
}
120+
121+
/// Run the scheduled tasks.
122+
///
123+
/// The return value indicates whether there are still tasks in the queue.
124+
pub(crate) fn run(&self) -> bool {
125+
for _ in 0..self.event_interval {
126+
// SAFETY:
127+
// `Scheduler` is `!Send` and `!Sync`, so this method is only called
128+
// on `TaskQueue`'s creator thread.
129+
let tasks = unsafe { self.task_queue.pop() };
130+
131+
// Run the tasks, which will poll the futures.
132+
// Since spawned tasks are not required to be `Send`, they must always be polled
133+
// on the same thread. Because `Scheduler` is `!Send` and `!Sync`, this is safe.
134+
match tasks {
135+
(Some(local), Some(sync)) => {
136+
local.run();
137+
sync.run();
138+
}
139+
(Some(local), None) => {
140+
local.run();
141+
}
142+
(None, Some(sync)) => {
143+
sync.run();
144+
}
145+
(None, None) => break,
146+
}
147+
}
148+
149+
// SAFETY:
150+
// `Scheduler` is `!Send` and `!Sync`, so this method is only called
151+
// on `TaskQueue`'s creator thread.
152+
!unsafe { self.task_queue.is_empty() }
153+
}
154+
}

0 commit comments

Comments
 (0)