Skip to content

Commit 466892e

Browse files
committed
Per connection trasnaction tracking and cleanup
1 parent 663247d commit 466892e

File tree

1 file changed

+249
-36
lines changed

1 file changed

+249
-36
lines changed

datafusion-postgres/src/handlers.rs

Lines changed: 249 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@ use pgwire::api::stmt::QueryParser;
1818
use pgwire::api::stmt::StoredStatement;
1919
use pgwire::api::{ClientInfo, PgWireServerHandlers, Type};
2020
use pgwire::error::{PgWireError, PgWireResult};
21-
use tokio::sync::Mutex;
21+
use std::sync::atomic::{AtomicU64, Ordering};
22+
use std::time::{Duration, Instant};
23+
use tokio::sync::{Mutex, RwLock};
2224

2325
use arrow_pg::datatypes::df;
2426
use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type};
@@ -63,13 +65,26 @@ impl PgWireServerHandlers for HandlerFactory {
6365
}
6466
}
6567

68+
/// Per-connection transaction state storage
69+
/// We use the process ID as the connection identifier since it's unique per connection
70+
pub type ConnectionId = i32;
71+
72+
#[derive(Debug, Clone)]
73+
struct ConnectionState {
74+
transaction_state: TransactionState,
75+
last_activity: Instant,
76+
}
77+
78+
type ConnectionStates = Arc<RwLock<HashMap<ConnectionId, ConnectionState>>>;
79+
6680
/// The pgwire handler backed by a datafusion `SessionContext`
6781
pub struct DfSessionService {
6882
session_context: Arc<SessionContext>,
6983
parser: Arc<Parser>,
7084
timezone: Arc<Mutex<String>>,
71-
transaction_state: Arc<Mutex<TransactionState>>,
85+
connection_states: ConnectionStates,
7286
auth_manager: Arc<AuthManager>,
87+
cleanup_counter: AtomicU64,
7388
}
7489

7590
impl DfSessionService {
@@ -84,11 +99,48 @@ impl DfSessionService {
8499
session_context,
85100
parser,
86101
timezone: Arc::new(Mutex::new("UTC".to_string())),
87-
transaction_state: Arc::new(Mutex::new(TransactionState::None)),
102+
connection_states: Arc::new(RwLock::new(HashMap::new())),
88103
auth_manager,
104+
cleanup_counter: AtomicU64::new(0),
105+
}
106+
}
107+
108+
async fn get_transaction_state(&self, client_id: ConnectionId) -> TransactionState {
109+
self.connection_states
110+
.read()
111+
.await
112+
.get(&client_id)
113+
.map(|s| s.transaction_state)
114+
.unwrap_or(TransactionState::None)
115+
}
116+
117+
async fn update_transaction_state(&self, client_id: ConnectionId, new_state: TransactionState {
118+
let mut states = self.connection_states.write().await;
119+
120+
// Update or insert state using entry API
121+
states
122+
.entry(client_id)
123+
.and_modify(|s| {
124+
s.transaction_state = new_state;
125+
s.last_activity = Instant::now();
126+
})
127+
.or_insert(ConnectionState {
128+
transaction_state: new_state,
129+
last_activity: Instant::now(),
130+
});
131+
132+
// Inline cleanup every 100 operations
133+
if self.cleanup_counter.fetch_add(1, Ordering::Relaxed) % 100 == 0 {
134+
let cutoff = Instant::now() - Duration::from_secs(3600);
135+
states.retain(|_, state| state.last_activity > cutoff);
89136
}
90137
}
91138

139+
fn get_client_id<C: ClientInfo>(client: &C) -> ConnectionId {
140+
// Use the process ID which is unique per connection
141+
client.pid_and_secret_key().0
142+
}
143+
92144
/// Check if the current user has permission to execute a query
93145
async fn check_query_permission<C>(&self, client: &C, query: &str) -> PgWireResult<()>
94146
where
@@ -213,18 +265,24 @@ impl DfSessionService {
213265
}
214266
}
215267

216-
async fn try_respond_transaction_statements<'a>(
268+
async fn try_respond_transaction_statements<'a, C>(
217269
&self,
270+
client: &C,
218271
query_lower: &str,
219-
) -> PgWireResult<Option<Response<'a>>> {
272+
) -> PgWireResult<Option<Response<'a>>>
273+
where
274+
C: ClientInfo,
275+
{
276+
let client_id = Self::get_client_id(client);
277+
220278
// Transaction handling based on pgwire example:
221279
// https://github.com/sunng87/pgwire/blob/master/examples/transaction.rs#L57
222280
match query_lower.trim() {
223281
"begin" | "begin transaction" | "begin work" | "start transaction" => {
224-
let mut state = self.transaction_state.lock().await;
225-
match *state {
282+
match self.get_transaction_state(client_id).await {
226283
TransactionState::None => {
227-
*state = TransactionState::Active;
284+
self.update_transaction_state(client_id, TransactionState::Active)
285+
.await;
228286
Ok(Some(Response::TransactionStart(Tag::new("BEGIN"))))
229287
}
230288
TransactionState::Active => {
@@ -245,10 +303,10 @@ impl DfSessionService {
245303
}
246304
}
247305
"commit" | "commit transaction" | "commit work" | "end" | "end transaction" => {
248-
let mut state = self.transaction_state.lock().await;
249-
match *state {
306+
match self.get_transaction_state(client_id).await {
250307
TransactionState::Active => {
251-
*state = TransactionState::None;
308+
self.update_transaction_state(client_id, TransactionState::None)
309+
.await;
252310
Ok(Some(Response::TransactionEnd(Tag::new("COMMIT"))))
253311
}
254312
TransactionState::None => {
@@ -257,14 +315,15 @@ impl DfSessionService {
257315
}
258316
TransactionState::Failed => {
259317
// COMMIT in failed transaction is treated as ROLLBACK
260-
*state = TransactionState::None;
318+
self.update_transaction_state(client_id, TransactionState::None)
319+
.await;
261320
Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK"))))
262321
}
263322
}
264323
}
265324
"rollback" | "rollback transaction" | "rollback work" | "abort" => {
266-
let mut state = self.transaction_state.lock().await;
267-
*state = TransactionState::None;
325+
self.update_transaction_state(client_id, TransactionState::None)
326+
.await;
268327
Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK"))))
269328
}
270329
_ => Ok(None),
@@ -343,7 +402,7 @@ impl SimpleQueryHandler for DfSessionService {
343402
}
344403

345404
if let Some(resp) = self
346-
.try_respond_transaction_statements(&query_lower)
405+
.try_respond_transaction_statements(client, &query_lower)
347406
.await?
348407
{
349408
return Ok(vec![resp]);
@@ -354,17 +413,15 @@ impl SimpleQueryHandler for DfSessionService {
354413
}
355414

356415
// Check if we're in a failed transaction and block non-transaction commands
357-
{
358-
let state = self.transaction_state.lock().await;
359-
if *state == TransactionState::Failed {
360-
return Err(PgWireError::UserError(Box::new(
361-
pgwire::error::ErrorInfo::new(
362-
"ERROR".to_string(),
363-
"25P01".to_string(),
364-
"current transaction is aborted, commands ignored until end of transaction block".to_string(),
365-
),
366-
)));
367-
}
416+
let client_id = Self::get_client_id(client);
417+
if self.get_transaction_state(client_id).await == TransactionState::Failed {
418+
return Err(PgWireError::UserError(Box::new(
419+
pgwire::error::ErrorInfo::new(
420+
"ERROR".to_string(),
421+
"25P01".to_string(),
422+
"current transaction is aborted, commands ignored until end of transaction block".to_string(),
423+
),
424+
)));
368425
}
369426

370427
let df_result = self.session_context.sql(query).await;
@@ -374,11 +431,10 @@ impl SimpleQueryHandler for DfSessionService {
374431
Ok(df) => df,
375432
Err(e) => {
376433
// If we're in a transaction and a query fails, mark transaction as failed
377-
{
378-
let mut state = self.transaction_state.lock().await;
379-
if *state == TransactionState::Active {
380-
*state = TransactionState::Failed;
381-
}
434+
let client_id = Self::get_client_id(client);
435+
if self.get_transaction_state(client_id).await == TransactionState::Active {
436+
self.update_transaction_state(client_id, TransactionState::Failed)
437+
.await;
382438
}
383439
return Err(PgWireError::ApiError(Box::new(e)));
384440
}
@@ -496,10 +552,29 @@ impl ExtendedQueryHandler for DfSessionService {
496552
return Ok(resp);
497553
}
498554

555+
if let Some(resp) = self
556+
.try_respond_transaction_statements(client, &query)
557+
.await?
558+
{
559+
return Ok(resp);
560+
}
561+
499562
if let Some(resp) = self.try_respond_show_statements(&query).await? {
500563
return Ok(resp);
501564
}
502565

566+
// Check if we're in a failed transaction and block non-transaction commands
567+
let client_id = Self::get_client_id(client);
568+
if self.get_transaction_state(client_id).await == TransactionState::Failed {
569+
return Err(PgWireError::UserError(Box::new(
570+
pgwire::error::ErrorInfo::new(
571+
"ERROR".to_string(),
572+
"25P01".to_string(),
573+
"current transaction is aborted, commands ignored until end of transaction block".to_string(),
574+
),
575+
)));
576+
}
577+
503578
let (_, plan) = &portal.statement.statement;
504579

505580
let param_types = plan
@@ -510,11 +585,18 @@ impl ExtendedQueryHandler for DfSessionService {
510585
.clone()
511586
.replace_params_with_values(&param_values)
512587
.map_err(|e| PgWireError::ApiError(Box::new(e)))?; // Fixed: Use &param_values
513-
let dataframe = self
514-
.session_context
515-
.execute_logical_plan(plan)
516-
.await
517-
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
588+
let dataframe = match self.session_context.execute_logical_plan(plan).await {
589+
Ok(df) => df,
590+
Err(e) => {
591+
// If we're in a transaction and a query fails, mark transaction as failed
592+
let client_id = Self::get_client_id(client);
593+
if self.get_transaction_state(client_id).await == TransactionState::Active {
594+
self.update_transaction_state(client_id, TransactionState::Failed)
595+
.await;
596+
}
597+
return Err(PgWireError::ApiError(Box::new(e)));
598+
}
599+
};
518600
let resp = df::encode_dataframe(dataframe, &portal.result_column_format).await?;
519601
Ok(Response::Query(resp))
520602
}
@@ -555,3 +637,134 @@ fn ordered_param_types(types: &HashMap<String, Option<DataType>>) -> Vec<Option<
555637
types.sort_by(|a, b| a.0.cmp(b.0));
556638
types.into_iter().map(|pt| pt.1.as_ref()).collect()
557639
}
640+
641+
#[cfg(test)]
642+
mod tests {
643+
use super::*;
644+
use datafusion::prelude::SessionContext;
645+
646+
#[tokio::test]
647+
async fn test_transaction_isolation() {
648+
let session_context = Arc::new(SessionContext::new());
649+
let auth_manager = Arc::new(AuthManager::new());
650+
let service = DfSessionService::new(session_context, auth_manager);
651+
652+
// Simulate two different connection IDs
653+
let client_id_1 = 1001;
654+
let client_id_2 = 1002;
655+
656+
// Client 1 starts a transaction
657+
service
658+
.update_transaction_state(client_id_1, TransactionState::Active)
659+
.await;
660+
661+
// Client 2 starts a transaction
662+
service
663+
.update_transaction_state(client_id_2, TransactionState::Active)
664+
.await;
665+
666+
// Verify both have active transactions independently
667+
{
668+
let states = service.connection_states.read().await;
669+
assert_eq!(
670+
states.get(&client_id_1).map(|s| s.transaction_state),
671+
Some(TransactionState::Active)
672+
);
673+
assert_eq!(
674+
states.get(&client_id_2).map(|s| s.transaction_state),
675+
Some(TransactionState::Active)
676+
);
677+
}
678+
679+
// Client 1 fails a transaction
680+
service
681+
.update_transaction_state(client_id_1, TransactionState::Failed)
682+
.await;
683+
684+
// Verify client 1 is failed but client 2 is still active
685+
{
686+
let states = service.connection_states.read().await;
687+
assert_eq!(
688+
states.get(&client_id_1).map(|s| s.transaction_state),
689+
Some(TransactionState::Failed)
690+
);
691+
assert_eq!(
692+
states.get(&client_id_2).map(|s| s.transaction_state),
693+
Some(TransactionState::Active)
694+
);
695+
}
696+
697+
// Client 1 rollback
698+
service
699+
.update_transaction_state(client_id_1, TransactionState::None)
700+
.await;
701+
702+
// Client 2 commit
703+
service
704+
.update_transaction_state(client_id_2, TransactionState::None)
705+
.await;
706+
707+
// Verify both are back to None state
708+
{
709+
let states = service.connection_states.read().await;
710+
assert_eq!(
711+
states.get(&client_id_1).map(|s| s.transaction_state),
712+
Some(TransactionState::None)
713+
);
714+
assert_eq!(
715+
states.get(&client_id_2).map(|s| s.transaction_state),
716+
Some(TransactionState::None)
717+
);
718+
}
719+
}
720+
721+
#[tokio::test]
722+
async fn test_opportunistic_cleanup() {
723+
let session_context = Arc::new(SessionContext::new());
724+
let auth_manager = Arc::new(AuthManager::new());
725+
let service = DfSessionService::new(session_context, auth_manager);
726+
727+
// Add some connection states
728+
service
729+
.update_transaction_state(2001, TransactionState::Active)
730+
.await;
731+
service
732+
.update_transaction_state(2002, TransactionState::Failed)
733+
.await;
734+
735+
// Manually create an old connection
736+
{
737+
let mut states = service.connection_states.write().await;
738+
states.insert(
739+
2003,
740+
ConnectionState {
741+
transaction_state: TransactionState::Active,
742+
last_activity: Instant::now() - Duration::from_secs(7200), // 2 hours old
743+
},
744+
);
745+
}
746+
747+
// Set cleanup counter to trigger cleanup on next update (fetch_add returns old value)
748+
service.cleanup_counter.store(99, Ordering::Relaxed);
749+
750+
// First update sets counter to 100 (99 + 1)
751+
service
752+
.update_transaction_state(2004, TransactionState::Active)
753+
.await;
754+
755+
// This should trigger cleanup (counter becomes 100, 100 % 100 == 0)
756+
service
757+
.update_transaction_state(2005, TransactionState::Active)
758+
.await;
759+
760+
// Verify only the old connection was removed (cleanup is now inline, no wait needed)
761+
{
762+
let states = service.connection_states.read().await;
763+
assert!(states.contains_key(&2001));
764+
assert!(states.contains_key(&2002));
765+
assert!(!states.contains_key(&2003)); // Old connection should be removed
766+
assert!(states.contains_key(&2004));
767+
assert!(states.contains_key(&2005));
768+
}
769+
}
770+
}

0 commit comments

Comments
 (0)