@@ -27,9 +27,20 @@ use tokio::sync::Mutex;
2727use crate :: auth:: AuthManager ;
2828use arrow_pg:: datatypes:: df;
2929use arrow_pg:: datatypes:: { arrow_schema_to_pg_fields, into_pg_type} ;
30+ use datafusion:: sql:: sqlparser;
3031use datafusion_pg_catalog:: pg_catalog:: context:: { Permission , ResourceType } ;
3132use datafusion_pg_catalog:: sql:: PostgresCompatibilityParser ;
3233
34+ #[ async_trait]
35+ pub trait QueryHook : Send + Sync {
36+ async fn handle_query (
37+ & self ,
38+ statement : & sqlparser:: ast:: Statement ,
39+ session_context : & SessionContext ,
40+ client : & dyn ClientInfo ,
41+ ) -> Option < PgWireResult < Response > > ;
42+ }
43+
3344// Metadata keys for session-level settings
3445const METADATA_STATEMENT_TIMEOUT : & str = "statement_timeout_ms" ;
3546
@@ -45,9 +56,16 @@ pub struct HandlerFactory {
4556}
4657
4758impl HandlerFactory {
48- pub fn new ( session_context : Arc < SessionContext > , auth_manager : Arc < AuthManager > ) -> Self {
49- let session_service =
50- Arc :: new ( DfSessionService :: new ( session_context, auth_manager. clone ( ) ) ) ;
59+ pub fn new (
60+ session_context : Arc < SessionContext > ,
61+ auth_manager : Arc < AuthManager > ,
62+ query_hooks : Vec < Arc < dyn QueryHook > > ,
63+ ) -> Self {
64+ let session_service = Arc :: new ( DfSessionService :: new (
65+ session_context,
66+ auth_manager. clone ( ) ,
67+ query_hooks,
68+ ) ) ;
5169 HandlerFactory { session_service }
5270 }
5371}
@@ -87,12 +105,14 @@ pub struct DfSessionService {
87105 parser : Arc < Parser > ,
88106 timezone : Arc < Mutex < String > > ,
89107 auth_manager : Arc < AuthManager > ,
108+ query_hooks : Vec < Arc < dyn QueryHook > > ,
90109}
91110
92111impl DfSessionService {
93112 pub fn new (
94113 session_context : Arc < SessionContext > ,
95114 auth_manager : Arc < AuthManager > ,
115+ query_hooks : Vec < Arc < dyn QueryHook > > ,
96116 ) -> DfSessionService {
97117 let parser = Arc :: new ( Parser {
98118 session_context : session_context. clone ( ) ,
@@ -103,6 +123,7 @@ impl DfSessionService {
103123 parser,
104124 timezone : Arc :: new ( Mutex :: new ( "UTC" . to_string ( ) ) ) ,
105125 auth_manager,
126+ query_hooks,
106127 }
107128 }
108129
@@ -468,6 +489,16 @@ impl SimpleQueryHandler for DfSessionService {
468489 self . check_query_permission ( client, & query) . await ?;
469490 }
470491
492+ // Call query hooks with the parsed statement
493+ for hook in & self . query_hooks {
494+ if let Some ( result) = hook
495+ . handle_query ( & statement, & self . session_context , client)
496+ . await
497+ {
498+ return result. map ( |response| vec ! [ response] ) ;
499+ }
500+ }
501+
471502 if let Some ( resp) = self
472503 . try_respond_set_statements ( client, & query_lower)
473504 . await ?
@@ -610,6 +641,26 @@ impl ExtendedQueryHandler for DfSessionService {
610641 . to_string ( ) ;
611642 log:: debug!( "Received execute extended query: {query}" ) ; // Log for debugging
612643
644+ // Check query hooks first
645+ for hook in & self . query_hooks {
646+ // Parse the SQL to get the Statement for the hook
647+ let sql = & portal. statement . statement . 0 ;
648+ let statements = self
649+ . parser
650+ . sql_parser
651+ . parse ( sql)
652+ . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
653+
654+ if let Some ( statement) = statements. into_iter ( ) . next ( ) {
655+ if let Some ( result) = hook
656+ . handle_query ( & statement, & self . session_context , client)
657+ . await
658+ {
659+ return result;
660+ }
661+ }
662+ }
663+
613664 // Check permissions for the query (skip for SET and SHOW statements)
614665 if !query. starts_with ( "set" ) && !query. starts_with ( "show" ) {
615666 self . check_query_permission ( client, & portal. statement . statement . 0 )
@@ -909,7 +960,7 @@ mod tests {
909960 async fn test_statement_timeout_set_and_show ( ) {
910961 let session_context = Arc :: new ( SessionContext :: new ( ) ) ;
911962 let auth_manager = Arc :: new ( AuthManager :: new ( ) ) ;
912- let service = DfSessionService :: new ( session_context, auth_manager) ;
963+ let service = DfSessionService :: new ( session_context, auth_manager, vec ! [ ] ) ;
913964 let mut client = MockClient :: new ( ) ;
914965
915966 // Test setting timeout to 5000ms
@@ -935,7 +986,7 @@ mod tests {
935986 async fn test_statement_timeout_disable ( ) {
936987 let session_context = Arc :: new ( SessionContext :: new ( ) ) ;
937988 let auth_manager = Arc :: new ( AuthManager :: new ( ) ) ;
938- let service = DfSessionService :: new ( session_context, auth_manager) ;
989+ let service = DfSessionService :: new ( session_context, auth_manager, vec ! [ ] ) ;
939990 let mut client = MockClient :: new ( ) ;
940991
941992 // Set timeout first
@@ -953,4 +1004,46 @@ mod tests {
9531004 let timeout = DfSessionService :: get_statement_timeout ( & client) ;
9541005 assert_eq ! ( timeout, None ) ;
9551006 }
1007+
1008+ struct TestHook ;
1009+
1010+ #[ async_trait]
1011+ impl QueryHook for TestHook {
1012+ async fn handle_query (
1013+ & self ,
1014+ statement : & sqlparser:: ast:: Statement ,
1015+ _ctx : & SessionContext ,
1016+ _client : & dyn ClientInfo ,
1017+ ) -> Option < PgWireResult < Response > > {
1018+ if statement. to_string ( ) . contains ( "magic" ) {
1019+ Some ( Ok ( Response :: EmptyQuery ) )
1020+ } else {
1021+ None
1022+ }
1023+ }
1024+ }
1025+
1026+ #[ tokio:: test]
1027+ async fn test_query_hooks ( ) {
1028+ let hook = TestHook ;
1029+ let ctx = SessionContext :: new ( ) ;
1030+ let client = MockClient :: new ( ) ;
1031+
1032+ // Parse a statement that contains "magic"
1033+ let parser = PostgresCompatibilityParser :: new ( ) ;
1034+ let statements = parser. parse ( "SELECT magic" ) . unwrap ( ) ;
1035+ let stmt = & statements[ 0 ] ;
1036+
1037+ // Hook should intercept
1038+ let result = hook. handle_query ( stmt, & ctx, & client) . await ;
1039+ assert ! ( result. is_some( ) ) ;
1040+
1041+ // Parse a normal statement
1042+ let statements = parser. parse ( "SELECT 1" ) . unwrap ( ) ;
1043+ let stmt = & statements[ 0 ] ;
1044+
1045+ // Hook should not intercept
1046+ let result = hook. handle_query ( stmt, & ctx, & client) . await ;
1047+ assert ! ( result. is_none( ) ) ;
1048+ }
9561049}
0 commit comments