Skip to content

Commit e8f1020

Browse files
authored
Merge pull request #180 from cipherstash/optional-connection-timeout
fix: make connection timeout optional
2 parents 5c5a1ee + 3274f39 commit e8f1020

File tree

10 files changed

+101
-33
lines changed

10 files changed

+101
-33
lines changed

packages/cipherstash-proxy/src/config/database.rs

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@ pub struct DatabaseConfig {
1919
#[serde(deserialize_with = "protected_string_deserializer")]
2020
password: Protected<String>,
2121

22-
#[serde(default = "DatabaseConfig::default_connection_timeout")]
23-
pub connection_timeout: u64,
22+
pub connection_timeout: Option<u64>,
2423

2524
#[serde(default)]
2625
pub with_tls_verification: bool,
@@ -41,11 +40,6 @@ impl DatabaseConfig {
4140
5432
4241
}
4342

44-
// 5 minutes
45-
pub const fn default_connection_timeout() -> u64 {
46-
1000 * 60 * 5
47-
}
48-
4943
pub const fn default_config_reload_interval() -> u64 {
5044
60
5145
}
@@ -73,8 +67,8 @@ impl DatabaseConfig {
7367
self.password.to_owned().risky_unwrap()
7468
}
7569

76-
pub fn connection_timeout(&self) -> Duration {
77-
Duration::from_millis(self.connection_timeout)
70+
pub fn connection_timeout(&self) -> Option<Duration> {
71+
self.connection_timeout.map(Duration::from_millis)
7872
}
7973

8074
pub fn server_name(&self) -> Result<ServerName, Error> {

packages/cipherstash-proxy/src/connect/mod.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use tokio::{
1414
use tokio_postgres::Client;
1515
use tracing::{debug, error, info, warn};
1616

17+
const TCP_USER_TIMEOUT: Duration = Duration::from_secs(10);
1718
const TCP_KEEPALIVE_INTERVAL: Duration = Duration::from_secs(5);
1819
const TCP_KEEPALIVE_TIME: Duration = Duration::from_secs(5);
1920
const TCP_KEEPALIVE_RETRIES: u32 = 5;
@@ -120,6 +121,17 @@ pub fn configure(stream: &TcpStream) {
120121
);
121122
});
122123

124+
#[cfg(target_os = "linux")]
125+
match sock_ref.set_tcp_user_timeout(Some(TCP_USER_TIMEOUT)) {
126+
Ok(_) => (),
127+
Err(err) => {
128+
warn!(
129+
msg = "Error configuring tcp_user_timeout for connection",
130+
error = err.to_string()
131+
);
132+
}
133+
}
134+
123135
match sock_ref.set_keepalive(true) {
124136
Ok(_) => {
125137
let params = &TcpKeepalive::new()

packages/cipherstash-proxy/src/error.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@ use crate::{postgresql::Column, Identifier};
22
use bytes::BytesMut;
33
use cipherstash_client::encryption;
44
use metrics_exporter_prometheus::BuildError;
5-
use std::io;
5+
use std::{io, time::Duration};
66
use thiserror::Error;
7-
use tokio::time::error::Elapsed;
87

98
const ERROR_DOC_BASE_URL: &str = "https://github.com/cipherstash/proxy/blob/main/docs/errors.md";
109

@@ -25,8 +24,8 @@ pub enum Error {
2524
#[error("Connection closed by client")]
2625
ConnectionClosed,
2726

28-
#[error("Connection timed out")]
29-
ConnectionTimeout(#[from] Elapsed),
27+
#[error("Connection timed out after {} ms", duration.as_secs())]
28+
ConnectionTimeout { duration: Duration },
3029

3130
#[error("Error creating connection")]
3231
DatabaseConnection,

packages/cipherstash-proxy/src/main.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
111111
Error::CancelRequest => {
112112
info!(msg = "Database connection closed after cancel request");
113113
}
114-
Error::ConnectionTimeout(_) => {
115-
warn!(msg = "Database connection timeout");
114+
Error::ConnectionTimeout{..} => {
115+
warn!(msg = "Database connection timeout", error = err.to_string());
116116
}
117117
_ => {
118118
error!(msg = "Database connection error", error = err.to_string());

packages/cipherstash-proxy/src/postgresql/backend.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ where
8181
pub async fn rewrite(&mut self) -> Result<(), Error> {
8282
let connection_timeout = self.encrypt.config.database.connection_timeout();
8383

84-
let (code, mut bytes) = protocol::read_message_with_timeout(
84+
let (code, mut bytes) = protocol::read_message(
8585
&mut self.server_reader,
8686
self.context.client_id,
8787
connection_timeout,

packages/cipherstash-proxy/src/postgresql/frontend.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ where
7676

7777
pub async fn rewrite(&mut self) -> Result<(), Error> {
7878
let connection_timeout = self.encrypt.config.database.connection_timeout();
79-
let (code, mut bytes) = protocol::read_message_with_timeout(
79+
let (code, mut bytes) = protocol::read_message(
8080
&mut self.client_reader,
8181
self.context.client_id,
8282
connection_timeout,

packages/cipherstash-proxy/src/postgresql/handler.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ pub async fn handler(
7070
);
7171

7272
loop {
73-
let startup_message = startup::read_message_with_timeout(
73+
let startup_message = startup::read_message(
7474
&mut client_stream,
7575
encrypt.config.database.connection_timeout(),
7676
)
@@ -125,8 +125,7 @@ pub async fn handler(
125125

126126
let connection_timeout = encrypt.config.database.connection_timeout();
127127
let (_code, bytes) =
128-
protocol::read_message_with_timeout(&mut client_stream, client_id, connection_timeout)
129-
.await?;
128+
protocol::read_message(&mut client_stream, client_id, connection_timeout).await?;
130129

131130
let password_message = PasswordMessage::try_from(&bytes)?;
132131

packages/cipherstash-proxy/src/postgresql/messages/error_response.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ pub const CODE_INVALID_PASSWORD: &str = "28P01";
1616
pub const CODE_RAISE_EXCEPTION: &str = "P0001";
1717
pub const CODE_SYNTAX_ERROR: &str = "42601";
1818
pub const CODE_INVALID_TEXT_REPRESENTATION: &str = "22P02";
19+
pub const CODE_IDLE_SESSION_TIMEOUT: &str = "57P05";
1920

2021
///
2122
/// ErrorResponse (B)
@@ -58,6 +59,29 @@ pub enum ErrorResponseCode {
5859
}
5960

6061
impl ErrorResponse {
62+
pub fn connection_timeout() -> Self {
63+
Self {
64+
fields: vec![
65+
Field {
66+
code: ErrorResponseCode::Severity,
67+
value: "FATAL".to_string(),
68+
},
69+
Field {
70+
code: ErrorResponseCode::SeverityLegacy,
71+
value: "FATAL".to_string(),
72+
},
73+
Field {
74+
code: ErrorResponseCode::Code,
75+
value: CODE_IDLE_SESSION_TIMEOUT.to_string(),
76+
},
77+
Field {
78+
code: ErrorResponseCode::Message,
79+
value: "Connection timeout".to_string(),
80+
},
81+
],
82+
}
83+
}
84+
6185
pub fn invalid_password(message: &str) -> Self {
6286
Self {
6387
fields: vec![

packages/cipherstash-proxy/src/postgresql/protocol.rs

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,18 +81,37 @@ pub async fn read_auth_message<S: AsyncRead + Unpin>(
8181
Authentication::try_from(&bytes)
8282
}
8383

84+
///
85+
/// Reads a Postgres message from client with an optional timeout
86+
///
87+
/// Timeout values are in config
88+
///
89+
///
90+
pub async fn read_message<S: AsyncRead + Unpin>(
91+
mut stream: S,
92+
client_id: i32,
93+
connection_timeout: Option<Duration>,
94+
) -> Result<(Code, BytesMut), Error> {
95+
match connection_timeout {
96+
Some(duration) => read_message_with_timeout(stream, client_id, duration).await,
97+
None => read(&mut stream, client_id).await,
98+
}
99+
}
100+
84101
///
85102
/// Reads a Postgres message from client with a timeout
86103
///
87104
/// Timeout values are in config
88105
///
89106
///
90-
pub async fn read_message_with_timeout<S: AsyncRead + Unpin>(
107+
async fn read_message_with_timeout<S: AsyncRead + Unpin>(
91108
mut stream: S,
92109
client_id: i32,
93-
connection_timeout: Duration,
110+
duration: Duration,
94111
) -> Result<(Code, BytesMut), Error> {
95-
timeout(connection_timeout, read_message(&mut stream, client_id)).await?
112+
timeout(duration, read(&mut stream, client_id))
113+
.await
114+
.map_err(|_| Error::ConnectionTimeout { duration })?
96115
}
97116

98117
///
@@ -102,7 +121,7 @@ pub async fn read_message_with_timeout<S: AsyncRead + Unpin>(
102121
/// Byte is then passed as `code` to this function to preserve the message structure
103122
///
104123
///
105-
pub async fn read_message<S: AsyncRead + Unpin>(
124+
async fn read<S: AsyncRead + Unpin>(
106125
mut stream: S,
107126
client_id: i32,
108127
) -> Result<(Code, BytesMut), Error> {

packages/cipherstash-proxy/src/postgresql/startup.rs

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,35 @@ pub async fn with_tls(stream: AsyncStream, config: &TandemConfig) -> Result<Asyn
4646
}
4747
}
4848

49-
pub async fn read_message_with_timeout<C>(
50-
client: &mut C,
51-
connection_timeout: Duration,
52-
) -> Result<StartupMessage, Error>
53-
where
54-
C: AsyncRead + Unpin,
55-
{
56-
timeout(connection_timeout, read_message(client)).await?
49+
///
50+
/// Reads a Postgres startup message from client with an optional timeout
51+
///
52+
/// Timeout values are in config
53+
///
54+
///
55+
pub async fn read_message<S: AsyncRead + Unpin>(
56+
mut stream: S,
57+
connection_timeout: Option<Duration>,
58+
) -> Result<StartupMessage, Error> {
59+
match connection_timeout {
60+
Some(duration) => read_message_with_timeout(stream, duration).await,
61+
None => read(&mut stream).await,
62+
}
63+
}
64+
65+
///
66+
/// Reads a Postgres message from client with a timeout
67+
///
68+
/// Timeout values are in config
69+
///
70+
///
71+
async fn read_message_with_timeout<S: AsyncRead + Unpin>(
72+
mut stream: S,
73+
duration: Duration,
74+
) -> Result<StartupMessage, Error> {
75+
timeout(duration, read(&mut stream))
76+
.await
77+
.map_err(|_| Error::ConnectionTimeout { duration })?
5778
}
5879

5980
///
@@ -62,7 +83,7 @@ where
6283
///
6384
///
6485
///
65-
async fn read_message<C>(client: &mut C) -> Result<StartupMessage, Error>
86+
async fn read<C>(client: &mut C) -> Result<StartupMessage, Error>
6687
where
6788
C: AsyncRead + Unpin,
6889
{

0 commit comments

Comments
 (0)