Skip to content

Commit 7f9520c

Browse files
committed
feat(postgres): keepalives for tcp
1 parent e723688 commit 7f9520c

File tree

8 files changed

+352
-14
lines changed

8 files changed

+352
-14
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

sqlx-core/Cargo.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ any = []
1919
json = ["serde", "serde_json"]
2020

2121
# for conditional compilation
22-
_rt-async-std = ["async-std", "async-io"]
23-
_rt-tokio = ["tokio", "tokio-stream"]
22+
_rt-async-std = ["async-std", "async-io", "socket2"]
23+
_rt-tokio = ["tokio", "tokio-stream", "socket2"]
2424
_tls-native-tls = ["native-tls"]
2525
_tls-rustls-aws-lc-rs = ["_tls-rustls", "rustls/aws-lc-rs", "webpki-roots"]
2626
_tls-rustls-ring-webpki = ["_tls-rustls", "rustls/ring", "webpki-roots"]
@@ -83,6 +83,7 @@ hashlink = "0.10.0"
8383
indexmap = "2.0"
8484
event-listener = "5.2.0"
8585
hashbrown = "0.15.0"
86+
socket2 = { version = "0.5", features = ["all"], optional = true }
8687

8788
[dev-dependencies]
8889
sqlx = { workspace = true, features = ["postgres", "sqlite", "mysql", "migrate", "macros", "time", "uuid"] }

sqlx-core/src/net/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@ mod socket;
22
pub mod tls;
33

44
pub use socket::{
5-
connect_tcp, connect_uds, BufferedSocket, Socket, SocketIntoBox, WithSocket, WriteBuffer,
5+
connect_tcp, connect_uds, BufferedSocket, KeepaliveConfig, Socket, SocketIntoBox, WithSocket,
6+
WriteBuffer,
67
};

sqlx-core/src/net/socket/mod.rs

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::io;
33
use std::path::Path;
44
use std::pin::Pin;
55
use std::task::{ready, Context, Poll};
6+
use std::time::Duration;
67

78
use bytes::BufMut;
89

@@ -12,6 +13,25 @@ use crate::io::ReadBuf;
1213

1314
mod buffered;
1415

16+
/// Configuration for TCP keepalive probes on a connection.
17+
///
18+
/// All fields default to `None`, meaning the OS default is used.
19+
/// Constructing a `KeepaliveConfig::default()` and passing it enables keepalive
20+
/// with OS defaults for all parameters.
21+
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
22+
pub struct KeepaliveConfig {
23+
/// Time the connection must be idle before keepalive probes begin.
24+
/// `None` means the OS default.
25+
pub idle: Option<Duration>,
26+
/// Interval between keepalive probes.
27+
/// `None` means the OS default.
28+
pub interval: Option<Duration>,
29+
/// Maximum number of failed probes before the connection is dropped.
30+
/// Only supported on Unix; ignored on other platforms.
31+
/// `None` means the OS default.
32+
pub retries: Option<u32>,
33+
}
34+
1535
pub trait Socket: Send + Sync + Unpin + 'static {
1636
fn try_read(&mut self, buf: &mut dyn ReadBuf) -> io::Result<usize>;
1737

@@ -184,9 +204,44 @@ impl<S: Socket + ?Sized> Socket for Box<S> {
184204
}
185205
}
186206

207+
#[cfg(any(feature = "_rt-tokio", feature = "_rt-async-std"))]
208+
fn build_tcp_keepalive(config: &KeepaliveConfig) -> socket2::TcpKeepalive {
209+
let mut ka = socket2::TcpKeepalive::new();
210+
211+
if let Some(idle) = config.idle {
212+
ka = ka.with_time(idle);
213+
}
214+
215+
// socket2's `with_interval` is unavailable on these platforms.
216+
#[cfg(not(any(
217+
target_os = "haiku",
218+
target_os = "openbsd",
219+
target_os = "redox",
220+
target_os = "solaris",
221+
)))]
222+
if let Some(interval) = config.interval {
223+
ka = ka.with_interval(interval);
224+
}
225+
226+
// socket2's `with_retries` is unavailable on these platforms.
227+
#[cfg(not(any(
228+
target_os = "haiku",
229+
target_os = "openbsd",
230+
target_os = "redox",
231+
target_os = "solaris",
232+
target_os = "windows",
233+
)))]
234+
if let Some(retries) = config.retries {
235+
ka = ka.with_retries(retries);
236+
}
237+
238+
ka
239+
}
240+
187241
pub async fn connect_tcp<Ws: WithSocket>(
188242
host: &str,
189243
port: u16,
244+
keepalive: Option<&KeepaliveConfig>,
190245
with_socket: Ws,
191246
) -> crate::Result<Ws::Output> {
192247
// IPv6 addresses in URLs will be wrapped in brackets and the `url` crate doesn't trim those.
@@ -199,6 +254,11 @@ pub async fn connect_tcp<Ws: WithSocket>(
199254
let stream = TcpStream::connect((host, port)).await?;
200255
stream.set_nodelay(true)?;
201256

257+
if let Some(ka) = keepalive {
258+
let sock = socket2::SockRef::from(&stream);
259+
sock.set_tcp_keepalive(&build_tcp_keepalive(ka))?;
260+
}
261+
202262
return Ok(with_socket.with_socket(stream).await);
203263
}
204264

@@ -216,6 +276,12 @@ pub async fn connect_tcp<Ws: WithSocket>(
216276
.await
217277
.and_then(|s| {
218278
s.get_ref().set_nodelay(true)?;
279+
280+
if let Some(ka) = keepalive {
281+
let sock = socket2::SockRef::from(s.get_ref());
282+
sock.set_tcp_keepalive(&build_tcp_keepalive(ka))?;
283+
}
284+
219285
Ok(s)
220286
});
221287
match stream {
@@ -238,7 +304,7 @@ pub async fn connect_tcp<Ws: WithSocket>(
238304

239305
#[cfg(not(feature = "_rt-async-std"))]
240306
{
241-
crate::rt::missing_rt((host, port, with_socket))
307+
crate::rt::missing_rt((host, port, keepalive, with_socket))
242308
}
243309
}
244310

sqlx-mysql/src/connection/establish.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@ impl MySqlConnection {
1818

1919
let handshake = match &options.socket {
2020
Some(path) => crate::net::connect_uds(path, do_handshake).await?,
21-
None => crate::net::connect_tcp(&options.host, options.port, do_handshake).await?,
21+
None => {
22+
crate::net::connect_tcp(&options.host, options.port, None, do_handshake).await?
23+
}
2224
};
2325

2426
let stream = handshake?;

sqlx-postgres/src/connection/stream.rs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,20 @@ impl PgStream {
4444
pub(super) async fn connect(options: &PgConnectOptions) -> Result<Self, Error> {
4545
let socket_result = match options.fetch_socket() {
4646
Some(ref path) => net::connect_uds(path, MaybeUpgradeTls(options)).await?,
47-
None => net::connect_tcp(&options.host, options.port, MaybeUpgradeTls(options)).await?,
47+
None => {
48+
let keepalive = if options.keepalives {
49+
Some(&options.keepalive_config)
50+
} else {
51+
None
52+
};
53+
net::connect_tcp(
54+
&options.host,
55+
options.port,
56+
keepalive,
57+
MaybeUpgradeTls(options),
58+
)
59+
.await?
60+
}
4861
};
4962

5063
let socket = socket_result?;

sqlx-postgres/src/options/mod.rs

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@ use std::borrow::Cow;
22
use std::env::var;
33
use std::fmt::{Display, Write};
44
use std::path::{Path, PathBuf};
5+
use std::time::Duration;
56

67
pub use ssl_mode::PgSslMode;
78

9+
use crate::net::KeepaliveConfig;
810
use crate::{connection::LogSettings, net::tls::CertificateInput};
911

1012
mod connect;
@@ -30,6 +32,8 @@ pub struct PgConnectOptions {
3032
pub(crate) log_settings: LogSettings,
3133
pub(crate) extra_float_digits: Option<Cow<'static, str>>,
3234
pub(crate) options: Option<String>,
35+
pub(crate) keepalives: bool,
36+
pub(crate) keepalive_config: KeepaliveConfig,
3337
}
3438

3539
impl Default for PgConnectOptions {
@@ -90,6 +94,9 @@ impl PgConnectOptions {
9094
extra_float_digits: Some("2".into()),
9195
log_settings: Default::default(),
9296
options: var("PGOPTIONS").ok(),
97+
// Matches libpq default: keepalives=1 with OS defaults for timers.
98+
keepalives: true,
99+
keepalive_config: KeepaliveConfig::default(),
93100
}
94101
}
95102

@@ -441,6 +448,85 @@ impl PgConnectOptions {
441448
self
442449
}
443450

451+
/// Enables or disables TCP keepalive on the connection.
452+
///
453+
/// This option is ignored for Unix domain sockets.
454+
///
455+
/// Keepalive is enabled by default.
456+
///
457+
/// When enabled, OS defaults are used for all timer parameters unless
458+
/// overridden by [`keepalives_idle`][Self::keepalives_idle],
459+
/// [`keepalives_interval`][Self::keepalives_interval], or
460+
/// [`keepalives_retries`][Self::keepalives_retries].
461+
///
462+
/// # Example
463+
///
464+
/// ```rust
465+
/// # use sqlx_postgres::PgConnectOptions;
466+
/// let options = PgConnectOptions::new()
467+
/// .keepalives(false);
468+
/// ```
469+
pub fn keepalives(mut self, enable: bool) -> Self {
470+
self.keepalives = enable;
471+
self
472+
}
473+
474+
/// Sets the idle time before TCP keepalive probes begin.
475+
///
476+
/// This is ignored for Unix domain sockets, or if the `keepalives`
477+
/// option is disabled.
478+
///
479+
/// # Example
480+
///
481+
/// ```rust
482+
/// # use std::time::Duration;
483+
/// # use sqlx_postgres::PgConnectOptions;
484+
/// let options = PgConnectOptions::new()
485+
/// .keepalives_idle(Duration::from_secs(60));
486+
/// ```
487+
pub fn keepalives_idle(mut self, idle: Duration) -> Self {
488+
self.keepalive_config.idle = Some(idle);
489+
self
490+
}
491+
492+
/// Sets the interval between TCP keepalive probes.
493+
///
494+
/// This is ignored for Unix domain sockets, or if the `keepalives`
495+
/// option is disabled.
496+
///
497+
/// # Example
498+
///
499+
/// ```rust
500+
/// # use std::time::Duration;
501+
/// # use sqlx_postgres::PgConnectOptions;
502+
/// let options = PgConnectOptions::new()
503+
/// .keepalives_interval(Duration::from_secs(5));
504+
/// ```
505+
pub fn keepalives_interval(mut self, interval: Duration) -> Self {
506+
self.keepalive_config.interval = Some(interval);
507+
self
508+
}
509+
510+
/// Sets the maximum number of TCP keepalive probes before the connection is dropped.
511+
///
512+
/// This is ignored for Unix domain sockets, or if the `keepalives`
513+
/// option is disabled.
514+
///
515+
/// Only supported on Unix platforms; ignored on other platforms.
516+
///
517+
/// # Example
518+
///
519+
/// ```rust
520+
/// # use sqlx_postgres::PgConnectOptions;
521+
/// let options = PgConnectOptions::new()
522+
/// .keepalives_retries(3);
523+
/// ```
524+
#[cfg(unix)]
525+
pub fn keepalives_retries(mut self, retries: u32) -> Self {
526+
self.keepalive_config.retries = Some(retries);
527+
self
528+
}
529+
444530
/// We try using a socket if hostname starts with `/` or if socket parameter
445531
/// is specified.
446532
pub(crate) fn fetch_socket(&self) -> Option<String> {
@@ -569,6 +655,34 @@ impl PgConnectOptions {
569655
pub fn get_options(&self) -> Option<&str> {
570656
self.options.as_deref()
571657
}
658+
659+
/// Get whether TCP keepalives are enabled.
660+
///
661+
/// # Example
662+
///
663+
/// ```rust
664+
/// # use sqlx_postgres::PgConnectOptions;
665+
/// let options = PgConnectOptions::new();
666+
/// assert!(options.get_keepalives());
667+
/// ```
668+
pub fn get_keepalives(&self) -> bool {
669+
self.keepalives
670+
}
671+
672+
/// Get the idle time before TCP keepalive probes begin.
673+
pub fn get_keepalives_idle(&self) -> Option<Duration> {
674+
self.keepalive_config.idle
675+
}
676+
677+
/// Get the interval between TCP keepalive probes.
678+
pub fn get_keepalives_interval(&self) -> Option<Duration> {
679+
self.keepalive_config.interval
680+
}
681+
682+
/// Get the maximum number of TCP keepalive probes.
683+
pub fn get_keepalives_retries(&self) -> Option<u32> {
684+
self.keepalive_config.retries
685+
}
572686
}
573687

574688
fn default_host(port: u16) -> String {

0 commit comments

Comments
 (0)