Skip to content

Commit 655124a

Browse files
LeoPatOZ and 0xNeshi authored
Pubsub checks (#153)
Co-authored-by: 0xNeshi <[email protected]>
1 parent 956c77f commit 655124a

File tree

2 files changed

+169
-42
lines changed

2 files changed

+169
-42
lines changed

src/robust_provider.rs

Lines changed: 167 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -85,8 +85,8 @@ impl<N: Network> RobustProvider<N> {
8585

8686
/// Set the base delay for exponential backoff retries.
8787
#[must_use]
88-
pub fn min_delay(mut self, retry_interval: Duration) -> Self {
89-
self.min_delay = retry_interval;
88+
pub fn min_delay(mut self, min_delay: Duration) -> Self {
89+
self.min_delay = min_delay;
9090
self
9191
}
9292

@@ -105,8 +105,8 @@ impl<N: Network> RobustProvider<N> {
105105
///
106106
/// Fallback providers are used when the primary provider times out or fails.
107107
#[must_use]
108-
pub fn fallback(mut self, provider: RootProvider<N>) -> Self {
109-
self.providers.push(provider);
108+
pub fn fallback(mut self, provider: impl Provider<N>) -> Self {
109+
self.providers.push(provider.root().to_owned());
110110
self
111111
}
112112

@@ -122,9 +122,10 @@ impl<N: Network> RobustProvider<N> {
122122
) -> Result<N::BlockResponse, Error> {
123123
info!("eth_getBlockByNumber called");
124124
let result = self
125-
.retry_with_total_timeout(move |provider| async move {
126-
provider.get_block_by_number(number).await
127-
})
125+
.retry_with_total_timeout(
126+
move |provider| async move { provider.get_block_by_number(number).await },
127+
false,
128+
)
128129
.await;
129130
if let Err(e) = &result {
130131
error!(error = %e, "eth_getByBlockNumber failed");
@@ -144,6 +145,7 @@ impl<N: Network> RobustProvider<N> {
144145
let result = self
145146
.retry_with_total_timeout(
146147
move |provider| async move { provider.get_block_number().await },
148+
false,
147149
)
148150
.await;
149151
if let Err(e) = &result {
@@ -164,9 +166,10 @@ impl<N: Network> RobustProvider<N> {
164166
) -> Result<N::BlockResponse, Error> {
165167
info!("eth_getBlockByHash called");
166168
let result = self
167-
.retry_with_total_timeout(move |provider| async move {
168-
provider.get_block_by_hash(hash).await
169-
})
169+
.retry_with_total_timeout(
170+
move |provider| async move { provider.get_block_by_hash(hash).await },
171+
false,
172+
)
170173
.await;
171174
if let Err(e) = &result {
172175
error!(error = %e, "eth_getBlockByHash failed");
@@ -186,6 +189,7 @@ impl<N: Network> RobustProvider<N> {
186189
let result = self
187190
.retry_with_total_timeout(
188191
move |provider| async move { provider.get_logs(filter).await },
192+
false,
189193
)
190194
.await;
191195
if let Err(e) = &result {
@@ -202,11 +206,12 @@ impl<N: Network> RobustProvider<N> {
202206
/// after exhausting retries or if the call times out.
203207
pub async fn subscribe_blocks(&self) -> Result<Subscription<N::HeaderResponse>, Error> {
204208
info!("eth_subscribe called");
205-
// We need this otherwise error is not clear
209+
// immediately fail if primary does not support pubsub
206210
self.root().client().expect_pubsub_frontend();
207211
let result = self
208212
.retry_with_total_timeout(
209213
move |provider| async move { provider.subscribe_blocks().await },
214+
true,
210215
)
211216
.await;
212217
if let Err(e) = &result {
@@ -224,17 +229,27 @@ impl<N: Network> RobustProvider<N> {
224229
/// If the timeout is exceeded and fallback providers are available, it will
225230
/// attempt to use each fallback provider in sequence.
226231
///
232+
/// If `require_pubsub` is true, providers that don't support pubsub will be skipped.
233+
///
227234
/// # Errors
228235
///
229236
/// - Returns [`RpcError<TransportErrorKind>`] with message "total operation timeout exceeded
230237
/// and all fallback providers failed" if the overall timeout elapses and no fallback
231238
/// providers succeed.
239+
/// - Returns [`RpcError::Transport(TransportErrorKind::PubsubUnavailable)`] if `require_pubsub`
240+
/// is true and all providers don't support pubsub.
232241
/// - Propagates any [`RpcError<TransportErrorKind>`] from the underlying retries.
233-
async fn retry_with_total_timeout<T: Debug, F, Fut>(&self, operation: F) -> Result<T, Error>
242+
async fn retry_with_total_timeout<T: Debug, F, Fut>(
243+
&self,
244+
operation: F,
245+
require_pubsub: bool,
246+
) -> Result<T, Error>
234247
where
235248
F: Fn(RootProvider<N>) -> Fut,
236249
Fut: Future<Output = Result<T, RpcError<TransportErrorKind>>>,
237250
{
251+
let mut skipped_count = 0;
252+
238253
let mut providers = self.providers.iter();
239254
let primary = providers.next().expect("should have primary provider");
240255

@@ -253,6 +268,11 @@ impl<N: Network> RobustProvider<N> {
253268
// This loop starts at index 1 automatically
254269
for (idx, provider) in providers.enumerate() {
255270
let fallback_num = idx + 1;
271+
if require_pubsub && !Self::supports_pubsub(provider) {
272+
info!("Fallback provider {} doesn't support pubsub, skipping", fallback_num);
273+
skipped_count += 1;
274+
continue;
275+
}
256276
info!("Attempting fallback provider {}/{}", fallback_num, self.providers.len() - 1);
257277

258278
match self.try_provider_with_timeout(provider, &operation).await {
@@ -267,6 +287,13 @@ impl<N: Network> RobustProvider<N> {
267287
}
268288
}
269289

290+
// If all providers were skipped due to pubsub requirement
291+
if skipped_count == self.providers.len() {
292+
error!("All providers skipped - none support pubsub");
293+
return Err(RpcError::Transport(TransportErrorKind::PubsubUnavailable).into());
294+
}
295+
296+
// Return the last error encountered
270297
error!("All providers failed or timed out");
271298
Err(last_error)
272299
}
@@ -298,25 +325,30 @@ impl<N: Network> RobustProvider<N> {
298325
.map_err(Error::from)?
299326
.map_err(Error::from)
300327
}
328+
329+
/// Check if a provider supports pubsub
330+
fn supports_pubsub(provider: &RootProvider<N>) -> bool {
331+
provider.client().pubsub_frontend().is_some()
332+
}
301333
}
302334

303335
#[cfg(test)]
304336
mod tests {
305337
use super::*;
306-
use alloy::network::Ethereum;
338+
use alloy::{
339+
network::Ethereum,
340+
providers::{ProviderBuilder, WsConnect},
341+
};
342+
use alloy_node_bindings::Anvil;
307343
use std::sync::atomic::{AtomicUsize, Ordering};
308344
use tokio::time::sleep;
309345

310-
fn test_provider(
311-
timeout: u64,
312-
max_retries: usize,
313-
retry_interval: u64,
314-
) -> RobustProvider<Ethereum> {
346+
fn test_provider(timeout: u64, max_retries: usize, min_delay: u64) -> RobustProvider<Ethereum> {
315347
RobustProvider {
316348
providers: vec![RootProvider::new_http("http://localhost:8545".parse().unwrap())],
317349
max_timeout: Duration::from_millis(timeout),
318350
max_retries,
319-
min_delay: Duration::from_millis(retry_interval),
351+
min_delay: Duration::from_millis(min_delay),
320352
}
321353
}
322354

@@ -327,11 +359,14 @@ mod tests {
327359
let call_count = AtomicUsize::new(0);
328360

329361
let result = provider
330-
.retry_with_total_timeout(|_| async {
331-
call_count.fetch_add(1, Ordering::SeqCst);
332-
let count = call_count.load(Ordering::SeqCst);
333-
Ok(count)
334-
})
362+
.retry_with_total_timeout(
363+
|_| async {
364+
call_count.fetch_add(1, Ordering::SeqCst);
365+
let count = call_count.load(Ordering::SeqCst);
366+
Ok(count)
367+
},
368+
false,
369+
)
335370
.await;
336371

337372
assert!(matches!(result, Ok(1)));
@@ -344,14 +379,17 @@ mod tests {
344379
let call_count = AtomicUsize::new(0);
345380

346381
let result = provider
347-
.retry_with_total_timeout(|_| async {
348-
call_count.fetch_add(1, Ordering::SeqCst);
349-
let count = call_count.load(Ordering::SeqCst);
350-
match count {
351-
3 => Ok(count),
352-
_ => Err(TransportErrorKind::BackendGone.into()),
353-
}
354-
})
382+
.retry_with_total_timeout(
383+
|_| async {
384+
call_count.fetch_add(1, Ordering::SeqCst);
385+
let count = call_count.load(Ordering::SeqCst);
386+
match count {
387+
3 => Ok(count),
388+
_ => Err(TransportErrorKind::BackendGone.into()),
389+
}
390+
},
391+
false,
392+
)
355393
.await;
356394

357395
assert!(matches!(result, Ok(3)));
@@ -364,10 +402,13 @@ mod tests {
364402
let call_count = AtomicUsize::new(0);
365403

366404
let result: Result<(), Error> = provider
367-
.retry_with_total_timeout(|_| async {
368-
call_count.fetch_add(1, Ordering::SeqCst);
369-
Err(TransportErrorKind::BackendGone.into())
370-
})
405+
.retry_with_total_timeout(
406+
|_| async {
407+
call_count.fetch_add(1, Ordering::SeqCst);
408+
Err(TransportErrorKind::BackendGone.into())
409+
},
410+
false,
411+
)
371412
.await;
372413

373414
assert!(matches!(result, Err(Error::RpcError(_))));
@@ -380,12 +421,98 @@ mod tests {
380421
let provider = test_provider(max_timeout, 10, 1);
381422

382423
let result = provider
383-
.retry_with_total_timeout(move |_provider| async move {
384-
sleep(Duration::from_millis(max_timeout + 10)).await;
385-
Ok(42)
386-
})
424+
.retry_with_total_timeout(
425+
move |_provider| async move {
426+
sleep(Duration::from_millis(max_timeout + 10)).await;
427+
Ok(42)
428+
},
429+
false,
430+
)
387431
.await;
388432

389433
assert!(matches!(result, Err(Error::Timeout)));
390434
}
435+
436+
#[tokio::test]
437+
async fn test_subscribe_fails_causes_backup_to_be_used() {
438+
let anvil_1 = Anvil::new().port(2222_u16).try_spawn().expect("Failed to start anvil");
439+
440+
let ws_provider_1 = ProviderBuilder::new()
441+
.connect_ws(WsConnect::new(anvil_1.ws_endpoint_url().as_str()))
442+
.await
443+
.expect("Failed to connect to WS");
444+
445+
let anvil_2 = Anvil::new().port(1111_u16).try_spawn().expect("Failed to start anvil");
446+
447+
let ws_provider_2 = ProviderBuilder::new()
448+
.connect_ws(WsConnect::new(anvil_2.ws_endpoint_url().as_str()))
449+
.await
450+
.expect("Failed to connect to WS");
451+
452+
let robust = RobustProvider::new(ws_provider_1)
453+
.fallback(ws_provider_2)
454+
.max_timeout(Duration::from_secs(5))
455+
.max_retries(10)
456+
.min_delay(Duration::from_millis(100));
457+
458+
drop(anvil_1);
459+
460+
let result = robust.subscribe_blocks().await;
461+
462+
assert!(result.is_ok(), "Expected subscribe blocks to work");
463+
}
464+
465+
#[tokio::test]
466+
#[should_panic(expected = "called pubsub_frontend on a non-pubsub transport")]
467+
async fn test_subscribe_fails_if_primary_provider_lacks_pubsub() {
468+
let anvil = Anvil::new().try_spawn().expect("Failed to start anvil");
469+
470+
let http_provider = ProviderBuilder::new().connect_http(anvil.endpoint_url());
471+
let ws_provider = ProviderBuilder::new()
472+
.connect_ws(WsConnect::new(anvil.ws_endpoint_url().as_str()))
473+
.await
474+
.expect("Failed to connect to WS");
475+
476+
let robust = RobustProvider::new(http_provider)
477+
.fallback(ws_provider)
478+
.max_timeout(Duration::from_secs(5))
479+
.max_retries(10)
480+
.min_delay(Duration::from_millis(100));
481+
482+
let _ = robust.subscribe_blocks().await;
483+
}
484+
485+
#[tokio::test]
486+
async fn test_ws_fails_http_fallback_returns_primary_error() {
487+
let anvil_1 = Anvil::new().try_spawn().expect("Failed to start anvil");
488+
489+
let ws_provider = ProviderBuilder::new()
490+
.connect_ws(WsConnect::new(anvil_1.ws_endpoint_url().as_str()))
491+
.await
492+
.expect("Failed to connect to WS");
493+
494+
let anvil_2 = Anvil::new().port(8222_u16).try_spawn().expect("Failed to start anvil");
495+
let http_provider = ProviderBuilder::new().connect_http(anvil_2.endpoint_url());
496+
497+
let robust = RobustProvider::new(ws_provider.clone())
498+
.fallback(http_provider)
499+
.max_timeout(Duration::from_millis(500))
500+
.max_retries(0)
501+
.min_delay(Duration::from_millis(10));
502+
503+
// force ws_provider to fail and return BackendGone
504+
drop(anvil_1);
505+
506+
let err = robust.subscribe_blocks().await.unwrap_err();
507+
508+
// The error should be either a Timeout or BackendGone from the primary WS provider,
509+
// NOT a PubsubUnavailable error (which would indicate HTTP fallback was attempted)
510+
match err {
511+
Error::Timeout => {}
512+
Error::RpcError(e) => {
513+
assert!(matches!(e.as_ref(), RpcError::Transport(TransportErrorKind::BackendGone)));
514+
}
515+
Error::BlockNotFound(id) => panic!("Unexpected error type: BlockNotFound({id})"),
516+
}
517+
}
391518
}

tests/block_range_scanner.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,7 @@ async fn live_mode_processes_all_blocks_respecting_block_confirmations() -> anyh
3232

3333
robust_provider.root().anvil_mine(Some(1), None).await?;
3434

35-
assert_next!(stream, 6..=6);
35+
assert_next!(stream, 6..=6, timeout = 10);
3636
assert_empty!(stream);
3737

3838
// --- 1 block confirmation ---
@@ -50,7 +50,7 @@ async fn live_mode_processes_all_blocks_respecting_block_confirmations() -> anyh
5050

5151
robust_provider.root().anvil_mine(Some(1), None).await?;
5252

53-
assert_next!(stream, 11..=11);
53+
assert_next!(stream, 11..=11, timeout = 10);
5454
assert_empty!(stream);
5555

5656
Ok(())

0 commit comments

Comments
 (0)