@@ -23,6 +23,16 @@ use tokio::sync::Mutex;
2323use arrow_pg:: datatypes:: df;
2424use arrow_pg:: datatypes:: { arrow_schema_to_pg_fields, into_pg_type} ;
2525
26+ #[ async_trait]
27+ pub trait QueryHook : Send + Sync {
28+ async fn handle_query (
29+ & self ,
30+ query : & str ,
31+ session_context : & SessionContext ,
32+ client : & dyn ClientInfo ,
33+ ) -> Option < PgWireResult < Vec < Response < ' static > > > > ;
34+ }
35+
2636#[ derive( Debug , Clone , Copy , PartialEq ) ]
2737pub enum TransactionState {
2838 None ,
@@ -44,7 +54,7 @@ pub struct HandlerFactory {
4454impl HandlerFactory {
4555 pub fn new ( session_context : Arc < SessionContext > , auth_manager : Arc < AuthManager > ) -> Self {
4656 let session_service =
47- Arc :: new ( DfSessionService :: new ( session_context, auth_manager. clone ( ) ) ) ;
57+ Arc :: new ( DfSessionService :: new ( session_context, auth_manager. clone ( ) , None ) ) ;
4858 HandlerFactory { session_service }
4959 }
5060}
@@ -70,12 +80,14 @@ pub struct DfSessionService {
7080 timezone : Arc < Mutex < String > > ,
7181 transaction_state : Arc < Mutex < TransactionState > > ,
7282 auth_manager : Arc < AuthManager > ,
83+ query_hook : Option < Arc < dyn QueryHook > > ,
7384}
7485
7586impl DfSessionService {
7687 pub fn new (
7788 session_context : Arc < SessionContext > ,
7889 auth_manager : Arc < AuthManager > ,
90+ query_hook : Option < Arc < dyn QueryHook > > ,
7991 ) -> DfSessionService {
8092 let parser = Arc :: new ( Parser {
8193 session_context : session_context. clone ( ) ,
@@ -86,6 +98,7 @@ impl DfSessionService {
8698 timezone : Arc :: new ( Mutex :: new ( "UTC" . to_string ( ) ) ) ,
8799 transaction_state : Arc :: new ( Mutex :: new ( TransactionState :: None ) ) ,
88100 auth_manager,
101+ query_hook,
89102 }
90103 }
91104
@@ -374,6 +387,13 @@ impl SimpleQueryHandler for DfSessionService {
374387 }
375388 }
376389
390+ // Check query hook first
391+ if let Some ( hook) = & self . query_hook {
392+ if let Some ( result) = hook. handle_query ( query, & self . session_context , client) . await {
393+ return result;
394+ }
395+ }
396+
377397 let df_result = self . session_context . sql ( query) . await ;
378398
379399 // Handle query execution errors and transaction state
0 commit comments