Skip to content

Commit f874255

Browse files
committed
feat(postgres): keepalives for tcp
1 parent f5cdf33 commit f874255

File tree

8 files changed

+371
-22
lines changed

8 files changed

+371
-22
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

sqlx-core/Cargo.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,11 @@ json = ["serde", "serde_json"]
2121

2222
# for conditional compilation
2323
_rt-async-global-executor = ["async-global-executor", "_rt-async-io", "_rt-async-task"]
24-
_rt-async-io = ["async-io", "async-fs"] # see note at async-fs declaration
24+
_rt-async-io = ["async-io", "async-fs", "socket2"] # see note at async-fs declaration
2525
_rt-async-std = ["async-std", "_rt-async-io"]
2626
_rt-async-task = ["async-task"]
2727
_rt-smol = ["smol", "_rt-async-io", "_rt-async-task"]
28-
_rt-tokio = ["tokio", "tokio-stream"]
29-
28+
_rt-tokio = ["tokio", "tokio-stream", "socket2"]
3029
_tls-native-tls = ["native-tls"]
3130
_tls-rustls-aws-lc-rs = ["_tls-rustls", "rustls/aws-lc-rs", "webpki-roots"]
3231
_tls-rustls-ring-webpki = ["_tls-rustls", "rustls/ring", "webpki-roots"]
@@ -102,6 +101,7 @@ hashlink = "0.11.0"
102101
indexmap = "2.0"
103102
event-listener = "5.2.0"
104103
hashbrown = "0.16.0"
104+
socket2 = { version = "0.5", features = ["all"], optional = true }
105105

106106
thiserror.workspace = true
107107

sqlx-core/src/net/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@ mod socket;
22
pub mod tls;
33

44
pub use socket::{
5-
connect_tcp, connect_uds, BufferedSocket, Socket, SocketIntoBox, WithSocket, WriteBuffer,
5+
connect_tcp, connect_uds, BufferedSocket, KeepaliveConfig, Socket, SocketIntoBox, WithSocket,
6+
WriteBuffer,
67
};

sqlx-core/src/net/socket/mod.rs

Lines changed: 86 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::io;
33
use std::path::Path;
44
use std::pin::Pin;
55
use std::task::{ready, Context, Poll};
6+
use std::time::Duration;
67

78
pub use buffered::{BufferedSocket, WriteBuffer};
89
use bytes::BufMut;
@@ -12,6 +13,25 @@ use crate::io::ReadBuf;
1213

1314
mod buffered;
1415

16+
/// Configuration for TCP keepalive probes on a connection.
17+
///
18+
/// All fields default to `None`, meaning the OS default is used.
19+
/// Constructing a `KeepaliveConfig::default()` and passing it enables keepalive
20+
/// with OS defaults for all parameters.
21+
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
22+
pub struct KeepaliveConfig {
23+
/// Time the connection must be idle before keepalive probes begin.
24+
/// `None` means the OS default.
25+
pub idle: Option<Duration>,
26+
/// Interval between keepalive probes.
27+
/// `None` means the OS default.
28+
pub interval: Option<Duration>,
29+
/// Maximum number of failed probes before the connection is dropped.
30+
/// Only supported on Unix; ignored on other platforms.
31+
/// `None` means the OS default.
32+
pub retries: Option<u32>,
33+
}
34+
1535
pub trait Socket: Send + Sync + Unpin + 'static {
1636
fn try_read(&mut self, buf: &mut dyn ReadBuf) -> io::Result<usize>;
1737

@@ -181,23 +201,63 @@ impl<S: Socket + ?Sized> Socket for Box<S> {
181201
}
182202
}
183203

204+
#[cfg(any(feature = "_rt-tokio", feature = "_rt-async-io"))]
205+
fn build_tcp_keepalive(config: &KeepaliveConfig) -> socket2::TcpKeepalive {
206+
let mut ka = socket2::TcpKeepalive::new();
207+
208+
if let Some(idle) = config.idle {
209+
ka = ka.with_time(idle);
210+
}
211+
212+
// socket2's `with_interval` is unavailable on these platforms.
213+
#[cfg(not(any(
214+
target_os = "haiku",
215+
target_os = "openbsd",
216+
target_os = "redox",
217+
target_os = "solaris",
218+
)))]
219+
if let Some(interval) = config.interval {
220+
ka = ka.with_interval(interval);
221+
}
222+
223+
// socket2's `with_retries` is unavailable on these platforms.
224+
#[cfg(not(any(
225+
target_os = "haiku",
226+
target_os = "openbsd",
227+
target_os = "redox",
228+
target_os = "solaris",
229+
target_os = "windows",
230+
)))]
231+
if let Some(retries) = config.retries {
232+
ka = ka.with_retries(retries);
233+
}
234+
235+
ka
236+
}
237+
184238
pub async fn connect_tcp<Ws: WithSocket>(
185239
host: &str,
186240
port: u16,
241+
keepalive: Option<&KeepaliveConfig>,
187242
with_socket: Ws,
188243
) -> crate::Result<Ws::Output> {
189244
#[cfg(feature = "_rt-tokio")]
190245
if crate::rt::rt_tokio::available() {
191-
return Ok(with_socket
192-
.with_socket(tokio::net::TcpStream::connect((host, port)).await?)
193-
.await);
246+
let stream = tokio::net::TcpStream::connect((host, port)).await?;
247+
248+
if let Some(ka) = keepalive {
249+
let sock = socket2::SockRef::from(&stream);
250+
sock.set_tcp_keepalive(&build_tcp_keepalive(ka))?;
251+
}
252+
253+
return Ok(with_socket.with_socket(stream).await);
194254
}
195255

196256
cfg_if! {
197257
if #[cfg(feature = "_rt-async-io")] {
198-
Ok(with_socket.with_socket(connect_tcp_async_io(host, port).await?).await)
258+
Ok(with_socket.with_socket(connect_tcp_async_io(host, port, keepalive).await?).await)
199259
} else {
200-
crate::rt::missing_rt((host, port, with_socket))
260+
crate::rt::missing_rt((host, port, keepalive, with_socket))
201261
}
202262
}
203263
}
@@ -208,15 +268,26 @@ pub async fn connect_tcp<Ws: WithSocket>(
208268
///
209269
/// This implements the same behavior as [`tokio::net::TcpStream::connect()`].
210270
#[cfg(feature = "_rt-async-io")]
211-
async fn connect_tcp_async_io(host: &str, port: u16) -> crate::Result<impl Socket> {
271+
async fn connect_tcp_async_io(
272+
host: &str,
273+
port: u16,
274+
keepalive: Option<&KeepaliveConfig>,
275+
) -> crate::Result<impl Socket> {
212276
use async_io::Async;
213277
use std::net::{IpAddr, TcpStream, ToSocketAddrs};
214278

215279
// IPv6 addresses in URLs will be wrapped in brackets and the `url` crate doesn't trim those.
216280
let host = host.trim_matches(&['[', ']'][..]);
217281

218282
if let Ok(addr) = host.parse::<IpAddr>() {
219-
return Ok(Async::<TcpStream>::connect((addr, port)).await?);
283+
let stream = Async::<TcpStream>::connect((addr, port)).await?;
284+
285+
if let Some(ka) = keepalive {
286+
let sock = socket2::SockRef::from(stream.get_ref());
287+
sock.set_tcp_keepalive(&build_tcp_keepalive(ka))?;
288+
}
289+
290+
return Ok(stream);
220291
}
221292

222293
let host = host.to_string();
@@ -232,7 +303,14 @@ async fn connect_tcp_async_io(host: &str, port: u16) -> crate::Result<impl Socke
232303
// Loop through all the Socket Addresses that the hostname resolves to
233304
for socket_addr in addresses {
234305
match Async::<TcpStream>::connect(socket_addr).await {
235-
Ok(stream) => return Ok(stream),
306+
Ok(stream) => {
307+
if let Some(ka) = keepalive {
308+
let sock = socket2::SockRef::from(stream.get_ref());
309+
sock.set_tcp_keepalive(&build_tcp_keepalive(ka))?;
310+
}
311+
312+
return Ok(stream);
313+
}
236314
Err(e) => last_err = Some(e),
237315
}
238316
}

sqlx-mysql/src/connection/establish.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ impl MySqlConnection {
1717

1818
let handshake = match &options.socket {
1919
Some(path) => crate::net::connect_uds(path, do_handshake).await?,
20-
None => crate::net::connect_tcp(&options.host, options.port, do_handshake).await?,
20+
None => {
21+
crate::net::connect_tcp(&options.host, options.port, None, do_handshake).await?
22+
}
2123
};
2224

2325
let stream = handshake?;

sqlx-postgres/src/connection/stream.rs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,20 @@ impl PgStream {
4444
pub(super) async fn connect(options: &PgConnectOptions) -> Result<Self, Error> {
4545
let socket_result = match options.fetch_socket() {
4646
Some(ref path) => net::connect_uds(path, MaybeUpgradeTls(options)).await?,
47-
None => net::connect_tcp(&options.host, options.port, MaybeUpgradeTls(options)).await?,
47+
None => {
48+
let keepalive = if options.keepalives {
49+
Some(&options.keepalive_config)
50+
} else {
51+
None
52+
};
53+
net::connect_tcp(
54+
&options.host,
55+
options.port,
56+
keepalive,
57+
MaybeUpgradeTls(options),
58+
)
59+
.await?
60+
}
4861
};
4962

5063
let socket = socket_result?;

sqlx-postgres/src/options/mod.rs

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@ use std::borrow::Cow;
22
use std::env::var;
33
use std::fmt::{self, Display, Write};
44
use std::path::{Path, PathBuf};
5+
use std::time::Duration;
56

67
pub use ssl_mode::PgSslMode;
78

9+
use crate::net::KeepaliveConfig;
810
use crate::{connection::LogSettings, net::tls::CertificateInput};
911

1012
mod connect;
@@ -30,6 +32,8 @@ pub struct PgConnectOptions {
3032
pub(crate) log_settings: LogSettings,
3133
pub(crate) extra_float_digits: Option<Cow<'static, str>>,
3234
pub(crate) options: Option<String>,
35+
pub(crate) keepalives: bool,
36+
pub(crate) keepalive_config: KeepaliveConfig,
3337
}
3438

3539
impl Default for PgConnectOptions {
@@ -97,6 +101,9 @@ impl PgConnectOptions {
97101
extra_float_digits: Some("2".into()),
98102
log_settings: Default::default(),
99103
options: var("PGOPTIONS").ok(),
104+
// Matches libpq default: keepalives=1 with OS defaults for timers.
105+
keepalives: true,
106+
keepalive_config: KeepaliveConfig::default(),
100107
}
101108
}
102109

@@ -452,6 +459,85 @@ impl PgConnectOptions {
452459
self
453460
}
454461

462+
/// Enables or disables TCP keepalive on the connection.
463+
///
464+
/// This option is ignored for Unix domain sockets.
465+
///
466+
/// Keepalive is enabled by default.
467+
///
468+
/// When enabled, OS defaults are used for all timer parameters unless
469+
/// overridden by [`keepalives_idle`][Self::keepalives_idle],
470+
/// [`keepalives_interval`][Self::keepalives_interval], or
471+
/// [`keepalives_retries`][Self::keepalives_retries].
472+
///
473+
/// # Example
474+
///
475+
/// ```rust
476+
/// # use sqlx_postgres::PgConnectOptions;
477+
/// let options = PgConnectOptions::new()
478+
/// .keepalives(false);
479+
/// ```
480+
pub fn keepalives(mut self, enable: bool) -> Self {
481+
self.keepalives = enable;
482+
self
483+
}
484+
485+
/// Sets the idle time before TCP keepalive probes begin.
486+
///
487+
/// This is ignored for Unix domain sockets, or if the `keepalives`
488+
/// option is disabled.
489+
///
490+
/// # Example
491+
///
492+
/// ```rust
493+
/// # use std::time::Duration;
494+
/// # use sqlx_postgres::PgConnectOptions;
495+
/// let options = PgConnectOptions::new()
496+
/// .keepalives_idle(Duration::from_secs(60));
497+
/// ```
498+
pub fn keepalives_idle(mut self, idle: Duration) -> Self {
499+
self.keepalive_config.idle = Some(idle);
500+
self
501+
}
502+
503+
/// Sets the interval between TCP keepalive probes.
504+
///
505+
/// This is ignored for Unix domain sockets, or if the `keepalives`
506+
/// option is disabled.
507+
///
508+
/// # Example
509+
///
510+
/// ```rust
511+
/// # use std::time::Duration;
512+
/// # use sqlx_postgres::PgConnectOptions;
513+
/// let options = PgConnectOptions::new()
514+
/// .keepalives_interval(Duration::from_secs(5));
515+
/// ```
516+
pub fn keepalives_interval(mut self, interval: Duration) -> Self {
517+
self.keepalive_config.interval = Some(interval);
518+
self
519+
}
520+
521+
/// Sets the maximum number of TCP keepalive probes before the connection is dropped.
522+
///
523+
/// This is ignored for Unix domain sockets, or if the `keepalives`
524+
/// option is disabled.
525+
///
526+
/// Only supported on Unix platforms; ignored on other platforms.
527+
///
528+
/// # Example
529+
///
530+
/// ```rust
531+
/// # use sqlx_postgres::PgConnectOptions;
532+
/// let options = PgConnectOptions::new()
533+
/// .keepalives_retries(3);
534+
/// ```
535+
#[cfg(unix)]
536+
pub fn keepalives_retries(mut self, retries: u32) -> Self {
537+
self.keepalive_config.retries = Some(retries);
538+
self
539+
}
540+
455541
/// We try using a socket if hostname starts with `/` or if socket parameter
456542
/// is specified.
457543
pub(crate) fn fetch_socket(&self) -> Option<String> {
@@ -580,6 +666,34 @@ impl PgConnectOptions {
580666
pub fn get_options(&self) -> Option<&str> {
581667
self.options.as_deref()
582668
}
669+
670+
/// Get whether TCP keepalives are enabled.
671+
///
672+
/// # Example
673+
///
674+
/// ```rust
675+
/// # use sqlx_postgres::PgConnectOptions;
676+
/// let options = PgConnectOptions::new();
677+
/// assert!(options.get_keepalives());
678+
/// ```
679+
pub fn get_keepalives(&self) -> bool {
680+
self.keepalives
681+
}
682+
683+
/// Get the idle time before TCP keepalive probes begin.
684+
pub fn get_keepalives_idle(&self) -> Option<Duration> {
685+
self.keepalive_config.idle
686+
}
687+
688+
/// Get the interval between TCP keepalive probes.
689+
pub fn get_keepalives_interval(&self) -> Option<Duration> {
690+
self.keepalive_config.interval
691+
}
692+
693+
/// Get the maximum number of TCP keepalive probes.
694+
pub fn get_keepalives_retries(&self) -> Option<u32> {
695+
self.keepalive_config.retries
696+
}
583697
}
584698

585699
fn default_host(port: u16) -> String {

0 commit comments

Comments
 (0)