
Commit af7ce01

feat: add total timeout
1 parent 8bd13b1 commit af7ce01

File tree: 1 file changed (+55 −11 lines)


src/safe_provider.rs

Lines changed: 55 additions & 11 deletions
@@ -103,8 +103,9 @@ impl<N: Network> SafeProvider<N> {
     ) -> Result<Option<N::BlockResponse>, RpcError<TransportErrorKind>> {
         debug!("SafeProvider eth_getBlockByNumber called with number: {:?}", number);
         let provider = self.provider.clone();
-        let result =
-            self.retry_with_timeout(|| async { provider.get_block_by_number(number).await }).await;
+        let result = self
+            .retry_with_total_timeout(|| async { provider.get_block_by_number(number).await })
+            .await;
         if let Err(e) = &result {
             error!("SafeProvider eth_getByBlockNumber failed: {}", e);
         }
@@ -119,7 +120,8 @@ impl<N: Network> SafeProvider<N> {
     pub async fn get_block_number(&self) -> Result<u64, RpcError<TransportErrorKind>> {
         debug!("SafeProvider eth_getBlockNumber called");
         let provider = self.provider.clone();
-        let result = self.retry_with_timeout(|| async { provider.get_block_number().await }).await;
+        let result =
+            self.retry_with_total_timeout(|| async { provider.get_block_number().await }).await;
         if let Err(e) = &result {
             error!("SafeProvider eth_getBlockNumber failed: {}", e);
         }
@@ -137,8 +139,9 @@ impl<N: Network> SafeProvider<N> {
     ) -> Result<Option<N::BlockResponse>, RpcError<TransportErrorKind>> {
         debug!("SafeProvider eth_getBlockByHash called with hash: {:?}", hash);
         let provider = self.provider.clone();
-        let result =
-            self.retry_with_timeout(|| async { provider.get_block_by_hash(hash).await }).await;
+        let result = self
+            .retry_with_total_timeout(|| async { provider.get_block_by_hash(hash).await })
+            .await;
         if let Err(e) = &result {
             error!("SafeProvider eth_getBlockByHash failed: {}", e);
         }
@@ -156,7 +159,8 @@ impl<N: Network> SafeProvider<N> {
     ) -> Result<Vec<Log>, RpcError<TransportErrorKind>> {
         debug!("eth_getLogs called with filter: {:?}", filter);
         let provider = self.provider.clone();
-        let result = self.retry_with_timeout(|| async { provider.get_logs(filter).await }).await;
+        let result =
+            self.retry_with_total_timeout(|| async { provider.get_logs(filter).await }).await;
         if let Err(e) = &result {
             error!("SafeProvider eth_getLogs failed: {}", e);
         }
@@ -173,19 +177,20 @@ impl<N: Network> SafeProvider<N> {
     ) -> Result<Subscription<N::HeaderResponse>, RpcError<TransportErrorKind>> {
         debug!("eth_subscribe called");
         let provider = self.provider.clone();
-        let result = self.retry_with_timeout(|| async { provider.subscribe_blocks().await }).await;
+        let result =
+            self.retry_with_total_timeout(|| async { provider.subscribe_blocks().await }).await;
         if let Err(e) = &result {
             error!("SafeProvider eth_subscribe failed: {}", e);
         }
         result
     }

-    /// Execute `operation` with exponential backoff and a total timeout.
+    /// Execute `operation` with exponential backoff respecting only the backoff budget.
     ///
     /// # Errors
     /// Returns `RpcError<TransportErrorKind>` if all attempts fail or the
-    /// total delay exceeds the configured timeout.
-    pub(crate) async fn retry_with_timeout<T, F, Fut>(
+    /// cumulative backoff delay exceeds the configured budget.
+    async fn retry_with_timeout<T, F, Fut>(
         &self,
         operation: F,
     ) -> Result<T, RpcError<TransportErrorKind>>
@@ -195,18 +200,42 @@ impl<N: Network> SafeProvider<N> {
     {
         let retry_strategy = ExponentialBuilder::default()
            .with_max_times(self.max_retries)
-           .with_total_delay(Some(self.max_timeout))
            .with_min_delay(self.retry_interval);

        operation.retry(retry_strategy).sleep(tokio::time::sleep).await
     }
+
+    /// Execute `operation` with exponential backoff and a true total timeout.
+    ///
+    /// Wraps the retry logic with `tokio::time::timeout(self.max_timeout, ...)` so
+    /// the entire operation (including time spent inside the RPC call) cannot exceed
+    /// `max_timeout`.
+    ///
+    /// # Errors
+    /// - Returns `RpcError<TransportErrorKind>` with message "total operation timeout exceeded" if
+    ///   the overall timeout elapses.
+    /// - Propagates any `RpcError<TransportErrorKind>` from the underlying retries.
+    async fn retry_with_total_timeout<T, F, Fut>(
+        &self,
+        operation: F,
+    ) -> Result<T, RpcError<TransportErrorKind>>
+    where
+        F: Fn() -> Fut,
+        Fut: Future<Output = Result<T, RpcError<TransportErrorKind>>>,
+    {
+        match tokio::time::timeout(self.max_timeout, self.retry_with_timeout(operation)).await {
+            Ok(res) => res,
+            Err(_) => Err(TransportErrorKind::custom_str("total operation timeout exceeded")),
+        }
+    }
 }

 #[cfg(test)]
 mod tests {
     use super::*;
     use alloy::network::Ethereum;
     use std::sync::{Arc, Mutex};
+    use tokio::time::sleep;

     fn create_test_provider(
         timeout: Duration,
@@ -297,4 +326,19 @@ mod tests {
         assert!(result.is_err());
         assert_eq!(*call_count.lock().unwrap(), 3);
     }
+
+    #[tokio::test]
+    async fn test_retry_with_timeout_respects_total_delay() {
+        let max_timeout = Duration::from_millis(50);
+        let provider = create_test_provider(max_timeout, 10, Duration::from_millis(1));
+
+        let result = provider
+            .retry_with_total_timeout(move || async move {
+                sleep(max_timeout + Duration::from_millis(10)).await;
+                Ok(42)
+            })
+            .await;
+
+        assert!(result.is_err());
+    }
 }
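
The behavioral difference this commit targets: backon's `with_total_delay` only caps the cumulative backoff sleep between attempts, whereas the new `retry_with_total_timeout` wraps the entire retry loop in `tokio::time::timeout`, so time spent inside a hung RPC call also counts against `max_timeout`. A minimal, self-contained sketch of that pattern follows; it uses a plain fixed-interval retry loop and a `String` error type instead of backon and `RpcError`, and the free function and parameter names here are illustrative, not taken from this crate.

// Hypothetical standalone sketch of the retry-then-total-timeout pattern.
use std::future::Future;
use std::time::Duration;

async fn retry_with_total_timeout<T, F, Fut>(
    max_retries: usize,
    retry_interval: Duration,
    max_timeout: Duration,
    operation: F,
) -> Result<T, String>
where
    F: Fn() -> Fut,
    Fut: Future<Output = Result<T, String>>,
{
    // Inner future: retry the operation with a fixed pause between attempts.
    let retries = async {
        let mut last_err = String::from("no attempts made");
        for attempt in 0..=max_retries {
            match operation().await {
                Ok(value) => return Ok(value),
                Err(e) => last_err = e,
            }
            if attempt < max_retries {
                tokio::time::sleep(retry_interval).await;
            }
        }
        Err(last_err)
    };

    // Outer timeout bounds everything: the attempts themselves plus the
    // sleeps between them, i.e. a true wall-clock budget.
    match tokio::time::timeout(max_timeout, retries).await {
        Ok(res) => res,
        Err(_) => Err("total operation timeout exceeded".to_string()),
    }
}

#[tokio::main]
async fn main() {
    // An operation that always outruns the total budget, so the outer
    // timeout fires even though retry attempts remain.
    let result: Result<u64, String> = retry_with_total_timeout(
        10,
        Duration::from_millis(1),
        Duration::from_millis(50),
        || async {
            tokio::time::sleep(Duration::from_millis(60)).await;
            Ok(42)
        },
    )
    .await;

    assert!(result.is_err());
    println!("{result:?}");
}

Under those assumptions the call fails with the timeout error rather than waiting out all ten retries, which is the same scenario the new test in this commit exercises.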
