@@ -566,7 +566,7 @@ impl SimpleQueryHandler for DfSessionService {
566566
567567#[ async_trait]
568568impl ExtendedQueryHandler for DfSessionService {
569- type Statement = ( String , Option < LogicalPlan > ) ;
569+ type Statement = ( String , Option < ( sqlparser :: ast :: Statement , LogicalPlan ) > ) ;
570570 type QueryParser = Parser ;
571571
572572 fn query_parser ( & self ) -> Arc < Self :: QueryParser > {
@@ -581,7 +581,7 @@ impl ExtendedQueryHandler for DfSessionService {
581581 where
582582 C : ClientInfo + Unpin + Send + Sync ,
583583 {
584- if let ( _, Some ( plan) ) = & target. statement {
584+ if let ( _, Some ( ( _ , plan) ) ) = & target. statement {
585585 let schema = plan. schema ( ) ;
586586 let fields = arrow_schema_to_pg_fields ( schema. as_arrow ( ) , & Format :: UnifiedBinary ) ?;
587587 let params = plan
@@ -613,7 +613,7 @@ impl ExtendedQueryHandler for DfSessionService {
613613 where
614614 C : ClientInfo + Unpin + Send + Sync ,
615615 {
616- if let ( _, Some ( plan) ) = & target. statement . statement {
616+ if let ( _, Some ( ( _ , plan) ) ) = & target. statement . statement {
617617 let format = & target. result_column_format ;
618618 let schema = plan. schema ( ) ;
619619 let fields = arrow_schema_to_pg_fields ( schema. as_arrow ( ) , format) ?;
@@ -695,7 +695,7 @@ impl ExtendedQueryHandler for DfSessionService {
695695 ) ) ) ;
696696 }
697697
698- if let ( _, Some ( plan) ) = & portal. statement . statement {
698+ if let ( _, Some ( ( _ , plan) ) ) = & portal. statement . statement {
699699 let param_types = plan
700700 . get_parameter_types ( )
701701 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
@@ -834,7 +834,7 @@ impl Parser {
834834
835835#[ async_trait]
836836impl QueryParser for Parser {
837- type Statement = ( String , Option < LogicalPlan > ) ;
837+ type Statement = ( String , Option < ( sqlparser :: ast :: Statement , LogicalPlan ) > ) ;
838838
839839 async fn parse_sql < C > (
840840 & self ,
@@ -844,14 +844,6 @@ impl QueryParser for Parser {
844844 ) -> PgWireResult < Self :: Statement > {
845845 log:: debug!( "Received parse extended query: {sql}" ) ; // Log for debugging
846846
847- // Check for transaction commands that shouldn't be parsed by DataFusion
848- if let Some ( plan) = self
849- . try_shortcut_parse_plan ( sql)
850- . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?
851- {
852- return Ok ( ( sql. to_string ( ) , Some ( plan) ) ) ;
853- }
854-
855847 let mut statements = self
856848 . sql_parser
857849 . parse ( sql)
@@ -862,15 +854,23 @@ impl QueryParser for Parser {
862854
863855 let statement = statements. remove ( 0 ) ;
864856
857+ // Check for transaction commands that shouldn't be parsed by DataFusion
858+ if let Some ( plan) = self
859+ . try_shortcut_parse_plan ( sql)
860+ . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?
861+ {
862+ return Ok ( ( sql. to_string ( ) , Some ( ( statement, plan) ) ) ) ;
863+ }
864+
865865 let query = statement. to_string ( ) ;
866866
867867 let context = & self . session_context ;
868868 let state = context. state ( ) ;
869869 let logical_plan = state
870- . statement_to_plan ( Statement :: Statement ( Box :: new ( statement) ) )
870+ . statement_to_plan ( Statement :: Statement ( Box :: new ( statement. clone ( ) ) ) )
871871 . await
872872 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
873- Ok ( ( query, Some ( logical_plan) ) )
873+ Ok ( ( query, Some ( ( statement , logical_plan) ) ) )
874874 }
875875}
876876
0 commit comments