@@ -85,8 +85,8 @@ impl<N: Network> RobustProvider<N> {
8585
8686 /// Set the base delay for exponential backoff retries.
8787 #[ must_use]
88- pub fn min_delay ( mut self , retry_interval : Duration ) -> Self {
89- self . min_delay = retry_interval ;
88+ pub fn min_delay ( mut self , min_delay : Duration ) -> Self {
89+ self . min_delay = min_delay ;
9090 self
9191 }
9292
@@ -105,8 +105,8 @@ impl<N: Network> RobustProvider<N> {
105105 ///
106106 /// Fallback providers are used when the primary provider times out or fails.
107107 #[ must_use]
108- pub fn fallback ( mut self , provider : RootProvider < N > ) -> Self {
109- self . providers . push ( provider) ;
108+ pub fn fallback ( mut self , provider : impl Provider < N > ) -> Self {
109+ self . providers . push ( provider. root ( ) . to_owned ( ) ) ;
110110 self
111111 }
112112
@@ -122,9 +122,10 @@ impl<N: Network> RobustProvider<N> {
122122 ) -> Result < N :: BlockResponse , Error > {
123123 info ! ( "eth_getBlockByNumber called" ) ;
124124 let result = self
125- . retry_with_total_timeout ( move |provider| async move {
126- provider. get_block_by_number ( number) . await
127- } )
125+ . retry_with_total_timeout (
126+ move |provider| async move { provider. get_block_by_number ( number) . await } ,
127+ false ,
128+ )
128129 . await ;
129130 if let Err ( e) = & result {
130131 error ! ( error = %e, "eth_getByBlockNumber failed" ) ;
@@ -144,6 +145,7 @@ impl<N: Network> RobustProvider<N> {
144145 let result = self
145146 . retry_with_total_timeout (
146147 move |provider| async move { provider. get_block_number ( ) . await } ,
148+ false ,
147149 )
148150 . await ;
149151 if let Err ( e) = & result {
@@ -164,9 +166,10 @@ impl<N: Network> RobustProvider<N> {
164166 ) -> Result < N :: BlockResponse , Error > {
165167 info ! ( "eth_getBlockByHash called" ) ;
166168 let result = self
167- . retry_with_total_timeout ( move |provider| async move {
168- provider. get_block_by_hash ( hash) . await
169- } )
169+ . retry_with_total_timeout (
170+ move |provider| async move { provider. get_block_by_hash ( hash) . await } ,
171+ false ,
172+ )
170173 . await ;
171174 if let Err ( e) = & result {
172175 error ! ( error = %e, "eth_getBlockByHash failed" ) ;
@@ -186,6 +189,7 @@ impl<N: Network> RobustProvider<N> {
186189 let result = self
187190 . retry_with_total_timeout (
188191 move |provider| async move { provider. get_logs ( filter) . await } ,
192+ false ,
189193 )
190194 . await ;
191195 if let Err ( e) = & result {
@@ -202,11 +206,12 @@ impl<N: Network> RobustProvider<N> {
202206 /// after exhausting retries or if the call times out.
203207 pub async fn subscribe_blocks ( & self ) -> Result < Subscription < N :: HeaderResponse > , Error > {
204208 info ! ( "eth_subscribe called" ) ;
205- // We need this otherwise error is not clear
209+ // immediately fail if primary does not support pubsub
206210 self . root ( ) . client ( ) . expect_pubsub_frontend ( ) ;
207211 let result = self
208212 . retry_with_total_timeout (
209213 move |provider| async move { provider. subscribe_blocks ( ) . await } ,
214+ true ,
210215 )
211216 . await ;
212217 if let Err ( e) = & result {
@@ -224,17 +229,27 @@ impl<N: Network> RobustProvider<N> {
224229 /// If the timeout is exceeded and fallback providers are available, it will
225230 /// attempt to use each fallback provider in sequence.
226231 ///
232+ /// If `require_pubsub` is true, providers that don't support pubsub will be skipped.
233+ ///
227234 /// # Errors
228235 ///
229236 /// - Returns [`RpcError<TransportErrorKind>`] with message "total operation timeout exceeded
230237 /// and all fallback providers failed" if the overall timeout elapses and no fallback
231238 /// providers succeed.
239+ /// - Returns [`RpcError::Transport(TransportErrorKind::PubsubUnavailable)`] if `require_pubsub`
240+ /// is true and all providers don't support pubsub.
232241 /// - Propagates any [`RpcError<TransportErrorKind>`] from the underlying retries.
233- async fn retry_with_total_timeout < T : Debug , F , Fut > ( & self , operation : F ) -> Result < T , Error >
242+ async fn retry_with_total_timeout < T : Debug , F , Fut > (
243+ & self ,
244+ operation : F ,
245+ require_pubsub : bool ,
246+ ) -> Result < T , Error >
234247 where
235248 F : Fn ( RootProvider < N > ) -> Fut ,
236249 Fut : Future < Output = Result < T , RpcError < TransportErrorKind > > > ,
237250 {
251+ let mut skipped_count = 0 ;
252+
238253 let mut providers = self . providers . iter ( ) ;
239254 let primary = providers. next ( ) . expect ( "should have primary provider" ) ;
240255
@@ -253,6 +268,11 @@ impl<N: Network> RobustProvider<N> {
253268 // This loop starts at index 1 automatically
254269 for ( idx, provider) in providers. enumerate ( ) {
255270 let fallback_num = idx + 1 ;
271+ if require_pubsub && !Self :: supports_pubsub ( provider) {
272+ info ! ( "Fallback provider {} doesn't support pubsub, skipping" , fallback_num) ;
273+ skipped_count += 1 ;
274+ continue ;
275+ }
256276 info ! ( "Attempting fallback provider {}/{}" , fallback_num, self . providers. len( ) - 1 ) ;
257277
258278 match self . try_provider_with_timeout ( provider, & operation) . await {
@@ -267,6 +287,13 @@ impl<N: Network> RobustProvider<N> {
267287 }
268288 }
269289
290+ // If all providers were skipped due to pubsub requirement
291+ if skipped_count == self . providers . len ( ) {
292+ error ! ( "All providers skipped - none support pubsub" ) ;
293+ return Err ( RpcError :: Transport ( TransportErrorKind :: PubsubUnavailable ) . into ( ) ) ;
294+ }
295+
296+ // Return the last error encountered
270297 error ! ( "All providers failed or timed out" ) ;
271298 Err ( last_error)
272299 }
@@ -298,25 +325,30 @@ impl<N: Network> RobustProvider<N> {
298325 . map_err ( Error :: from) ?
299326 . map_err ( Error :: from)
300327 }
328+
329+ /// Check if a provider supports pubsub
330+ fn supports_pubsub ( provider : & RootProvider < N > ) -> bool {
331+ provider. client ( ) . pubsub_frontend ( ) . is_some ( )
332+ }
301333}
302334
303335#[ cfg( test) ]
304336mod tests {
305337 use super :: * ;
306- use alloy:: network:: Ethereum ;
338+ use alloy:: {
339+ network:: Ethereum ,
340+ providers:: { ProviderBuilder , WsConnect } ,
341+ } ;
342+ use alloy_node_bindings:: Anvil ;
307343 use std:: sync:: atomic:: { AtomicUsize , Ordering } ;
308344 use tokio:: time:: sleep;
309345
310- fn test_provider (
311- timeout : u64 ,
312- max_retries : usize ,
313- retry_interval : u64 ,
314- ) -> RobustProvider < Ethereum > {
346+ fn test_provider ( timeout : u64 , max_retries : usize , min_delay : u64 ) -> RobustProvider < Ethereum > {
315347 RobustProvider {
316348 providers : vec ! [ RootProvider :: new_http( "http://localhost:8545" . parse( ) . unwrap( ) ) ] ,
317349 max_timeout : Duration :: from_millis ( timeout) ,
318350 max_retries,
319- min_delay : Duration :: from_millis ( retry_interval ) ,
351+ min_delay : Duration :: from_millis ( min_delay ) ,
320352 }
321353 }
322354
@@ -327,11 +359,14 @@ mod tests {
327359 let call_count = AtomicUsize :: new ( 0 ) ;
328360
329361 let result = provider
330- . retry_with_total_timeout ( |_| async {
331- call_count. fetch_add ( 1 , Ordering :: SeqCst ) ;
332- let count = call_count. load ( Ordering :: SeqCst ) ;
333- Ok ( count)
334- } )
362+ . retry_with_total_timeout (
363+ |_| async {
364+ call_count. fetch_add ( 1 , Ordering :: SeqCst ) ;
365+ let count = call_count. load ( Ordering :: SeqCst ) ;
366+ Ok ( count)
367+ } ,
368+ false ,
369+ )
335370 . await ;
336371
337372 assert ! ( matches!( result, Ok ( 1 ) ) ) ;
@@ -344,14 +379,17 @@ mod tests {
344379 let call_count = AtomicUsize :: new ( 0 ) ;
345380
346381 let result = provider
347- . retry_with_total_timeout ( |_| async {
348- call_count. fetch_add ( 1 , Ordering :: SeqCst ) ;
349- let count = call_count. load ( Ordering :: SeqCst ) ;
350- match count {
351- 3 => Ok ( count) ,
352- _ => Err ( TransportErrorKind :: BackendGone . into ( ) ) ,
353- }
354- } )
382+ . retry_with_total_timeout (
383+ |_| async {
384+ call_count. fetch_add ( 1 , Ordering :: SeqCst ) ;
385+ let count = call_count. load ( Ordering :: SeqCst ) ;
386+ match count {
387+ 3 => Ok ( count) ,
388+ _ => Err ( TransportErrorKind :: BackendGone . into ( ) ) ,
389+ }
390+ } ,
391+ false ,
392+ )
355393 . await ;
356394
357395 assert ! ( matches!( result, Ok ( 3 ) ) ) ;
@@ -364,10 +402,13 @@ mod tests {
364402 let call_count = AtomicUsize :: new ( 0 ) ;
365403
366404 let result: Result < ( ) , Error > = provider
367- . retry_with_total_timeout ( |_| async {
368- call_count. fetch_add ( 1 , Ordering :: SeqCst ) ;
369- Err ( TransportErrorKind :: BackendGone . into ( ) )
370- } )
405+ . retry_with_total_timeout (
406+ |_| async {
407+ call_count. fetch_add ( 1 , Ordering :: SeqCst ) ;
408+ Err ( TransportErrorKind :: BackendGone . into ( ) )
409+ } ,
410+ false ,
411+ )
371412 . await ;
372413
373414 assert ! ( matches!( result, Err ( Error :: RpcError ( _) ) ) ) ;
@@ -380,12 +421,98 @@ mod tests {
380421 let provider = test_provider ( max_timeout, 10 , 1 ) ;
381422
382423 let result = provider
383- . retry_with_total_timeout ( move |_provider| async move {
384- sleep ( Duration :: from_millis ( max_timeout + 10 ) ) . await ;
385- Ok ( 42 )
386- } )
424+ . retry_with_total_timeout (
425+ move |_provider| async move {
426+ sleep ( Duration :: from_millis ( max_timeout + 10 ) ) . await ;
427+ Ok ( 42 )
428+ } ,
429+ false ,
430+ )
387431 . await ;
388432
389433 assert ! ( matches!( result, Err ( Error :: Timeout ) ) ) ;
390434 }
435+
436+ #[ tokio:: test]
437+ async fn test_subscribe_fails_causes_backup_to_be_used ( ) {
438+ let anvil_1 = Anvil :: new ( ) . port ( 2222_u16 ) . try_spawn ( ) . expect ( "Failed to start anvil" ) ;
439+
440+ let ws_provider_1 = ProviderBuilder :: new ( )
441+ . connect_ws ( WsConnect :: new ( anvil_1. ws_endpoint_url ( ) . as_str ( ) ) )
442+ . await
443+ . expect ( "Failed to connect to WS" ) ;
444+
445+ let anvil_2 = Anvil :: new ( ) . port ( 1111_u16 ) . try_spawn ( ) . expect ( "Failed to start anvil" ) ;
446+
447+ let ws_provider_2 = ProviderBuilder :: new ( )
448+ . connect_ws ( WsConnect :: new ( anvil_2. ws_endpoint_url ( ) . as_str ( ) ) )
449+ . await
450+ . expect ( "Failed to connect to WS" ) ;
451+
452+ let robust = RobustProvider :: new ( ws_provider_1)
453+ . fallback ( ws_provider_2)
454+ . max_timeout ( Duration :: from_secs ( 5 ) )
455+ . max_retries ( 10 )
456+ . min_delay ( Duration :: from_millis ( 100 ) ) ;
457+
458+ drop ( anvil_1) ;
459+
460+ let result = robust. subscribe_blocks ( ) . await ;
461+
462+ assert ! ( result. is_ok( ) , "Expected subscribe blocks to work" ) ;
463+ }
464+
465+ #[ tokio:: test]
466+ #[ should_panic( expected = "called pubsub_frontend on a non-pubsub transport" ) ]
467+ async fn test_subscribe_fails_if_primary_provider_lacks_pubsub ( ) {
468+ let anvil = Anvil :: new ( ) . try_spawn ( ) . expect ( "Failed to start anvil" ) ;
469+
470+ let http_provider = ProviderBuilder :: new ( ) . connect_http ( anvil. endpoint_url ( ) ) ;
471+ let ws_provider = ProviderBuilder :: new ( )
472+ . connect_ws ( WsConnect :: new ( anvil. ws_endpoint_url ( ) . as_str ( ) ) )
473+ . await
474+ . expect ( "Failed to connect to WS" ) ;
475+
476+ let robust = RobustProvider :: new ( http_provider)
477+ . fallback ( ws_provider)
478+ . max_timeout ( Duration :: from_secs ( 5 ) )
479+ . max_retries ( 10 )
480+ . min_delay ( Duration :: from_millis ( 100 ) ) ;
481+
482+ let _ = robust. subscribe_blocks ( ) . await ;
483+ }
484+
485+ #[ tokio:: test]
486+ async fn test_ws_fails_http_fallback_returns_primary_error ( ) {
487+ let anvil_1 = Anvil :: new ( ) . try_spawn ( ) . expect ( "Failed to start anvil" ) ;
488+
489+ let ws_provider = ProviderBuilder :: new ( )
490+ . connect_ws ( WsConnect :: new ( anvil_1. ws_endpoint_url ( ) . as_str ( ) ) )
491+ . await
492+ . expect ( "Failed to connect to WS" ) ;
493+
494+ let anvil_2 = Anvil :: new ( ) . port ( 8222_u16 ) . try_spawn ( ) . expect ( "Failed to start anvil" ) ;
495+ let http_provider = ProviderBuilder :: new ( ) . connect_http ( anvil_2. endpoint_url ( ) ) ;
496+
497+ let robust = RobustProvider :: new ( ws_provider. clone ( ) )
498+ . fallback ( http_provider)
499+ . max_timeout ( Duration :: from_millis ( 500 ) )
500+ . max_retries ( 0 )
501+ . min_delay ( Duration :: from_millis ( 10 ) ) ;
502+
503+ // force ws_provider to fail and return BackendGone
504+ drop ( anvil_1) ;
505+
506+ let err = robust. subscribe_blocks ( ) . await . unwrap_err ( ) ;
507+
508+ // The error should be either a Timeout or BackendGone from the primary WS provider,
509+ // NOT a PubsubUnavailable error (which would indicate HTTP fallback was attempted)
510+ match err {
511+ Error :: Timeout => { }
512+ Error :: RpcError ( e) => {
513+ assert ! ( matches!( e. as_ref( ) , RpcError :: Transport ( TransportErrorKind :: BackendGone ) ) ) ;
514+ }
515+ Error :: BlockNotFound ( id) => panic ! ( "Unexpected error type: BlockNotFound({id})" ) ,
516+ }
517+ }
391518}
0 commit comments