@@ -30,17 +30,14 @@ use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type};
3030use datafusion_pg_catalog:: pg_catalog:: context:: { Permission , ResourceType } ;
3131use datafusion_pg_catalog:: sql:: PostgresCompatibilityParser ;
3232
33- /// Statement type represents a parsed SQL query with its logical plan
34- pub type Statement = ( String , Option < LogicalPlan > ) ;
35-
3633#[ async_trait]
3734pub trait QueryHook : Send + Sync {
3835 async fn handle_query (
3936 & self ,
4037 statement : & Statement ,
4138 session_context : & SessionContext ,
4239 client : & dyn ClientInfo ,
43- ) -> Option < PgWireResult < Vec < Response < ' static > > > > ;
40+ ) -> Option < PgWireResult < Vec < Response > > > ;
4441}
4542
4643#[ derive( Debug , Clone , Copy , PartialEq ) ]
@@ -66,8 +63,11 @@ pub struct HandlerFactory {
6663
6764impl HandlerFactory {
6865 pub fn new ( session_context : Arc < SessionContext > , auth_manager : Arc < AuthManager > ) -> Self {
69- let session_service =
70- Arc :: new ( DfSessionService :: new ( session_context, auth_manager. clone ( ) , None ) ) ;
66+ let session_service = Arc :: new ( DfSessionService :: new (
67+ session_context,
68+ auth_manager. clone ( ) ,
69+ None ,
70+ ) ) ;
7171 HandlerFactory { session_service }
7272 }
7373}
@@ -491,26 +491,16 @@ impl SimpleQueryHandler for DfSessionService {
491491 self . check_query_permission ( client, & query) . await ?;
492492 }
493493
494- // Parse query into logical plan for hook
495- if let Some ( hook) = & self . query_hook {
496- // Create logical plan from query
497- let state = self . session_context . state ( ) ;
498- let logical_plan_result = state. create_logical_plan ( query) . await ;
499-
500- if let Ok ( logical_plan) = logical_plan_result {
501- // Optimize the logical plan
502- let optimized_result = state. optimize ( & logical_plan) ;
503-
504- if let Ok ( optimized) = optimized_result {
505- // Create Statement tuple and call hook
506- let statement = ( query. to_string ( ) , optimized) ;
507- if let Some ( result) = hook. handle_query ( & statement, & self . session_context , client) . await {
508- return result;
509- }
494+ // Call query hook with the parsed statement
495+ if let Some ( hook) = & self . query_hook {
496+ let wrapped_statement = Statement :: Statement ( Box :: new ( statement. clone ( ) ) ) ;
497+ if let Some ( result) = hook
498+ . handle_query ( & wrapped_statement, & self . session_context , client)
499+ . await
500+ {
501+ return result;
510502 }
511503 }
512- // If parsing or optimization fails, we'll continue with normal processing
513- }
514504
515505 if let Some ( resp) = self
516506 . try_respond_set_statements ( client, & query_lower)
@@ -578,7 +568,7 @@ impl SimpleQueryHandler for DfSessionService {
578568
579569#[ async_trait]
580570impl ExtendedQueryHandler for DfSessionService {
581- type Statement = Statement ;
571+ type Statement = ( String , Option < LogicalPlan > ) ;
582572 type QueryParser = Parser ;
583573
584574 fn query_parser ( & self ) -> Arc < Self :: QueryParser > {
@@ -656,11 +646,25 @@ impl ExtendedQueryHandler for DfSessionService {
656646
657647 // Check query hook first
658648 if let Some ( hook) = & self . query_hook {
659- if let Some ( result) = hook. handle_query ( & portal. statement . statement , & self . session_context , client) . await {
660- // Convert Vec<Response> to single Response
661- // For extended query, we expect a single response
662- if let Some ( response) = result?. into_iter ( ) . next ( ) {
663- return Ok ( response) ;
649+ // Parse the SQL to get the Statement for the hook
650+ let sql = & portal. statement . statement . 0 ;
651+ let statements = self
652+ . parser
653+ . sql_parser
654+ . parse ( sql)
655+ . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
656+
657+ if let Some ( statement) = statements. into_iter ( ) . next ( ) {
658+ let wrapped_statement = Statement :: Statement ( Box :: new ( statement) ) ;
659+ if let Some ( result) = hook
660+ . handle_query ( & wrapped_statement, & self . session_context , client)
661+ . await
662+ {
663+ // Convert Vec<Response> to single Response
664+ // For extended query, we expect a single response
665+ if let Some ( response) = result?. into_iter ( ) . next ( ) {
666+ return Ok ( response) ;
667+ }
664668 }
665669 }
666670 }
@@ -837,7 +841,7 @@ impl Parser {
837841
838842#[ async_trait]
839843impl QueryParser for Parser {
840- type Statement = Statement ;
844+ type Statement = ( String , Option < LogicalPlan > ) ;
841845
842846 async fn parse_sql < C > (
843847 & self ,
0 commit comments