@@ -374,7 +374,23 @@ impl SimpleQueryHandler for DfSessionService {
374374 }
375375 }
376376
377- let df_result = self . session_context . sql ( query) . await ;
377+ // Add query timeout for simple queries
378+ let query_timeout = std:: time:: Duration :: from_secs ( 60 ) ; // 60 seconds
379+ let df_result = match tokio:: time:: timeout (
380+ query_timeout,
381+ self . session_context . sql ( query)
382+ ) . await {
383+ Ok ( result) => result,
384+ Err ( _) => {
385+ return Err ( PgWireError :: UserError ( Box :: new (
386+ pgwire:: error:: ErrorInfo :: new (
387+ "ERROR" . to_string ( ) ,
388+ "57014" . to_string ( ) , // PostgreSQL query_canceled error code
389+ format ! ( "Query execution timeout after {} seconds" , query_timeout. as_secs( ) ) ,
390+ ) ,
391+ ) ) ) ;
392+ }
393+ } ;
378394
379395 // Handle query execution errors and transaction state
380396 let df = match df_result {
@@ -509,19 +525,93 @@ impl ExtendedQueryHandler for DfSessionService {
509525
510526 let ( _, plan) = & portal. statement . statement ;
511527
512- let param_types = plan
513- . get_parameter_types ( )
514- . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
528+ // Enhanced parameter type inference with fallback for NULL + NULL scenarios
529+ let param_types = match plan. get_parameter_types ( ) {
530+ Ok ( types) => types,
531+ Err ( e) => {
532+ let error_msg = e. to_string ( ) ;
533+ if error_msg. contains ( "Cannot get result type for arithmetic operation Null + Null" )
534+ || error_msg. contains ( "Invalid arithmetic operation: Null + Null" ) {
535+ // Fallback: assume all parameters are integers for arithmetic operations
536+ log:: warn!( "DataFusion type inference failed for arithmetic operation, using integer fallback" ) ;
537+ let param_count = portal. statement . parameter_types . len ( ) ;
538+ std:: collections:: HashMap :: from_iter (
539+ ( 0 ..param_count) . map ( |i| ( format ! ( "${}" , i + 1 ) , Some ( datafusion:: arrow:: datatypes:: DataType :: Int32 ) ) )
540+ )
541+ } else {
542+ return Err ( PgWireError :: ApiError ( Box :: new ( e) ) ) ;
543+ }
544+ }
545+ } ;
515546 let param_values = df:: deserialize_parameters ( portal, & ordered_param_types ( & param_types) ) ?; // Fixed: Use ¶m_types
516- let plan = plan
517- . clone ( )
518- . replace_params_with_values ( & param_values)
519- . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?; // Fixed: Use ¶m_values
520- let dataframe = self
521- . session_context
522- . execute_logical_plan ( plan)
523- . await
524- . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
547+
548+ // Replace parameters with values, with automatic retry for type inference failures
549+ let plan = match plan. clone ( ) . replace_params_with_values ( & param_values) {
550+ Ok ( plan) => plan,
551+ Err ( e) => {
552+ let error_msg = e. to_string ( ) ;
553+ if error_msg. contains ( "Cannot get result type for arithmetic operation Null + Null" )
554+ || error_msg. contains ( "Invalid arithmetic operation: Null + Null" ) {
555+ log:: info!( "Retrying query with enhanced type casting for arithmetic operations" ) ;
556+
557+ // Attempt to reparse the query with explicit type casting
558+ let original_query = & portal. statement . statement . 0 ;
559+ let enhanced_query = enhance_query_with_type_casting ( original_query) ;
560+
561+ // Try to create a new plan with the enhanced query
562+ match self . session_context . sql ( & enhanced_query) . await {
563+ Ok ( new_plan_df) => {
564+ // Get the logical plan from the new dataframe
565+ let new_plan = new_plan_df. logical_plan ( ) . clone ( ) ;
566+
567+ // Try parameter substitution again with the new plan
568+ match new_plan. replace_params_with_values ( & param_values) {
569+ Ok ( final_plan) => final_plan,
570+ Err ( _) => {
571+ // If it still fails, return helpful error message
572+ return Err ( PgWireError :: UserError ( Box :: new (
573+ pgwire:: error:: ErrorInfo :: new (
574+ "ERROR" . to_string ( ) ,
575+ "42804" . to_string ( ) ,
576+ "Cannot infer parameter types for arithmetic operation. Please use explicit type casting like $1::integer + $2::integer" . to_string ( ) ,
577+ ) ,
578+ ) ) ) ;
579+ }
580+ }
581+ }
582+ Err ( _) => {
583+ // If enhanced query fails, return helpful error message
584+ return Err ( PgWireError :: UserError ( Box :: new (
585+ pgwire:: error:: ErrorInfo :: new (
586+ "ERROR" . to_string ( ) ,
587+ "42804" . to_string ( ) ,
588+ "Cannot infer parameter types for arithmetic operation. Please use explicit type casting like $1::integer + $2::integer" . to_string ( ) ,
589+ ) ,
590+ ) ) ) ;
591+ }
592+ }
593+ } else {
594+ return Err ( PgWireError :: ApiError ( Box :: new ( e) ) ) ;
595+ }
596+ }
597+ } ;
598+ // Add query timeout to prevent long-running queries from hanging connections
599+ let query_timeout = std:: time:: Duration :: from_secs ( 60 ) ; // 60 seconds
600+ let dataframe = match tokio:: time:: timeout (
601+ query_timeout,
602+ self . session_context . execute_logical_plan ( plan)
603+ ) . await {
604+ Ok ( result) => result. map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?,
605+ Err ( _) => {
606+ return Err ( PgWireError :: UserError ( Box :: new (
607+ pgwire:: error:: ErrorInfo :: new (
608+ "ERROR" . to_string ( ) ,
609+ "57014" . to_string ( ) , // PostgreSQL query_canceled error code
610+ format ! ( "Query execution timeout after {} seconds" , query_timeout. as_secs( ) ) ,
611+ ) ,
612+ ) ) ) ;
613+ }
614+ } ;
525615 let resp = df:: encode_dataframe ( dataframe, & portal. result_column_format ) . await ?;
526616 Ok ( Response :: Query ( resp) )
527617 }
@@ -555,6 +645,27 @@ impl QueryParser for Parser {
555645 }
556646}
557647
648+ /// Enhance a SQL query by adding type casting to parameters in arithmetic operations
649+ /// This helps DataFusion's type inference when it encounters ambiguous parameter types
650+ fn enhance_query_with_type_casting ( query : & str ) -> String {
651+ use regex:: Regex ;
652+
653+ // Pattern to match arithmetic operations with parameters: $1 + $2, $1 - $2, etc.
654+ let arithmetic_pattern = Regex :: new ( r"\$(\d+)\s*([+\-*/])\s*\$(\d+)" ) . unwrap ( ) ;
655+
656+ // Replace untyped parameters in arithmetic operations with integer-cast parameters
657+ let enhanced = arithmetic_pattern. replace_all ( query, "$$$1::integer $2 $$$3::integer" ) ;
658+
659+ // Pattern to match single parameters in potentially ambiguous contexts
660+ let single_param_pattern = Regex :: new ( r"\$(\d+)(?!::)(?=\s*[+\-*/=<>]|\s*\))" ) . unwrap ( ) ;
661+
662+ // Add integer casting to remaining untyped parameters in arithmetic contexts
663+ let enhanced = single_param_pattern. replace_all ( & enhanced, "$$$1::integer" ) ;
664+
665+ log:: debug!( "Enhanced query: {} -> {}" , query, enhanced) ;
666+ enhanced. to_string ( )
667+ }
668+
558669fn ordered_param_types ( types : & HashMap < String , Option < DataType > > ) -> Vec < Option < & DataType > > {
559670 // Datafusion stores the parameters as a map. In our case, the keys will be
560671 // `$1`, `$2` etc. The values will be the parameter types.
0 commit comments