Skip to content

Commit 6d6073f

Browse files
authored
refactor(stage): perform trie computation as separate task (#373)
Spawn a separate task on the cpu-bound blocking task for performing the actual state trie computation to avoid blocking the async executor.
1 parent f3c9ab5 commit 6d6073f

File tree

3 files changed

+47
-21
lines changed

3 files changed

+47
-21
lines changed

crates/node/src/full/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ impl Node {
127127
let block_downloader = BatchBlockDownloader::new_gateway(gateway_client.clone(), 20);
128128
pipeline.add_stage(Blocks::new(storage_provider.clone(), block_downloader));
129129
pipeline.add_stage(Classes::new(storage_provider.clone(), gateway_client.clone(), 20));
130-
pipeline.add_stage(StateTrie::new(storage_provider.clone()));
130+
pipeline.add_stage(StateTrie::new(storage_provider.clone(), task_spawner.clone()));
131131

132132
// -- build chain tip watcher using gateway client
133133

crates/sync/stage/src/trie.rs

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use katana_provider::api::block::HeaderProvider;
55
use katana_provider::api::state_update::StateUpdateProvider;
66
use katana_provider::api::trie::TrieWriter;
77
use katana_provider::{MutableProvider, ProviderFactory};
8+
use katana_tasks::TaskSpawner;
89
use starknet::macros::short_string;
910
use starknet_types_core::hash::{Poseidon, StarkHash};
1011
use tracing::{debug, debug_span, error};
@@ -22,19 +23,20 @@ use crate::{Stage, StageExecutionInput, StageExecutionOutput, StageResult};
2223
#[derive(Debug)]
2324
pub struct StateTrie<P> {
2425
storage_provider: P,
26+
task_spawner: TaskSpawner,
2527
}
2628

2729
impl<P> StateTrie<P> {
2830
/// Create a new [`StateTrie`] stage.
29-
pub fn new(storage_provider: P) -> Self {
30-
Self { storage_provider }
31+
pub fn new(storage_provider: P, task_spawner: TaskSpawner) -> Self {
32+
Self { storage_provider, task_spawner }
3133
}
3234
}
3335

3436
impl<P> Stage for StateTrie<P>
3537
where
3638
P: ProviderFactory,
37-
<P as ProviderFactory>::ProviderMut: StateUpdateProvider + HeaderProvider + TrieWriter,
39+
<P as ProviderFactory>::ProviderMut: StateUpdateProvider + HeaderProvider + TrieWriter + Clone,
3840
{
3941
fn id(&self) -> &'static str {
4042
"StateTrie"
@@ -58,21 +60,37 @@ where
5860
.state_update(block_number.into())?
5961
.ok_or(Error::MissingStateUpdate(block_number))?;
6062

61-
let computed_contract_trie_root =
62-
provider_mut.trie_insert_contract_updates(block_number, &state_update)?;
63-
64-
debug!(
65-
contract_trie_root = format!("{computed_contract_trie_root:#x}"),
66-
"Computed contract trie root."
67-
);
68-
69-
let computed_class_trie_root = provider_mut
70-
.trie_insert_declared_classes(block_number, &state_update.declared_classes)?;
71-
72-
debug!(
73-
classes_tri_root = format!("{computed_class_trie_root:#x}"),
74-
"Computed classes trie root."
75-
);
63+
let provider_mut_clone = provider_mut.clone();
64+
let (computed_contract_trie_root, computed_class_trie_root) = self
65+
.task_spawner
66+
.cpu_bound()
67+
.spawn(move || {
68+
let computed_contract_trie_root = provider_mut_clone
69+
.trie_insert_contract_updates(block_number, &state_update)?;
70+
71+
debug!(
72+
contract_trie_root = format!("{computed_contract_trie_root:#x}"),
73+
"Computed contract trie root."
74+
);
75+
76+
let computed_class_trie_root = provider_mut_clone
77+
.trie_insert_declared_classes(
78+
block_number,
79+
&state_update.declared_classes,
80+
)?;
81+
82+
debug!(
83+
classes_tri_root = format!("{computed_class_trie_root:#x}"),
84+
"Computed classes trie root."
85+
);
86+
87+
Result::<(Felt, Felt), crate::Error>::Ok((
88+
computed_contract_trie_root,
89+
computed_class_trie_root,
90+
))
91+
})
92+
.await
93+
.map_err(Error::StateComputationTaskJoinError)??;
7694

7795
let computed_state_root = if computed_class_trie_root == Felt::ZERO {
7896
computed_contract_trie_root
@@ -121,6 +139,9 @@ pub enum Error {
121139
#[error("Missing state update for block {0}")]
122140
MissingStateUpdate(BlockNumber),
123141

142+
#[error("State computation task join error: {0}")]
143+
StateComputationTaskJoinError(katana_tasks::JoinError),
144+
124145
#[error(
125146
"State root mismatch at block {block_number}: expected (from header) {expected:#x}, \
126147
computed {computed:#x}"

crates/sync/stage/tests/trie.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use katana_provider::api::trie::TrieWriter;
1212
use katana_provider::{ProviderFactory, ProviderResult};
1313
use katana_stage::trie::StateTrie;
1414
use katana_stage::{Stage, StageExecutionInput};
15+
use katana_tasks::TaskManager;
1516
use rstest::rstest;
1617
use starknet::macros::short_string;
1718
use starknet_types_core::hash::{Poseidon, StarkHash};
@@ -220,6 +221,7 @@ async fn verify_state_roots_success(
220221
#[case] to_block: BlockNumber,
221222
#[case] expected_blocks: Vec<BlockNumber>,
222223
) {
224+
let task_manager = TaskManager::current();
223225
let mut provider = MockProvider::new();
224226

225227
// Configure blocks with correct state roots
@@ -230,8 +232,10 @@ async fn verify_state_roots_success(
230232
provider = provider.with_header(num, header).with_state_update(num, state_update);
231233
}
232234

235+
let mut stage = StateTrie::new(provider.clone(), task_manager.task_spawner());
236+
233237
let input = StageExecutionInput::new(from_block, to_block);
234-
let result = StateTrie::new(provider.clone()).execute(&input).await;
238+
let result = stage.execute(&input).await;
235239
assert!(result.is_ok(), "Stage execution should succeed");
236240

237241
// Verify that trie inserts were called for each block (twice per block: classes + contracts)
@@ -251,7 +255,8 @@ async fn state_root_mismatch_returns_error() {
251255
.with_header(block_number, header)
252256
.with_state_update(block_number, state_update);
253257

254-
let mut stage = StateTrie::new(provider);
258+
let task_manager = TaskManager::current();
259+
let mut stage = StateTrie::new(provider, task_manager.task_spawner());
255260
let input = StageExecutionInput::new(block_number, block_number);
256261

257262
let result = stage.execute(&input).await;

0 commit comments

Comments
 (0)