Skip to content

Commit 8f01e6f

Browse files
committed
fix tests
1 parent fd7fd7b commit 8f01e6f

File tree

2 files changed

+150
-64
lines changed

2 files changed

+150
-64
lines changed

crates/sync/stage/tests/block.rs

Lines changed: 95 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use katana_primitives::state::StateUpdatesWithClasses;
1717
use katana_primitives::{felt, ContractAddress, Felt};
1818
use katana_provider::api::block::{BlockHashProvider, BlockNumberProvider, BlockWriter};
1919
use katana_provider::test_utils::test_provider;
20-
use katana_provider::{ProviderError, ProviderResult};
20+
use katana_provider::{ProviderError, ProviderFactory, ProviderResult};
2121
use katana_stage::blocks::{BatchBlockDownloader, BlockDownloader, Blocks};
2222
use katana_stage::{Stage, StageExecutionInput};
2323
use rstest::rstest;
@@ -119,11 +119,11 @@ impl BlockDownloader for MockBlockDownloader {
119119
}
120120
}
121121

122-
/// Mock BlockWriter implementation for testing.
122+
/// Mock provider implementation for testing.
123123
///
124124
/// Tracks all insert operations and can be configured to return errors.
125-
#[derive(Clone)]
126-
struct MockProvider {
125+
#[derive(Clone, Debug)]
126+
struct MockInnerProvider {
127127
/// Stored blocks with their receipts and state updates.
128128
blocks: Arc<Mutex<Vec<(SealedBlockWithStatus, StateUpdatesWithClasses, Vec<Receipt>)>>>,
129129
/// Whether to return an error on insert.
@@ -132,40 +132,17 @@ struct MockProvider {
132132
error_message: Arc<Mutex<String>>,
133133
}
134134

135-
impl MockProvider {
136-
fn new() -> Self {
137-
Self {
138-
blocks: Arc::new(Mutex::new(Vec::new())),
139-
should_fail: Arc::new(Mutex::new(false)),
140-
error_message: Arc::new(Mutex::new(String::new())),
141-
}
142-
}
143-
144-
/// Add a block directly to the provider's storage.
145-
fn with_block(self, block: SealedBlockWithStatus) -> Self {
146-
self.blocks.lock().unwrap().push((block, Default::default(), Default::default()));
147-
self
148-
}
149-
150-
/// Configure the mock to fail on insert operations.
151-
fn with_insert_error(self, error: String) -> Self {
152-
*self.should_fail.lock().unwrap() = true;
153-
*self.error_message.lock().unwrap() = error;
154-
self
155-
}
156-
157-
/// Get the number of blocks stored.
158-
fn stored_block_count(&self) -> usize {
159-
self.blocks.lock().unwrap().len()
160-
}
161-
162-
/// Get all stored block numbers.
163-
fn stored_block_numbers(&self) -> Vec<BlockNumber> {
164-
self.blocks.lock().unwrap().iter().map(|(block, _, _)| block.block.header.number).collect()
135+
impl MockInnerProvider {
136+
fn new(
137+
blocks: Arc<Mutex<Vec<(SealedBlockWithStatus, StateUpdatesWithClasses, Vec<Receipt>)>>>,
138+
should_fail: Arc<Mutex<bool>>,
139+
error_message: Arc<Mutex<String>>,
140+
) -> Self {
141+
Self { blocks, should_fail, error_message }
165142
}
166143
}
167144

168-
impl BlockWriter for MockProvider {
145+
impl BlockWriter for MockInnerProvider {
169146
fn insert_block_with_states_and_receipts(
170147
&self,
171148
block: SealedBlockWithStatus,
@@ -184,7 +161,7 @@ impl BlockWriter for MockProvider {
184161
}
185162
}
186163

187-
impl BlockHashProvider for MockProvider {
164+
impl BlockHashProvider for MockInnerProvider {
188165
fn latest_hash(&self) -> ProviderResult<BlockHash> {
189166
self.blocks
190167
.lock()
@@ -205,6 +182,85 @@ impl BlockHashProvider for MockProvider {
205182
}
206183
}
207184

185+
impl katana_provider::MutableProvider for MockInnerProvider {
186+
fn commit(self) -> ProviderResult<()> {
187+
Ok(())
188+
}
189+
}
190+
191+
/// Mock ProviderFactory implementation for testing.
192+
///
193+
/// Tracks all insert operations and can be configured to return errors.
194+
#[derive(Clone)]
195+
struct MockProvider {
196+
/// Stored blocks with their receipts and state updates.
197+
blocks: Arc<Mutex<Vec<(SealedBlockWithStatus, StateUpdatesWithClasses, Vec<Receipt>)>>>,
198+
/// Whether to return an error on insert.
199+
should_fail: Arc<Mutex<bool>>,
200+
/// Error message to return when should_fail is true.
201+
error_message: Arc<Mutex<String>>,
202+
}
203+
204+
impl std::fmt::Debug for MockProvider {
205+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
206+
f.debug_struct("MockProvider").finish_non_exhaustive()
207+
}
208+
}
209+
210+
impl MockProvider {
211+
fn new() -> Self {
212+
Self {
213+
blocks: Arc::new(Mutex::new(Vec::new())),
214+
should_fail: Arc::new(Mutex::new(false)),
215+
error_message: Arc::new(Mutex::new(String::new())),
216+
}
217+
}
218+
219+
/// Add a block directly to the provider's storage.
220+
fn with_block(self, block: SealedBlockWithStatus) -> Self {
221+
self.blocks.lock().unwrap().push((block, Default::default(), Default::default()));
222+
self
223+
}
224+
225+
/// Configure the mock to fail on insert operations.
226+
fn with_insert_error(self, error: String) -> Self {
227+
*self.should_fail.lock().unwrap() = true;
228+
*self.error_message.lock().unwrap() = error;
229+
self
230+
}
231+
232+
/// Get the number of blocks stored.
233+
fn stored_block_count(&self) -> usize {
234+
self.blocks.lock().unwrap().len()
235+
}
236+
237+
/// Get all stored block numbers.
238+
fn stored_block_numbers(&self) -> Vec<BlockNumber> {
239+
self.blocks.lock().unwrap().iter().map(|(block, _, _)| block.block.header.number).collect()
240+
}
241+
}
242+
243+
impl katana_provider::ProviderFactory for MockProvider {
244+
type Provider = MockInnerProvider;
245+
type ProviderMut = MockInnerProvider;
246+
247+
fn provider(&self) -> Self::Provider {
248+
MockInnerProvider::new(
249+
Arc::clone(&self.blocks),
250+
Arc::clone(&self.should_fail),
251+
Arc::clone(&self.error_message),
252+
)
253+
}
254+
255+
fn provider_mut(&self) -> Self::ProviderMut {
256+
MockInnerProvider::new(
257+
Arc::clone(&self.blocks),
258+
Arc::clone(&self.should_fail),
259+
Arc::clone(&self.error_message),
260+
)
261+
}
262+
}
263+
208264
/// Helper function to create a minimal test `SealedBlockWithStatus`.
209265
///
210266
/// Creates a block with the given number and automatically sets the parent hash
@@ -403,13 +459,14 @@ async fn fetch_blocks_from_gateway() {
403459
let feeder_gateway = SequencerGateway::sepolia();
404460
let downloader = BatchBlockDownloader::new_gateway(feeder_gateway, 10);
405461

406-
let mut stage = Blocks::new(&provider, downloader);
462+
let mut stage = Blocks::new(provider.clone(), downloader);
407463

408464
let input = StageExecutionInput::new(from_block, to_block);
409465
stage.execute(&input).await.expect("failed to execute stage");
410466

411467
// check provider storage
412-
let block_number = provider.latest_number().expect("failed to get latest block number");
468+
let block_number =
469+
provider.provider().latest_number().expect("failed to get latest block number");
413470
assert_eq!(block_number, to_block);
414471
}
415472

crates/sync/stage/tests/trie.rs

Lines changed: 55 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -9,63 +9,86 @@ use katana_primitives::Felt;
99
use katana_provider::api::block::HeaderProvider;
1010
use katana_provider::api::state_update::StateUpdateProvider;
1111
use katana_provider::api::trie::TrieWriter;
12-
use katana_provider::ProviderResult;
12+
use katana_provider::{ProviderFactory, ProviderResult};
1313
use katana_stage::trie::StateTrie;
1414
use katana_stage::{Stage, StageExecutionInput};
1515
use rstest::rstest;
1616
use starknet::macros::short_string;
1717
use starknet_types_core::hash::{Poseidon, StarkHash};
1818

19-
/// Mock provider implementation for testing StateTrie stage.
19+
/// Mock ProviderFactory implementation for testing StateTrie stage.
2020
///
2121
/// Provides configurable responses for headers, state updates, and trie operations.
2222
#[derive(Clone)]
2323
struct MockProvider {
24-
/// Map of block number to header.
25-
headers: Arc<Mutex<HashMap<BlockNumber, Header>>>,
26-
/// Map of block number to state update.
27-
state_updates: Arc<Mutex<HashMap<BlockNumber, StateUpdates>>>,
28-
/// Track trie insert calls for verification.
29-
trie_insert_calls: Arc<Mutex<Vec<BlockNumber>>>,
30-
/// Whether to return an error on trie operations.
31-
should_fail: Arc<Mutex<bool>>,
24+
inner: MockInnerProvider,
25+
}
26+
27+
impl std::fmt::Debug for MockProvider {
28+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29+
f.debug_struct("MockProvider").finish_non_exhaustive()
30+
}
3231
}
3332

3433
impl MockProvider {
3534
fn new() -> Self {
3635
Self {
37-
headers: Arc::new(Mutex::new(HashMap::new())),
38-
state_updates: Arc::new(Mutex::new(HashMap::new())),
39-
trie_insert_calls: Arc::new(Mutex::new(Vec::new())),
40-
should_fail: Arc::new(Mutex::new(false)),
36+
inner: MockInnerProvider {
37+
headers: Arc::new(Mutex::new(HashMap::new())),
38+
state_updates: Arc::new(Mutex::new(HashMap::new())),
39+
trie_insert_calls: Arc::new(Mutex::new(Vec::new())),
40+
should_fail: Arc::new(Mutex::new(false)),
41+
},
4142
}
4243
}
4344

4445
/// Configure a header for a specific block.
4546
fn with_header(self, block_number: BlockNumber, header: Header) -> Self {
46-
self.headers.lock().unwrap().insert(block_number, header);
47+
self.inner.headers.lock().unwrap().insert(block_number, header);
4748
self
4849
}
4950

5051
/// Configure a state update for a specific block.
5152
fn with_state_update(self, block_number: BlockNumber, state_update: StateUpdates) -> Self {
52-
self.state_updates.lock().unwrap().insert(block_number, state_update);
53-
self
54-
}
55-
56-
/// Configure the mock to fail on trie operations.
57-
fn with_trie_error(self) -> Self {
58-
*self.should_fail.lock().unwrap() = true;
53+
self.inner.state_updates.lock().unwrap().insert(block_number, state_update);
5954
self
6055
}
6156

6257
/// Get all block numbers that had trie inserts called.
6358
fn trie_insert_blocks(&self) -> Vec<BlockNumber> {
64-
self.trie_insert_calls.lock().unwrap().clone()
59+
self.inner.trie_insert_calls.lock().unwrap().clone()
6560
}
6661
}
6762

68-
impl HeaderProvider for MockProvider {
63+
impl ProviderFactory for MockProvider {
64+
type Provider = MockInnerProvider;
65+
type ProviderMut = MockInnerProvider;
66+
67+
fn provider(&self) -> Self::Provider {
68+
self.inner.clone()
69+
}
70+
71+
fn provider_mut(&self) -> Self::ProviderMut {
72+
self.inner.clone()
73+
}
74+
}
75+
76+
/// Mock inner provider implementation for testing StateTrie stage.
77+
///
78+
/// Provides configurable responses for headers, state updates, and trie operations.
79+
#[derive(Clone, Debug)]
80+
struct MockInnerProvider {
81+
/// Map of block number to header.
82+
headers: Arc<Mutex<HashMap<BlockNumber, Header>>>,
83+
/// Map of block number to state update.
84+
state_updates: Arc<Mutex<HashMap<BlockNumber, StateUpdates>>>,
85+
/// Track trie insert calls for verification.
86+
trie_insert_calls: Arc<Mutex<Vec<BlockNumber>>>,
87+
/// Whether to return an error on trie operations.
88+
should_fail: Arc<Mutex<bool>>,
89+
}
90+
91+
impl HeaderProvider for MockInnerProvider {
6992
fn header(&self, id: BlockHashOrNumber) -> ProviderResult<Option<Header>> {
7093
let block_number = match id {
7194
BlockHashOrNumber::Num(num) => num,
@@ -80,7 +103,7 @@ impl HeaderProvider for MockProvider {
80103
}
81104
}
82105

83-
impl StateUpdateProvider for MockProvider {
106+
impl StateUpdateProvider for MockInnerProvider {
84107
fn state_update(&self, block_id: BlockHashOrNumber) -> ProviderResult<Option<StateUpdates>> {
85108
let block_number = match block_id {
86109
BlockHashOrNumber::Num(num) => num,
@@ -109,7 +132,7 @@ impl StateUpdateProvider for MockProvider {
109132
}
110133
}
111134

112-
impl TrieWriter for MockProvider {
135+
impl TrieWriter for MockInnerProvider {
113136
fn trie_insert_declared_classes(
114137
&self,
115138
block_number: BlockNumber,
@@ -139,6 +162,12 @@ impl TrieWriter for MockProvider {
139162
}
140163
}
141164

165+
impl katana_provider::MutableProvider for MockInnerProvider {
166+
fn commit(self) -> ProviderResult<()> {
167+
Ok(())
168+
}
169+
}
170+
142171
/// Helper function to compute the expected state root from mock trie roots.
143172
fn compute_mock_state_root() -> Felt {
144173
let class_trie_root = Felt::from(0x1234u64);

0 commit comments

Comments
 (0)