@@ -3,7 +3,7 @@ use std::sync::Arc;
33
44use async_trait:: async_trait;
55use datafusion:: arrow:: datatypes:: { DataType , Field , Schema } ;
6- use datafusion:: common:: ToDFSchema ;
6+ use datafusion:: common:: { ParamValues , ToDFSchema } ;
77use datafusion:: error:: DataFusionError ;
88use datafusion:: logical_expr:: LogicalPlan ;
99use datafusion:: prelude:: * ;
@@ -33,11 +33,22 @@ use datafusion_pg_catalog::sql::PostgresCompatibilityParser;
3333
3434#[ async_trait]
3535pub trait QueryHook : Send + Sync {
36- async fn handle_query (
36+ /// called in simple query handler to return response directly
37+ async fn handle_simple_query (
3738 & self ,
3839 statement : & sqlparser:: ast:: Statement ,
3940 session_context : & SessionContext ,
40- client : & dyn ClientInfo ,
41+ client : & ( dyn ClientInfo + Send + Sync ) ,
42+ ) -> Option < PgWireResult < Response > > ;
43+
44+ /// called at extended query execute phase, for query execution
45+ async fn handle_extended_query (
46+ & self ,
47+ statement : & sqlparser:: ast:: Statement ,
48+ logical_plan : & LogicalPlan ,
49+ params : & ParamValues ,
50+ session_context : & SessionContext ,
51+ client : & ( dyn ClientInfo + Send + Sync ) ,
4152 ) -> Option < PgWireResult < Response > > ;
4253}
4354
@@ -492,7 +503,7 @@ impl SimpleQueryHandler for DfSessionService {
492503 // Call query hooks with the parsed statement
493504 for hook in & self . query_hooks {
494505 if let Some ( result) = hook
495- . handle_query ( & statement, & self . session_context , client)
506+ . handle_simple_query ( & statement, & self . session_context , client)
496507 . await
497508 {
498509 results. push ( result?) ;
@@ -643,21 +654,29 @@ impl ExtendedQueryHandler for DfSessionService {
643654 log:: debug!( "Received execute extended query: {query}" ) ; // Log for debugging
644655
645656 // Check query hooks first
646- for hook in & self . query_hooks {
647- // Parse the SQL to get the Statement for the hook
648- let sql = & portal. statement . statement . 0 ;
649- let statements = self
650- . parser
651- . sql_parser
652- . parse ( sql)
653- . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
654-
655- if let Some ( statement) = statements. into_iter ( ) . next ( ) {
656- if let Some ( result) = hook
657- . handle_query ( & statement, & self . session_context , client)
658- . await
659- {
660- return result;
657+ if !self . query_hooks . is_empty ( ) {
658+ if let ( _, Some ( ( statement, plan) ) ) = & portal. statement . statement {
659+ // TODO: in the case where query hooks all return None, we do the param handling again later.
660+ let param_types = plan
661+ . get_parameter_types ( )
662+ . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
663+
664+ let param_values: ParamValues =
665+ df:: deserialize_parameters ( portal, & ordered_param_types ( & param_types) ) ?;
666+
667+ for hook in & self . query_hooks {
668+ if let Some ( result) = hook
669+ . handle_extended_query (
670+ statement,
671+ plan,
672+ & param_values,
673+ & self . session_context ,
674+ client,
675+ )
676+ . await
677+ {
678+ return result;
679+ }
661680 }
662681 }
663682 }
@@ -1010,18 +1029,29 @@ mod tests {
10101029
10111030 #[ async_trait]
10121031 impl QueryHook for TestHook {
1013- async fn handle_query (
1032+ async fn handle_simple_query (
10141033 & self ,
10151034 statement : & sqlparser:: ast:: Statement ,
10161035 _ctx : & SessionContext ,
1017- _client : & dyn ClientInfo ,
1036+ _client : & ( dyn ClientInfo + Sync + Send ) ,
10181037 ) -> Option < PgWireResult < Response > > {
10191038 if statement. to_string ( ) . contains ( "magic" ) {
10201039 Some ( Ok ( Response :: EmptyQuery ) )
10211040 } else {
10221041 None
10231042 }
10241043 }
1044+
1045+ async fn handle_extended_query (
1046+ & self ,
1047+ _statement : & sqlparser:: ast:: Statement ,
1048+ _logical_plan : & LogicalPlan ,
1049+ _params : & ParamValues ,
1050+ _session_context : & SessionContext ,
1051+ _client : & ( dyn ClientInfo + Send + Sync ) ,
1052+ ) -> Option < PgWireResult < Response > > {
1053+ todo ! ( ) ;
1054+ }
10251055 }
10261056
10271057 #[ tokio:: test]
@@ -1036,15 +1066,15 @@ mod tests {
10361066 let stmt = & statements[ 0 ] ;
10371067
10381068 // Hook should intercept
1039- let result = hook. handle_query ( stmt, & ctx, & client) . await ;
1069+ let result = hook. handle_simple_query ( stmt, & ctx, & client) . await ;
10401070 assert ! ( result. is_some( ) ) ;
10411071
10421072 // Parse a normal statement
10431073 let statements = parser. parse ( "SELECT 1" ) . unwrap ( ) ;
10441074 let stmt = & statements[ 0 ] ;
10451075
10461076 // Hook should not intercept
1047- let result = hook. handle_query ( stmt, & ctx, & client) . await ;
1077+ let result = hook. handle_simple_query ( stmt, & ctx, & client) . await ;
10481078 assert ! ( result. is_none( ) ) ;
10491079 }
10501080}
0 commit comments