
Commit 01a9cf1

feat: add custom error for safe provider
1 parent: 0546feb

4 files changed: +76 -30 lines

src/error.rs

Lines changed: 17 additions & 1 deletion
@@ -7,7 +7,7 @@ use alloy::{
 };
 use thiserror::Error;
 
-use crate::block_range_scanner::Message;
+use crate::{block_range_scanner::Message, safe_provider::SafeProviderError};
 
 #[derive(Error, Debug, Clone)]
 pub enum ScannerError {
@@ -42,6 +42,22 @@ pub enum ScannerError {
 
     #[error("Block not found, block number: {0}")]
     BlockNotFound(BlockNumberOrTag),
+
+    #[error("Operation timed out")]
+    Timeout,
+
+    #[error("Retry failed after {0} tries")]
+    RetryFail(usize),
+}
+
+impl From<SafeProviderError> for ScannerError {
+    fn from(error: SafeProviderError) -> ScannerError {
+        match error {
+            SafeProviderError::RpcError(err) => ScannerError::RpcError(err),
+            SafeProviderError::Timeout => ScannerError::Timeout,
+            SafeProviderError::RetryFail(num) => ScannerError::RetryFail(num),
+        }
+    }
 }
 
 impl From<Result<RangeInclusive<BlockNumber>, ScannerError>> for Message {
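
With the From impl above in place, any function returning Result<_, ScannerError> can call SafeProvider methods and propagate failures with the ? operator. A minimal sketch, assuming it lives inside this crate (fetch_latest_block is hypothetical, not part of the commit):

use alloy::network::Network;

use crate::{ScannerError, safe_provider::SafeProvider};

// Hypothetical helper: the ? operator applies the new
// From<SafeProviderError> for ScannerError conversion automatically.
async fn fetch_latest_block<N: Network>(
    provider: &SafeProvider<N>,
) -> Result<u64, ScannerError> {
    Ok(provider.get_block_number().await?)
}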

src/event_scanner/message.rs

Lines changed: 8 additions & 1 deletion
@@ -1,6 +1,6 @@
 use alloy::{rpc::types::Log, sol_types::SolEvent};
 
-use crate::{ScannerError, ScannerMessage};
+use crate::{ScannerError, ScannerMessage, safe_provider::SafeProviderError};
 
 pub type Message = ScannerMessage<Vec<Log>, ScannerError>;
 
@@ -10,6 +10,13 @@ impl From<Vec<Log>> for Message {
     }
 }
 
+impl From<SafeProviderError> for Message {
+    fn from(error: SafeProviderError) -> Message {
+        let scanner_error: ScannerError = error.into();
+        scanner_error.into()
+    }
+}
+
 impl<E: SolEvent> PartialEq<Vec<E>> for Message {
     fn eq(&self, other: &Vec<E>) -> bool {
         self.eq(&other.as_slice())
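
The new impl chains the two existing conversions (SafeProviderError into ScannerError, then ScannerError into Message), so a provider failure collapses into a broadcastable message in one call. A hedged sketch of the call-site shape this enables (report_failure is illustrative, not part of the commit):

use tokio::sync::broadcast::Sender;

use crate::{event_scanner::message::Message, safe_provider::SafeProviderError};

// Illustrative helper: forward a provider failure to every listener.
// err.into() runs through the From<SafeProviderError> for Message impl.
fn report_failure(sender: &Sender<Message>, err: SafeProviderError) {
    // send only errors when there are no active receivers; safe to ignore here
    let _ = sender.send(err.into());
}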

src/event_scanner/modes/common.rs

Lines changed: 2 additions & 3 deletions
@@ -3,12 +3,11 @@ use std::ops::RangeInclusive;
 use crate::{
     block_range_scanner::{MAX_BUFFERED_MESSAGES, Message as BlockRangeMessage},
     event_scanner::{filter::EventFilter, listener::EventListener, message::Message},
-    safe_provider::SafeProvider,
+    safe_provider::{SafeProvider, SafeProviderError},
 };
 use alloy::{
     network::Network,
     rpc::types::{Filter, Log},
-    transports::{RpcError, TransportErrorKind},
 };
 use tokio::sync::{
     broadcast::{self, Sender, error::RecvError},
@@ -130,7 +129,7 @@ async fn get_logs<N: Network>(
     event_filter: &EventFilter,
     log_filter: &Filter,
     provider: &SafeProvider<N>,
-) -> Result<Vec<Log>, RpcError<TransportErrorKind>> {
+) -> Result<Vec<Log>, SafeProviderError> {
     let log_filter = log_filter.clone().from_block(*range.start()).to_block(*range.end());
 
     match provider.get_logs(&log_filter).await {
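
Dropping the transports import means this module no longer names raw transport types; error handling can branch on SafeProviderError variants instead. A hypothetical triage helper showing that shape (is_fatal is not part of the commit):

use crate::safe_provider::SafeProviderError;

// Hypothetical policy, for illustration only.
fn is_fatal(err: &SafeProviderError) -> bool {
    match err {
        // The total timeout budget elapsed; a later attempt may still succeed.
        SafeProviderError::Timeout => false,
        // Retries were already exhausted inside SafeProvider.
        SafeProviderError::RetryFail(_) => true,
        // A transport/RPC failure that survived the retry policy.
        SafeProviderError::RpcError(_) => true,
    }
}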

src/safe_provider.rs

Lines changed: 49 additions & 25 deletions
@@ -1,4 +1,4 @@
-use std::{future::Future, time::Duration};
+use std::{future::Future, sync::Arc, time::Duration};
 
 use alloy::{
     eips::BlockNumberOrTag,
@@ -9,8 +9,25 @@ use alloy::{
     transports::{RpcError, TransportErrorKind},
 };
 use backon::{ExponentialBuilder, Retryable};
+use thiserror::Error;
 use tracing::{error, info};
 
+#[derive(Error, Debug, Clone)]
+pub enum SafeProviderError {
+    #[error("RPC error: {0}")]
+    RpcError(Arc<RpcError<TransportErrorKind>>),
+    #[error("Operation timed out")]
+    Timeout,
+    #[error("Retry failed after {0} tries")]
+    RetryFail(usize),
+}
+
+impl From<RpcError<TransportErrorKind>> for SafeProviderError {
+    fn from(err: RpcError<TransportErrorKind>) -> Self {
+        SafeProviderError::RpcError(Arc::new(err))
+    }
+}
+
 /// Safe provider wrapper with built-in retry and timeout mechanisms.
 ///
 /// This wrapper around Alloy providers automatically handles retries,
@@ -70,9 +87,11 @@ impl<N: Network> SafeProvider<N> {
     pub async fn get_block_by_number(
         &self,
         number: BlockNumberOrTag,
-    ) -> Result<Option<N::BlockResponse>, RpcError<TransportErrorKind>> {
+    ) -> Result<Option<N::BlockResponse>, SafeProviderError> {
         info!("eth_getBlockByNumber called");
-        let operation = async || self.provider.get_block_by_number(number).await;
+        let operation = async || {
+            self.provider.get_block_by_number(number).await.map_err(SafeProviderError::from)
+        };
         let result = self.retry_with_total_timeout(operation).await;
         if let Err(e) = &result {
             error!(error = %e, "eth_getByBlockNumber failed");
@@ -86,9 +105,10 @@ impl<N: Network> SafeProvider<N> {
     ///
     /// Returns an error if RPC call fails repeatedly even
     /// after exhausting retries or if the call times out.
-    pub async fn get_block_number(&self) -> Result<u64, RpcError<TransportErrorKind>> {
+    pub async fn get_block_number(&self) -> Result<u64, SafeProviderError> {
         info!("eth_getBlockNumber called");
-        let operation = || self.provider.get_block_number();
+        let operation =
+            async || self.provider.get_block_number().await.map_err(SafeProviderError::from);
         let result = self.retry_with_total_timeout(operation).await;
         if let Err(e) = &result {
             error!(error = %e, "eth_getBlockNumber failed");
@@ -105,9 +125,10 @@ impl<N: Network> SafeProvider<N> {
     pub async fn get_block_by_hash(
         &self,
         hash: alloy::primitives::BlockHash,
-    ) -> Result<Option<N::BlockResponse>, RpcError<TransportErrorKind>> {
+    ) -> Result<Option<N::BlockResponse>, SafeProviderError> {
         info!("eth_getBlockByHash called");
-        let operation = async || self.provider.get_block_by_hash(hash).await;
+        let operation =
+            async || self.provider.get_block_by_hash(hash).await.map_err(SafeProviderError::from);
         let result = self.retry_with_total_timeout(operation).await;
         if let Err(e) = &result {
             error!(error = %e, "eth_getBlockByHash failed");
@@ -121,12 +142,10 @@ impl<N: Network> SafeProvider<N> {
     ///
     /// Returns an error if RPC call fails repeatedly even
     /// after exhausting retries or if the call times out.
-    pub async fn get_logs(
-        &self,
-        filter: &Filter,
-    ) -> Result<Vec<Log>, RpcError<TransportErrorKind>> {
+    pub async fn get_logs(&self, filter: &Filter) -> Result<Vec<Log>, SafeProviderError> {
         info!("eth_getLogs called");
-        let operation = || self.provider.get_logs(filter);
+        let operation =
+            async || self.provider.get_logs(filter).await.map_err(SafeProviderError::from);
         let result = self.retry_with_total_timeout(operation).await;
         if let Err(e) = &result {
             error!(error = %e, "eth_getLogs failed");
@@ -142,11 +161,14 @@ impl<N: Network> SafeProvider<N> {
     /// after exhausting retries or if the call times out.
     pub async fn subscribe_blocks(
         &self,
-    ) -> Result<Subscription<N::HeaderResponse>, RpcError<TransportErrorKind>> {
+    ) -> Result<Subscription<N::HeaderResponse>, SafeProviderError> {
         info!("eth_subscribe called");
         let provider = self.provider.clone();
-        let result =
-            self.retry_with_total_timeout(|| async { provider.subscribe_blocks().await }).await;
+        let result = self
+            .retry_with_total_timeout(|| async {
+                provider.subscribe_blocks().await.map_err(SafeProviderError::from)
+            })
+            .await;
         if let Err(e) = &result {
             error!(error = %e, "eth_subscribe failed");
         }
@@ -167,10 +189,10 @@ impl<N: Network> SafeProvider<N> {
     async fn retry_with_total_timeout<T, F, Fut>(
         &self,
         operation: F,
-    ) -> Result<T, RpcError<TransportErrorKind>>
+    ) -> Result<T, SafeProviderError>
     where
         F: Fn() -> Fut,
-        Fut: Future<Output = Result<T, RpcError<TransportErrorKind>>>,
+        Fut: Future<Output = Result<T, SafeProviderError>>,
     {
         let retry_strategy = ExponentialBuilder::default()
             .with_max_times(self.max_retries)
@@ -182,8 +204,9 @@ impl<N: Network> SafeProvider<N> {
         )
         .await
         {
-            Ok(res) => res,
-            Err(_) => Err(TransportErrorKind::custom_str("total operation timeout exceeded")),
+            Ok(Ok(res)) => Ok(res),
+            Ok(Err(_)) => Err(SafeProviderError::RetryFail(self.max_retries + 1)),
+            Err(_) => Err(SafeProviderError::Timeout),
         }
     }
 }
@@ -234,7 +257,9 @@ mod tests {
             .retry_with_total_timeout(|| async {
                 call_count.fetch_add(1, Ordering::SeqCst);
                 if call_count.load(Ordering::SeqCst) < 3 {
-                    Err(TransportErrorKind::custom_str("temporary error"))
+                    Err(SafeProviderError::RpcError(Arc::new(TransportErrorKind::custom_str(
+                        "temp error",
+                    ))))
                 } else {
                     Ok(call_count.load(Ordering::SeqCst))
                 }
@@ -253,14 +278,13 @@ mod tests {
         let result = provider
             .retry_with_total_timeout(|| async {
                 call_count.fetch_add(1, Ordering::SeqCst);
-                Err::<i32, RpcError<TransportErrorKind>>(TransportErrorKind::custom_str(
-                    "permanent error",
-                ))
+                // permanent error
+                Err::<i32, SafeProviderError>(SafeProviderError::Timeout)
             })
             .await;
 
         let err = result.unwrap_err();
-        assert!(err.to_string().contains("permanent error"),);
+        assert!(matches!(err, SafeProviderError::RetryFail(3)));
         assert_eq!(call_count.load(Ordering::SeqCst), 3);
     }
 
@@ -277,6 +301,6 @@ mod tests {
             .await;
 
         let err = result.unwrap_err();
-        assert!(err.to_string().contains("total operation timeout exceeded"),);
+        assert!(matches!(err, SafeProviderError::Timeout));
     }
 }
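
The three-arm match in retry_with_total_timeout mirrors tokio::time::timeout wrapping the retried future: the outer Result says whether the total deadline elapsed, the inner one whether the retries ultimately failed. A standalone sketch of the same nesting (not this crate's code; values and strings are placeholders):

use std::time::Duration;

#[tokio::main]
async fn main() {
    // Stand-in for an operation whose retries have run out.
    let attempt = async { Result::<u32, &str>::Err("retries exhausted") };

    let outcome = match tokio::time::timeout(Duration::from_secs(1), attempt).await {
        Ok(Ok(v)) => format!("value {v}"),         // finished in time and succeeded
        Ok(Err(_)) => "retry failure".to_string(), // finished in time, every try failed
        Err(_) => "total timeout".to_string(),     // the deadline fired first
    };
    println!("{outcome}");
}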
