@@ -18,7 +18,9 @@ use pgwire::api::stmt::QueryParser;
1818use pgwire:: api:: stmt:: StoredStatement ;
1919use pgwire:: api:: { ClientInfo , PgWireServerHandlers , Type } ;
2020use pgwire:: error:: { PgWireError , PgWireResult } ;
21- use tokio:: sync:: Mutex ;
21+ use std:: sync:: atomic:: { AtomicU64 , Ordering } ;
22+ use std:: time:: { Duration , Instant } ;
23+ use tokio:: sync:: { Mutex , RwLock } ;
2224
2325use arrow_pg:: datatypes:: df;
2426use arrow_pg:: datatypes:: { arrow_schema_to_pg_fields, into_pg_type} ;
@@ -63,13 +65,26 @@ impl PgWireServerHandlers for HandlerFactory {
6365 }
6466}
6567
68+ /// Per-connection transaction state storage
69+ /// We use the process ID as the connection identifier since it's unique per connection
70+ pub type ConnectionId = i32 ;
71+
72+ #[ derive( Debug , Clone ) ]
73+ struct ConnectionState {
74+ transaction_state : TransactionState ,
75+ last_activity : Instant ,
76+ }
77+
78+ type ConnectionStates = Arc < RwLock < HashMap < ConnectionId , ConnectionState > > > ;
79+
6680/// The pgwire handler backed by a datafusion `SessionContext`
6781pub struct DfSessionService {
6882 session_context : Arc < SessionContext > ,
6983 parser : Arc < Parser > ,
7084 timezone : Arc < Mutex < String > > ,
71- transaction_state : Arc < Mutex < TransactionState > > ,
85+ connection_states : ConnectionStates ,
7286 auth_manager : Arc < AuthManager > ,
87+ cleanup_counter : AtomicU64 ,
7388}
7489
7590impl DfSessionService {
@@ -84,11 +99,48 @@ impl DfSessionService {
8499 session_context,
85100 parser,
86101 timezone : Arc :: new ( Mutex :: new ( "UTC" . to_string ( ) ) ) ,
87- transaction_state : Arc :: new ( Mutex :: new ( TransactionState :: None ) ) ,
102+ connection_states : Arc :: new ( RwLock :: new ( HashMap :: new ( ) ) ) ,
88103 auth_manager,
104+ cleanup_counter : AtomicU64 :: new ( 0 ) ,
105+ }
106+ }
107+
108+ async fn get_transaction_state ( & self , client_id : ConnectionId ) -> TransactionState {
109+ self . connection_states
110+ . read ( )
111+ . await
112+ . get ( & client_id)
113+ . map ( |s| s. transaction_state )
114+ . unwrap_or ( TransactionState :: None )
115+ }
116+
117+ async fn update_transaction_state ( & self , client_id : ConnectionId , new_state : TransactionState {
118+ let mut states = self . connection_states . write ( ) . await ;
119+
120+ // Update or insert state using entry API
121+ states
122+ . entry ( client_id)
123+ . and_modify ( |s| {
124+ s. transaction_state = new_state;
125+ s. last_activity = Instant :: now ( ) ;
126+ } )
127+ . or_insert ( ConnectionState {
128+ transaction_state : new_state,
129+ last_activity : Instant :: now ( ) ,
130+ } ) ;
131+
132+ // Inline cleanup every 100 operations
133+ if self . cleanup_counter . fetch_add ( 1 , Ordering :: Relaxed ) % 100 == 0 {
134+ let cutoff = Instant :: now ( ) - Duration :: from_secs ( 3600 ) ;
135+ states. retain ( |_, state| state. last_activity > cutoff) ;
89136 }
90137 }
91138
139+ fn get_client_id < C : ClientInfo > ( client : & C ) -> ConnectionId {
140+ // Use the process ID which is unique per connection
141+ client. pid_and_secret_key ( ) . 0
142+ }
143+
92144 /// Check if the current user has permission to execute a query
93145 async fn check_query_permission < C > ( & self , client : & C , query : & str ) -> PgWireResult < ( ) >
94146 where
@@ -213,18 +265,24 @@ impl DfSessionService {
213265 }
214266 }
215267
216- async fn try_respond_transaction_statements < ' a > (
268+ async fn try_respond_transaction_statements < ' a , C > (
217269 & self ,
270+ client : & C ,
218271 query_lower : & str ,
219- ) -> PgWireResult < Option < Response < ' a > > > {
272+ ) -> PgWireResult < Option < Response < ' a > > >
273+ where
274+ C : ClientInfo ,
275+ {
276+ let client_id = Self :: get_client_id ( client) ;
277+
220278 // Transaction handling based on pgwire example:
221279 // https://github.com/sunng87/pgwire/blob/master/examples/transaction.rs#L57
222280 match query_lower. trim ( ) {
223281 "begin" | "begin transaction" | "begin work" | "start transaction" => {
224- let mut state = self . transaction_state . lock ( ) . await ;
225- match * state {
282+ match self . get_transaction_state ( client_id) . await {
226283 TransactionState :: None => {
227- * state = TransactionState :: Active ;
284+ self . update_transaction_state ( client_id, TransactionState :: Active )
285+ . await ;
228286 Ok ( Some ( Response :: TransactionStart ( Tag :: new ( "BEGIN" ) ) ) )
229287 }
230288 TransactionState :: Active => {
@@ -245,10 +303,10 @@ impl DfSessionService {
245303 }
246304 }
247305 "commit" | "commit transaction" | "commit work" | "end" | "end transaction" => {
248- let mut state = self . transaction_state . lock ( ) . await ;
249- match * state {
306+ match self . get_transaction_state ( client_id) . await {
250307 TransactionState :: Active => {
251- * state = TransactionState :: None ;
308+ self . update_transaction_state ( client_id, TransactionState :: None )
309+ . await ;
252310 Ok ( Some ( Response :: TransactionEnd ( Tag :: new ( "COMMIT" ) ) ) )
253311 }
254312 TransactionState :: None => {
@@ -257,14 +315,15 @@ impl DfSessionService {
257315 }
258316 TransactionState :: Failed => {
259317 // COMMIT in failed transaction is treated as ROLLBACK
260- * state = TransactionState :: None ;
318+ self . update_transaction_state ( client_id, TransactionState :: None )
319+ . await ;
261320 Ok ( Some ( Response :: TransactionEnd ( Tag :: new ( "ROLLBACK" ) ) ) )
262321 }
263322 }
264323 }
265324 "rollback" | "rollback transaction" | "rollback work" | "abort" => {
266- let mut state = self . transaction_state . lock ( ) . await ;
267- * state = TransactionState :: None ;
325+ self . update_transaction_state ( client_id , TransactionState :: None )
326+ . await ;
268327 Ok ( Some ( Response :: TransactionEnd ( Tag :: new ( "ROLLBACK" ) ) ) )
269328 }
270329 _ => Ok ( None ) ,
@@ -343,7 +402,7 @@ impl SimpleQueryHandler for DfSessionService {
343402 }
344403
345404 if let Some ( resp) = self
346- . try_respond_transaction_statements ( & query_lower)
405+ . try_respond_transaction_statements ( client , & query_lower)
347406 . await ?
348407 {
349408 return Ok ( vec ! [ resp] ) ;
@@ -354,17 +413,15 @@ impl SimpleQueryHandler for DfSessionService {
354413 }
355414
356415 // Check if we're in a failed transaction and block non-transaction commands
357- {
358- let state = self . transaction_state . lock ( ) . await ;
359- if * state == TransactionState :: Failed {
360- return Err ( PgWireError :: UserError ( Box :: new (
361- pgwire:: error:: ErrorInfo :: new (
362- "ERROR" . to_string ( ) ,
363- "25P01" . to_string ( ) ,
364- "current transaction is aborted, commands ignored until end of transaction block" . to_string ( ) ,
365- ) ,
366- ) ) ) ;
367- }
416+ let client_id = Self :: get_client_id ( client) ;
417+ if self . get_transaction_state ( client_id) . await == TransactionState :: Failed {
418+ return Err ( PgWireError :: UserError ( Box :: new (
419+ pgwire:: error:: ErrorInfo :: new (
420+ "ERROR" . to_string ( ) ,
421+ "25P01" . to_string ( ) ,
422+ "current transaction is aborted, commands ignored until end of transaction block" . to_string ( ) ,
423+ ) ,
424+ ) ) ) ;
368425 }
369426
370427 let df_result = self . session_context . sql ( query) . await ;
@@ -374,11 +431,10 @@ impl SimpleQueryHandler for DfSessionService {
374431 Ok ( df) => df,
375432 Err ( e) => {
376433 // If we're in a transaction and a query fails, mark transaction as failed
377- {
378- let mut state = self . transaction_state . lock ( ) . await ;
379- if * state == TransactionState :: Active {
380- * state = TransactionState :: Failed ;
381- }
434+ let client_id = Self :: get_client_id ( client) ;
435+ if self . get_transaction_state ( client_id) . await == TransactionState :: Active {
436+ self . update_transaction_state ( client_id, TransactionState :: Failed )
437+ . await ;
382438 }
383439 return Err ( PgWireError :: ApiError ( Box :: new ( e) ) ) ;
384440 }
@@ -496,10 +552,29 @@ impl ExtendedQueryHandler for DfSessionService {
496552 return Ok ( resp) ;
497553 }
498554
555+ if let Some ( resp) = self
556+ . try_respond_transaction_statements ( client, & query)
557+ . await ?
558+ {
559+ return Ok ( resp) ;
560+ }
561+
499562 if let Some ( resp) = self . try_respond_show_statements ( & query) . await ? {
500563 return Ok ( resp) ;
501564 }
502565
566+ // Check if we're in a failed transaction and block non-transaction commands
567+ let client_id = Self :: get_client_id ( client) ;
568+ if self . get_transaction_state ( client_id) . await == TransactionState :: Failed {
569+ return Err ( PgWireError :: UserError ( Box :: new (
570+ pgwire:: error:: ErrorInfo :: new (
571+ "ERROR" . to_string ( ) ,
572+ "25P01" . to_string ( ) ,
573+ "current transaction is aborted, commands ignored until end of transaction block" . to_string ( ) ,
574+ ) ,
575+ ) ) ) ;
576+ }
577+
503578 let ( _, plan) = & portal. statement . statement ;
504579
505580 let param_types = plan
@@ -510,11 +585,18 @@ impl ExtendedQueryHandler for DfSessionService {
510585 . clone ( )
511586 . replace_params_with_values ( & param_values)
512587 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?; // Fixed: Use ¶m_values
513- let dataframe = self
514- . session_context
515- . execute_logical_plan ( plan)
516- . await
517- . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
588+ let dataframe = match self . session_context . execute_logical_plan ( plan) . await {
589+ Ok ( df) => df,
590+ Err ( e) => {
591+ // If we're in a transaction and a query fails, mark transaction as failed
592+ let client_id = Self :: get_client_id ( client) ;
593+ if self . get_transaction_state ( client_id) . await == TransactionState :: Active {
594+ self . update_transaction_state ( client_id, TransactionState :: Failed )
595+ . await ;
596+ }
597+ return Err ( PgWireError :: ApiError ( Box :: new ( e) ) ) ;
598+ }
599+ } ;
518600 let resp = df:: encode_dataframe ( dataframe, & portal. result_column_format ) . await ?;
519601 Ok ( Response :: Query ( resp) )
520602 }
@@ -555,3 +637,134 @@ fn ordered_param_types(types: &HashMap<String, Option<DataType>>) -> Vec<Option<
555637 types. sort_by ( |a, b| a. 0 . cmp ( b. 0 ) ) ;
556638 types. into_iter ( ) . map ( |pt| pt. 1 . as_ref ( ) ) . collect ( )
557639}
640+
641+ #[ cfg( test) ]
642+ mod tests {
643+ use super :: * ;
644+ use datafusion:: prelude:: SessionContext ;
645+
646+ #[ tokio:: test]
647+ async fn test_transaction_isolation ( ) {
648+ let session_context = Arc :: new ( SessionContext :: new ( ) ) ;
649+ let auth_manager = Arc :: new ( AuthManager :: new ( ) ) ;
650+ let service = DfSessionService :: new ( session_context, auth_manager) ;
651+
652+ // Simulate two different connection IDs
653+ let client_id_1 = 1001 ;
654+ let client_id_2 = 1002 ;
655+
656+ // Client 1 starts a transaction
657+ service
658+ . update_transaction_state ( client_id_1, TransactionState :: Active )
659+ . await ;
660+
661+ // Client 2 starts a transaction
662+ service
663+ . update_transaction_state ( client_id_2, TransactionState :: Active )
664+ . await ;
665+
666+ // Verify both have active transactions independently
667+ {
668+ let states = service. connection_states . read ( ) . await ;
669+ assert_eq ! (
670+ states. get( & client_id_1) . map( |s| s. transaction_state) ,
671+ Some ( TransactionState :: Active )
672+ ) ;
673+ assert_eq ! (
674+ states. get( & client_id_2) . map( |s| s. transaction_state) ,
675+ Some ( TransactionState :: Active )
676+ ) ;
677+ }
678+
679+ // Client 1 fails a transaction
680+ service
681+ . update_transaction_state ( client_id_1, TransactionState :: Failed )
682+ . await ;
683+
684+ // Verify client 1 is failed but client 2 is still active
685+ {
686+ let states = service. connection_states . read ( ) . await ;
687+ assert_eq ! (
688+ states. get( & client_id_1) . map( |s| s. transaction_state) ,
689+ Some ( TransactionState :: Failed )
690+ ) ;
691+ assert_eq ! (
692+ states. get( & client_id_2) . map( |s| s. transaction_state) ,
693+ Some ( TransactionState :: Active )
694+ ) ;
695+ }
696+
697+ // Client 1 rollback
698+ service
699+ . update_transaction_state ( client_id_1, TransactionState :: None )
700+ . await ;
701+
702+ // Client 2 commit
703+ service
704+ . update_transaction_state ( client_id_2, TransactionState :: None )
705+ . await ;
706+
707+ // Verify both are back to None state
708+ {
709+ let states = service. connection_states . read ( ) . await ;
710+ assert_eq ! (
711+ states. get( & client_id_1) . map( |s| s. transaction_state) ,
712+ Some ( TransactionState :: None )
713+ ) ;
714+ assert_eq ! (
715+ states. get( & client_id_2) . map( |s| s. transaction_state) ,
716+ Some ( TransactionState :: None )
717+ ) ;
718+ }
719+ }
720+
721+ #[ tokio:: test]
722+ async fn test_opportunistic_cleanup ( ) {
723+ let session_context = Arc :: new ( SessionContext :: new ( ) ) ;
724+ let auth_manager = Arc :: new ( AuthManager :: new ( ) ) ;
725+ let service = DfSessionService :: new ( session_context, auth_manager) ;
726+
727+ // Add some connection states
728+ service
729+ . update_transaction_state ( 2001 , TransactionState :: Active )
730+ . await ;
731+ service
732+ . update_transaction_state ( 2002 , TransactionState :: Failed )
733+ . await ;
734+
735+ // Manually create an old connection
736+ {
737+ let mut states = service. connection_states . write ( ) . await ;
738+ states. insert (
739+ 2003 ,
740+ ConnectionState {
741+ transaction_state : TransactionState :: Active ,
742+ last_activity : Instant :: now ( ) - Duration :: from_secs ( 7200 ) , // 2 hours old
743+ } ,
744+ ) ;
745+ }
746+
747+ // Set cleanup counter to trigger cleanup on next update (fetch_add returns old value)
748+ service. cleanup_counter . store ( 99 , Ordering :: Relaxed ) ;
749+
750+ // First update sets counter to 100 (99 + 1)
751+ service
752+ . update_transaction_state ( 2004 , TransactionState :: Active )
753+ . await ;
754+
755+ // This should trigger cleanup (counter becomes 100, 100 % 100 == 0)
756+ service
757+ . update_transaction_state ( 2005 , TransactionState :: Active )
758+ . await ;
759+
760+ // Verify only the old connection was removed (cleanup is now inline, no wait needed)
761+ {
762+ let states = service. connection_states . read ( ) . await ;
763+ assert ! ( states. contains_key( & 2001 ) ) ;
764+ assert ! ( states. contains_key( & 2002 ) ) ;
765+ assert ! ( !states. contains_key( & 2003 ) ) ; // Old connection should be removed
766+ assert ! ( states. contains_key( & 2004 ) ) ;
767+ assert ! ( states. contains_key( & 2005 ) ) ;
768+ }
769+ }
770+ }
0 commit comments