@@ -43,9 +43,9 @@ pub struct HandlerFactory {
4343}
4444
4545impl HandlerFactory {
46- pub fn new ( session_context : Arc < SessionContext > , auth_manager : Arc < AuthManager > ) -> Self {
46+ pub fn new ( session_context : Arc < SessionContext > , auth_manager : Arc < AuthManager > , query_timeout : Option < std :: time :: Duration > ) -> Self {
4747 let session_service =
48- Arc :: new ( DfSessionService :: new ( session_context, auth_manager. clone ( ) ) ) ;
48+ Arc :: new ( DfSessionService :: new ( session_context, auth_manager. clone ( ) , query_timeout ) ) ;
4949 HandlerFactory { session_service }
5050 }
5151}
@@ -71,12 +71,14 @@ pub struct DfSessionService {
7171 timezone : Arc < Mutex < String > > ,
7272 auth_manager : Arc < AuthManager > ,
7373 sql_rewrite_rules : Vec < Arc < dyn SqlStatementRewriteRule > > ,
74+ query_timeout : Option < std:: time:: Duration > ,
7475}
7576
7677impl DfSessionService {
7778 pub fn new (
7879 session_context : Arc < SessionContext > ,
7980 auth_manager : Arc < AuthManager > ,
81+ query_timeout : Option < std:: time:: Duration > ,
8082 ) -> DfSessionService {
8183 let sql_rewrite_rules: Vec < Arc < dyn SqlStatementRewriteRule > > = vec ! [
8284 Arc :: new( AliasDuplicatedProjectionRewrite ) ,
@@ -97,9 +99,12 @@ impl DfSessionService {
9799 timezone : Arc :: new ( Mutex :: new ( "UTC" . to_string ( ) ) ) ,
98100 auth_manager,
99101 sql_rewrite_rules,
102+ query_timeout,
100103 }
101104 }
102105
106+
107+
103108 /// Check if the current user has permission to execute a query
104109 async fn check_query_permission < C > ( & self , client : & C , query : & str ) -> PgWireResult < ( ) >
105110 where
@@ -378,7 +383,19 @@ impl SimpleQueryHandler for DfSessionService {
378383 ) ) ) ;
379384 }
380385
381- let df_result = self . session_context . sql ( & query) . await ;
386+ let df_result = if let Some ( timeout) = self . query_timeout {
387+ tokio:: time:: timeout ( timeout, self . session_context . sql ( & query) )
388+ . await
389+ . map_err ( |_| {
390+ PgWireError :: UserError ( Box :: new ( pgwire:: error:: ErrorInfo :: new (
391+ "ERROR" . to_string ( ) ,
392+ "57014" . to_string ( ) , // query_canceled error code
393+ "canceling statement due to query timeout" . to_string ( ) ,
394+ ) ) )
395+ } ) ?
396+ } else {
397+ self . session_context . sql ( & query) . await
398+ } ;
382399
383400 // Handle query execution errors and transaction state
384401 let df = match df_result {
@@ -540,10 +557,23 @@ impl ExtendedQueryHandler for DfSessionService {
540557 . optimize ( & plan)
541558 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
542559
543- let dataframe = match self . session_context . execute_logical_plan ( optimised) . await {
544- Ok ( df) => df,
545- Err ( e) => {
546- return Err ( PgWireError :: ApiError ( Box :: new ( e) ) ) ;
560+ let dataframe = if let Some ( timeout) = self . query_timeout {
561+ tokio:: time:: timeout ( timeout, self . session_context . execute_logical_plan ( optimised) )
562+ . await
563+ . map_err ( |_| {
564+ PgWireError :: UserError ( Box :: new ( pgwire:: error:: ErrorInfo :: new (
565+ "ERROR" . to_string ( ) ,
566+ "57014" . to_string ( ) , // query_canceled error code
567+ "canceling statement due to query timeout" . to_string ( ) ,
568+ ) ) )
569+ } ) ?
570+ . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?
571+ } else {
572+ match self . session_context . execute_logical_plan ( optimised) . await {
573+ Ok ( df) => df,
574+ Err ( e) => {
575+ return Err ( PgWireError :: ApiError ( Box :: new ( e) ) ) ;
576+ }
547577 }
548578 } ;
549579 let resp = df:: encode_dataframe ( dataframe, & portal. result_column_format ) . await ?;
@@ -593,3 +623,76 @@ fn ordered_param_types(types: &HashMap<String, Option<DataType>>) -> Vec<Option<
593623 types. sort_by ( |a, b| a. 0 . cmp ( b. 0 ) ) ;
594624 types. into_iter ( ) . map ( |pt| pt. 1 . as_ref ( ) ) . collect ( )
595625}
626+
627+ #[ cfg( test) ]
628+ mod tests {
629+ use super :: * ;
630+ use crate :: { auth:: AuthManager , ServerOptions } ;
631+ use datafusion:: prelude:: SessionContext ;
632+ use std:: time:: Duration ;
633+
634+ #[ test]
635+ fn test_server_options_default_timeout ( ) {
636+ let opts = ServerOptions :: default ( ) ;
637+ assert_eq ! ( opts. query_timeout, Some ( Duration :: from_secs( 30 ) ) ) ;
638+ }
639+
640+ #[ test]
641+ fn test_server_options_no_timeout ( ) {
642+ let mut opts = ServerOptions :: new ( ) ;
643+ opts. query_timeout = None ;
644+ assert_eq ! ( opts. query_timeout, None ) ;
645+ }
646+
647+ #[ test]
648+ fn test_handler_factory_with_timeout ( ) {
649+ let session_context = Arc :: new ( SessionContext :: new ( ) ) ;
650+ let auth_manager = Arc :: new ( AuthManager :: new ( ) ) ;
651+ let timeout = Some ( Duration :: from_secs ( 60 ) ) ;
652+
653+ let factory = HandlerFactory :: new ( session_context, auth_manager, timeout) ;
654+ assert_eq ! ( factory. session_service. query_timeout, timeout) ;
655+ }
656+
657+ #[ test]
658+ fn test_session_service_timeout_configuration ( ) {
659+ let session_context = Arc :: new ( SessionContext :: new ( ) ) ;
660+ let auth_manager = Arc :: new ( AuthManager :: new ( ) ) ;
661+
662+ // Test with timeout
663+ let service_with_timeout = DfSessionService :: new (
664+ session_context. clone ( ) ,
665+ auth_manager. clone ( ) ,
666+ Some ( Duration :: from_secs ( 45 ) )
667+ ) ;
668+ assert_eq ! ( service_with_timeout. query_timeout, Some ( Duration :: from_secs( 45 ) ) ) ;
669+
670+ // Test without timeout (None)
671+ let service_no_timeout = DfSessionService :: new (
672+ session_context,
673+ auth_manager,
674+ None
675+ ) ;
676+ assert_eq ! ( service_no_timeout. query_timeout, None ) ;
677+ }
678+
679+ #[ test]
680+ fn test_timeout_configuration_from_seconds ( ) {
681+ // Test 0 seconds = no timeout
682+ let opts_no_timeout = ServerOptions :: new ( ) . with_query_timeout_secs ( 0 ) ;
683+ assert_eq ! ( opts_no_timeout. query_timeout, None ) ;
684+
685+ // Test positive seconds = Some(Duration)
686+ let opts_with_timeout = ServerOptions :: new ( ) . with_query_timeout_secs ( 60 ) ;
687+ assert_eq ! ( opts_with_timeout. query_timeout, Some ( Duration :: from_secs( 60 ) ) ) ;
688+ }
689+
690+ #[ test]
691+ fn test_max_connections_configuration ( ) {
692+ let opts = ServerOptions :: new ( ) . with_max_connections ( 500 ) ;
693+ assert_eq ! ( opts. max_connections, 500 ) ;
694+
695+ let opts_default = ServerOptions :: default ( ) ;
696+ assert_eq ! ( opts_default. max_connections, 1000 ) ;
697+ }
698+ }
0 commit comments