1- use std:: time :: Duration ;
1+ use std:: sync :: Arc ;
22
33use anyhow:: Result ;
4+ use backon:: { ExponentialBuilder , Retryable } ;
45use katana_primitives:: block:: BlockNumber ;
5- use katana_primitives:: class:: {
6- CasmContractClass , ClassHash , CompiledClass , ContractClass , SierraContractClass ,
7- } ;
8- use katana_primitives:: conversion:: rpc:: { legacy_rpc_to_class, StarknetRsLegacyContractClass } ;
6+ use katana_primitives:: class:: { ClassHash , ContractClass , SierraContractClass } ;
7+ use katana_primitives:: conversion:: rpc:: StarknetRsLegacyContractClass ;
98use katana_provider:: traits:: contract:: { ContractClassWriter , ContractClassWriterExt } ;
109use katana_provider:: traits:: state_update:: StateUpdateProvider ;
1110use katana_rpc_types:: class:: RpcSierraContractClass ;
1211use starknet:: providers:: sequencer:: models:: { BlockId , DeployedClass } ;
1312use starknet:: providers:: { ProviderError , SequencerGatewayProvider } ;
14- use tracing:: info ;
13+ use tracing:: warn ;
1514
1615use super :: { Stage , StageExecutionInput , StageResult } ;
1716
@@ -24,22 +23,97 @@ pub enum Error {
2423#[ derive( Debug ) ]
2524pub struct Classes < P > {
2625 provider : P ,
27- feeder_gateway : SequencerGatewayProvider ,
26+ downloader : Downloader ,
2827}
2928
3029impl < P > Classes < P > {
3130 pub fn new ( provider : P , feeder_gateway : SequencerGatewayProvider ) -> Self {
32- Self { provider, feeder_gateway }
31+ Self { provider, downloader : Downloader :: new ( feeder_gateway) }
3332 }
3433}
3534
36- impl < P > Classes < P >
35+ #[ async_trait:: async_trait]
36+ impl < P > Stage for Classes < P >
3737where
38- P : StateUpdateProvider + ContractClassWriter ,
38+ P : StateUpdateProvider + ContractClassWriter + ContractClassWriterExt ,
3939{
40- #[ allow( deprecated) ]
41- async fn get_class ( & self , hash : ClassHash , block : BlockNumber ) -> Result < ContractClass , Error > {
42- let class = self . feeder_gateway . get_class_by_hash ( hash, BlockId :: Number ( block) ) . await ?;
40+ fn id ( & self ) -> & ' static str {
41+ "Classes"
42+ }
43+
44+ async fn execute ( & mut self , input : & StageExecutionInput ) -> StageResult {
45+ for i in input. from ..=input. to {
46+ let class_hashes = self . provider . declared_classes ( i. into ( ) ) ?. unwrap ( ) ;
47+ let class_hashes = class_hashes. keys ( ) . map ( |hash| * hash) . collect :: < Vec < _ > > ( ) ;
48+
49+ let classes = self . downloader . fetch_classes ( & class_hashes, i) . await ?;
50+ for ( hash, class) in classes {
51+ self . provider . set_class ( hash, class) ?;
52+ }
53+ }
54+
55+ Ok ( ( ) )
56+ }
57+ }
58+
59+ #[ derive( Debug , Clone ) ]
60+ struct Downloader {
61+ client : Arc < SequencerGatewayProvider > ,
62+ }
63+
64+ impl Downloader {
65+ fn new ( client : SequencerGatewayProvider ) -> Self {
66+ Self { client : Arc :: new ( client) }
67+ }
68+
69+ async fn fetch_classes (
70+ & self ,
71+ classes : & [ ClassHash ] ,
72+ block : BlockNumber ,
73+ ) -> Result < Vec < ( ClassHash , ContractClass ) > , Error > {
74+ let mut all_results = Vec :: with_capacity ( classes. len ( ) ) ;
75+
76+ for hash in classes {
77+ let mut futures = Vec :: new ( ) ;
78+
79+ futures. push ( self . fetch_class_with_retry ( * hash, block) ) ;
80+ let batch_results = futures:: future:: join_all ( futures) . await ;
81+
82+ all_results. extend ( batch_results) ;
83+ }
84+
85+ all_results. into_iter ( ) . collect ( )
86+ }
87+
88+ async fn fetch_class_with_retry (
89+ & self ,
90+ hash : ClassHash ,
91+ block : BlockNumber ,
92+ ) -> Result < ( ClassHash , ContractClass ) , Error > {
93+ let request = || async move {
94+ #[ allow( deprecated) ]
95+ self . clone ( ) . fetch_class ( hash, block) . await
96+ } ;
97+
98+ // Retry only when being rate limited
99+ let result = request
100+ . retry ( ExponentialBuilder :: default ( ) )
101+ . when ( |e| matches ! ( e, Error :: Gateway ( ProviderError :: RateLimited ) ) )
102+ . notify ( |error, _| {
103+ warn ! ( target: "pipeline" , hash = format!( "{hash:#x}" ) , %block, %error, "Retrying class download." ) ;
104+ } )
105+ . await ?;
106+
107+ Ok ( ( hash, result) )
108+ }
109+
110+ async fn fetch_class (
111+ & self ,
112+ hash : ClassHash ,
113+ block : BlockNumber ,
114+ ) -> Result < ContractClass , Error > {
115+ #[ allow( deprecated) ]
116+ let class = self . client . get_class_by_hash ( hash, BlockId :: Number ( block) ) . await ?;
43117
44118 let class = match class {
45119 DeployedClass :: LegacyClass ( legacy) => {
@@ -59,36 +133,6 @@ where
59133 }
60134}
61135
62- #[ async_trait:: async_trait]
63- impl < P > Stage for Classes < P >
64- where
65- P : StateUpdateProvider + ContractClassWriter + ContractClassWriterExt ,
66- {
67- fn id ( & self ) -> & ' static str {
68- "Classes"
69- }
70-
71- async fn execute ( & mut self , input : & StageExecutionInput ) -> StageResult {
72- for i in input. from ..=input. to {
73- // loop thru all the class hashes in the current block
74- let class_hashes = self . provider . declared_classes ( i. into ( ) ) ?. unwrap ( ) ;
75-
76- // TODO: do this in parallel
77- for hash in class_hashes. keys ( ) {
78- info ! ( target: "pipeline" , class_hash = format!( "{hash:#x}" ) , "Fetching class artifacts." ) ;
79-
80- // 1. fetch sierra and casm class from fgw
81- let class = self . get_class ( * hash, i) . await ?;
82- self . provider . set_class ( * hash, class) ?;
83-
84- tokio:: time:: sleep ( Duration :: from_secs ( 1 ) ) . await ;
85- }
86- }
87-
88- Ok ( ( ) )
89- }
90- }
91-
92136fn to_inner_legacy_class ( class : StarknetRsLegacyContractClass ) -> Result < ContractClass > {
93137 let value = serde_json:: to_value ( class) ?;
94138 let class = serde_json:: from_value :: < katana_primitives:: class:: LegacyContractClass > ( value) ?;
0 commit comments