11use std:: collections:: HashMap ;
2- use std:: hash:: { Hash , Hasher } ;
32use std:: sync:: Arc ;
43
54use crate :: auth:: { AuthManager , Permission , ResourceType } ;
@@ -19,20 +18,12 @@ use pgwire::api::stmt::QueryParser;
1918use pgwire:: api:: stmt:: StoredStatement ;
2019use pgwire:: api:: { ClientInfo , PgWireServerHandlers , Type } ;
2120use pgwire:: error:: { PgWireError , PgWireResult } ;
22- use std:: sync:: atomic:: { AtomicU64 , Ordering } ;
23- use std:: time:: { Duration , Instant } ;
24- use tokio:: sync:: { Mutex , RwLock } ;
21+ use pgwire:: messages:: response:: TransactionStatus ;
22+ use tokio:: sync:: Mutex ;
2523
2624use arrow_pg:: datatypes:: df;
2725use arrow_pg:: datatypes:: { arrow_schema_to_pg_fields, into_pg_type} ;
2826
29- #[ derive( Debug , Clone , Copy , PartialEq ) ]
30- pub enum TransactionState {
31- None ,
32- Active ,
33- Failed ,
34- }
35-
3627/// Simple startup handler that does no authentication
3728/// For production, use DfAuthSource with proper pgwire authentication handlers
3829pub struct SimpleStartupHandler ;
@@ -66,26 +57,12 @@ impl PgWireServerHandlers for HandlerFactory {
6657 }
6758}
6859
69- /// Per-connection transaction state storage
70- /// We use a hash of both PID and secret key as the connection identifier for better uniqueness
71- pub type ConnectionId = u64 ;
72-
73- #[ derive( Debug , Clone ) ]
74- struct ConnectionState {
75- transaction_state : TransactionState ,
76- last_activity : Instant ,
77- }
78-
79- type ConnectionStates = Arc < RwLock < HashMap < ConnectionId , ConnectionState > > > ;
80-
8160/// The pgwire handler backed by a datafusion `SessionContext`
8261pub struct DfSessionService {
8362 session_context : Arc < SessionContext > ,
8463 parser : Arc < Parser > ,
8564 timezone : Arc < Mutex < String > > ,
86- connection_states : ConnectionStates ,
8765 auth_manager : Arc < AuthManager > ,
88- cleanup_counter : AtomicU64 ,
8966}
9067
9168impl DfSessionService {
@@ -100,57 +77,10 @@ impl DfSessionService {
10077 session_context,
10178 parser,
10279 timezone : Arc :: new ( Mutex :: new ( "UTC" . to_string ( ) ) ) ,
103- connection_states : Arc :: new ( RwLock :: new ( HashMap :: new ( ) ) ) ,
10480 auth_manager,
105- cleanup_counter : AtomicU64 :: new ( 0 ) ,
106- }
107- }
108-
109- async fn get_transaction_state ( & self , client_id : ConnectionId ) -> TransactionState {
110- self . connection_states
111- . read ( )
112- . await
113- . get ( & client_id)
114- . map ( |s| s. transaction_state )
115- . unwrap_or ( TransactionState :: None )
116- }
117-
118- async fn update_transaction_state ( & self , client_id : ConnectionId , new_state : TransactionState ) {
119- let mut states = self . connection_states . write ( ) . await ;
120-
121- // Update or insert state using entry API
122- states
123- . entry ( client_id)
124- . and_modify ( |s| {
125- s. transaction_state = new_state;
126- s. last_activity = Instant :: now ( ) ;
127- } )
128- . or_insert ( ConnectionState {
129- transaction_state : new_state,
130- last_activity : Instant :: now ( ) ,
131- } ) ;
132-
133- // Inline cleanup every 100 operations
134- if self . cleanup_counter . fetch_add ( 1 , Ordering :: Relaxed ) % 100 == 0 {
135- let cutoff = Instant :: now ( ) - Duration :: from_secs ( 3600 ) ;
136- states. retain ( |_, state| state. last_activity > cutoff) ;
13781 }
13882 }
13983
140- fn get_client_id < C : ClientInfo > ( client : & C ) -> ConnectionId {
141- // Use a hash of PID, secret key, and socket address for better uniqueness
142- let ( pid, secret) = client. pid_and_secret_key ( ) ;
143- let socket_addr = client. socket_addr ( ) ;
144-
145- // Create a hash of all identifying values
146- let mut hasher = std:: collections:: hash_map:: DefaultHasher :: new ( ) ;
147- pid. hash ( & mut hasher) ;
148- secret. hash ( & mut hasher) ;
149- socket_addr. hash ( & mut hasher) ;
150-
151- hasher. finish ( )
152- }
153-
15484 /// Check if the current user has permission to execute a query
15585 async fn check_query_permission < C > ( & self , client : & C , query : & str ) -> PgWireResult < ( ) >
15686 where
@@ -290,24 +220,15 @@ impl DfSessionService {
290220 where
291221 C : ClientInfo ,
292222 {
293- let client_id = Self :: get_client_id ( client) ;
294-
295223 // Transaction handling based on pgwire example:
296224 // https://github.com/sunng87/pgwire/blob/master/examples/transaction.rs#L57
297225 match query_lower. trim ( ) {
298226 "begin" | "begin transaction" | "begin work" | "start transaction" => {
299- match self . get_transaction_state ( client_id) . await {
300- TransactionState :: None => {
301- self . update_transaction_state ( client_id, TransactionState :: Active )
302- . await ;
303- Ok ( Some ( Response :: TransactionStart ( Tag :: new ( "BEGIN" ) ) ) )
304- }
305- TransactionState :: Active => {
306- // Already in transaction, PostgreSQL allows this but issues a warning
307- // For simplicity, we'll just return BEGIN again
227+ match client. transaction_status ( ) {
228+ TransactionStatus :: Idle | TransactionStatus :: Transaction => {
308229 Ok ( Some ( Response :: TransactionStart ( Tag :: new ( "BEGIN" ) ) ) )
309230 }
310- TransactionState :: Failed => {
231+ TransactionStatus :: Error => {
311232 // Can't start new transaction from failed state
312233 Err ( PgWireError :: UserError ( Box :: new (
313234 pgwire:: error:: ErrorInfo :: new (
@@ -320,27 +241,16 @@ impl DfSessionService {
320241 }
321242 }
322243 "commit" | "commit transaction" | "commit work" | "end" | "end transaction" => {
323- match self . get_transaction_state ( client_id) . await {
324- TransactionState :: Active => {
325- self . update_transaction_state ( client_id, TransactionState :: None )
326- . await ;
327- Ok ( Some ( Response :: TransactionEnd ( Tag :: new ( "COMMIT" ) ) ) )
328- }
329- TransactionState :: None => {
330- // PostgreSQL allows COMMIT outside transaction with warning
244+ match client. transaction_status ( ) {
245+ TransactionStatus :: Idle | TransactionStatus :: Transaction => {
331246 Ok ( Some ( Response :: TransactionEnd ( Tag :: new ( "COMMIT" ) ) ) )
332247 }
333- TransactionState :: Failed => {
334- // COMMIT in failed transaction is treated as ROLLBACK
335- self . update_transaction_state ( client_id, TransactionState :: None )
336- . await ;
248+ TransactionStatus :: Error => {
337249 Ok ( Some ( Response :: TransactionEnd ( Tag :: new ( "ROLLBACK" ) ) ) )
338250 }
339251 }
340252 }
341253 "rollback" | "rollback transaction" | "rollback work" | "abort" => {
342- self . update_transaction_state ( client_id, TransactionState :: None )
343- . await ;
344254 Ok ( Some ( Response :: TransactionEnd ( Tag :: new ( "ROLLBACK" ) ) ) )
345255 }
346256 _ => Ok ( None ) ,
@@ -399,7 +309,7 @@ impl SimpleQueryHandler for DfSessionService {
399309 C : ClientInfo + Unpin + Send + Sync ,
400310 {
401311 let query_lower = query. to_lowercase ( ) . trim ( ) . to_string ( ) ;
402- log:: debug!( "Received query: {}" , query ) ; // Log the query for debugging
312+ log:: debug!( "Received query: {query}" ) ; // Log the query for debugging
403313
404314 // Check permissions for the query (skip for SET, transaction, and SHOW statements)
405315 if !query_lower. starts_with ( "set" )
@@ -429,9 +339,9 @@ impl SimpleQueryHandler for DfSessionService {
429339 return Ok ( vec ! [ resp] ) ;
430340 }
431341
432- // Check if we're in a failed transaction and block non-transaction commands
433- let client_id = Self :: get_client_id ( client ) ;
434- if self . get_transaction_state ( client_id ) . await == TransactionState :: Failed {
342+ // Check if we're in a failed transaction and block non-transaction
343+ // commands
344+ if client . transaction_status ( ) == TransactionStatus :: Error {
435345 return Err ( PgWireError :: UserError ( Box :: new (
436346 pgwire:: error:: ErrorInfo :: new (
437347 "ERROR" . to_string ( ) ,
@@ -447,12 +357,6 @@ impl SimpleQueryHandler for DfSessionService {
447357 let df = match df_result {
448358 Ok ( df) => df,
449359 Err ( e) => {
450- // If we're in a transaction and a query fails, mark transaction as failed
451- let client_id = Self :: get_client_id ( client) ;
452- if self . get_transaction_state ( client_id) . await == TransactionState :: Active {
453- self . update_transaction_state ( client_id, TransactionState :: Failed )
454- . await ;
455- }
456360 return Err ( PgWireError :: ApiError ( Box :: new ( e) ) ) ;
457361 }
458362 } ;
@@ -557,7 +461,7 @@ impl ExtendedQueryHandler for DfSessionService {
557461 . to_lowercase ( )
558462 . trim ( )
559463 . to_string ( ) ;
560- log:: debug!( "Received execute extended query: {}" , query ) ; // Log for debugging
464+ log:: debug!( "Received execute extended query: {query}" ) ; // Log for debugging
561465
562466 // Check permissions for the query (skip for SET and SHOW statements)
563467 if !query. starts_with ( "set" ) && !query. starts_with ( "show" ) {
@@ -580,9 +484,9 @@ impl ExtendedQueryHandler for DfSessionService {
580484 return Ok ( resp) ;
581485 }
582486
583- // Check if we're in a failed transaction and block non-transaction commands
584- let client_id = Self :: get_client_id ( client ) ;
585- if self . get_transaction_state ( client_id ) . await == TransactionState :: Failed {
487+ // Check if we're in a failed transaction and block non-transaction
488+ // commands
489+ if client . transaction_status ( ) == TransactionStatus :: Error {
586490 return Err ( PgWireError :: UserError ( Box :: new (
587491 pgwire:: error:: ErrorInfo :: new (
588492 "ERROR" . to_string ( ) ,
@@ -605,12 +509,6 @@ impl ExtendedQueryHandler for DfSessionService {
605509 let dataframe = match self . session_context . execute_logical_plan ( plan) . await {
606510 Ok ( df) => df,
607511 Err ( e) => {
608- // If we're in a transaction and a query fails, mark transaction as failed
609- let client_id = Self :: get_client_id ( client) ;
610- if self . get_transaction_state ( client_id) . await == TransactionState :: Active {
611- self . update_transaction_state ( client_id, TransactionState :: Failed )
612- . await ;
613- }
614512 return Err ( PgWireError :: ApiError ( Box :: new ( e) ) ) ;
615513 }
616514 } ;
@@ -633,7 +531,7 @@ impl QueryParser for Parser {
633531 sql : & str ,
634532 _types : & [ Type ] ,
635533 ) -> PgWireResult < Self :: Statement > {
636- log:: debug!( "Received parse extended query: {}" , sql ) ; // Log for debugging
534+ log:: debug!( "Received parse extended query: {sql}" ) ; // Log for debugging
637535 let context = & self . session_context ;
638536 let state = context. state ( ) ;
639537 let logical_plan = state
@@ -654,134 +552,3 @@ fn ordered_param_types(types: &HashMap<String, Option<DataType>>) -> Vec<Option<
654552 types. sort_by ( |a, b| a. 0 . cmp ( b. 0 ) ) ;
655553 types. into_iter ( ) . map ( |pt| pt. 1 . as_ref ( ) ) . collect ( )
656554}
657-
658- #[ cfg( test) ]
659- mod tests {
660- use super :: * ;
661- use datafusion:: prelude:: SessionContext ;
662-
663- #[ tokio:: test]
664- async fn test_transaction_isolation ( ) {
665- let session_context = Arc :: new ( SessionContext :: new ( ) ) ;
666- let auth_manager = Arc :: new ( AuthManager :: new ( ) ) ;
667- let service = DfSessionService :: new ( session_context, auth_manager) ;
668-
669- // Simulate two different connection IDs
670- let client_id_1 = 1001 ;
671- let client_id_2 = 1002 ;
672-
673- // Client 1 starts a transaction
674- service
675- . update_transaction_state ( client_id_1, TransactionState :: Active )
676- . await ;
677-
678- // Client 2 starts a transaction
679- service
680- . update_transaction_state ( client_id_2, TransactionState :: Active )
681- . await ;
682-
683- // Verify both have active transactions independently
684- {
685- let states = service. connection_states . read ( ) . await ;
686- assert_eq ! (
687- states. get( & client_id_1) . map( |s| s. transaction_state) ,
688- Some ( TransactionState :: Active )
689- ) ;
690- assert_eq ! (
691- states. get( & client_id_2) . map( |s| s. transaction_state) ,
692- Some ( TransactionState :: Active )
693- ) ;
694- }
695-
696- // Client 1 fails a transaction
697- service
698- . update_transaction_state ( client_id_1, TransactionState :: Failed )
699- . await ;
700-
701- // Verify client 1 is failed but client 2 is still active
702- {
703- let states = service. connection_states . read ( ) . await ;
704- assert_eq ! (
705- states. get( & client_id_1) . map( |s| s. transaction_state) ,
706- Some ( TransactionState :: Failed )
707- ) ;
708- assert_eq ! (
709- states. get( & client_id_2) . map( |s| s. transaction_state) ,
710- Some ( TransactionState :: Active )
711- ) ;
712- }
713-
714- // Client 1 rollback
715- service
716- . update_transaction_state ( client_id_1, TransactionState :: None )
717- . await ;
718-
719- // Client 2 commit
720- service
721- . update_transaction_state ( client_id_2, TransactionState :: None )
722- . await ;
723-
724- // Verify both are back to None state
725- {
726- let states = service. connection_states . read ( ) . await ;
727- assert_eq ! (
728- states. get( & client_id_1) . map( |s| s. transaction_state) ,
729- Some ( TransactionState :: None )
730- ) ;
731- assert_eq ! (
732- states. get( & client_id_2) . map( |s| s. transaction_state) ,
733- Some ( TransactionState :: None )
734- ) ;
735- }
736- }
737-
738- #[ tokio:: test]
739- async fn test_opportunistic_cleanup ( ) {
740- let session_context = Arc :: new ( SessionContext :: new ( ) ) ;
741- let auth_manager = Arc :: new ( AuthManager :: new ( ) ) ;
742- let service = DfSessionService :: new ( session_context, auth_manager) ;
743-
744- // Add some connection states
745- service
746- . update_transaction_state ( 2001 , TransactionState :: Active )
747- . await ;
748- service
749- . update_transaction_state ( 2002 , TransactionState :: Failed )
750- . await ;
751-
752- // Manually create an old connection
753- {
754- let mut states = service. connection_states . write ( ) . await ;
755- states. insert (
756- 2003 ,
757- ConnectionState {
758- transaction_state : TransactionState :: Active ,
759- last_activity : Instant :: now ( ) - Duration :: from_secs ( 7200 ) , // 2 hours old
760- } ,
761- ) ;
762- }
763-
764- // Set cleanup counter to trigger cleanup on next update (fetch_add returns old value)
765- service. cleanup_counter . store ( 99 , Ordering :: Relaxed ) ;
766-
767- // First update sets counter to 100 (99 + 1)
768- service
769- . update_transaction_state ( 2004 , TransactionState :: Active )
770- . await ;
771-
772- // This should trigger cleanup (counter becomes 100, 100 % 100 == 0)
773- service
774- . update_transaction_state ( 2005 , TransactionState :: Active )
775- . await ;
776-
777- // Verify only the old connection was removed (cleanup is now inline, no wait needed)
778- {
779- let states = service. connection_states . read ( ) . await ;
780- assert ! ( states. contains_key( & 2001 ) ) ;
781- assert ! ( states. contains_key( & 2002 ) ) ;
782- assert ! ( !states. contains_key( & 2003 ) ) ; // Old connection should be removed
783- assert ! ( states. contains_key( & 2004 ) ) ;
784- assert ! ( states. contains_key( & 2005 ) ) ;
785- }
786- }
787- }
0 commit comments