Skip to content

Commit c8b42c2

Browse files
committed
scoped blockign
1 parent cab4de7 commit c8b42c2

File tree

4 files changed

+175
-20
lines changed

4 files changed

+175
-20
lines changed

crates/sync/stage/src/blocks/mod.rs

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use katana_primitives::transaction::{Tx, TxWithHash};
1313
use katana_primitives::Felt;
1414
use katana_provider::api::block::{BlockHashProvider, BlockWriter};
1515
use katana_provider::{MutableProvider, ProviderError, ProviderFactory};
16+
use katana_tasks::TaskSpawner;
1617
use num_traits::ToPrimitive;
1718
use starknet::core::types::ResourcePrice;
1819
use tracing::{error, info_span, Instrument};
@@ -28,19 +29,20 @@ pub use downloader::{BatchBlockDownloader, BlockDownloader};
2829
pub struct Blocks<PF, B> {
2930
provider: PF,
3031
downloader: B,
32+
task_spawner: TaskSpawner,
3133
}
3234

3335
impl<PF, B> Blocks<PF, B> {
3436
/// Create a new [`Blocks`] stage.
35-
pub fn new(provider: PF, downloader: B) -> Self {
36-
Self { provider, downloader }
37+
pub fn new(provider: PF, downloader: B, task_spawner: TaskSpawner) -> Self {
38+
Self { provider, downloader, task_spawner }
3739
}
3840

3941
/// Validates that the downloaded blocks form a valid chain.
4042
///
4143
/// This method checks the chain invariant: block N's parent hash must be block N-1's hash.
4244
/// For the first block in the list (if not block 0), it fetches the parent hash from storage.
43-
fn validate_chain_invariant(&self, blocks: &[StateUpdateWithBlock]) -> Result<(), Error>
45+
async fn validate_chain_invariant(&self, blocks: &[StateUpdateWithBlock]) -> Result<(), Error>
4446
where
4547
PF: ProviderFactory,
4648
<PF as ProviderFactory>::Provider: BlockHashProvider,
@@ -71,7 +73,27 @@ impl<PF, B> Blocks<PF, B> {
7173
}
7274
}
7375

74-
// Validate the rest of the blocks in the list
76+
// self.task_spawner.cpu_bound().spawn(|| {
77+
// // Validate the rest of the blocks in the list
78+
// for window in blocks.windows(2) {
79+
// let prev_block = &window[0].block;
80+
// let curr_block = &window[1].block;
81+
82+
// let prev_hash = prev_block.block_hash.unwrap_or_default();
83+
// let curr_block_num = curr_block.block_number.unwrap_or_default();
84+
85+
// if curr_block.parent_block_hash != prev_hash {
86+
// return Err(Error::ChainInvariantViolation {
87+
// block_num: curr_block_num,
88+
// parent_hash: curr_block.parent_block_hash,
89+
// expected_hash: prev_hash,
90+
// });
91+
// }
92+
// }
93+
94+
// Ok(())
95+
// });
96+
7597
for window in blocks.windows(2) {
7698
let prev_block = &window[0].block;
7799
let curr_block = &window[1].block;
@@ -112,7 +134,10 @@ where
112134
.await
113135
.map_err(Error::Gateway)?;
114136

115-
self.validate_chain_invariant(&blocks)?;
137+
self.task_spawner
138+
.scope(|s| s.spawn(self.validate_chain_invariant(&blocks)))
139+
.await
140+
.unwrap()?;
116141

117142
let span = info_span!(target: "stage", "blocks.insert", from = %input.from(), to = %input.to());
118143
let _enter = span.enter();

crates/tasks/src/blocking.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,14 @@ impl CpuBlockingTaskPool {
8686
});
8787
CpuBlockingJoinHandle { inner: rx }
8888
}
89+
90+
pub(crate) fn scope<'scope, F, R>(&self, func: F) -> R
91+
where
92+
F: for<'s> FnOnce(&'s rayon::Scope<'scope>) -> R + Send + 'scope,
93+
R: Send + 'scope,
94+
{
95+
self.pool.scope(func)
96+
}
8997
}
9098

9199
#[derive(Debug)]

crates/tasks/src/scope.rs

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use core::future::Future;
22
use core::marker::PhantomData;
33
use core::pin::Pin;
44
use std::panic::{self, AssertUnwindSafe};
5-
use std::sync::Arc;
5+
use std::sync::{mpsc, Arc};
66
use std::task::Poll;
77

88
use futures::stream::FuturesUnordered;
@@ -50,7 +50,29 @@ impl<'scope> ScopedTaskSpawner<'scope> {
5050
ScopedJoinHandle { receiver, _marker: PhantomData }
5151
}
5252

53+
/// Returns a scoped spawner for CPU-bound work.
54+
pub fn cpu_bound(&self) -> ScopedCpuTaskSpawner<'scope> {
55+
ScopedCpuTaskSpawner {
56+
inner: self.inner.clone(),
57+
sender: self.sender.clone(),
58+
_marker: PhantomData,
59+
}
60+
}
61+
5362
/// Spawn a blocking task that can borrow data with lifetime `'scope`.
63+
///
64+
/// # Runtime requirement
65+
///
66+
/// This uses [`tokio::task::block_in_place`] under the hood, which only works on
67+
/// Tokio's multi-threaded runtime. On a current-thread runtime this will panic. If
68+
/// you need to support a current-thread runtime, prefer:
69+
///
70+
/// 1. Moving/cloning the data so the closure can be `'static` and using the non-scoped
71+
/// `TaskSpawner::spawn_blocking`, or
72+
/// 2. Running the work inline (accepting that it will block the current thread).
73+
///
74+
/// The scoped variant exists to let you borrow non-`'static` data; that capability
75+
/// currently depends on `block_in_place`.
5476
pub fn spawn_blocking<F, R>(&self, func: F) -> ScopedJoinHandle<'scope, R>
5577
where
5678
F: FnOnce() -> R + Send + 'scope,
@@ -83,6 +105,67 @@ impl<'scope> ScopedTaskSpawner<'scope> {
83105
}
84106
}
85107

108+
/// Spawner for CPU-bound scoped tasks executed on the dedicated blocking pool.
109+
#[derive(Debug, Clone)]
110+
pub struct ScopedCpuTaskSpawner<'scope> {
111+
pub(crate) inner: Arc<TaskManagerInner>,
112+
pub(crate) sender: UnboundedSender<BoxScopedFuture<'scope>>,
113+
pub(crate) _marker: PhantomData<&'scope ()>,
114+
}
115+
116+
impl<'scope> ScopedCpuTaskSpawner<'scope> {
117+
/// Spawn a CPU-bound task that can borrow data with lifetime `'scope`.
118+
///
119+
/// # Runtime requirement
120+
///
121+
/// Uses [`tokio::task::block_in_place`] to enter the blocking pool. This only works on
122+
/// Tokio's multithreaded runtime. On a current-thread runtime this will panic. If you
123+
/// must support a current-thread runtime, either move/clone data into a `'static` closure
124+
/// and use the non-scoped `CPUBoundTaskSpawner`, or run the work inline knowing it will
125+
/// block the thread.
126+
pub fn spawn<F, R>(&self, func: F) -> ScopedJoinHandle<'scope, R>
127+
where
128+
F: FnOnce() -> R + Send + 'scope,
129+
R: Send + 'scope,
130+
{
131+
let (tx, rx) = oneshot::channel();
132+
let mut receiver = rx;
133+
let inner = self.inner.clone();
134+
135+
let task = async move {
136+
let result = block_in_place(|| {
137+
if inner.on_cancel.is_cancelled() {
138+
return Err(JoinError::Cancelled);
139+
}
140+
141+
let (res_tx, res_rx) = mpsc::channel();
142+
inner.blocking_pool.scope(|scope| {
143+
scope.spawn(|_| {
144+
let _ = res_tx.send(panic::catch_unwind(AssertUnwindSafe(func)));
145+
});
146+
});
147+
148+
match res_rx.recv() {
149+
Ok(Ok(value)) => Ok(value),
150+
Ok(Err(panic)) => Err(JoinError::Panic(panic)),
151+
Err(..) => Err(JoinError::Cancelled),
152+
}
153+
});
154+
155+
let _ = tx.send(result);
156+
};
157+
158+
let task = self.inner.tracker.track_future(task);
159+
let task = Box::pin(task);
160+
161+
if self.sender.send(task).is_err() {
162+
receiver.close();
163+
}
164+
165+
ScopedJoinHandle { receiver, _marker: PhantomData }
166+
}
167+
}
168+
86169
pub(crate) struct TaskScope<'scope, R>
87170
where
88171
R: Send + 'scope,

crates/tasks/src/spawner.rs

Lines changed: 53 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -58,21 +58,26 @@ impl TaskSpawner {
5858
/// Runs a scoped block in which tasks may borrow non-`'static` data but are guaranteed to be
5959
/// completed before this method returns.
6060
///
61-
/// # Examples
61+
/// Unlike [`TaskSpawner::spawn`], scoped tasks don't need to be `'static`. Without a scope, the
62+
/// compiler prevents you from capturing stack data because the task could outlive the
63+
/// borrow:
6264
///
63-
/// What happens if you try to capture non-`'static` data without a scope?
6465
/// ```compile_fail
6566
/// # use katana_tasks::TaskManager;
6667
/// # #[tokio::main] async fn main() {
6768
/// let manager = TaskManager::current();
6869
/// let spawner = manager.task_spawner();
69-
/// let mut prefix = String::from("hi ");
70+
/// let prefix = String::from("Hello world!");
7071
///
7172
/// // ERROR: `spawn` requires futures to be `'static`.
72-
/// spawner.spawn(async { prefix.push_str("there") });
73+
/// spawner.spawn(async { println!("{prefix}") });
7374
/// # }
7475
/// ```
7576
///
77+
/// A scope keeps everything within the lifetime of the call, so you can borrow safely.
78+
///
79+
/// # Examples
80+
///
7681
/// Borrow stack data in an async task:
7782
/// ```
7883
/// # use katana_tasks::TaskManager;
@@ -119,7 +124,7 @@ impl TaskSpawner {
119124
/// # }
120125
/// ```
121126
///
122-
/// Use scoped blocking work with borrowed data:
127+
/// Use scoped blocking work with borrowed data (requires a multithreaded Tokio runtime):
123128
/// ```
124129
/// # use katana_tasks::TaskManager;
125130
/// # use std::sync::{Arc, Mutex};
@@ -422,36 +427,70 @@ mod tests {
422427
let spawner = manager.task_spawner();
423428

424429
let text = String::from("scoped");
425-
let expected_len = text.len();
426430

427-
spawner
431+
let returned_ref = spawner
428432
.scope(|scope| {
429433
let text_ref: &String = &text;
434+
scope.spawn(async move { text_ref })
435+
})
436+
.await
437+
.unwrap();
438+
439+
// original value is still valid after scope returns
440+
assert_eq!(text, *returned_ref);
441+
}
442+
443+
#[tokio::test(flavor = "multi_thread")]
444+
async fn scoped_spawn_blocking_allows_borrowed_data() {
445+
let manager = TaskManager::current();
446+
let spawner = manager.task_spawner();
447+
448+
let counter = AtomicUsize::new(0);
449+
450+
spawner
451+
.scope(|scope| {
452+
let counter_ref = &counter;
430453

431454
async move {
432-
let handle = scope.spawn(async move { text_ref.len() });
433-
assert_eq!(handle.await.unwrap(), expected_len);
455+
let handle = scope.spawn_blocking(move || {
456+
counter_ref.fetch_add(1, Ordering::SeqCst);
457+
counter_ref.load(Ordering::SeqCst)
458+
});
459+
460+
assert_eq!(handle.await.unwrap(), 1);
434461
}
435462
})
436463
.await;
437464

438-
// original value is still valid after scope returns
439-
assert_eq!(text.len(), expected_len);
465+
assert_eq!(counter.load(Ordering::SeqCst), 1);
466+
467+
let mut prefix = String::from("hello ");
468+
spawner
469+
.scope(|scope| {
470+
scope.spawn(async {
471+
prefix.push_str("world!");
472+
})
473+
})
474+
.await
475+
.unwrap();
476+
477+
assert_eq!(prefix.as_str(), "hello world!")
440478
}
441479

442-
#[tokio::test]
443-
async fn scoped_spawn_blocking_allows_borrowed_data() {
480+
#[tokio::test(flavor = "multi_thread")]
481+
async fn scoped_cpu_bound_allows_borrowed_data() {
444482
let manager = TaskManager::current();
445483
let spawner = manager.task_spawner();
446484

447485
let counter = AtomicUsize::new(0);
448486

449487
spawner
450488
.scope(|scope| {
489+
let cpu = scope.cpu_bound();
451490
let counter_ref = &counter;
452491

453492
async move {
454-
let handle = scope.spawn_blocking(move || {
493+
let handle = cpu.spawn(move || {
455494
counter_ref.fetch_add(1, Ordering::SeqCst);
456495
counter_ref.load(Ordering::SeqCst)
457496
});

0 commit comments

Comments
 (0)