@@ -1009,4 +1009,46 @@ mod tests {
10091009 let timeout = DfSessionService :: get_statement_timeout ( & client) ;
10101010 assert_eq ! ( timeout, None ) ;
10111011 }
1012+
1013+ struct TestHook ;
1014+
1015+ #[ async_trait]
1016+ impl QueryHook for TestHook {
1017+ async fn handle_query (
1018+ & self ,
1019+ statement : & Statement ,
1020+ _ctx : & SessionContext ,
1021+ _client : & dyn ClientInfo ,
1022+ ) -> Option < PgWireResult < Vec < Response > > > {
1023+ if statement. to_string ( ) . contains ( "magic" ) {
1024+ Some ( Ok ( vec ! [ Response :: EmptyQuery ] ) )
1025+ } else {
1026+ None
1027+ }
1028+ }
1029+ }
1030+
1031+ #[ tokio:: test]
1032+ async fn test_query_hooks ( ) {
1033+ let hook = TestHook ;
1034+ let ctx = SessionContext :: new ( ) ;
1035+ let client = MockClient :: new ( ) ;
1036+
1037+ // Parse a statement that contains "magic"
1038+ let parser = PostgresCompatibilityParser :: new ( ) ;
1039+ let statements = parser. parse ( "SELECT magic" ) . unwrap ( ) ;
1040+ let stmt = Statement :: Statement ( Box :: new ( statements[ 0 ] . clone ( ) ) ) ;
1041+
1042+ // Hook should intercept
1043+ let result = hook. handle_query ( & stmt, & ctx, & client) . await ;
1044+ assert ! ( result. is_some( ) ) ;
1045+
1046+ // Parse a normal statement
1047+ let statements = parser. parse ( "SELECT 1" ) . unwrap ( ) ;
1048+ let stmt = Statement :: Statement ( Box :: new ( statements[ 0 ] . clone ( ) ) ) ;
1049+
1050+ // Hook should not intercept
1051+ let result = hook. handle_query ( & stmt, & ctx, & client) . await ;
1052+ assert ! ( result. is_none( ) ) ;
1053+ }
10121054}
0 commit comments