@@ -376,21 +376,22 @@ impl SimpleQueryHandler for DfSessionService {
376376
377377 // Add query timeout for simple queries
378378 let query_timeout = std:: time:: Duration :: from_secs ( 60 ) ; // 60 seconds
379- let df_result = match tokio:: time:: timeout (
380- query_timeout,
381- self . session_context . sql ( query)
382- ) . await {
383- Ok ( result) => result,
384- Err ( _) => {
385- return Err ( PgWireError :: UserError ( Box :: new (
386- pgwire:: error:: ErrorInfo :: new (
387- "ERROR" . to_string ( ) ,
388- "57014" . to_string ( ) , // PostgreSQL query_canceled error code
389- format ! ( "Query execution timeout after {} seconds" , query_timeout. as_secs( ) ) ,
390- ) ,
391- ) ) ) ;
392- }
393- } ;
379+ let df_result =
380+ match tokio:: time:: timeout ( query_timeout, self . session_context . sql ( query) ) . await {
381+ Ok ( result) => result,
382+ Err ( _) => {
383+ return Err ( PgWireError :: UserError ( Box :: new (
384+ pgwire:: error:: ErrorInfo :: new (
385+ "ERROR" . to_string ( ) ,
386+ "57014" . to_string ( ) , // PostgreSQL query_canceled error code
387+ format ! (
388+ "Query execution timeout after {} seconds" ,
389+ query_timeout. as_secs( )
390+ ) ,
391+ ) ,
392+ ) ) ) ;
393+ }
394+ } ;
394395
395396 // Handle query execution errors and transaction state
396397 let df = match df_result {
@@ -530,40 +531,47 @@ impl ExtendedQueryHandler for DfSessionService {
530531 Ok ( types) => types,
531532 Err ( e) => {
532533 let error_msg = e. to_string ( ) ;
533- if error_msg. contains ( "Cannot get result type for arithmetic operation Null + Null" )
534- || error_msg. contains ( "Invalid arithmetic operation: Null + Null" ) {
534+ if error_msg. contains ( "Cannot get result type for arithmetic operation Null + Null" )
535+ || error_msg. contains ( "Invalid arithmetic operation: Null + Null" )
536+ {
535537 // Fallback: assume all parameters are integers for arithmetic operations
536538 log:: warn!( "DataFusion type inference failed for arithmetic operation, using integer fallback" ) ;
537539 let param_count = portal. statement . parameter_types . len ( ) ;
538- std:: collections:: HashMap :: from_iter (
539- ( 0 ..param_count) . map ( |i| ( format ! ( "${}" , i + 1 ) , Some ( datafusion:: arrow:: datatypes:: DataType :: Int32 ) ) )
540- )
540+ std:: collections:: HashMap :: from_iter ( ( 0 ..param_count) . map ( |i| {
541+ (
542+ format ! ( "${}" , i + 1 ) ,
543+ Some ( datafusion:: arrow:: datatypes:: DataType :: Int32 ) ,
544+ )
545+ } ) )
541546 } else {
542547 return Err ( PgWireError :: ApiError ( Box :: new ( e) ) ) ;
543548 }
544549 }
545550 } ;
546551 let param_values = df:: deserialize_parameters ( portal, & ordered_param_types ( & param_types) ) ?; // Fixed: Use ¶m_types
547-
552+
548553 // Replace parameters with values, with automatic retry for type inference failures
549554 let plan = match plan. clone ( ) . replace_params_with_values ( & param_values) {
550555 Ok ( plan) => plan,
551556 Err ( e) => {
552557 let error_msg = e. to_string ( ) ;
553- if error_msg. contains ( "Cannot get result type for arithmetic operation Null + Null" )
554- || error_msg. contains ( "Invalid arithmetic operation: Null + Null" ) {
555- log:: info!( "Retrying query with enhanced type casting for arithmetic operations" ) ;
556-
558+ if error_msg. contains ( "Cannot get result type for arithmetic operation Null + Null" )
559+ || error_msg. contains ( "Invalid arithmetic operation: Null + Null" )
560+ {
561+ log:: info!(
562+ "Retrying query with enhanced type casting for arithmetic operations"
563+ ) ;
564+
557565 // Attempt to reparse the query with explicit type casting
558566 let original_query = & portal. statement . statement . 0 ;
559567 let enhanced_query = enhance_query_with_type_casting ( original_query) ;
560-
568+
561569 // Try to create a new plan with the enhanced query
562570 match self . session_context . sql ( & enhanced_query) . await {
563571 Ok ( new_plan_df) => {
564572 // Get the logical plan from the new dataframe
565573 let new_plan = new_plan_df. logical_plan ( ) . clone ( ) ;
566-
574+
567575 // Try parameter substitution again with the new plan
568576 match new_plan. replace_params_with_values ( & param_values) {
569577 Ok ( final_plan) => final_plan,
@@ -599,15 +607,20 @@ impl ExtendedQueryHandler for DfSessionService {
599607 let query_timeout = std:: time:: Duration :: from_secs ( 60 ) ; // 60 seconds
600608 let dataframe = match tokio:: time:: timeout (
601609 query_timeout,
602- self . session_context . execute_logical_plan ( plan)
603- ) . await {
610+ self . session_context . execute_logical_plan ( plan) ,
611+ )
612+ . await
613+ {
604614 Ok ( result) => result. map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?,
605615 Err ( _) => {
606616 return Err ( PgWireError :: UserError ( Box :: new (
607617 pgwire:: error:: ErrorInfo :: new (
608618 "ERROR" . to_string ( ) ,
609619 "57014" . to_string ( ) , // PostgreSQL query_canceled error code
610- format ! ( "Query execution timeout after {} seconds" , query_timeout. as_secs( ) ) ,
620+ format ! (
621+ "Query execution timeout after {} seconds" ,
622+ query_timeout. as_secs( )
623+ ) ,
611624 ) ,
612625 ) ) ) ;
613626 }
@@ -649,19 +662,19 @@ impl QueryParser for Parser {
649662/// This helps DataFusion's type inference when it encounters ambiguous parameter types
650663fn enhance_query_with_type_casting ( query : & str ) -> String {
651664 use regex:: Regex ;
652-
665+
653666 // Pattern to match arithmetic operations with parameters: $1 + $2, $1 - $2, etc.
654667 let arithmetic_pattern = Regex :: new ( r"\$(\d+)\s*([+\-*/])\s*\$(\d+)" ) . unwrap ( ) ;
655-
668+
656669 // Replace untyped parameters in arithmetic operations with integer-cast parameters
657670 let enhanced = arithmetic_pattern. replace_all ( query, "$$$1::integer $2 $$$3::integer" ) ;
658-
671+
659672 // Pattern to match single parameters in potentially ambiguous contexts
660673 let single_param_pattern = Regex :: new ( r"\$(\d+)(?!::)(?=\s*[+\-*/=<>]|\s*\))" ) . unwrap ( ) ;
661-
674+
662675 // Add integer casting to remaining untyped parameters in arithmetic contexts
663676 let enhanced = single_param_pattern. replace_all ( & enhanced, "$$$1::integer" ) ;
664-
677+
665678 log:: debug!( "Enhanced query: {} -> {}" , query, enhanced) ;
666679 enhanced. to_string ( )
667680}
0 commit comments