@@ -31,6 +31,9 @@ use tokio::sync::Mutex;
3131use arrow_pg:: datatypes:: df;
3232use arrow_pg:: datatypes:: { arrow_schema_to_pg_fields, into_pg_type} ;
3333
34+ // Metadata keys for session-level settings
35+ const METADATA_STATEMENT_TIMEOUT : & str = "statement_timeout_ms" ;
36+
3437/// Simple startup handler that does no authentication
3538/// For production, use DfAuthSource with proper pgwire authentication handlers
3639pub struct SimpleStartupHandler ;
@@ -71,7 +74,6 @@ pub struct DfSessionService {
7174 timezone : Arc < Mutex < String > > ,
7275 auth_manager : Arc < AuthManager > ,
7376 sql_rewrite_rules : Vec < Arc < dyn SqlStatementRewriteRule > > ,
74- statement_timeout : Arc < Mutex < Option < std:: time:: Duration > > > ,
7577}
7678
7779impl DfSessionService {
@@ -98,7 +100,31 @@ impl DfSessionService {
98100 timezone : Arc :: new ( Mutex :: new ( "UTC" . to_string ( ) ) ) ,
99101 auth_manager,
100102 sql_rewrite_rules,
101- statement_timeout : Arc :: new ( Mutex :: new ( None ) ) ,
103+ }
104+ }
105+
106+ /// Get statement timeout from client metadata
107+ fn get_statement_timeout < C > ( client : & C ) -> Option < std:: time:: Duration >
108+ where
109+ C : ClientInfo ,
110+ {
111+ client
112+ . metadata ( )
113+ . get ( METADATA_STATEMENT_TIMEOUT )
114+ . and_then ( |s| s. parse :: < u64 > ( ) . ok ( ) )
115+ . map ( std:: time:: Duration :: from_millis)
116+ }
117+
118+ /// Set statement timeout in client metadata
119+ fn set_statement_timeout < C > ( client : & mut C , timeout : Option < std:: time:: Duration > )
120+ where
121+ C : ClientInfo ,
122+ {
123+ let metadata = client. metadata_mut ( ) ;
124+ if let Some ( duration) = timeout {
125+ metadata. insert ( METADATA_STATEMENT_TIMEOUT . to_string ( ) , duration. as_millis ( ) . to_string ( ) ) ;
126+ } else {
127+ metadata. remove ( METADATA_STATEMENT_TIMEOUT ) ;
102128 }
103129 }
104130
@@ -196,10 +222,14 @@ impl DfSessionService {
196222 Ok ( QueryResponse :: new ( Arc :: new ( fields) , Box :: pin ( row_stream) ) )
197223 }
198224
199- async fn try_respond_set_statements < ' a > (
225+ async fn try_respond_set_statements < ' a , C > (
200226 & self ,
227+ client : & mut C ,
201228 query_lower : & str ,
202- ) -> PgWireResult < Option < Response < ' a > > > {
229+ ) -> PgWireResult < Option < Response < ' a > > >
230+ where
231+ C : ClientInfo ,
232+ {
203233 if query_lower. starts_with ( "set" ) {
204234 if query_lower. starts_with ( "set time zone" ) {
205235 let parts: Vec < & str > = query_lower. split_whitespace ( ) . collect ( ) ;
@@ -221,10 +251,9 @@ impl DfSessionService {
221251 let parts: Vec < & str > = query_lower. split_whitespace ( ) . collect ( ) ;
222252 if parts. len ( ) >= 3 {
223253 let timeout_str = parts[ 2 ] . trim_matches ( '"' ) . trim_matches ( '\'' ) ;
224- let mut statement_timeout = self . statement_timeout . lock ( ) . await ;
225254
226- if timeout_str == "0" || timeout_str. is_empty ( ) {
227- * statement_timeout = None ;
255+ let timeout = if timeout_str == "0" || timeout_str. is_empty ( ) {
256+ None
228257 } else {
229258 // Parse timeout value (supports ms, s, min formats)
230259 let timeout_ms = if timeout_str. ends_with ( "ms" ) {
@@ -245,14 +274,12 @@ impl DfSessionService {
245274 } ;
246275
247276 match timeout_ms {
248- Ok ( ms) if ms > 0 => {
249- * statement_timeout = Some ( std:: time:: Duration :: from_millis ( ms) ) ;
250- }
251- _ => {
252- * statement_timeout = None ;
253- }
277+ Ok ( ms) if ms > 0 => Some ( std:: time:: Duration :: from_millis ( ms) ) ,
278+ _ => None ,
254279 }
255- }
280+ } ;
281+
282+ Self :: set_statement_timeout ( client, timeout) ;
256283 Ok ( Some ( Response :: Execution ( Tag :: new ( "SET" ) ) ) )
257284 } else {
258285 Err ( PgWireError :: UserError ( Box :: new (
@@ -322,10 +349,14 @@ impl DfSessionService {
322349 }
323350 }
324351
325- async fn try_respond_show_statements < ' a > (
352+ async fn try_respond_show_statements < ' a , C > (
326353 & self ,
354+ client : & C ,
327355 query_lower : & str ,
328- ) -> PgWireResult < Option < Response < ' a > > > {
356+ ) -> PgWireResult < Option < Response < ' a > > >
357+ where
358+ C : ClientInfo ,
359+ {
329360 if query_lower. starts_with ( "show " ) {
330361 match query_lower. strip_suffix ( ";" ) . unwrap_or ( query_lower) {
331362 "show time zone" => {
@@ -354,7 +385,7 @@ impl DfSessionService {
354385 Ok ( Some ( Response :: Query ( resp) ) )
355386 }
356387 "show statement_timeout" => {
357- let timeout = * self . statement_timeout . lock ( ) . await ;
388+ let timeout = Self :: get_statement_timeout ( client ) ;
358389 let timeout_str = match timeout {
359390 Some ( duration) => format ! ( "{}ms" , duration. as_millis( ) ) ,
360391 None => "0" . to_string ( ) ,
@@ -408,7 +439,7 @@ impl SimpleQueryHandler for DfSessionService {
408439 self . check_query_permission ( client, & query) . await ?;
409440 }
410441
411- if let Some ( resp) = self . try_respond_set_statements ( & query_lower) . await ? {
442+ if let Some ( resp) = self . try_respond_set_statements ( client , & query_lower) . await ? {
412443 return Ok ( vec ! [ resp] ) ;
413444 }
414445
@@ -419,7 +450,7 @@ impl SimpleQueryHandler for DfSessionService {
419450 return Ok ( vec ! [ resp] ) ;
420451 }
421452
422- if let Some ( resp) = self . try_respond_show_statements ( & query_lower) . await ? {
453+ if let Some ( resp) = self . try_respond_show_statements ( client , & query_lower) . await ? {
423454 return Ok ( vec ! [ resp] ) ;
424455 }
425456
@@ -436,7 +467,7 @@ impl SimpleQueryHandler for DfSessionService {
436467 }
437468
438469 let df_result = {
439- let timeout = * self . statement_timeout . lock ( ) . await ;
470+ let timeout = Self :: get_statement_timeout ( client ) ;
440471 if let Some ( timeout_duration) = timeout {
441472 tokio:: time:: timeout ( timeout_duration, self . session_context . sql ( & query) )
442473 . await
@@ -568,7 +599,7 @@ impl ExtendedQueryHandler for DfSessionService {
568599 . await ?;
569600 }
570601
571- if let Some ( resp) = self . try_respond_set_statements ( & query) . await ? {
602+ if let Some ( resp) = self . try_respond_set_statements ( client , & query) . await ? {
572603 return Ok ( resp) ;
573604 }
574605
@@ -579,7 +610,7 @@ impl ExtendedQueryHandler for DfSessionService {
579610 return Ok ( resp) ;
580611 }
581612
582- if let Some ( resp) = self . try_respond_show_statements ( & query) . await ? {
613+ if let Some ( resp) = self . try_respond_show_statements ( client , & query) . await ? {
583614 return Ok ( resp) ;
584615 }
585616
@@ -613,7 +644,7 @@ impl ExtendedQueryHandler for DfSessionService {
613644 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
614645
615646 let dataframe = {
616- let timeout = * self . statement_timeout . lock ( ) . await ;
647+ let timeout = Self :: get_statement_timeout ( client ) ;
617648 if let Some ( timeout_duration) = timeout {
618649 tokio:: time:: timeout (
619650 timeout_duration,
@@ -690,28 +721,88 @@ mod tests {
690721 use super :: * ;
691722 use crate :: auth:: AuthManager ;
692723 use datafusion:: prelude:: SessionContext ;
724+ use std:: collections:: HashMap ;
693725 use std:: time:: Duration ;
694726
727+ struct MockClient {
728+ metadata : HashMap < String , String > ,
729+ }
730+
731+ impl MockClient {
732+ fn new ( ) -> Self {
733+ Self {
734+ metadata : HashMap :: new ( ) ,
735+ }
736+ }
737+ }
738+
739+ impl ClientInfo for MockClient {
740+ fn socket_addr ( & self ) -> std:: net:: SocketAddr {
741+ "127.0.0.1:5432" . parse ( ) . unwrap ( )
742+ }
743+
744+ fn is_secure ( & self ) -> bool {
745+ false
746+ }
747+
748+ fn protocol_version ( & self ) -> pgwire:: messages:: ProtocolVersion {
749+ pgwire:: messages:: ProtocolVersion :: PROTOCOL3_0
750+ }
751+
752+ fn set_protocol_version ( & mut self , _version : pgwire:: messages:: ProtocolVersion ) { }
753+
754+ fn pid_and_secret_key ( & self ) -> ( i32 , pgwire:: messages:: startup:: SecretKey ) {
755+ ( 0 , pgwire:: messages:: startup:: SecretKey :: I32 ( 0 ) )
756+ }
757+
758+ fn set_pid_and_secret_key ( & mut self , _pid : i32 , _secret_key : pgwire:: messages:: startup:: SecretKey ) { }
759+
760+ fn state ( & self ) -> pgwire:: api:: PgWireConnectionState {
761+ pgwire:: api:: PgWireConnectionState :: ReadyForQuery
762+ }
763+
764+ fn set_state ( & mut self , _new_state : pgwire:: api:: PgWireConnectionState ) { }
765+
766+ fn transaction_status ( & self ) -> pgwire:: messages:: response:: TransactionStatus {
767+ pgwire:: messages:: response:: TransactionStatus :: Idle
768+ }
769+
770+ fn set_transaction_status ( & mut self , _new_status : pgwire:: messages:: response:: TransactionStatus ) { }
771+
772+ fn metadata ( & self ) -> & HashMap < String , String > {
773+ & self . metadata
774+ }
775+
776+ fn metadata_mut ( & mut self ) -> & mut HashMap < String , String > {
777+ & mut self . metadata
778+ }
779+
780+ fn client_certificates < ' a > ( & self ) -> Option < & [ rustls_pki_types:: CertificateDer < ' a > ] > {
781+ None
782+ }
783+ }
784+
695785 #[ tokio:: test]
696786 async fn test_statement_timeout_set_and_show ( ) {
697787 let session_context = Arc :: new ( SessionContext :: new ( ) ) ;
698788 let auth_manager = Arc :: new ( AuthManager :: new ( ) ) ;
699789 let service = DfSessionService :: new ( session_context, auth_manager) ;
790+ let mut client = MockClient :: new ( ) ;
700791
701792 // Test setting timeout to 5000ms
702793 let set_response = service
703- . try_respond_set_statements ( "set statement_timeout '5000ms'" )
794+ . try_respond_set_statements ( & mut client , "set statement_timeout '5000ms'" )
704795 . await
705796 . unwrap ( ) ;
706797 assert ! ( set_response. is_some( ) ) ;
707798
708- // Verify the timeout was set
709- let timeout = * service . statement_timeout . lock ( ) . await ;
799+ // Verify the timeout was set in client metadata
800+ let timeout = DfSessionService :: get_statement_timeout ( & client ) ;
710801 assert_eq ! ( timeout, Some ( Duration :: from_millis( 5000 ) ) ) ;
711802
712803 // Test SHOW statement_timeout
713804 let show_response = service
714- . try_respond_show_statements ( "show statement_timeout" )
805+ . try_respond_show_statements ( & client , "show statement_timeout" )
715806 . await
716807 . unwrap ( ) ;
717808 assert ! ( show_response. is_some( ) ) ;
@@ -722,20 +813,21 @@ mod tests {
722813 let session_context = Arc :: new ( SessionContext :: new ( ) ) ;
723814 let auth_manager = Arc :: new ( AuthManager :: new ( ) ) ;
724815 let service = DfSessionService :: new ( session_context, auth_manager) ;
816+ let mut client = MockClient :: new ( ) ;
725817
726818 // Set timeout first
727819 service
728- . try_respond_set_statements ( "set statement_timeout '1000ms'" )
820+ . try_respond_set_statements ( & mut client , "set statement_timeout '1000ms'" )
729821 . await
730822 . unwrap ( ) ;
731823
732824 // Disable timeout with 0
733825 service
734- . try_respond_set_statements ( "set statement_timeout '0'" )
826+ . try_respond_set_statements ( & mut client , "set statement_timeout '0'" )
735827 . await
736828 . unwrap ( ) ;
737829
738- let timeout = * service . statement_timeout . lock ( ) . await ;
830+ let timeout = DfSessionService :: get_statement_timeout ( & client ) ;
739831 assert_eq ! ( timeout, None ) ;
740832 }
741833}
0 commit comments