Skip to content

Commit b5a3489

Browse files
committed
feat(task): scope tasks
1 parent 0506350 commit b5a3489

File tree

2 files changed

+239
-6
lines changed

2 files changed

+239
-6
lines changed

crates/sync/stage/src/blocks/mod.rs

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use katana_primitives::transaction::{Tx, TxWithHash};
1313
use katana_primitives::Felt;
1414
use katana_provider::api::block::{BlockHashProvider, BlockWriter};
1515
use katana_provider::{MutableProvider, ProviderError, ProviderFactory};
16+
use katana_tasks::TaskSpawner;
1617
use num_traits::ToPrimitive;
1718
use starknet::core::types::ResourcePrice;
1819
use tracing::{error, info_span, Instrument};
@@ -28,12 +29,13 @@ pub use downloader::{BatchBlockDownloader, BlockDownloader};
2829
pub struct Blocks<PF, B> {
2930
provider: PF,
3031
downloader: B,
32+
task_spawner: TaskSpawner,
3133
}
3234

3335
impl<PF, B> Blocks<PF, B> {
3436
/// Create a new [`Blocks`] stage.
35-
pub fn new(provider: PF, downloader: B) -> Self {
36-
Self { provider, downloader }
37+
pub fn new(provider: PF, downloader: B, task_spawner: TaskSpawner) -> Self {
38+
Self { provider, downloader, task_spawner }
3739
}
3840

3941
/// Validates that the downloaded blocks form a valid chain.
@@ -112,12 +114,14 @@ where
112114
.await
113115
.map_err(Error::Gateway)?;
114116

117+
self.task_spawner
118+
.scope(|s| Box::pin(s.spawn_blocking(|| self.validate_chain_invariant(&blocks))))
119+
.await
120+
.map_err(Error::BlockValidationTaskJoinError)??;
121+
115122
let span = info_span!(target: "stage", "blocks.insert", from = %input.from(), to = %input.to());
116123
let _enter = span.enter();
117124

118-
// TODO: spawn onto a blocking thread pool
119-
self.validate_chain_invariant(&blocks)?;
120-
121125
let provider_mut = self.provider.provider_mut();
122126

123127
for block in blocks {
@@ -152,6 +156,9 @@ pub enum Error {
152156
#[error(transparent)]
153157
Provider(#[from] ProviderError),
154158

159+
#[error("block validation task error: {0}")]
160+
BlockValidationTaskJoinError(katana_tasks::JoinError),
161+
155162
#[error(
156163
"chain invariant violation: block {block_num} parent hash {parent_hash:#x} does not match \
157164
previous block hash {expected_hash:#x}"

crates/tasks/src/spawner.rs

Lines changed: 227 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,22 @@
11
use core::future::Future;
2+
use core::marker::PhantomData;
3+
use core::pin::Pin;
24
use std::panic::{self, AssertUnwindSafe};
35
use std::sync::Arc;
46

5-
use futures::FutureExt;
7+
use futures::stream::FuturesUnordered;
8+
use futures::{FutureExt, StreamExt};
9+
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
10+
use tokio::sync::oneshot;
11+
use tokio::task::block_in_place;
612
use tokio_util::sync::CancellationToken;
713
use tracing::{debug, error};
814

915
use crate::{CpuBlockingJoinHandle, Inner, JoinError, JoinHandle};
1016

17+
type BoxScopedFuture<'scope> = Pin<Box<dyn Future<Output = ()> + Send + 'scope>>;
18+
type BoxScopeFuture<'scope, T> = Pin<Box<dyn Future<Output = T> + Send + 'scope>>;
19+
1120
/// A spawner for spawning tasks on the [`TaskManager`] that it was derived from.
1221
///
1322
/// This is the main way to spawn tasks on a [`TaskManager`]. It can only be created
@@ -53,11 +62,181 @@ impl TaskSpawner {
5362
JoinHandle(handle)
5463
}
5564

65+
/// Runs a scoped block in which tasks may borrow non-`'static` data but are guaranteed to be
66+
/// completed before this method returns.
67+
pub async fn scope<'scope, R>(
68+
&'scope self,
69+
f: impl FnOnce(ScopedTaskSpawner<'scope>) -> BoxScopeFuture<'scope, R>,
70+
) -> R
71+
where
72+
R: Send + 'scope,
73+
{
74+
let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();
75+
let scoped_spawner: ScopedTaskSpawner<'scope> =
76+
ScopedTaskSpawner { inner: self.inner.clone(), sender, _marker: PhantomData };
77+
78+
let user_future = f(scoped_spawner.clone());
79+
80+
TaskScope::new(scoped_spawner, receiver, user_future).run().await
81+
}
82+
5683
pub(crate) fn cancellation_token(&self) -> &CancellationToken {
5784
&self.inner.on_cancel
5885
}
5986
}
6087

88+
/// Spawner used inside [`TaskSpawner::scope`] to spawn scoped tasks.
89+
#[derive(Debug, Clone)]
90+
pub struct ScopedTaskSpawner<'scope> {
91+
inner: Arc<Inner>,
92+
sender: UnboundedSender<BoxScopedFuture<'scope>>,
93+
_marker: PhantomData<&'scope ()>,
94+
}
95+
96+
impl<'scope> ScopedTaskSpawner<'scope> {
97+
/// Spawn an async task that can borrow data with lifetime `'scope`.
98+
pub fn spawn<F>(&self, fut: F) -> ScopedJoinHandle<'scope, F::Output>
99+
where
100+
F: Future + Send + 'scope,
101+
F::Output: Send + 'scope,
102+
{
103+
let (tx, rx) = oneshot::channel();
104+
let mut receiver = rx;
105+
let cancellation_token = self.inner.on_cancel.clone();
106+
107+
let task = Box::pin(async move {
108+
let result = cancellable(cancellation_token, fut).await;
109+
let _ = tx.send(result);
110+
});
111+
112+
if self.sender.send(task).is_err() {
113+
receiver.close();
114+
}
115+
116+
ScopedJoinHandle { receiver, _marker: PhantomData }
117+
}
118+
119+
/// Spawn a blocking task that can borrow data with lifetime `'scope`.
120+
pub fn spawn_blocking<F, R>(&self, func: F) -> ScopedJoinHandle<'scope, R>
121+
where
122+
F: FnOnce() -> R + Send + 'scope,
123+
R: Send + 'scope,
124+
{
125+
let (tx, rx) = oneshot::channel();
126+
let mut receiver = rx;
127+
let cancellation_token = self.inner.on_cancel.clone();
128+
129+
let task = Box::pin(async move {
130+
let result = block_in_place(|| {
131+
if cancellation_token.is_cancelled() {
132+
return Err(JoinError::Cancelled);
133+
}
134+
135+
panic::catch_unwind(AssertUnwindSafe(func)).map_err(JoinError::Panic)
136+
});
137+
138+
let _ = tx.send(result);
139+
});
140+
141+
if self.sender.send(task).is_err() {
142+
receiver.close();
143+
}
144+
145+
ScopedJoinHandle { receiver, _marker: PhantomData }
146+
}
147+
}
148+
149+
struct TaskScope<'scope, R>
150+
where
151+
R: Send + 'scope,
152+
{
153+
tasks: FuturesUnordered<BoxScopedFuture<'scope>>,
154+
receiver: UnboundedReceiver<BoxScopedFuture<'scope>>,
155+
user_future: BoxScopeFuture<'scope, R>,
156+
spawner: Option<ScopedTaskSpawner<'scope>>,
157+
receiver_closed: bool,
158+
}
159+
160+
impl<'scope, R> TaskScope<'scope, R>
161+
where
162+
R: Send + 'scope,
163+
{
164+
fn new(
165+
spawner: ScopedTaskSpawner<'scope>,
166+
receiver: UnboundedReceiver<BoxScopedFuture<'scope>>,
167+
user_future: BoxScopeFuture<'scope, R>,
168+
) -> Self {
169+
Self {
170+
tasks: FuturesUnordered::new(),
171+
receiver,
172+
user_future,
173+
spawner: Some(spawner),
174+
receiver_closed: false,
175+
}
176+
}
177+
178+
async fn run(mut self) -> R {
179+
let mut user_output = None;
180+
let mut user_done = false;
181+
182+
loop {
183+
tokio::select! {
184+
result = self.receiver.recv(), if !self.receiver_closed => {
185+
match result {
186+
Some(task) => self.tasks.push(task),
187+
None => self.receiver_closed = true,
188+
}
189+
}
190+
191+
Some(_) = self.tasks.next(), if !self.tasks.is_empty() => {}
192+
193+
output = self.user_future.as_mut(), if !user_done => {
194+
user_output = Some(output);
195+
user_done = true;
196+
// Drop the spawner to close the sender side of the channel.
197+
self.spawner.take();
198+
}
199+
200+
else => {
201+
if user_output.is_some() && self.receiver_closed && self.tasks.is_empty() {
202+
break;
203+
}
204+
}
205+
}
206+
}
207+
208+
user_output.expect("user future completed")
209+
}
210+
}
211+
212+
/// Join handle for scoped tasks spawned via [`TaskSpawner::scope`].
213+
pub struct ScopedJoinHandle<'scope, T> {
214+
receiver: oneshot::Receiver<crate::Result<T>>,
215+
_marker: PhantomData<&'scope ()>,
216+
}
217+
218+
impl<T> core::fmt::Debug for ScopedJoinHandle<'_, T> {
219+
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
220+
f.debug_struct("ScopedJoinHandle").finish()
221+
}
222+
}
223+
224+
impl<T> Future for ScopedJoinHandle<'_, T> {
225+
type Output = crate::Result<T>;
226+
227+
fn poll(
228+
self: Pin<&mut Self>,
229+
cx: &mut core::task::Context<'_>,
230+
) -> core::task::Poll<Self::Output> {
231+
let this = self.get_mut();
232+
match Pin::new(&mut this.receiver).poll(cx) {
233+
core::task::Poll::Ready(Ok(value)) => core::task::Poll::Ready(value),
234+
core::task::Poll::Ready(Err(..)) => core::task::Poll::Ready(Err(JoinError::Cancelled)),
235+
core::task::Poll::Pending => core::task::Poll::Pending,
236+
}
237+
}
238+
}
239+
61240
/// A task spawner dedicated for spawning CPU-bound blocking tasks.
62241
///
63242
/// Tasks spawned by this spawner will be executed on a thread pool dedicated to CPU-bound tasks.
@@ -215,6 +394,7 @@ async fn cancellable<F: Future>(
215394
mod tests {
216395

217396
use std::future::pending;
397+
use std::sync::atomic::{AtomicUsize, Ordering};
218398

219399
use crate::TaskManager;
220400

@@ -275,4 +455,50 @@ mod tests {
275455
let error = result.expect_err("should be panic error");
276456
assert!(error.is_panic());
277457
}
458+
459+
#[tokio::test]
460+
async fn scoped_spawn_allows_borrowed_data() {
461+
let manager = TaskManager::current();
462+
let spawner = manager.task_spawner();
463+
464+
let text = String::from("scoped");
465+
let expected_len = text.len();
466+
467+
spawner
468+
.scope(|scope| {
469+
let text_ref: &String = &text;
470+
Box::pin(async move {
471+
let handle = scope.spawn(async move { text_ref.len() });
472+
assert_eq!(handle.await.unwrap(), expected_len);
473+
})
474+
})
475+
.await;
476+
477+
// original value is still valid after scope returns
478+
assert_eq!(text.len(), expected_len);
479+
}
480+
481+
#[tokio::test]
482+
async fn scoped_spawn_blocking_allows_borrowed_data() {
483+
let manager = TaskManager::current();
484+
let spawner = manager.task_spawner();
485+
486+
let counter = AtomicUsize::new(0);
487+
488+
spawner
489+
.scope(|scope| {
490+
let counter_ref = &counter;
491+
Box::pin(async move {
492+
let handle = scope.spawn_blocking(move || {
493+
counter_ref.fetch_add(1, Ordering::SeqCst);
494+
counter_ref.load(Ordering::SeqCst)
495+
});
496+
497+
assert_eq!(handle.await.unwrap(), 1);
498+
})
499+
})
500+
.await;
501+
502+
assert_eq!(counter.load(Ordering::SeqCst), 1);
503+
}
278504
}

0 commit comments

Comments
 (0)