Skip to content

Commit 83bbec2

Browse files
committed
move to different module
1 parent 95665f7 commit 83bbec2

File tree

9 files changed

+219
-185
lines changed

9 files changed

+219
-185
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ num-bigint = "0.4.3"
148148
num-traits = { version = "0.2", default-features = false }
149149
once_cell = "1.0"
150150
parking_lot = "0.12.1"
151+
pin-project = "1.1"
151152
postcard = { version = "1.0.10", features = [ "use-std" ], default-features = false }
152153
rand = "0.8.5"
153154
rayon = "1.8.0"

crates/sync/stage/src/blocks/mod.rs

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ use katana_primitives::transaction::{Tx, TxWithHash};
1313
use katana_primitives::Felt;
1414
use katana_provider::api::block::{BlockHashProvider, BlockWriter};
1515
use katana_provider::{MutableProvider, ProviderError, ProviderFactory};
16-
use katana_tasks::TaskSpawner;
1716
use num_traits::ToPrimitive;
1817
use starknet::core::types::ResourcePrice;
1918
use tracing::{error, info_span, Instrument};
@@ -29,13 +28,12 @@ pub use downloader::{BatchBlockDownloader, BlockDownloader};
2928
pub struct Blocks<PF, B> {
3029
provider: PF,
3130
downloader: B,
32-
task_spawner: TaskSpawner,
3331
}
3432

3533
impl<PF, B> Blocks<PF, B> {
3634
/// Create a new [`Blocks`] stage.
37-
pub fn new(provider: PF, downloader: B, task_spawner: TaskSpawner) -> Self {
38-
Self { provider, downloader, task_spawner }
35+
pub fn new(provider: PF, downloader: B) -> Self {
36+
Self { provider, downloader }
3937
}
4038

4139
/// Validates that the downloaded blocks form a valid chain.
@@ -114,10 +112,7 @@ where
114112
.await
115113
.map_err(Error::Gateway)?;
116114

117-
self.task_spawner
118-
.scope(|s| s.spawn_blocking(|| self.validate_chain_invariant(&blocks)))
119-
.await
120-
.map_err(Error::BlockValidationTaskJoinError)??;
115+
self.validate_chain_invariant(&blocks)?;
121116

122117
let span = info_span!(target: "stage", "blocks.insert", from = %input.from(), to = %input.to());
123118
let _enter = span.enter();

crates/tasks/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ version.workspace = true
77

88
[dependencies]
99
futures.workspace = true
10+
pin-project.workspace = true
1011
rayon.workspace = true
1112
thiserror.workspace = true
1213
tokio.workspace = true

crates/tasks/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22

33
mod blocking;
44
mod manager;
5+
mod scope;
56
mod spawner;
67
mod task;
78

89
pub use blocking::*;
910
pub use manager::*;
11+
pub use scope::*;
1012
pub use spawner::*;
1113
pub use task::*;

crates/tasks/src/manager.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,11 @@ use crate::{CpuBlockingTaskPool, TaskSpawner};
3535
/// [blog post]: https://ryhl.io/blog/async-what-is-blocking/
3636
#[derive(Debug, Clone)]
3737
pub struct TaskManager {
38-
inner: Arc<Inner>,
38+
inner: Arc<TaskManagerInner>,
3939
}
4040

4141
#[derive(Debug)]
42-
pub(crate) struct Inner {
42+
pub(crate) struct TaskManagerInner {
4343
/// A handle to the Tokio runtime.
4444
pub(crate) handle: Handle,
4545
/// Keep track of currently running tasks.
@@ -62,7 +62,7 @@ impl TaskManager {
6262
.expect("failed to build blocking task thread pool");
6363

6464
Self {
65-
inner: Arc::new(Inner {
65+
inner: Arc::new(TaskManagerInner {
6666
handle,
6767
blocking_pool,
6868
tracker: TaskTracker::new(),

crates/tasks/src/scope.rs

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
use core::future::Future;
2+
use core::marker::PhantomData;
3+
use core::pin::Pin;
4+
use std::panic::{self, AssertUnwindSafe};
5+
use std::sync::Arc;
6+
use std::task::Poll;
7+
8+
use futures::stream::FuturesUnordered;
9+
use futures::StreamExt;
10+
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
11+
use tokio::sync::oneshot;
12+
use tokio::task::block_in_place;
13+
14+
use crate::{Cancellable, JoinError, TaskManagerInner};
15+
16+
type BoxScopedFuture<'scope> = Pin<Box<dyn Future<Output = ()> + Send + 'scope>>;
17+
type BoxScopeFuture<'scope, T> = Pin<Box<dyn Future<Output = T> + Send + 'scope>>;
18+
19+
/// Spawner used inside [`TaskSpawner::scope`] to spawn scoped tasks.
20+
#[derive(Debug, Clone)]
21+
pub struct ScopedTaskSpawner<'scope> {
22+
pub(crate) inner: Arc<TaskManagerInner>,
23+
pub(crate) sender: UnboundedSender<BoxScopedFuture<'scope>>,
24+
pub(crate) _marker: PhantomData<&'scope ()>,
25+
}
26+
27+
impl<'scope> ScopedTaskSpawner<'scope> {
28+
/// Spawn an async task that can borrow data with lifetime `'scope`.
29+
pub fn spawn<F>(&self, fut: F) -> ScopedJoinHandle<'scope, F::Output>
30+
where
31+
F: Future + Send + 'scope,
32+
F::Output: Send + 'scope,
33+
{
34+
let (tx, rx) = oneshot::channel();
35+
let mut receiver = rx;
36+
let cancellation_token = self.inner.on_cancel.clone();
37+
38+
let task = Box::pin(async move {
39+
let result = Cancellable::new(cancellation_token, fut).await;
40+
let _ = tx.send(result);
41+
});
42+
43+
if self.sender.send(task).is_err() {
44+
receiver.close();
45+
}
46+
47+
ScopedJoinHandle { receiver, _marker: PhantomData }
48+
}
49+
50+
/// Spawn a blocking task that can borrow data with lifetime `'scope`.
51+
pub fn spawn_blocking<F, R>(&self, func: F) -> ScopedJoinHandle<'scope, R>
52+
where
53+
F: FnOnce() -> R + Send + 'scope,
54+
R: Send + 'scope,
55+
{
56+
let (tx, rx) = oneshot::channel();
57+
let mut receiver = rx;
58+
let cancellation_token = self.inner.on_cancel.clone();
59+
60+
let task = Box::pin(async move {
61+
let result = block_in_place(|| {
62+
if cancellation_token.is_cancelled() {
63+
return Err(JoinError::Cancelled);
64+
}
65+
66+
panic::catch_unwind(AssertUnwindSafe(func)).map_err(JoinError::Panic)
67+
});
68+
69+
let _ = tx.send(result);
70+
});
71+
72+
if self.sender.send(task).is_err() {
73+
receiver.close();
74+
}
75+
76+
ScopedJoinHandle { receiver, _marker: PhantomData }
77+
}
78+
}
79+
80+
pub(crate) struct TaskScope<'scope, R>
81+
where
82+
R: Send + 'scope,
83+
{
84+
tasks: FuturesUnordered<BoxScopedFuture<'scope>>,
85+
receiver: UnboundedReceiver<BoxScopedFuture<'scope>>,
86+
user_future: BoxScopeFuture<'scope, R>,
87+
spawner: Option<ScopedTaskSpawner<'scope>>,
88+
receiver_closed: bool,
89+
}
90+
91+
impl<'scope, R> TaskScope<'scope, R>
92+
where
93+
R: Send + 'scope,
94+
{
95+
pub(crate) fn new(
96+
spawner: ScopedTaskSpawner<'scope>,
97+
receiver: UnboundedReceiver<BoxScopedFuture<'scope>>,
98+
user_future: BoxScopeFuture<'scope, R>,
99+
) -> Self {
100+
Self {
101+
tasks: FuturesUnordered::new(),
102+
receiver,
103+
user_future,
104+
spawner: Some(spawner),
105+
receiver_closed: false,
106+
}
107+
}
108+
109+
pub(crate) async fn run(mut self) -> R {
110+
let mut user_output = None;
111+
let mut user_done = false;
112+
113+
loop {
114+
tokio::select! {
115+
result = self.receiver.recv(), if !self.receiver_closed => {
116+
match result {
117+
Some(task) => self.tasks.push(task),
118+
None => self.receiver_closed = true,
119+
}
120+
}
121+
122+
Some(_) = self.tasks.next(), if !self.tasks.is_empty() => {}
123+
124+
output = self.user_future.as_mut(), if !user_done => {
125+
user_output = Some(output);
126+
user_done = true;
127+
// Drop the spawner to close the sender side of the channel.
128+
self.spawner.take();
129+
}
130+
131+
else => {
132+
if user_output.is_some() && self.receiver_closed && self.tasks.is_empty() {
133+
break;
134+
}
135+
}
136+
}
137+
}
138+
139+
user_output.expect("user future completed")
140+
}
141+
}
142+
143+
/// Join handle for scoped tasks spawned via
144+
/// [`TaskSpawner::scope`](crate::spawner::TaskSpawner::scope).
145+
pub struct ScopedJoinHandle<'scope, T> {
146+
receiver: oneshot::Receiver<crate::Result<T>>,
147+
_marker: PhantomData<&'scope ()>,
148+
}
149+
150+
impl<T> core::fmt::Debug for ScopedJoinHandle<'_, T> {
151+
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
152+
f.debug_struct("ScopedJoinHandle").finish()
153+
}
154+
}
155+
156+
impl<T> Future for ScopedJoinHandle<'_, T> {
157+
type Output = crate::Result<T>;
158+
159+
fn poll(self: Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> Poll<Self::Output> {
160+
let this = self.get_mut();
161+
match Pin::new(&mut this.receiver).poll(cx) {
162+
Poll::Pending => Poll::Pending,
163+
Poll::Ready(Ok(value)) => Poll::Ready(value),
164+
Poll::Ready(Err(..)) => Poll::Ready(Err(JoinError::Cancelled)),
165+
}
166+
}
167+
}

0 commit comments

Comments
 (0)