Skip to content

Commit c46f1dd

Browse files
authored
Add statement mode (#612)
1 parent d06e265 commit c46f1dd

File tree

10 files changed

+162
-48
lines changed

10 files changed

+162
-48
lines changed

pgdog-config/src/core.rs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,21 @@ impl Config {
302302
}
303303
}
304304

305+
// Check pooler mode.
306+
let mut pooler_mode = HashMap::<String, Option<PoolerMode>>::new();
307+
for database in &self.databases {
308+
if let Some(mode) = pooler_mode.get(&database.name) {
309+
if mode != &database.pooler_mode {
310+
warn!(
311+
"database \"{}\" (shard={}, role={}) has a different \"pooler_mode\" setting, ignoring",
312+
database.name, database.shard, database.role,
313+
);
314+
}
315+
} else {
316+
pooler_mode.insert(database.name.clone(), database.pooler_mode.clone());
317+
}
318+
}
319+
305320
// Check that idle_healthcheck_interval is shorter than ban_timeout.
306321
if self.general.ban_timeout > 0
307322
&& self.general.idle_healthcheck_interval >= self.general.ban_timeout
@@ -316,12 +331,12 @@ impl Config {
316331
match self.general.passthrough_auth {
317332
PassthoughAuth::Enabled if !self.general.tls_client_required => {
318333
warn!(
319-
"consider setting tls_client_required while passthrough_auth is enabled to prevent clients from exposing plaintext passwords"
334+
"consider setting \"tls_client_required\" while \"passthrough_auth\" is enabled to prevent clients from exposing plaintext passwords"
320335
);
321336
}
322337
PassthoughAuth::EnabledPlain => {
323338
warn!(
324-
"passthrough_auth plain is enabled - network traffic may expose plaintext passwords"
339+
"\"passthrough_auth\" is set to \"plain\", network traffic may expose plaintext passwords"
325340
)
326341
}
327342
_ => (),

pgdog-config/src/pooling.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,15 @@ pub enum PoolerMode {
4949
#[default]
5050
Transaction,
5151
Session,
52+
Statement,
5253
}
5354

5455
impl std::fmt::Display for PoolerMode {
5556
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
5657
match self {
5758
Self::Transaction => write!(f, "transaction"),
5859
Self::Session => write!(f, "session"),
60+
Self::Statement => write!(f, "statement"),
5961
}
6062
}
6163
}

pgdog/src/backend/pool/cluster.rs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,21 @@ pub struct ClusterShardConfig {
8585
pub replicas: Vec<PoolConfig>,
8686
}
8787

88+
impl ClusterShardConfig {
89+
pub fn pooler_mode(&self) -> PoolerMode {
90+
// One of these will exist.
91+
92+
if let Some(ref primary) = self.primary {
93+
return primary.config.pooler_mode;
94+
}
95+
96+
self.replicas
97+
.first()
98+
.map(|replica| replica.config.pooler_mode)
99+
.unwrap_or_default()
100+
}
101+
}
102+
88103
/// Cluster creation config.
89104
pub struct ClusterConfig<'a> {
90105
pub name: &'a str,
@@ -122,12 +137,17 @@ impl<'a> ClusterConfig<'a> {
122137
sharded_schemas: ShardedSchemas,
123138
rewrite: &'a Rewrite,
124139
) -> Self {
140+
let pooler_mode = shards
141+
.first()
142+
.map(|shard| shard.pooler_mode())
143+
.unwrap_or(user.pooler_mode.unwrap_or(general.pooler_mode));
144+
125145
Self {
126146
name: &user.database,
127147
password: user.password(),
128148
user: &user.name,
129149
replication_sharding: user.replication_sharding.clone(),
130-
pooler_mode: user.pooler_mode.unwrap_or(general.pooler_mode),
150+
pooler_mode,
131151
lb_strategy: general.load_balancing_strategy,
132152
shards,
133153
sharded_tables,

pgdog/src/backend/pool/connection/mod.rs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -402,18 +402,17 @@ impl Connection {
402402
self.cluster.as_ref().ok_or(Error::NotConnected)
403403
}
404404

405-
/// Transaction mode pooling.
405+
/// Pooler is in session mode.
406406
#[inline]
407-
pub(crate) fn transaction_mode(&self) -> bool {
407+
pub(crate) fn session_mode(&self) -> bool {
408408
self.cluster()
409-
.map(|c| c.pooler_mode() == PoolerMode::Transaction)
409+
.map(|c| c.pooler_mode() == PoolerMode::Session)
410410
.unwrap_or(true)
411411
}
412412

413-
/// Pooler is in session mod
414413
#[inline]
415-
pub(crate) fn session_mode(&self) -> bool {
416-
!self.transaction_mode()
414+
pub(crate) fn pooler_mode(&self) -> PoolerMode {
415+
self.cluster().map(|c| c.pooler_mode()).unwrap_or_default()
417416
}
418417

419418
/// This is an admin DB connection.

pgdog/src/config/mod.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,19 +144,26 @@ pub fn overrides(overrides: Overrides) {
144144
// Test helper functions
145145
#[cfg(test)]
146146
pub fn load_test() {
147+
load_test_with_pooler_mode(PoolerMode::Transaction)
148+
}
149+
150+
#[cfg(test)]
151+
pub fn load_test_with_pooler_mode(pooler_mode: PoolerMode) {
147152
use crate::backend::databases::init;
148153

149154
let mut config = ConfigAndUsers::default();
150155
config.config.databases = vec![Database {
151156
name: "pgdog".into(),
152157
host: "127.0.0.1".into(),
153158
port: 5432,
159+
pooler_mode: Some(pooler_mode),
154160
..Default::default()
155161
}];
156162
config.users.users = vec![User {
157163
name: "pgdog".into(),
158164
database: "pgdog".into(),
159165
password: Some("pgdog".into()),
166+
pooler_mode: Some(pooler_mode),
160167
..Default::default()
161168
}];
162169

pgdog/src/frontend/client/query_engine/query.rs

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use crate::{
99
state::State,
1010
};
1111

12-
use tracing::debug;
12+
use tracing::{debug, error};
1313

1414
use super::*;
1515

@@ -264,7 +264,7 @@ impl QueryEngine {
264264

265265
// Release the connection back into the pool before flushing data to client.
266266
// Flushing can take a minute and we don't want to block the connection from being reused.
267-
if self.backend.transaction_mode() && context.requests_left == 0 {
267+
if !self.backend.session_mode() && context.requests_left == 0 {
268268
self.backend.disconnect();
269269
}
270270

@@ -312,18 +312,14 @@ impl QueryEngine {
312312
&& !context.admin
313313
&& context.client_request.executable()
314314
{
315-
let bytes_sent = context
316-
.stream
317-
.error(
318-
ErrorResponse::cross_shard_disabled(),
319-
context.in_transaction(),
320-
)
321-
.await?;
322-
self.stats.sent(bytes_sent);
315+
let error = ErrorResponse::cross_shard_disabled();
316+
317+
self.error_response(context, error).await?;
323318

324319
if self.backend.connected() && self.backend.done() {
325320
self.backend.disconnect();
326321
}
322+
327323
Ok(false)
328324
} else {
329325
Ok(true)
@@ -367,19 +363,31 @@ impl QueryEngine {
367363
{
368364
let error = ErrorResponse::in_failed_transaction();
369365

370-
self.hooks.on_engine_error(context, &error)?;
371-
372-
let bytes_sent = context
373-
.stream
374-
.error(error, context.in_transaction())
375-
.await?;
376-
self.stats.sent(bytes_sent);
366+
self.error_response(context, error).await?;
377367

378368
Ok(false)
379369
} else {
380370
Ok(true)
381371
}
382372
}
373+
374+
pub(super) async fn error_response(
375+
&mut self,
376+
context: &mut QueryEngineContext<'_>,
377+
error: ErrorResponse,
378+
) -> Result<(), Error> {
379+
error!("{:?} [{:?}]", error.message, context.stream.peer_addr());
380+
381+
self.hooks.on_engine_error(context, &error)?;
382+
383+
let bytes_sent = context
384+
.stream
385+
.error(error, context.in_transaction())
386+
.await?;
387+
self.stats.sent(bytes_sent);
388+
389+
return Ok(());
390+
}
383391
}
384392

385393
#[derive(Debug, Default, Clone)]

pgdog/src/frontend/client/query_engine/route_query.rs

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
use tracing::{error, trace};
1+
use pgdog_config::PoolerMode;
2+
use tracing::trace;
23

34
use super::*;
45

@@ -7,6 +8,14 @@ impl QueryEngine {
78
&mut self,
89
context: &mut QueryEngineContext<'_>,
910
) -> Result<bool, Error> {
11+
// Check that we can route this transaction at all.
12+
if self.backend.pooler_mode() == PoolerMode::Statement && context.client_request.is_begin()
13+
{
14+
self.error_response(context, ErrorResponse::transaction_statement_mode())
15+
.await?;
16+
return Ok(false);
17+
}
18+
1019
// Admin doesn't have a cluster.
1120
let cluster = if let Ok(cluster) = self.backend.cluster() {
1221
if !context.in_transaction() && !cluster.online() {
@@ -17,18 +26,14 @@ impl QueryEngine {
1726

1827
match self.backend.cluster() {
1928
Ok(cluster) => cluster,
20-
Err(err) => {
29+
Err(_) => {
2130
// Cluster is gone.
22-
error!("{:?} [{:?}]", err, context.stream.peer_addr());
23-
let error =
24-
ErrorResponse::connection(&identifier.user, &identifier.database);
25-
self.hooks.on_engine_error(context, &error)?;
31+
self.error_response(
32+
context,
33+
ErrorResponse::connection(&identifier.user, &identifier.database),
34+
)
35+
.await?;
2636

27-
let bytes_sent = context
28-
.stream
29-
.error(error, context.in_transaction())
30-
.await?;
31-
self.stats.sent(bytes_sent);
3237
return Ok(false);
3338
}
3439
}
@@ -55,17 +60,9 @@ impl QueryEngine {
5560
);
5661
}
5762
Err(err) => {
58-
error!("{:?} [{:?}]", err, context.stream.peer_addr());
59-
60-
let error = ErrorResponse::syntax(err.to_string().as_str());
61-
62-
self.hooks.on_engine_error(context, &error)?;
63-
64-
let bytes_sent = context
65-
.stream
66-
.error(error, context.in_transaction())
63+
self.error_response(context, ErrorResponse::syntax(err.to_string().as_str()))
6764
.await?;
68-
self.stats.sent(bytes_sent);
65+
6966
return Ok(false);
7067
}
7168
}

pgdog/src/frontend/client/test/mod.rs

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use std::time::{Duration, Instant};
22

3+
use pgdog_config::PoolerMode;
34
use tokio::{
45
io::{AsyncReadExt, AsyncWriteExt},
56
net::{TcpListener, TcpStream},
@@ -10,7 +11,10 @@ use bytes::{Buf, BufMut, BytesMut};
1011

1112
use crate::{
1213
backend::databases::databases,
13-
config::{config, load_test, load_test_replicas, set, PreparedStatements, Role},
14+
config::{
15+
config, load_test, load_test_replicas, load_test_with_pooler_mode, set, PreparedStatements,
16+
Role,
17+
},
1418
frontend::{
1519
client::{BufferEvent, QueryEngine},
1620
prepared_statements, Client,
@@ -739,6 +743,25 @@ async fn test_client_login_timeout() {
739743
handle.await.unwrap().unwrap();
740744
}
741745

746+
#[tokio::test]
747+
async fn test_statement_mode() {
748+
crate::logger();
749+
750+
load_test_with_pooler_mode(PoolerMode::Statement);
751+
let (mut conn, mut client) = parallel_test_client().await;
752+
753+
let _ = tokio::spawn(async move {
754+
client.run().await.unwrap();
755+
});
756+
757+
let req = buffer!({ Query::new("BEGIN") });
758+
conn.write_all(&req).await.unwrap();
759+
760+
let msgs = read!(conn, ['E', 'Z']);
761+
let error = ErrorResponse::from_bytes(msgs[0].clone().freeze()).unwrap();
762+
assert_eq!(error.code, "58000");
763+
}
764+
742765
#[tokio::test]
743766
async fn test_client_login_timeout_does_not_affect_queries() {
744767
crate::logger();

pgdog/src/frontend/client_request.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
use std::ops::{Deref, DerefMut};
33

44
use lazy_static::lazy_static;
5+
use regex::Regex;
56

67
use crate::{
78
frontend::router::parser::RewritePlan,
@@ -108,6 +109,26 @@ impl ClientRequest {
108109
Ok(None)
109110
}
110111

112+
pub fn is_begin(&self) -> bool {
113+
lazy_static! {
114+
static ref BEGIN: Regex = Regex::new("(?i)^BEGIN").unwrap();
115+
}
116+
117+
for message in &self.messages {
118+
let query = match message {
119+
ProtocolMessage::Parse(parse) => parse.query(),
120+
ProtocolMessage::Query(query) => query.query(),
121+
_ => continue,
122+
};
123+
124+
if BEGIN.is_match(query) {
125+
return true;
126+
}
127+
}
128+
129+
false
130+
}
131+
111132
/// If this buffer contains bound parameters, retrieve them.
112133
pub fn parameters(&self) -> Result<Option<&Bind>, Error> {
113134
for message in &self.messages {
@@ -462,4 +483,16 @@ mod test {
462483
assert_eq!(second_slice[2].code(), 'H'); // Flush
463484
assert_eq!(second_slice[3].code(), 'S'); // Sync
464485
}
486+
487+
#[test]
488+
fn test_detect_begin() {
489+
for query in [
490+
ProtocolMessage::Query(Query::new("begin")),
491+
ProtocolMessage::Query(Query::new("BEGIN WORK REPEATABLE READ")),
492+
ProtocolMessage::Parse(Parse::new_anonymous("BEGIN")),
493+
] {
494+
let req = ClientRequest::from(vec![query]);
495+
assert!(req.is_begin());
496+
}
497+
}
465498
}

0 commit comments

Comments
 (0)