diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index ab03bcf..e81db85 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::hash::{Hash, Hasher}; use std::sync::Arc; use crate::auth::{AuthManager, Permission, ResourceType}; @@ -18,7 +19,9 @@ use pgwire::api::stmt::QueryParser; use pgwire::api::stmt::StoredStatement; use pgwire::api::{ClientInfo, PgWireServerHandlers, Type}; use pgwire::error::{PgWireError, PgWireResult}; -use tokio::sync::Mutex; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::time::{Duration, Instant}; +use tokio::sync::{Mutex, RwLock}; use arrow_pg::datatypes::df; use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type}; @@ -63,13 +66,26 @@ impl PgWireServerHandlers for HandlerFactory { } } +/// Per-connection transaction state storage +/// We use a hash of both PID and secret key as the connection identifier for better uniqueness +pub type ConnectionId = u64; + +#[derive(Debug, Clone)] +struct ConnectionState { + transaction_state: TransactionState, + last_activity: Instant, +} + +type ConnectionStates = Arc>>; + /// The pgwire handler backed by a datafusion `SessionContext` pub struct DfSessionService { session_context: Arc, parser: Arc, timezone: Arc>, - transaction_state: Arc>, + connection_states: ConnectionStates, auth_manager: Arc, + cleanup_counter: AtomicU64, } impl DfSessionService { @@ -84,11 +100,57 @@ impl DfSessionService { session_context, parser, timezone: Arc::new(Mutex::new("UTC".to_string())), - transaction_state: Arc::new(Mutex::new(TransactionState::None)), + connection_states: Arc::new(RwLock::new(HashMap::new())), auth_manager, + cleanup_counter: AtomicU64::new(0), + } + } + + async fn get_transaction_state(&self, client_id: ConnectionId) -> TransactionState { + self.connection_states + .read() + .await + .get(&client_id) + .map(|s| s.transaction_state) + .unwrap_or(TransactionState::None) + } + + async fn update_transaction_state(&self, client_id: ConnectionId, new_state: TransactionState) { + let mut states = self.connection_states.write().await; + + // Update or insert state using entry API + states + .entry(client_id) + .and_modify(|s| { + s.transaction_state = new_state; + s.last_activity = Instant::now(); + }) + .or_insert(ConnectionState { + transaction_state: new_state, + last_activity: Instant::now(), + }); + + // Inline cleanup every 100 operations + if self.cleanup_counter.fetch_add(1, Ordering::Relaxed) % 100 == 0 { + let cutoff = Instant::now() - Duration::from_secs(3600); + states.retain(|_, state| state.last_activity > cutoff); } } + fn get_client_id(client: &C) -> ConnectionId { + // Use a hash of PID, secret key, and socket address for better uniqueness + let (pid, secret) = client.pid_and_secret_key(); + let socket_addr = client.socket_addr(); + + // Create a hash of all identifying values + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + pid.hash(&mut hasher); + secret.hash(&mut hasher); + socket_addr.hash(&mut hasher); + + hasher.finish() + } + /// Check if the current user has permission to execute a query async fn check_query_permission(&self, client: &C, query: &str) -> PgWireResult<()> where @@ -213,18 +275,24 @@ impl DfSessionService { } } - async fn try_respond_transaction_statements<'a>( + async fn try_respond_transaction_statements<'a, C>( &self, + client: &C, query_lower: &str, - ) -> PgWireResult>> { + ) -> PgWireResult>> + where + C: ClientInfo, + { + let client_id = Self::get_client_id(client); + // Transaction handling based on pgwire example: // https://github.com/sunng87/pgwire/blob/master/examples/transaction.rs#L57 match query_lower.trim() { "begin" | "begin transaction" | "begin work" | "start transaction" => { - let mut state = self.transaction_state.lock().await; - match *state { + match self.get_transaction_state(client_id).await { TransactionState::None => { - *state = TransactionState::Active; + self.update_transaction_state(client_id, TransactionState::Active) + .await; Ok(Some(Response::TransactionStart(Tag::new("BEGIN")))) } TransactionState::Active => { @@ -245,10 +313,10 @@ impl DfSessionService { } } "commit" | "commit transaction" | "commit work" | "end" | "end transaction" => { - let mut state = self.transaction_state.lock().await; - match *state { + match self.get_transaction_state(client_id).await { TransactionState::Active => { - *state = TransactionState::None; + self.update_transaction_state(client_id, TransactionState::None) + .await; Ok(Some(Response::TransactionEnd(Tag::new("COMMIT")))) } TransactionState::None => { @@ -257,14 +325,15 @@ impl DfSessionService { } TransactionState::Failed => { // COMMIT in failed transaction is treated as ROLLBACK - *state = TransactionState::None; + self.update_transaction_state(client_id, TransactionState::None) + .await; Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK")))) } } } "rollback" | "rollback transaction" | "rollback work" | "abort" => { - let mut state = self.transaction_state.lock().await; - *state = TransactionState::None; + self.update_transaction_state(client_id, TransactionState::None) + .await; Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK")))) } _ => Ok(None), @@ -343,7 +412,7 @@ impl SimpleQueryHandler for DfSessionService { } if let Some(resp) = self - .try_respond_transaction_statements(&query_lower) + .try_respond_transaction_statements(client, &query_lower) .await? { return Ok(vec![resp]); @@ -354,17 +423,15 @@ impl SimpleQueryHandler for DfSessionService { } // Check if we're in a failed transaction and block non-transaction commands - { - let state = self.transaction_state.lock().await; - if *state == TransactionState::Failed { - return Err(PgWireError::UserError(Box::new( - pgwire::error::ErrorInfo::new( - "ERROR".to_string(), - "25P01".to_string(), - "current transaction is aborted, commands ignored until end of transaction block".to_string(), - ), - ))); - } + let client_id = Self::get_client_id(client); + if self.get_transaction_state(client_id).await == TransactionState::Failed { + return Err(PgWireError::UserError(Box::new( + pgwire::error::ErrorInfo::new( + "ERROR".to_string(), + "25P01".to_string(), + "current transaction is aborted, commands ignored until end of transaction block".to_string(), + ), + ))); } let df_result = self.session_context.sql(query).await; @@ -374,11 +441,10 @@ impl SimpleQueryHandler for DfSessionService { Ok(df) => df, Err(e) => { // If we're in a transaction and a query fails, mark transaction as failed - { - let mut state = self.transaction_state.lock().await; - if *state == TransactionState::Active { - *state = TransactionState::Failed; - } + let client_id = Self::get_client_id(client); + if self.get_transaction_state(client_id).await == TransactionState::Active { + self.update_transaction_state(client_id, TransactionState::Failed) + .await; } return Err(PgWireError::ApiError(Box::new(e))); } @@ -496,10 +562,29 @@ impl ExtendedQueryHandler for DfSessionService { return Ok(resp); } + if let Some(resp) = self + .try_respond_transaction_statements(client, &query) + .await? + { + return Ok(resp); + } + if let Some(resp) = self.try_respond_show_statements(&query).await? { return Ok(resp); } + // Check if we're in a failed transaction and block non-transaction commands + let client_id = Self::get_client_id(client); + if self.get_transaction_state(client_id).await == TransactionState::Failed { + return Err(PgWireError::UserError(Box::new( + pgwire::error::ErrorInfo::new( + "ERROR".to_string(), + "25P01".to_string(), + "current transaction is aborted, commands ignored until end of transaction block".to_string(), + ), + ))); + } + let (_, plan) = &portal.statement.statement; let param_types = plan @@ -510,11 +595,18 @@ impl ExtendedQueryHandler for DfSessionService { .clone() .replace_params_with_values(¶m_values) .map_err(|e| PgWireError::ApiError(Box::new(e)))?; // Fixed: Use ¶m_values - let dataframe = self - .session_context - .execute_logical_plan(plan) - .await - .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let dataframe = match self.session_context.execute_logical_plan(plan).await { + Ok(df) => df, + Err(e) => { + // If we're in a transaction and a query fails, mark transaction as failed + let client_id = Self::get_client_id(client); + if self.get_transaction_state(client_id).await == TransactionState::Active { + self.update_transaction_state(client_id, TransactionState::Failed) + .await; + } + return Err(PgWireError::ApiError(Box::new(e))); + } + }; let resp = df::encode_dataframe(dataframe, &portal.result_column_format).await?; Ok(Response::Query(resp)) } @@ -555,3 +647,134 @@ fn ordered_param_types(types: &HashMap>) -> Vec