Skip to content

Commit 76741c5

Browse files
author
Preslav Le
committed
Add optional ping timeout
1 parent 452da1a commit 76741c5

File tree

3 files changed

+146
-2
lines changed

3 files changed

+146
-2
lines changed

sqlx-core/src/pool/connection.rs

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use std::fmt::{self, Debug, Formatter};
22
use std::future::{self, Future};
3+
use std::io;
34
use std::ops::{Deref, DerefMut};
45
use std::sync::Arc;
56
use std::time::{Duration, Instant};
@@ -15,6 +16,29 @@ use crate::pool::options::PoolConnectionMetadata;
1516

1617
const CLOSE_ON_DROP_TIMEOUT: Duration = Duration::from_secs(5);
1718

19+
/// Helper function to execute a ping with an optional timeout.
20+
///
21+
/// If `timeout` is `Some(Duration::ZERO)`, immediately returns a timeout error
22+
/// for deterministic testing behavior.
23+
async fn ping_with_timeout<F>(timeout: Option<Duration>, ping: F) -> Result<(), Error>
24+
where
25+
F: Future<Output = Result<(), Error>>,
26+
{
27+
match timeout {
28+
Some(timeout) if timeout.is_zero() => {
29+
// Duration::ZERO means "always timeout immediately"
30+
// This provides deterministic behavior for testing
31+
Err(Error::Io(io::Error::new(io::ErrorKind::TimedOut, "ping timed out")))
32+
}
33+
Some(timeout) => {
34+
crate::rt::timeout(timeout, ping)
35+
.await
36+
.unwrap_or_else(|_| Err(Error::Io(io::Error::new(io::ErrorKind::TimedOut, "ping timed out"))))
37+
}
38+
None => ping.await,
39+
}
40+
}
41+
1842
/// A connection managed by a [`Pool`][crate::pool::Pool].
1943
///
2044
/// Will be returned to the pool on-drop.
@@ -311,7 +335,8 @@ impl<DB: Database> Floating<DB, Live<DB>> {
311335
// returned to the pool; also of course, if it was dropped due to an error
312336
// this is simply a band-aid as SQLx-next connections should be able
313337
// to recover from cancellations
314-
if let Err(error) = self.raw.ping().await {
338+
let ping_result = ping_with_timeout(self.guard.pool.options.ping_timeout, self.raw.ping()).await;
339+
if let Err(error) = ping_result {
315340
tracing::warn!(
316341
%error,
317342
"error occurred while testing the connection on-release",
@@ -370,7 +395,7 @@ impl<DB: Database> Floating<DB, Idle<DB>> {
370395
}
371396

372397
pub async fn ping(&mut self) -> Result<(), Error> {
373-
self.live.raw.ping().await
398+
ping_with_timeout(self.guard.pool.options.ping_timeout, self.live.raw.ping()).await
374399
}
375400

376401
pub fn into_live(self) -> Floating<DB, Live<DB>> {

sqlx-core/src/pool/options.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ pub struct PoolOptions<DB: Database> {
8282
pub(crate) min_connections: u32,
8383
pub(crate) max_lifetime: Option<Duration>,
8484
pub(crate) idle_timeout: Option<Duration>,
85+
pub(crate) ping_timeout: Option<Duration>,
8586
pub(crate) fair: bool,
8687

8788
pub(crate) parent_pool: Option<Pool<DB>>,
@@ -105,6 +106,7 @@ impl<DB: Database> Clone for PoolOptions<DB> {
105106
min_connections: self.min_connections,
106107
max_lifetime: self.max_lifetime,
107108
idle_timeout: self.idle_timeout,
109+
ping_timeout: self.ping_timeout,
108110
fair: self.fair,
109111
parent_pool: self.parent_pool.clone(),
110112
}
@@ -160,6 +162,7 @@ impl<DB: Database> PoolOptions<DB> {
160162
acquire_timeout: Duration::from_secs(30),
161163
idle_timeout: Some(Duration::from_secs(10 * 60)),
162164
max_lifetime: Some(Duration::from_secs(30 * 60)),
165+
ping_timeout: None,
163166
fair: true,
164167
parent_pool: None,
165168
}
@@ -307,6 +310,21 @@ impl<DB: Database> PoolOptions<DB> {
307310
self.idle_timeout
308311
}
309312

313+
/// Set the timeout for pinging connections when they are returned to the pool.
314+
///
315+
/// If the ping takes longer than this, the connection is closed and a warning is logged.
316+
///
317+
/// When set to `None` (the default), there is no timeout.
318+
pub fn ping_timeout(mut self, timeout: impl Into<Option<Duration>>) -> Self {
319+
self.ping_timeout = timeout.into();
320+
self
321+
}
322+
323+
/// Get the timeout for pinging connections when they are returned to the pool.
324+
pub fn get_ping_timeout(&self) -> Option<Duration> {
325+
self.ping_timeout
326+
}
327+
310328
/// If true, the health of a connection will be verified by a call to [`Connection::ping`]
311329
/// before returning the connection.
312330
///
@@ -590,6 +608,7 @@ impl<DB: Database> Debug for PoolOptions<DB> {
590608
.field("connect_timeout", &self.acquire_timeout)
591609
.field("max_lifetime", &self.max_lifetime)
592610
.field("idle_timeout", &self.idle_timeout)
611+
.field("ping_timeout", &self.ping_timeout)
593612
.field("test_before_acquire", &self.test_before_acquire)
594613
.finish()
595614
}

tests/any/pool.rs

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,3 +268,103 @@ async fn test_connection_maintenance() -> anyhow::Result<()> {
268268

269269
Ok(())
270270
}
271+
272+
#[sqlx_macros::test]
273+
async fn pool_ping_timeout_on_return() -> anyhow::Result<()> {
274+
sqlx::any::install_default_drivers();
275+
276+
// With a reasonable timeout, connections should be returned to the pool
277+
let pool = AnyPoolOptions::new()
278+
.ping_timeout(Duration::from_secs(10))
279+
.max_connections(1)
280+
.connect(&dotenvy::var("DATABASE_URL")?)
281+
.await?;
282+
283+
let mut conn = pool.acquire().await?;
284+
sqlx::query("SELECT 1").fetch_one(&mut *conn).await?;
285+
conn.return_to_pool().await;
286+
287+
assert_eq!(pool.num_idle(), 1);
288+
drop(pool);
289+
290+
// With a zero timeout, connections should be discarded on return
291+
let pool = AnyPoolOptions::new()
292+
.ping_timeout(Duration::ZERO)
293+
.max_connections(1)
294+
.connect(&dotenvy::var("DATABASE_URL")?)
295+
.await?;
296+
297+
let mut conn = pool.acquire().await?;
298+
sqlx::query("SELECT 1").fetch_one(&mut *conn).await?;
299+
conn.return_to_pool().await;
300+
301+
assert_eq!(pool.num_idle(), 0);
302+
303+
Ok(())
304+
}
305+
306+
#[sqlx_macros::test]
307+
async fn pool_ping_timeout_on_acquire() -> anyhow::Result<()> {
308+
sqlx::any::install_default_drivers();
309+
310+
// Helper to wait for idle connections
311+
async fn wait_for_idle(pool: &sqlx::AnyPool, expected: usize) {
312+
for _ in 0..100 {
313+
if pool.num_idle() == expected {
314+
return;
315+
}
316+
sqlx_core::rt::sleep(Duration::from_millis(50)).await;
317+
}
318+
panic!("timed out waiting for {} idle connections, got {}", expected, pool.num_idle());
319+
}
320+
321+
// With a reasonable timeout, idle connections should be used
322+
let connect_count = Arc::new(AtomicUsize::new(0));
323+
let connect_count_ = connect_count.clone();
324+
let pool = AnyPoolOptions::new()
325+
.ping_timeout(Duration::from_secs(10))
326+
.test_before_acquire(true)
327+
.min_connections(1)
328+
.max_connections(1)
329+
.after_connect(move |_conn, _meta| {
330+
connect_count_.fetch_add(1, Ordering::SeqCst);
331+
Box::pin(async { Ok(()) })
332+
})
333+
.connect(&dotenvy::var("DATABASE_URL")?)
334+
.await?;
335+
336+
wait_for_idle(&pool, 1).await;
337+
assert_eq!(connect_count.load(Ordering::SeqCst), 1);
338+
339+
// Acquire should reuse the same connection
340+
let _conn = pool.acquire().await?;
341+
assert_eq!(connect_count.load(Ordering::SeqCst), 1);
342+
drop(pool);
343+
344+
// With a zero timeout, idle connections should fail ping and be replaced
345+
let connect_count = Arc::new(AtomicUsize::new(0));
346+
let connect_count_ = connect_count.clone();
347+
let pool = AnyPoolOptions::new()
348+
.ping_timeout(Duration::ZERO)
349+
.test_before_acquire(true)
350+
.min_connections(1)
351+
.max_connections(1)
352+
// Disable timeouts to prevent the reaper from interfering
353+
.idle_timeout(None)
354+
.max_lifetime(None)
355+
.after_connect(move |_conn, _meta| {
356+
connect_count_.fetch_add(1, Ordering::SeqCst);
357+
Box::pin(async { Ok(()) })
358+
})
359+
.connect_lazy(&dotenvy::var("DATABASE_URL")?)?;
360+
361+
wait_for_idle(&pool, 1).await;
362+
assert_eq!(connect_count.load(Ordering::SeqCst), 1);
363+
364+
// Acquire - ping will fail and the caller will go ahead and open a new
365+
// connection. Importantly, the caller won't observe any error.
366+
let _conn = pool.acquire().await?;
367+
assert_eq!(connect_count.load(Ordering::SeqCst), 2);
368+
369+
Ok(())
370+
}

0 commit comments

Comments
 (0)