@@ -17,7 +17,7 @@ use katana_primitives::state::StateUpdatesWithClasses;
1717use katana_primitives:: { felt, ContractAddress , Felt } ;
1818use katana_provider:: api:: block:: { BlockHashProvider , BlockNumberProvider , BlockWriter } ;
1919use katana_provider:: test_utils:: test_provider;
20- use katana_provider:: { ProviderError , ProviderResult } ;
20+ use katana_provider:: { ProviderError , ProviderFactory , ProviderResult } ;
2121use katana_stage:: blocks:: { BatchBlockDownloader , BlockDownloader , Blocks } ;
2222use katana_stage:: { Stage , StageExecutionInput } ;
2323use rstest:: rstest;
@@ -119,11 +119,11 @@ impl BlockDownloader for MockBlockDownloader {
119119 }
120120}
121121
122- /// Mock BlockWriter implementation for testing.
122+ /// Mock provider implementation for testing.
123123///
124124/// Tracks all insert operations and can be configured to return errors.
125- #[ derive( Clone ) ]
126- struct MockProvider {
125+ #[ derive( Clone , Debug ) ]
126+ struct MockInnerProvider {
127127 /// Stored blocks with their receipts and state updates.
128128 blocks : Arc < Mutex < Vec < ( SealedBlockWithStatus , StateUpdatesWithClasses , Vec < Receipt > ) > > > ,
129129 /// Whether to return an error on insert.
@@ -132,40 +132,17 @@ struct MockProvider {
132132 error_message : Arc < Mutex < String > > ,
133133}
134134
135- impl MockProvider {
136- fn new ( ) -> Self {
137- Self {
138- blocks : Arc :: new ( Mutex :: new ( Vec :: new ( ) ) ) ,
139- should_fail : Arc :: new ( Mutex :: new ( false ) ) ,
140- error_message : Arc :: new ( Mutex :: new ( String :: new ( ) ) ) ,
141- }
142- }
143-
144- /// Add a block directly to the provider's storage.
145- fn with_block ( self , block : SealedBlockWithStatus ) -> Self {
146- self . blocks . lock ( ) . unwrap ( ) . push ( ( block, Default :: default ( ) , Default :: default ( ) ) ) ;
147- self
148- }
149-
150- /// Configure the mock to fail on insert operations.
151- fn with_insert_error ( self , error : String ) -> Self {
152- * self . should_fail . lock ( ) . unwrap ( ) = true ;
153- * self . error_message . lock ( ) . unwrap ( ) = error;
154- self
155- }
156-
157- /// Get the number of blocks stored.
158- fn stored_block_count ( & self ) -> usize {
159- self . blocks . lock ( ) . unwrap ( ) . len ( )
160- }
161-
162- /// Get all stored block numbers.
163- fn stored_block_numbers ( & self ) -> Vec < BlockNumber > {
164- self . blocks . lock ( ) . unwrap ( ) . iter ( ) . map ( |( block, _, _) | block. block . header . number ) . collect ( )
135+ impl MockInnerProvider {
136+ fn new (
137+ blocks : Arc < Mutex < Vec < ( SealedBlockWithStatus , StateUpdatesWithClasses , Vec < Receipt > ) > > > ,
138+ should_fail : Arc < Mutex < bool > > ,
139+ error_message : Arc < Mutex < String > > ,
140+ ) -> Self {
141+ Self { blocks, should_fail, error_message }
165142 }
166143}
167144
168- impl BlockWriter for MockProvider {
145+ impl BlockWriter for MockInnerProvider {
169146 fn insert_block_with_states_and_receipts (
170147 & self ,
171148 block : SealedBlockWithStatus ,
@@ -184,7 +161,7 @@ impl BlockWriter for MockProvider {
184161 }
185162}
186163
187- impl BlockHashProvider for MockProvider {
164+ impl BlockHashProvider for MockInnerProvider {
188165 fn latest_hash ( & self ) -> ProviderResult < BlockHash > {
189166 self . blocks
190167 . lock ( )
@@ -205,6 +182,85 @@ impl BlockHashProvider for MockProvider {
205182 }
206183}
207184
185+ impl katana_provider:: MutableProvider for MockInnerProvider {
186+ fn commit ( self ) -> ProviderResult < ( ) > {
187+ Ok ( ( ) )
188+ }
189+ }
190+
191+ /// Mock ProviderFactory implementation for testing.
192+ ///
193+ /// Tracks all insert operations and can be configured to return errors.
194+ #[ derive( Clone ) ]
195+ struct MockProvider {
196+ /// Stored blocks with their receipts and state updates.
197+ blocks : Arc < Mutex < Vec < ( SealedBlockWithStatus , StateUpdatesWithClasses , Vec < Receipt > ) > > > ,
198+ /// Whether to return an error on insert.
199+ should_fail : Arc < Mutex < bool > > ,
200+ /// Error message to return when should_fail is true.
201+ error_message : Arc < Mutex < String > > ,
202+ }
203+
204+ impl std:: fmt:: Debug for MockProvider {
205+ fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
206+ f. debug_struct ( "MockProvider" ) . finish_non_exhaustive ( )
207+ }
208+ }
209+
210+ impl MockProvider {
211+ fn new ( ) -> Self {
212+ Self {
213+ blocks : Arc :: new ( Mutex :: new ( Vec :: new ( ) ) ) ,
214+ should_fail : Arc :: new ( Mutex :: new ( false ) ) ,
215+ error_message : Arc :: new ( Mutex :: new ( String :: new ( ) ) ) ,
216+ }
217+ }
218+
219+ /// Add a block directly to the provider's storage.
220+ fn with_block ( self , block : SealedBlockWithStatus ) -> Self {
221+ self . blocks . lock ( ) . unwrap ( ) . push ( ( block, Default :: default ( ) , Default :: default ( ) ) ) ;
222+ self
223+ }
224+
225+ /// Configure the mock to fail on insert operations.
226+ fn with_insert_error ( self , error : String ) -> Self {
227+ * self . should_fail . lock ( ) . unwrap ( ) = true ;
228+ * self . error_message . lock ( ) . unwrap ( ) = error;
229+ self
230+ }
231+
232+ /// Get the number of blocks stored.
233+ fn stored_block_count ( & self ) -> usize {
234+ self . blocks . lock ( ) . unwrap ( ) . len ( )
235+ }
236+
237+ /// Get all stored block numbers.
238+ fn stored_block_numbers ( & self ) -> Vec < BlockNumber > {
239+ self . blocks . lock ( ) . unwrap ( ) . iter ( ) . map ( |( block, _, _) | block. block . header . number ) . collect ( )
240+ }
241+ }
242+
243+ impl katana_provider:: ProviderFactory for MockProvider {
244+ type Provider = MockInnerProvider ;
245+ type ProviderMut = MockInnerProvider ;
246+
247+ fn provider ( & self ) -> Self :: Provider {
248+ MockInnerProvider :: new (
249+ Arc :: clone ( & self . blocks ) ,
250+ Arc :: clone ( & self . should_fail ) ,
251+ Arc :: clone ( & self . error_message ) ,
252+ )
253+ }
254+
255+ fn provider_mut ( & self ) -> Self :: ProviderMut {
256+ MockInnerProvider :: new (
257+ Arc :: clone ( & self . blocks ) ,
258+ Arc :: clone ( & self . should_fail ) ,
259+ Arc :: clone ( & self . error_message ) ,
260+ )
261+ }
262+ }
263+
208264/// Helper function to create a minimal test `SealedBlockWithStatus`.
209265///
210266/// Creates a block with the given number and automatically sets the parent hash
@@ -403,13 +459,14 @@ async fn fetch_blocks_from_gateway() {
403459 let feeder_gateway = SequencerGateway :: sepolia ( ) ;
404460 let downloader = BatchBlockDownloader :: new_gateway ( feeder_gateway, 10 ) ;
405461
406- let mut stage = Blocks :: new ( & provider, downloader) ;
462+ let mut stage = Blocks :: new ( provider. clone ( ) , downloader) ;
407463
408464 let input = StageExecutionInput :: new ( from_block, to_block) ;
409465 stage. execute ( & input) . await . expect ( "failed to execute stage" ) ;
410466
411467 // check provider storage
412- let block_number = provider. latest_number ( ) . expect ( "failed to get latest block number" ) ;
468+ let block_number =
469+ provider. provider ( ) . latest_number ( ) . expect ( "failed to get latest block number" ) ;
413470 assert_eq ! ( block_number, to_block) ;
414471}
415472
0 commit comments