diff --git a/Cargo.lock b/Cargo.lock index 6f0c81f9..8375ea2f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2289,6 +2289,7 @@ dependencies = [ "serde", "serde_json", "sha1", + "smallvec", "socket2", "thiserror 2.0.12", "tikv-jemallocator", @@ -3322,9 +3323,9 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.15.0" +version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8917285742e9f3e1683f0a9c4e6b57960b7314d0b08d30d1ecd426713ee2eee9" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" dependencies = [ "serde", ] diff --git a/integration/go/go_pgx/load_balancer_test.go b/integration/go/go_pgx/load_balancer_test.go index aa337b95..c23032de 100644 --- a/integration/go/go_pgx/load_balancer_test.go +++ b/integration/go/go_pgx/load_balancer_test.go @@ -231,10 +231,18 @@ outer: } func prewarm(t *testing.T, pool *pgxpool.Pool) { + ctx := context.Background() for range 25 { - for _, q := range []string{"BEGIN", "SELECT 1", "COMMIT", "SELECT 1"} { - _, err := pool.Exec(context.Background(), q) - assert.NoError(t, err) - } + // transaction `BEGIN; SELECT 1; COMMIT;` + tx, err := pool.Begin(ctx) + assert.NoError(t, err) + _, err = tx.Exec(ctx, "SELECT 1") + assert.NoError(t, err) + err = tx.Commit(ctx) + assert.NoError(t, err) + + // no-transaction `SELECT 1;` + _, err = pool.Exec(ctx, "SELECT 1") + assert.NoError(t, err) } } diff --git a/integration/load_balancer/run.sh b/integration/load_balancer/run.sh index 05b7f394..701041cb 100644 --- a/integration/load_balancer/run.sh +++ b/integration/load_balancer/run.sh @@ -10,7 +10,7 @@ export PGHOST=127.0.0.1 export PGDATABASE=postgres export PGPASSWORD=postgres -docker-compose up -d +docker compose up -d echo "Waiting for Postgres to be ready" @@ -45,4 +45,4 @@ popd killall pgdog -docker-compose down +docker compose down diff --git a/pgdog/Cargo.toml b/pgdog/Cargo.toml index 238ee373..e00c1e5c 100644 --- a/pgdog/Cargo.toml +++ b/pgdog/Cargo.toml @@ -59,6 +59,7 @@ indexmap = "2.9" lru = "0.16" hickory-resolver = "0.25.2" lazy_static = "1" +smallvec = "1.15.1" [target.'cfg(not(target_env = "msvc"))'.dependencies] tikv-jemallocator = "0.6" diff --git a/pgdog/src/backend/pool/connection/binding.rs b/pgdog/src/backend/pool/connection/binding.rs index 9bbea4fe..f63be9bc 100644 --- a/pgdog/src/backend/pool/connection/binding.rs +++ b/pgdog/src/backend/pool/connection/binding.rs @@ -1,7 +1,7 @@ //! Binding between frontend client and a connection on the backend. use crate::{ - net::{parameter::Parameters, ProtocolMessage}, + net::{parameter::Parameters, ProtocolMessage, Query}, state::State, }; @@ -71,6 +71,7 @@ impl Binding { Binding::Admin(backend) => Ok(backend.read().await?), Binding::MultiShard(shards, state) => { + println!("2.1"); if shards.is_empty() { loop { debug!("multi-shard binding suspended"); @@ -237,6 +238,13 @@ impl Binding { Ok(()) } + /// Execute a BEGIN on all servers + /// TODO: Block mutli-shard BEGINs as transaction should not occur on multiple shards + pub async fn begin(&mut self) -> Result<(), Error> { + let query = Query::new("BEGIN"); + self.execute(query.query()).await + } + pub async fn link_client(&mut self, params: &Parameters) -> Result { match self { Binding::Server(Some(ref mut server)) => server.link_client(params).await, diff --git a/pgdog/src/backend/pool/connection/mirror.rs b/pgdog/src/backend/pool/connection/mirror.rs index d183ace8..8aa94669 100644 --- a/pgdog/src/backend/pool/connection/mirror.rs +++ b/pgdog/src/backend/pool/connection/mirror.rs @@ -9,6 +9,7 @@ use tracing::{debug, error}; use crate::backend::Cluster; use crate::config::config; use crate::frontend::client::timeouts::Timeouts; +use crate::frontend::logical_transaction::LogicalTransaction; use crate::frontend::{Command, PreparedStatements, Router, RouterContext}; use crate::net::Parameters; use crate::state::State; @@ -47,6 +48,8 @@ pub(crate) struct Mirror { params: Parameters, /// Mirror state. state: State, + /// Logical transaction state. + logical_transaction: LogicalTransaction, } impl Mirror { @@ -71,6 +74,7 @@ impl Mirror { cluster: cluster.clone(), state: State::Idle, params: Parameters::default(), + logical_transaction: LogicalTransaction::new(), }; let query_timeout = Timeouts::from_config(&config.config.general); @@ -135,7 +139,7 @@ impl Mirror { &self.cluster, &mut self.prepared_statements, &self.params, - false, + &self.logical_transaction, ) { match self.router.query(context) { Ok(command) => { diff --git a/pgdog/src/frontend/client/engine/context.rs b/pgdog/src/frontend/client/engine/context.rs index 909b0f09..98cf572e 100644 --- a/pgdog/src/frontend/client/engine/context.rs +++ b/pgdog/src/frontend/client/engine/context.rs @@ -1,5 +1,7 @@ use crate::{ - frontend::{client::Inner, Buffer, Client, PreparedStatements}, + frontend::{ + client::Inner, logical_transaction::LogicalTransaction, Buffer, Client, PreparedStatements, + }, net::Parameters, }; @@ -13,7 +15,7 @@ pub struct EngineContext<'a> { /// Client parameters. pub(super) params: &'a Parameters, /// Is the client inside a transaction? - pub(super) in_transaction: bool, + pub(super) logical_transaction: &'a LogicalTransaction, /// Messages currently in client's buffer. pub(super) buffer: &'a Buffer, } @@ -23,9 +25,13 @@ impl<'a> EngineContext<'a> { Self { prepared_statements: &mut client.prepared_statements, params: &client.params, - in_transaction: client.in_transaction, + logical_transaction: &client.logical_transaction, connected: inner.connected(), buffer: &client.request_buffer, } } + + pub fn in_transaction(&self) -> bool { + self.logical_transaction.in_transaction() + } } diff --git a/pgdog/src/frontend/client/engine/mod.rs b/pgdog/src/frontend/client/engine/mod.rs index fa5aae16..8bf7d3e8 100644 --- a/pgdog/src/frontend/client/engine/mod.rs +++ b/pgdog/src/frontend/client/engine/mod.rs @@ -54,7 +54,8 @@ impl<'a> Engine<'a> { 'S' => { if only_close || only_sync && !self.context.connected { messages.push( - ReadyForQuery::in_transaction(self.context.in_transaction).message()?, + ReadyForQuery::in_transaction(self.context.in_transaction()) + .message()?, ) } } @@ -73,7 +74,7 @@ impl<'a> Engine<'a> { #[cfg(test)] mod test { use crate::{ - frontend::{Buffer, PreparedStatements}, + frontend::{logical_transaction::LogicalTransaction, Buffer, PreparedStatements}, net::{Parameters, Parse, Sync}, }; @@ -93,11 +94,13 @@ mod test { Sync.into(), ]); + let logical_transaction = LogicalTransaction::new(); + let context = EngineContext { connected: false, prepared_statements: &mut prepared, params: ¶ms, - in_transaction: false, + logical_transaction: &logical_transaction, buffer: &buf, }; diff --git a/pgdog/src/frontend/client/inner.rs b/pgdog/src/frontend/client/inner.rs index 21051386..340fee54 100644 --- a/pgdog/src/frontend/client/inner.rs +++ b/pgdog/src/frontend/client/inner.rs @@ -6,8 +6,8 @@ use crate::{ Error as BackendError, }, frontend::{ - buffer::BufferedQuery, router::Error as RouterError, Buffer, Command, Comms, - PreparedStatements, Router, RouterContext, Stats, + logical_transaction::LogicalTransaction, router::Error as RouterError, Buffer, Command, + Comms, PreparedStatements, Router, RouterContext, Stats, }, net::Parameters, state::State, @@ -29,8 +29,6 @@ pub struct Inner { pub(super) router: Router, /// Client stats. pub(super) stats: Stats, - /// Start transaction statement, intercepted by the router. - pub(super) start_transaction: Option, /// Client-wide comms. pub(super) comms: Comms, } @@ -47,7 +45,6 @@ impl Inner { backend, router, stats: Stats::new(), - start_transaction: None, comms: client.comms.clone(), }) } @@ -58,7 +55,7 @@ impl Inner { buffer: &mut Buffer, prepared_statements: &mut PreparedStatements, params: &Parameters, - in_transaction: bool, + logical_transaction: &LogicalTransaction, ) -> Result, RouterError> { let command = self .backend @@ -71,7 +68,7 @@ impl Inner { cluster, // Cluster configuration. prepared_statements, // Prepared statements. params, // Client connection parameters. - in_transaction, // Client in explcitely started transaction. + logical_transaction, // Client in explcitely started transaction. )?; self.router.query(context) }) diff --git a/pgdog/src/frontend/client/mod.rs b/pgdog/src/frontend/client/mod.rs index ca5c2032..0576fd88 100644 --- a/pgdog/src/frontend/client/mod.rs +++ b/pgdog/src/frontend/client/mod.rs @@ -5,12 +5,16 @@ use std::time::Instant; use bytes::BytesMut; use engine::EngineContext; +use pg_query::protobuf::PartitionElem; +use smallvec::SmallVec; use timeouts::Timeouts; use tokio::time::timeout; use tokio::{select, spawn}; use tracing::{debug, enabled, error, info, trace, Level as LogLevel}; +use super::logical_transaction::{LogicalTransaction, TransactionError, TransactionStatus}; use super::{Buffer, Command, Comms, Error, PreparedStatements}; + use crate::auth::{md5, scram::Server}; use crate::backend::{ databases, @@ -51,7 +55,7 @@ pub struct Client { streaming: bool, shutdown: bool, prepared_statements: PreparedStatements, - in_transaction: bool, + logical_transaction: LogicalTransaction, timeouts: Timeouts, request_buffer: Buffer, stream_buffer: BytesMut, @@ -236,7 +240,7 @@ impl Client { replication_mode, connect_params: params, prepared_statements: PreparedStatements::new(), - in_transaction: false, + logical_transaction: LogicalTransaction::new(), timeouts: Timeouts::from_config(&config.config.general), request_buffer: Buffer::new(), stream_buffer: BytesMut::new(), @@ -277,7 +281,7 @@ impl Client { connect_params: connect_params.clone(), params: connect_params, admin: false, - in_transaction: false, + logical_transaction: LogicalTransaction::new(), timeouts: Timeouts::from_config(&config().config.general), request_buffer: Buffer::new(), stream_buffer: BytesMut::new(), @@ -309,15 +313,17 @@ impl Client { /// Run the client. async fn run(&mut self) -> Result<(), Error> { + println!("1"); let mut inner = Inner::new(self)?; let shutdown = self.comms.shutting_down(); + println!("2"); loop { let query_timeout = self.timeouts.query_timeout(&inner.stats.state); select! { _ = shutdown.notified() => { - if !inner.backend.connected() && inner.start_transaction.is_none() { + if !inner.backend.connected() && !self.in_transaction() { break; } } @@ -368,6 +374,20 @@ impl Client { /// Handle client messages. async fn client_messages(&mut self, mut inner: InnerBorrow<'_>) -> Result { + // We don't start a transaction on the servers until a client is actually executing something. + // This prevents us holding open connections to multiple servers + if self.should_trigger_buffered_begin() { + println!(""); + println!("****************************************"); + println!("****************************************"); + println!("****************************************"); + println!("****************************************"); + println!("****************************************"); + println!(""); + inner.backend.begin().await?; + self.logical_transaction.record_begin(); + } + inner .stats .received(self.request_buffer.total_message_len()); @@ -378,7 +398,7 @@ impl Client { "{} [{}] (in transaction: {})", query.query(), self.addr, - self.in_transaction + self.in_transaction() ); QueryLogger::new(&self.request_buffer).log().await?; } @@ -393,7 +413,7 @@ impl Client { match engine.execute().await? { Action::Intercept(msgs) => { self.stream.send_many(&msgs).await?; - inner.done(self.in_transaction); + inner.done(self.in_transaction()); self.update_stats(&mut inner); return Ok(false); } @@ -407,29 +427,37 @@ impl Client { &mut self.request_buffer, &mut self.prepared_statements, &self.params, - self.in_transaction, + &self.logical_transaction, ) { Ok(command) => command, Err(err) => { if err.empty_query() { self.stream.send(&EmptyQueryResponse).await?; self.stream - .send_flush(&ReadyForQuery::in_transaction(self.in_transaction)) + .send_flush(&ReadyForQuery::in_transaction(self.in_transaction())) .await?; } else { error!("{:?} [{}]", err, self.addr); self.stream .error( ErrorResponse::syntax(err.to_string().as_str()), - self.in_transaction, + self.in_transaction(), ) .await?; } - inner.done(self.in_transaction); + inner.done(self.in_transaction()); return Ok(false); } }; + println!(""); + println!("COMMAND: {:?}", command); + println!("-- connected? {}", connected); + + // AAAAAAA + // I decided that transactions keep the connection open. which makes total sense. + // how can we release a transaction to the connection pool and have it be reused by another client? + if !connected { // Simulate transaction starting // until client sends an actual query. @@ -441,27 +469,25 @@ impl Client { // to a shard. // match command { + Some(Command::CommitTransaction) => { + println!("HELLOOOOOOOO?"); + self.end_transaction(false).await?; + + inner.done(self.in_transaction()); + return Ok(false); + } Some(Command::StartTransaction(query)) => { if let BufferedQuery::Query(_) = query { self.start_transaction().await?; - inner.start_transaction = Some(query.clone()); - self.in_transaction = true; - inner.done(self.in_transaction); + + inner.done(self.in_transaction()); return Ok(false); } } Some(Command::RollbackTransaction) => { - inner.start_transaction = None; self.end_transaction(true).await?; - self.in_transaction = false; - inner.done(self.in_transaction); - return Ok(false); - } - Some(Command::CommitTransaction) => { - inner.start_transaction = None; - self.end_transaction(false).await?; - self.in_transaction = false; - inner.done(self.in_transaction); + + inner.done(self.in_transaction()); return Ok(false); } // How many shards are configured. @@ -470,11 +496,13 @@ impl Client { let mut dr = DataRow::new(); dr.add(*shards as i64); let cc = CommandComplete::from_str("SHOW"); - let rfq = ReadyForQuery::in_transaction(self.in_transaction); + let rfq = ReadyForQuery::in_transaction(self.in_transaction()); + self.stream .send_many(&[rd.message()?, dr.message()?, cc.message()?, rfq.message()?]) .await?; - inner.done(self.in_transaction); + + inner.done(self.in_transaction()); return Ok(false); } Some(Command::Deallocate) => { @@ -489,12 +517,20 @@ impl Client { return Ok(false); } - Some(Command::Query(query)) => { - if query.is_cross_shard() && self.cross_shard_disabled { + Some(Command::Query(route)) => { + if self.in_transaction() { + let shard = route.shard().clone(); + self.logical_transaction.execute_query(shard)?; + } + + if route.is_cross_shard() && self.cross_shard_disabled { self.stream - .error(ErrorResponse::cross_shard_disabled(), self.in_transaction) + .error( + ErrorResponse::cross_shard_disabled(), + self.logical_transaction.in_transaction(), + ) .await?; - inner.done(self.in_transaction); + inner.done(self.in_transaction()); inner.reset_router(); return Ok(false); } @@ -546,12 +582,12 @@ impl Client { if err.no_server() { error!("{} [{}]", err, self.addr); self.stream - .error(ErrorResponse::from_err(&err), self.in_transaction) + .error(ErrorResponse::from_err(&err), self.in_transaction()) .await?; // TODO: should this be wrapped in a method? inner.disconnect(); inner.reset_router(); - inner.done(self.in_transaction); + inner.done(self.in_transaction()); return Ok(false); } else { return Err(err.into()); @@ -560,15 +596,7 @@ impl Client { }; } - // We don't start a transaction on the servers until - // a client is actually executing something. - // - // This prevents us holding open connections to multiple servers - if self.request_buffer.executable() { - if let Some(query) = inner.start_transaction.take() { - inner.backend.execute(&query).await?; - } - } + println!("Inner.router.route: {}", inner.router.route()); for msg in self.request_buffer.iter() { if let ProtocolMessage::Bind(bind) = msg { @@ -617,6 +645,11 @@ impl Client { let message = message.backend(); let has_more_messages = inner.backend.has_more_messages(); + println!(""); + println!(""); + println!(""); + println!("BACKKEND: \n{:?}", message); + // Messages that we need to send to the client immediately. // ReadyForQuery (B) | CopyInResponse (B) | ErrorResponse(B) | NoticeResponse(B) | NotificationResponse (B) let flush = matches!(code, 'Z' | 'G' | 'E' | 'N' | 'A') @@ -627,13 +660,13 @@ impl Client { // ReadyForQuery (B) if code == 'Z' { inner.stats.query(); - // In transaction if buffered BEGIN from client - // or server is telling us we are. - self.in_transaction = message.in_transaction() || inner.start_transaction.is_some(); - inner.stats.idle(self.in_transaction); + + // In transaction if buffered BEGIN from client or server is telling us we are. + let in_transaction = message.in_transaction() || self.in_transaction(); + inner.stats.idle(in_transaction); // Flush mirrors. - if !self.in_transaction { + if !in_transaction { inner.backend.mirror_flush(); } } @@ -645,12 +678,15 @@ impl Client { // Flushing can take a minute and we don't want to block // the connection from being reused. if inner.backend.done() { - let changed_params = inner.backend.changed_params(); - if inner.transaction_mode() && !self.replication_mode { + if inner.transaction_mode() && !self.replication_mode && !self.in_transaction() { inner.disconnect(); } + + let changed_params = inner.backend.changed_params(); + inner.stats.transaction(); inner.reset_router(); + debug!( "transaction finished [{:.3}ms]", inner.stats.last_transaction_time.as_secs_f64() * 1000.0 @@ -758,13 +794,29 @@ impl Client { /// Tell the client we started a transaction. async fn start_transaction(&mut self) -> Result<(), Error> { - self.stream - .send_many(&[ - CommandComplete::new_begin().message()?.backend(), - ReadyForQuery::in_transaction(true).message()?, - ]) - .await?; + // stack‐allocate up to 3 messages: optional NOTICE + BEGIN + Ready + let mut messages: SmallVec<[Message; 3]> = SmallVec::new(); + + match self.logical_transaction.soft_begin() { + Err(TransactionError::ExpectedActive) => { + let notice = NoticeResponse::from(ErrorResponse::already_in_transaction()) + .message()? + .backend(); + + messages.push(notice); + } + Err(e) => return Err(e.into()), // any other error is fatal + Ok(()) => {} + } + + // push the BEGIN + in-transaction ready + messages.push(CommandComplete::new_begin().message()?.backend()); + messages.push(ReadyForQuery::in_transaction(true).message()?); + + // send all messages + self.stream.send_many(&messages).await?; debug!("transaction started"); + Ok(()) } @@ -773,18 +825,38 @@ impl Client { /// This avoids connecting to servers when clients start and commit transactions /// with no queries. async fn end_transaction(&mut self, rollback: bool) -> Result<(), Error> { + println!("ENDING??? ---"); + // stack‐allocate up to 3 messages: NOTICE + COMMIT/ROLLBACK + READY + let mut messages: SmallVec<[Message; 3]> = SmallVec::new(); + + let logical_result = if rollback { + self.logical_transaction.rollback() + } else { + self.logical_transaction.commit() + }; + + println!("ENDING --- {:?}", logical_result); + + match logical_result { + Err(TransactionError::ExpectedActive) => { + messages.push( + NoticeResponse::from(ErrorResponse::no_transaction()) + .message()? + .backend(), + ); + } + Err(e) => return Err(e.into()), + Ok(()) => {} + } + let cmd = if rollback { CommandComplete::new_rollback() } else { CommandComplete::new_commit() }; - let mut messages = if !self.in_transaction { - vec![NoticeResponse::from(ErrorResponse::no_transaction()).message()?] - } else { - vec![] - }; messages.push(cmd.message()?.backend()); messages.push(ReadyForQuery::idle().message()?); + self.stream.send_many(&messages).await?; debug!("transaction ended"); Ok(()) @@ -805,10 +877,10 @@ impl Client { self.stream .send_many(&[ CommandComplete::from_str(command).message()?.backend(), - ReadyForQuery::in_transaction(self.in_transaction).message()?, + ReadyForQuery::in_transaction(self.in_transaction()).message()?, ]) .await?; - inner.done(self.in_transaction); + inner.done(self.in_transaction()); Ok(()) } @@ -819,6 +891,31 @@ impl Client { .prepared_statements(self.prepared_statements.len_local()); inner.stats.memory_used(self.memory_usage()); } + + fn in_transaction(&self) -> bool { + self.logical_transaction.status == TransactionStatus::BeginPending + || self.logical_transaction.status == TransactionStatus::InProgress + } + + fn should_trigger_buffered_begin(&self) -> bool { + let executable = self.request_buffer.executable(); + let should_trigger_begin = self.logical_transaction.should_trigger_begin(); + + println!(""); + println!(""); + println!(""); + println!(""); + println!(""); + println!(""); + println!("buffer: {:?}", self.request_buffer); + println!(""); + println!( + "Executable: {}, should_trigger_begin: {}", + executable, should_trigger_begin + ); + + executable && should_trigger_begin + } } impl Drop for Client { diff --git a/pgdog/src/frontend/client/test/mod.rs b/pgdog/src/frontend/client/test/mod.rs index aaa089f9..508daa9c 100644 --- a/pgdog/src/frontend/client/test/mod.rs +++ b/pgdog/src/frontend/client/test/mod.rs @@ -128,18 +128,19 @@ async fn test_test_client() { let disconnect = client.client_messages(inner.get()).await.unwrap(); assert!(!disconnect); - assert!(!client.in_transaction); + assert!(!client.in_transaction()); assert_eq!(inner.stats.state, State::Active); // Buffer not cleared yet. assert_eq!(client.request_buffer.total_message_len(), query.len()); assert!(inner.backend.connected()); + let command = inner .command( &mut client.request_buffer, &mut client.prepared_statements, &client.params, - client.in_transaction, + &client.logical_transaction, ) .unwrap(); assert!(matches!(command, Some(Command::Query(_)))); @@ -272,6 +273,7 @@ async fn test_client_extended() { let _ = read!(conn, ['1', '2', 't', 'T', 'D', 'C', 'Z']); + println!("this does not print"); handle.await.unwrap(); } @@ -443,6 +445,29 @@ async fn test_lock_session() { async fn test_transaction_state() { let (mut conn, mut client, mut inner) = new_client!(true); + println!(""); + println!(""); + println!(""); + println!(""); + println!(""); + println!(""); + println!(""); + println!(""); + println!(""); + println!(""); + println!(""); + println!(""); + println!(""); + println!(""); + println!(""); + println!(""); + println!(""); + println!(""); + println!(""); + println!(""); + println!(""); + println!(""); + conn.write_all(&buffer!({ Query::new("BEGIN") })) .await .unwrap(); @@ -451,9 +476,8 @@ async fn test_transaction_state() { client.client_messages(inner.get()).await.unwrap(); read!(conn, ['C', 'Z']); - assert!(client.in_transaction); + assert!(client.in_transaction()); assert!(inner.router.route().is_write()); - assert!(inner.router.in_transaction()); conn.write_all(&buffer!( { Parse::named("test", "SELECT $1") }, @@ -467,9 +491,8 @@ async fn test_transaction_state() { client.client_messages(inner.get()).await.unwrap(); assert!(inner.router.routed()); - assert!(client.in_transaction); + assert!(client.in_transaction()); assert!(inner.router.route().is_write()); - assert!(inner.router.in_transaction()); for c in ['1', 't', 'T', 'Z'] { let msg = inner.backend.read().await.unwrap(); @@ -496,7 +519,6 @@ async fn test_transaction_state() { .await .unwrap(); - assert!(!inner.router.routed()); client.buffer(&State::Idle).await.unwrap(); client.client_messages(inner.get()).await.unwrap(); assert!(inner.router.routed()); @@ -504,34 +526,43 @@ async fn test_transaction_state() { for c in ['2', 'D', 'C', 'Z'] { let msg = inner.backend.read().await.unwrap(); assert_eq!(msg.code(), c); - client.server_message(&mut inner.get(), msg).await.unwrap(); } read!(conn, ['2', 'D', 'C', 'Z']); - assert!(inner.router.routed()); - assert!(client.in_transaction); + assert!(client.in_transaction()); assert!(inner.router.route().is_write()); - assert!(inner.router.in_transaction()); conn.write_all(&buffer!({ Query::new("COMMIT") })) .await .unwrap(); + assert!(client.in_transaction()); + client.buffer(&State::Idle).await.unwrap(); + + println!("2.5"); + client.client_messages(inner.get()).await.unwrap(); + println!("3."); + for c in ['C', 'Z'] { + println!("3.1"); let msg = inner.backend.read().await.unwrap(); + println!("mssage: {:?}", &msg); assert_eq!(msg.code(), c); + println!("3.2"); client.server_message(&mut inner.get(), msg).await.unwrap(); } + println!("4."); + read!(conn, ['C', 'Z']); - assert!(!client.in_transaction); + assert!(!client.in_transaction()); assert!(!inner.router.routed()); } diff --git a/pgdog/src/frontend/error.rs b/pgdog/src/frontend/error.rs index fbbc4b8e..9268ae9d 100644 --- a/pgdog/src/frontend/error.rs +++ b/pgdog/src/frontend/error.rs @@ -45,6 +45,9 @@ pub enum Error { #[error("join error")] Join(#[from] tokio::task::JoinError), + + #[error("transaction error: {0}")] + Transaction(#[from] super::logical_transaction::TransactionError), } impl Error { diff --git a/pgdog/src/frontend/logical_transaction.rs b/pgdog/src/frontend/logical_transaction.rs index 232b8034..37120053 100644 --- a/pgdog/src/frontend/logical_transaction.rs +++ b/pgdog/src/frontend/logical_transaction.rs @@ -37,6 +37,7 @@ use super::router::parser::Shard; #[derive(Debug)] pub struct LogicalTransaction { pub status: TransactionStatus, + begin_dispatched: bool, manual_shard: Option, dirty_shard: Option, } @@ -45,6 +46,7 @@ impl LogicalTransaction { pub fn new() -> Self { Self { status: TransactionStatus::Idle, + begin_dispatched: false, manual_shard: None, dirty_shard: None, } @@ -64,6 +66,15 @@ impl LogicalTransaction { .or_else(|| self.manual_shard.clone()) } + /// Return whether a transaction is currently open or pending. + /// This is because we don't actually trigger the begin until the first statement is executed. + pub fn in_transaction(&self) -> bool { + matches!( + self.status, + TransactionStatus::BeginPending | TransactionStatus::InProgress + ) + } + /// Mark that a `BEGIN` is pending. /// /// Transitions `Idle -> BeginPending`. @@ -77,12 +88,8 @@ impl LogicalTransaction { self.status = TransactionStatus::BeginPending; Ok(()) } - TransactionStatus::BeginPending | TransactionStatus::InProgress => { - Err(TransactionError::AlreadyInTransaction) - } - TransactionStatus::Committed | TransactionStatus::RolledBack => { - Err(TransactionError::AlreadyFinalized) - } + TransactionStatus::BeginPending => Err(TransactionError::ExpectedIdle), + TransactionStatus::InProgress => Err(TransactionError::ExpectedIdle), } } @@ -101,15 +108,12 @@ impl LogicalTransaction { self.touch_shard(shard)?; match self.status { + TransactionStatus::Idle => Err(TransactionError::ExpectedActive), TransactionStatus::BeginPending => { self.status = TransactionStatus::InProgress; Ok(()) } - - TransactionStatus::Idle => Err(TransactionError::NoPendingBegins), TransactionStatus::InProgress => Ok(()), - TransactionStatus::Committed => Err(TransactionError::AlreadyFinalized), - TransactionStatus::RolledBack => Err(TransactionError::AlreadyFinalized), } } @@ -123,15 +127,15 @@ impl LogicalTransaction { /// - `AlreadyFinalized` if already `Committed` or `RolledBack`. pub fn commit(&mut self) -> Result<(), TransactionError> { match self.status { + TransactionStatus::Idle => Err(TransactionError::ExpectedActive), + TransactionStatus::BeginPending => { + self.reset(); + Ok(()) + } TransactionStatus::InProgress => { - self.status = TransactionStatus::Committed; + self.reset(); Ok(()) } - - TransactionStatus::Idle => Err(TransactionError::NoPendingBegins), - TransactionStatus::BeginPending => Err(TransactionError::NoActiveTransaction), - TransactionStatus::Committed => Err(TransactionError::AlreadyFinalized), - TransactionStatus::RolledBack => Err(TransactionError::AlreadyFinalized), } } @@ -145,15 +149,15 @@ impl LogicalTransaction { /// - `AlreadyFinalized` if already `Committed` or `RolledBack`. pub fn rollback(&mut self) -> Result<(), TransactionError> { match self.status { + TransactionStatus::Idle => Err(TransactionError::ExpectedActive), + TransactionStatus::BeginPending => { + self.reset(); + Ok(()) + } TransactionStatus::InProgress => { - self.status = TransactionStatus::RolledBack; + self.reset(); Ok(()) } - - TransactionStatus::Idle => Err(TransactionError::NoPendingBegins), - TransactionStatus::BeginPending => Err(TransactionError::NoActiveTransaction), - TransactionStatus::Committed => Err(TransactionError::AlreadyFinalized), - TransactionStatus::RolledBack => Err(TransactionError::AlreadyFinalized), } } @@ -165,6 +169,16 @@ impl LogicalTransaction { self.status = TransactionStatus::Idle; self.manual_shard = None; self.dirty_shard = None; + self.begin_dispatched = false; + } + + /// TODO + pub fn record_begin(&mut self) { + self.begin_dispatched = true; + } + + pub fn should_trigger_begin(&self) -> bool { + self.in_transaction() && !self.begin_dispatched } /// Pin the transaction to a specific shard. @@ -237,10 +251,8 @@ impl LogicalTransaction { #[derive(Debug)] pub enum TransactionError { // Transaction lifecycle - AlreadyInTransaction, - NoActiveTransaction, - AlreadyFinalized, - NoPendingBegins, + ExpectedIdle, + ExpectedActive, // Sharding policy InvalidShardType, @@ -251,10 +263,8 @@ impl fmt::Display for TransactionError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { use TransactionError::*; match self { - AlreadyInTransaction => write!(f, "transaction already started"), - NoActiveTransaction => write!(f, "no active transaction"), - AlreadyFinalized => write!(f, "transaction already finalized"), - NoPendingBegins => write!(f, "transaction not pending"), + ExpectedIdle => write!(f, "there is already a transaction in progress"), + ExpectedActive => write!(f, "there is no transaction in progress"), InvalidShardType => write!(f, "sharding hints must be ::Direct(n)"), ShardConflict => { write!(f, "can't run a transaction on multiple shards") @@ -276,10 +286,6 @@ pub enum TransactionStatus { BeginPending, /// Transaction active. InProgress, - /// ROLLBACK issued. - RolledBack, - /// COMMIT issued. - Committed, } // ----------------------------------------------------------------------------- @@ -309,43 +315,49 @@ mod tests { let mut tx = LogicalTransaction::new(); tx.soft_begin().unwrap(); let err = tx.soft_begin().unwrap_err(); - assert!(matches!(err, TransactionError::AlreadyInTransaction)); + assert!(matches!(err, TransactionError::ExpectedIdle)); } #[test] fn test_soft_begin_in_progress_errors() { let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); tx.execute_query(Shard::Direct(0)).unwrap(); + let err = tx.soft_begin().unwrap_err(); - assert!(matches!(err, TransactionError::AlreadyInTransaction)); + assert!(matches!(err, TransactionError::ExpectedIdle)); } #[test] fn test_soft_begin_after_commit_errors() { let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); tx.execute_query(Shard::Direct(0)).unwrap(); tx.commit().unwrap(); - let err = tx.soft_begin().unwrap_err(); - assert!(matches!(err, TransactionError::AlreadyFinalized)); + + tx.soft_begin().unwrap(); // no panic } #[test] fn test_soft_begin_after_rollback_errors() { let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); tx.execute_query(Shard::Direct(0)).unwrap(); tx.rollback().unwrap(); - let err = tx.soft_begin().unwrap_err(); - assert!(matches!(err, TransactionError::AlreadyFinalized)); + + tx.soft_begin().unwrap(); // no panic } #[test] fn test_execute_query_from_begin_pending() { let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); tx.execute_query(Shard::Direct(0)).unwrap(); + assert_eq!(tx.status, TransactionStatus::InProgress); assert_eq!(tx.dirty_shard, Some(Shard::Direct(0))); } @@ -354,25 +366,31 @@ mod tests { fn test_execute_query_from_idle_errors() { let mut tx = LogicalTransaction::new(); let err = tx.execute_query(Shard::Direct(0)).unwrap_err(); - assert!(matches!(err, TransactionError::NoPendingBegins)); + assert!(matches!(err, TransactionError::ExpectedActive)); } #[test] fn test_execute_query_after_commit_errors() { - let mut tx = LogicalTransaction::new(); - tx.soft_begin().unwrap(); - tx.execute_query(Shard::Direct(0)).unwrap(); - tx.commit().unwrap(); - let err = tx.execute_query(Shard::Direct(0)).unwrap_err(); - assert!(matches!(err, TransactionError::AlreadyFinalized)); + let mut ltx = LogicalTransaction::new(); + + ltx.soft_begin().unwrap(); + ltx.execute_query(Shard::Direct(0)).unwrap(); + ltx.execute_query(Shard::Direct(0)).unwrap(); + ltx.execute_query(Shard::Direct(0)).unwrap(); + ltx.commit().unwrap(); + + let err = ltx.execute_query(Shard::Direct(0)).unwrap_err(); + assert!(matches!(err, TransactionError::ExpectedActive)); } #[test] fn test_execute_query_multiple_on_same_shard() { let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); tx.execute_query(Shard::Direct(0)).unwrap(); tx.execute_query(Shard::Direct(0)).unwrap(); + assert_eq!(tx.dirty_shard, Some(Shard::Direct(0))); assert_eq!(tx.status, TransactionStatus::InProgress); } @@ -380,8 +398,10 @@ mod tests { #[test] fn test_execute_query_cross_shard_errors() { let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); tx.execute_query(Shard::Direct(0)).unwrap(); + let err = tx.execute_query(Shard::Direct(1)).unwrap_err(); assert!(matches!(err, TransactionError::ShardConflict)); } @@ -389,7 +409,9 @@ mod tests { #[test] fn test_execute_query_invalid_shard_type_errors() { let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); + let err = tx.execute_query(Shard::All).unwrap_err(); assert!(matches!(err, TransactionError::InvalidShardType)); } @@ -397,90 +419,112 @@ mod tests { #[test] fn test_commit_from_in_progress() { let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); tx.execute_query(Shard::Direct(0)).unwrap(); tx.commit().unwrap(); - assert_eq!(tx.status, TransactionStatus::Committed); + + assert_eq!(tx.status, TransactionStatus::Idle); } #[test] fn test_commit_from_idle_errors() { let mut tx = LogicalTransaction::new(); let err = tx.commit().unwrap_err(); - assert!(matches!(err, TransactionError::NoPendingBegins)); + assert!(matches!(err, TransactionError::ExpectedActive)); } #[test] - fn test_commit_from_begin_pending_errors() { + fn test_commit_from_begin_pending() { let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); - let err = tx.commit().unwrap_err(); - assert!(matches!(err, TransactionError::NoActiveTransaction)); + tx.commit().unwrap(); // no-panic } #[test] fn test_commit_already_committed_errors() { let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); tx.execute_query(Shard::Direct(0)).unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); tx.commit().unwrap(); + let err = tx.commit().unwrap_err(); - assert!(matches!(err, TransactionError::AlreadyFinalized)); + assert!(matches!(err, TransactionError::ExpectedActive)); } #[test] fn test_rollback_from_in_progress() { let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); tx.execute_query(Shard::Direct(0)).unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); tx.rollback().unwrap(); - assert_eq!(tx.status, TransactionStatus::RolledBack); + + assert_eq!(tx.status, TransactionStatus::Idle); } #[test] fn test_rollback_from_begin_pending_errors() { let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); - let err = tx.rollback().unwrap_err(); - assert!(matches!(err, TransactionError::NoActiveTransaction)); + tx.rollback().unwrap(); // no-panic } #[test] fn test_reset_clears_state() { let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); tx.execute_query(Shard::Direct(0)).unwrap(); tx.set_manual_shard(Shard::Direct(0)).unwrap(); tx.reset(); + assert_eq!(tx.status, TransactionStatus::Idle); assert_eq!(tx.manual_shard, None); assert_eq!(tx.dirty_shard, None); } #[test] - fn test_set_manual_shard_before_touch() { + fn test_set_matching_manual_shard_before_touch() { let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); tx.set_manual_shard(Shard::Direct(0)).unwrap(); + assert_eq!(tx.manual_shard, Some(Shard::Direct(0))); + + tx.execute_query(Shard::Direct(0)).unwrap(); // should succeed + tx.execute_query(Shard::Direct(0)).unwrap(); // should succeed tx.execute_query(Shard::Direct(0)).unwrap(); // should succeed } #[test] fn test_set_manual_shard_after_touch_same_ok() { let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); tx.execute_query(Shard::Direct(0)).unwrap(); tx.set_manual_shard(Shard::Direct(0)).unwrap(); + assert_eq!(tx.manual_shard, Some(Shard::Direct(0))); } #[test] fn test_set_manual_shard_after_touch_different_errors() { let mut tx = LogicalTransaction::new(); + // touch shard 0 tx.soft_begin().unwrap(); tx.execute_query(Shard::Direct(0)).unwrap(); + // manually set shard 1 let err = tx.set_manual_shard(Shard::Direct(1)).unwrap_err(); assert!(matches!(err, TransactionError::ShardConflict)); @@ -490,8 +534,10 @@ mod tests { fn test_manual_then_dirty_conflict() { let mut tx = LogicalTransaction::new(); tx.soft_begin().unwrap(); + // pin to shard 0 tx.set_manual_shard(Shard::Direct(0)).unwrap(); + // touching another shard must fail let err = tx.execute_query(Shard::Direct(1)).unwrap_err(); assert!(matches!(err, TransactionError::ShardConflict)); @@ -507,8 +553,12 @@ mod tests { #[test] fn test_active_shard_dirty() { let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); tx.execute_query(Shard::Direct(69)).unwrap(); + tx.execute_query(Shard::Direct(69)).unwrap(); + tx.execute_query(Shard::Direct(69)).unwrap(); + assert_eq!(tx.active_shard(), Some(Shard::Direct(69))); } @@ -523,57 +573,81 @@ mod tests { fn test_rollback_from_idle_errors() { let mut tx = LogicalTransaction::new(); let err = tx.rollback().unwrap_err(); - assert!(matches!(err, TransactionError::NoPendingBegins)); + assert!(matches!(err, TransactionError::ExpectedActive)); } #[test] fn test_commit_after_rollback_errors() { let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); tx.execute_query(Shard::Direct(0)).unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); tx.rollback().unwrap(); + let err = tx.commit().unwrap_err(); - assert!(matches!(err, TransactionError::AlreadyFinalized)); + assert!(matches!(err, TransactionError::ExpectedActive)); } #[test] fn test_rollback_after_commit_errors() { let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); tx.execute_query(Shard::Direct(0)).unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); tx.commit().unwrap(); + let err = tx.rollback().unwrap_err(); - assert!(matches!(err, TransactionError::AlreadyFinalized)); + assert!(matches!(err, TransactionError::ExpectedActive)); } #[test] fn test_rollback_already_rolledback_errors() { let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); tx.execute_query(Shard::Direct(0)).unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); tx.rollback().unwrap(); + let err = tx.rollback().unwrap_err(); - assert!(matches!(err, TransactionError::AlreadyFinalized)); + println!("Error: {:?}", err); + assert!(matches!(err, TransactionError::ExpectedActive)); } #[test] fn test_execute_query_after_rollback_errors() { let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); tx.execute_query(Shard::Direct(0)).unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); tx.rollback().unwrap(); + let err = tx.execute_query(Shard::Direct(0)).unwrap_err(); - assert!(matches!(err, TransactionError::AlreadyFinalized)); + assert!(matches!(err, TransactionError::ExpectedActive)); } #[test] fn test_set_manual_shard_multiple_changes_before_execute() { let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); tx.set_manual_shard(Shard::Direct(1)).unwrap(); - tx.set_manual_shard(Shard::Direct(2)).unwrap(); + tx.set_manual_shard(Shard::Direct(2)).unwrap(); // change, no error. + assert_eq!(tx.manual_shard, Some(Shard::Direct(2))); + + tx.execute_query(Shard::Direct(2)).unwrap(); + tx.execute_query(Shard::Direct(2)).unwrap(); + tx.execute_query(Shard::Direct(2)).unwrap(); tx.execute_query(Shard::Direct(2)).unwrap(); + let err = tx.execute_query(Shard::Direct(1)).unwrap_err(); assert!(matches!(err, TransactionError::ShardConflict)); } @@ -581,9 +655,11 @@ mod tests { #[test] fn test_set_manual_shard_after_commit_same_ok() { let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); tx.execute_query(Shard::Direct(0)).unwrap(); tx.commit().unwrap(); + tx.set_manual_shard(Shard::Direct(0)).unwrap(); assert_eq!(tx.manual_shard, Some(Shard::Direct(0))); } @@ -591,31 +667,41 @@ mod tests { #[test] fn test_set_manual_shard_after_commit_different_errors() { let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); tx.execute_query(Shard::Direct(0)).unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); tx.commit().unwrap(); - let err = tx.set_manual_shard(Shard::Direct(1)).unwrap_err(); - assert!(matches!(err, TransactionError::ShardConflict)); + + tx.set_manual_shard(Shard::Direct(1)).unwrap(); // should not panic } #[test] fn test_set_manual_shard_after_rollback_same_ok() { let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); tx.execute_query(Shard::Direct(0)).unwrap(); tx.rollback().unwrap(); - tx.set_manual_shard(Shard::Direct(0)).unwrap(); - assert_eq!(tx.manual_shard, Some(Shard::Direct(0))); + + tx.set_manual_shard(Shard::Direct(88)).unwrap(); + assert_eq!(tx.manual_shard, Some(Shard::Direct(88))); // no panic } #[test] fn test_set_manual_shard_after_rollback_different_errors() { let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); tx.execute_query(Shard::Direct(0)).unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); + tx.execute_query(Shard::Direct(0)).unwrap(); tx.rollback().unwrap(); - let err = tx.set_manual_shard(Shard::Direct(1)).unwrap_err(); - assert!(matches!(err, TransactionError::ShardConflict)); + + tx.set_manual_shard(Shard::Direct(1)).unwrap(); // should not panic } #[test] @@ -634,10 +720,13 @@ mod tests { #[test] fn test_soft_begin_after_reset_from_finalized() { let mut tx = LogicalTransaction::new(); + tx.soft_begin().unwrap(); tx.execute_query(Shard::Direct(0)).unwrap(); tx.commit().unwrap(); + tx.reset(); + tx.soft_begin().unwrap(); assert_eq!(tx.status, TransactionStatus::BeginPending); } @@ -645,17 +734,12 @@ mod tests { #[test] fn test_active_shard_both_same() { let mut tx = LogicalTransaction::new(); + tx.set_manual_shard(Shard::Direct(3)).unwrap(); tx.soft_begin().unwrap(); tx.execute_query(Shard::Direct(3)).unwrap(); - assert_eq!(tx.active_shard(), Some(Shard::Direct(3))); - } - #[test] - fn test_statements_executed_remains_zero_after_execute() { - let mut tx = LogicalTransaction::new(); - tx.soft_begin().unwrap(); - tx.execute_query(Shard::Direct(0)).unwrap(); + assert_eq!(tx.active_shard(), Some(Shard::Direct(3))); } } diff --git a/pgdog/src/frontend/router/context.rs b/pgdog/src/frontend/router/context.rs index 117a2dbf..30f7ea20 100644 --- a/pgdog/src/frontend/router/context.rs +++ b/pgdog/src/frontend/router/context.rs @@ -1,7 +1,9 @@ use super::Error; use crate::{ backend::Cluster, - frontend::{buffer::BufferedQuery, Buffer, PreparedStatements}, + frontend::{ + buffer::BufferedQuery, logical_transaction::LogicalTransaction, Buffer, PreparedStatements, + }, net::{Bind, Parameters}, }; @@ -17,10 +19,10 @@ pub struct RouterContext<'a> { pub cluster: &'a Cluster, /// Client parameters, e.g. search_path. pub params: &'a Parameters, - /// Client inside transaction, - pub in_transaction: bool, /// Currently executing COPY statement. pub copy_mode: bool, + /// Client's logical_transaction struct, + pub logical_transaction: &'a LogicalTransaction, } impl<'a> RouterContext<'a> { @@ -29,7 +31,7 @@ impl<'a> RouterContext<'a> { cluster: &'a Cluster, stmt: &'a mut PreparedStatements, params: &'a Parameters, - in_transaction: bool, + logical_transaction: &'a LogicalTransaction, ) -> Result { let query = buffer.query()?; let bind = buffer.parameters()?; @@ -40,9 +42,13 @@ impl<'a> RouterContext<'a> { bind, params, prepared_statements: stmt, + logical_transaction, cluster, - in_transaction, copy_mode, }) } + + pub fn in_transaction(&self) -> bool { + self.logical_transaction.in_transaction() + } } diff --git a/pgdog/src/frontend/router/mod.rs b/pgdog/src/frontend/router/mod.rs index 2629139c..4feac537 100644 --- a/pgdog/src/frontend/router/mod.rs +++ b/pgdog/src/frontend/router/mod.rs @@ -98,11 +98,6 @@ impl Router { self.routed } - /// Query parser is inside a transaction. - pub fn in_transaction(&self) -> bool { - self.query_parser.in_transaction() - } - /// Get last commmand computed by the query parser. pub fn command(&self) -> &Command { &self.latest_command diff --git a/pgdog/src/frontend/router/parser/context.rs b/pgdog/src/frontend/router/parser/context.rs index 51d84d66..3c50bbaa 100644 --- a/pgdog/src/frontend/router/parser/context.rs +++ b/pgdog/src/frontend/router/parser/context.rs @@ -60,7 +60,7 @@ impl<'a> QueryParserContext<'a> { /// Write override enabled? pub(super) fn write_override(&self) -> bool { - self.router_context.in_transaction && self.rw_conservative() + self.in_transaction() && self.rw_conservative() } /// Are we using the conservative read/write separation strategy? @@ -92,4 +92,8 @@ impl<'a> QueryParserContext<'a> { pub(super) fn multi_tenant(&self) -> &Option { self.multi_tenant } + + pub(super) fn in_transaction(&self) -> bool { + self.router_context.in_transaction() + } } diff --git a/pgdog/src/frontend/router/parser/query/explain.rs b/pgdog/src/frontend/router/parser/query/explain.rs index a95ff5be..86d15745 100644 --- a/pgdog/src/frontend/router/parser/query/explain.rs +++ b/pgdog/src/frontend/router/parser/query/explain.rs @@ -28,6 +28,7 @@ mod tests { use super::*; use crate::backend::Cluster; + use crate::frontend::logical_transaction::LogicalTransaction; use crate::frontend::{Buffer, PreparedStatements, RouterContext}; use crate::net::messages::{Bind, Parameter, Parse, Query}; use crate::net::Parameters; @@ -39,8 +40,10 @@ mod tests { let cluster = Cluster::new_test(); let mut stmts = PreparedStatements::default(); let params = Parameters::default(); + let logical_transaction = LogicalTransaction::new(); - let ctx = RouterContext::new(&buffer, &cluster, &mut stmts, ¶ms, false).unwrap(); + let ctx = RouterContext::new(&buffer, &cluster, &mut stmts, ¶ms, &logical_transaction) + .unwrap(); match QueryParser::default().parse(ctx).unwrap().clone() { Command::Query(route) => route, @@ -65,8 +68,10 @@ mod tests { let cluster = Cluster::new_test(); let mut stmts = PreparedStatements::default(); let params = Parameters::default(); + let logical_transaction = LogicalTransaction::new(); - let ctx = RouterContext::new(&buffer, &cluster, &mut stmts, ¶ms, false).unwrap(); + let ctx = RouterContext::new(&buffer, &cluster, &mut stmts, ¶ms, &logical_transaction) + .unwrap(); match QueryParser::default().parse(ctx).unwrap().clone() { Command::Query(route) => route, diff --git a/pgdog/src/frontend/router/parser/query/mod.rs b/pgdog/src/frontend/router/parser/query/mod.rs index fae8727c..12be9f16 100644 --- a/pgdog/src/frontend/router/parser/query/mod.rs +++ b/pgdog/src/frontend/router/parser/query/mod.rs @@ -49,8 +49,6 @@ use tracing::{debug, trace}; /// #[derive(Debug)] pub struct QueryParser { - // The statement is executed inside a tranasction. - in_transaction: bool, // No matter what query is executed, we'll send it to the primary. write_override: bool, // Currently calculated shard. @@ -60,7 +58,6 @@ pub struct QueryParser { impl Default for QueryParser { fn default() -> Self { Self { - in_transaction: false, write_override: false, shard: Shard::All, } @@ -68,17 +65,11 @@ impl Default for QueryParser { } impl QueryParser { - /// Indicates we are in a transaction. - pub fn in_transaction(&self) -> bool { - self.in_transaction - } - /// Parse a query and return a command. pub fn parse(&mut self, context: RouterContext) -> Result { let mut qp_context = QueryParserContext::new(context); let mut command = if qp_context.query().is_ok() { - self.in_transaction = qp_context.router_context.in_transaction; self.write_override = qp_context.write_override(); self.query(&mut qp_context)? diff --git a/pgdog/src/frontend/router/parser/query/set.rs b/pgdog/src/frontend/router/parser/query/set.rs index cba3effd..7f053c41 100644 --- a/pgdog/src/frontend/router/parser/query/set.rs +++ b/pgdog/src/frontend/router/parser/query/set.rs @@ -60,7 +60,7 @@ impl QueryParser { // TODO: Handle SET commands for updating client // params without touching the server. name => { - if !self.in_transaction { + if !context.in_transaction() { let mut value = vec![]; for node in &stmt.args { diff --git a/pgdog/src/frontend/router/parser/query/test.rs b/pgdog/src/frontend/router/parser/query/test.rs index 799b251f..1c74ccc8 100644 --- a/pgdog/src/frontend/router/parser/query/test.rs +++ b/pgdog/src/frontend/router/parser/query/test.rs @@ -1,14 +1,16 @@ -use crate::net::{ - messages::{parse::Parse, Parameter}, - Close, Format, Sync, -}; - use super::{super::Shard, *}; + use crate::backend::Cluster; use crate::config::ReadWriteStrategy; -use crate::frontend::{Buffer, PreparedStatements, RouterContext}; +use crate::frontend::{ + logical_transaction::LogicalTransaction, Buffer, PreparedStatements, RouterContext, +}; use crate::net::messages::Query; use crate::net::Parameters; +use crate::net::{ + messages::{parse::Parse, Parameter}, + Close, Format, Sync, +}; macro_rules! command { ($query:expr) => {{ @@ -18,7 +20,10 @@ macro_rules! command { let cluster = Cluster::new_test(); let mut stmt = PreparedStatements::default(); let params = Parameters::default(); - let context = RouterContext::new(&buffer, &cluster, &mut stmt, ¶ms, false).unwrap(); + let logical_transaction = LogicalTransaction::new(); + let context = + RouterContext::new(&buffer, &cluster, &mut stmt, ¶ms, &logical_transaction) + .unwrap(); let command = query_parser.parse(context).unwrap().clone(); (command, query_parser) @@ -44,9 +49,21 @@ macro_rules! query_parser { let mut prep_stmts = PreparedStatements::default(); let params = Parameters::default(); let buffer: Buffer = vec![$query.into()].into(); - let router_context = - RouterContext::new(&buffer, &cluster, &mut prep_stmts, ¶ms, $in_transaction) - .unwrap(); + let mut logical_transaction = LogicalTransaction::new(); + + if $in_transaction { + logical_transaction.soft_begin().unwrap(); + } + + let router_context = RouterContext::new( + &buffer, + &cluster, + &mut prep_stmts, + ¶ms, + &logical_transaction, + ) + .unwrap(); + $qp.parse(router_context).unwrap() }}; @@ -69,6 +86,7 @@ macro_rules! parse { data: p.to_vec(), }) .collect::>(); + let logical_transaction = LogicalTransaction::new(); let bind = Bind::new_params_codes($name, ¶ms, $codes); let route = QueryParser::default() .parse( @@ -77,7 +95,7 @@ macro_rules! parse { &Cluster::new_test(), &mut PreparedStatements::default(), &Parameters::default(), - false, + &logical_transaction, ) .unwrap(), ) @@ -165,23 +183,20 @@ fn test_omni() { let q = "SELECT sharded_omni.* FROM sharded_omni WHERE sharded_omni.id = $1"; let route = query!(q); assert!(matches!(route.shard(), Shard::Direct(_))); - let (_, qp) = command!(q); - assert!(!qp.in_transaction); + let (_, _qp) = command!(q); } #[test] fn test_set() { let route = query!(r#"SET "pgdog.shard" TO 1"#); assert_eq!(route.shard(), &Shard::Direct(1)); - let (_, qp) = command!(r#"SET "pgdog.shard" TO 1"#); - assert!(!qp.in_transaction); + let (_, _qp) = command!(r#"SET "pgdog.shard" TO 1"#); let route = query!(r#"SET "pgdog.sharding_key" TO '11'"#); assert_eq!(route.shard(), &Shard::Direct(1)); - let (_, qp) = command!(r#"SET "pgdog.sharding_key" TO '11'"#); - assert!(!qp.in_transaction); + let (_, _qp) = command!(r#"SET "pgdog.sharding_key" TO '11'"#); - for (command, qp) in [ + for (command, _qp) in [ command!("SET TimeZone TO 'UTC'"), command!("SET TIME ZONE 'UTC'"), ] { @@ -192,10 +207,9 @@ fn test_set() { } _ => panic!("not a set"), }; - assert!(!qp.in_transaction); } - let (command, qp) = command!("SET statement_timeout TO 3000"); + let (command, _qp) = command!("SET statement_timeout TO 3000"); match command { Command::Set { name, value } => { assert_eq!(name, "statement_timeout"); @@ -203,11 +217,10 @@ fn test_set() { } _ => panic!("not a set"), }; - assert!(!qp.in_transaction); // TODO: user shouldn't be able to set these. // The server will report an error on synchronization. - let (command, qp) = command!("SET is_superuser TO true"); + let (command, _qp) = command!("SET is_superuser TO true"); match command { Command::Set { name, value } => { assert_eq!(name, "is_superuser"); @@ -215,7 +228,6 @@ fn test_set() { } _ => panic!("not a set"), }; - assert!(!qp.in_transaction); let (_, mut qp) = command!("BEGIN"); assert!(qp.write_override); @@ -241,15 +253,24 @@ fn test_set() { let cluster = Cluster::new_test(); let mut prep_stmts = PreparedStatements::default(); let params = Parameters::default(); - let router_context = - RouterContext::new(&buffer, &cluster, &mut prep_stmts, ¶ms, true).unwrap(); + + let mut logical_transaction = LogicalTransaction::new(); + logical_transaction.soft_begin().unwrap(); + + let router_context = RouterContext::new( + &buffer, + &cluster, + &mut prep_stmts, + ¶ms, + &logical_transaction, + ) + .unwrap(); let mut context = QueryParserContext::new(router_context); for read_only in [true, false] { context.read_only = read_only; // Overriding context above. let mut qp = QueryParser::default(); - qp.in_transaction = true; let route = qp.query(&mut context).unwrap(); match route { @@ -269,7 +290,6 @@ fn test_transaction() { _ => panic!("not a query"), }; - assert!(qp.in_transaction); assert!(qp.write_override); let route = query_parser!(qp, Parse::named("test", "SELECT $1"), true); @@ -290,9 +310,7 @@ fn test_transaction() { command, Command::StartTransaction(BufferedQuery::Query(_)) )); - assert!(qp.in_transaction); - qp.in_transaction = true; let route = query_parser!( qp, Query::new("SET application_name TO 'test'"), @@ -323,9 +341,8 @@ fn test_begin_extended() { #[test] fn test_show_shards() { - let (cmd, qp) = command!("SHOW pgdog.shards"); + let (cmd, _qp) = command!("SHOW pgdog.shards"); assert!(matches!(cmd, Command::Shards(2))); - assert!(!qp.in_transaction); } #[test] @@ -355,10 +372,13 @@ fn test_cte() { fn test_function_begin() { let (cmd, mut qp) = command!("BEGIN"); assert!(matches!(cmd, Command::StartTransaction(_))); - assert!(qp.in_transaction); let cluster = Cluster::new_test(); let mut prep_stmts = PreparedStatements::default(); let params = Parameters::default(); + + let mut logical_transaction = LogicalTransaction::new(); + logical_transaction.soft_begin().unwrap(); + let buffer: Buffer = vec![Query::new( "SELECT ROW(t1.*) AS tt1, @@ -377,15 +397,22 @@ WHERE t2.account = ( ) .into()] .into(); - let router_context = - RouterContext::new(&buffer, &cluster, &mut prep_stmts, ¶ms, true).unwrap(); + + let router_context = RouterContext::new( + &buffer, + &cluster, + &mut prep_stmts, + ¶ms, + &logical_transaction, + ) + .unwrap(); + let mut context = QueryParserContext::new(router_context); let route = qp.query(&mut context).unwrap(); match route { Command::Query(query) => assert!(query.is_write()), _ => panic!("not a select"), } - assert!(qp.in_transaction); } #[test] @@ -433,8 +460,10 @@ fn test_close_direct_one_shard() { let buf: Buffer = vec![Close::named("test").into(), Sync.into()].into(); let mut pp = PreparedStatements::default(); let params = Parameters::default(); + let logical_transaction = LogicalTransaction::new(); - let context = RouterContext::new(&buf, &cluster, &mut pp, ¶ms, false).unwrap(); + let context = + RouterContext::new(&buf, &cluster, &mut pp, ¶ms, &logical_transaction).unwrap(); let cmd = qp.parse(context).unwrap(); diff --git a/pgdog/src/frontend/router/parser/query/transaction.rs b/pgdog/src/frontend/router/parser/query/transaction.rs index 9554b0f6..85ffaa5a 100644 --- a/pgdog/src/frontend/router/parser/query/transaction.rs +++ b/pgdog/src/frontend/router/parser/query/transaction.rs @@ -25,7 +25,6 @@ impl QueryParser { TransactionStmtKind::TransStmtCommit => return Ok(Command::CommitTransaction), TransactionStmtKind::TransStmtRollback => return Ok(Command::RollbackTransaction), TransactionStmtKind::TransStmtBegin | TransactionStmtKind::TransStmtStart => { - self.in_transaction = true; return Ok(Command::StartTransaction(context.query()?.clone())); } _ => Ok(Command::Query(Route::write(None))), diff --git a/pgdog/src/net/messages/error_response.rs b/pgdog/src/net/messages/error_response.rs index 91109066..39fa0b63 100644 --- a/pgdog/src/net/messages/error_response.rs +++ b/pgdog/src/net/messages/error_response.rs @@ -138,6 +138,16 @@ impl ErrorResponse { ..Default::default() } } + + /// Warning for issuing BEGIN inside an existing transaction. + pub fn already_in_transaction() -> ErrorResponse { + ErrorResponse { + severity: "WARNING".into(), + code: "25001".into(), + message: "there is already a transaction in progress".into(), + ..Default::default() + } + } } impl Display for ErrorResponse {