@@ -38,7 +38,7 @@ pub trait QueryHook: Send + Sync {
3838 statement : & sqlparser:: ast:: Statement ,
3939 session_context : & SessionContext ,
4040 client : & dyn ClientInfo ,
41- ) -> Option < PgWireResult < Vec < Response > > > ;
41+ ) -> Option < PgWireResult < Response > > ;
4242}
4343
4444// Metadata keys for session-level settings
@@ -495,7 +495,7 @@ impl SimpleQueryHandler for DfSessionService {
495495 . handle_query ( & statement, & self . session_context , client)
496496 . await
497497 {
498- return result;
498+ return result. map ( |response| vec ! [ response ] ) ;
499499 }
500500 }
501501
@@ -656,11 +656,7 @@ impl ExtendedQueryHandler for DfSessionService {
656656 . handle_query ( & statement, & self . session_context , client)
657657 . await
658658 {
659- // Convert Vec<Response> to single Response
660- // For extended query, we expect a single response
661- if let Some ( response) = result?. into_iter ( ) . next ( ) {
662- return Ok ( response) ;
663- }
659+ return result;
664660 }
665661 }
666662 }
@@ -1015,12 +1011,12 @@ mod tests {
10151011 impl QueryHook for TestHook {
10161012 async fn handle_query (
10171013 & self ,
1018- statement : & Statement ,
1014+ statement : & sqlparser :: ast :: Statement ,
10191015 _ctx : & SessionContext ,
10201016 _client : & dyn ClientInfo ,
1021- ) -> Option < PgWireResult < Vec < Response > > > {
1017+ ) -> Option < PgWireResult < Response > > {
10221018 if statement. to_string ( ) . contains ( "magic" ) {
1023- Some ( Ok ( vec ! [ Response :: EmptyQuery ] ) )
1019+ Some ( Ok ( Response :: EmptyQuery ) )
10241020 } else {
10251021 None
10261022 }
@@ -1036,18 +1032,18 @@ mod tests {
10361032 // Parse a statement that contains "magic"
10371033 let parser = PostgresCompatibilityParser :: new ( ) ;
10381034 let statements = parser. parse ( "SELECT magic" ) . unwrap ( ) ;
1039- let stmt = Statement :: Statement ( Box :: new ( statements[ 0 ] . clone ( ) ) ) ;
1035+ let stmt = & statements[ 0 ] ;
10401036
10411037 // Hook should intercept
1042- let result = hook. handle_query ( & stmt, & ctx, & client) . await ;
1038+ let result = hook. handle_query ( stmt, & ctx, & client) . await ;
10431039 assert ! ( result. is_some( ) ) ;
10441040
10451041 // Parse a normal statement
10461042 let statements = parser. parse ( "SELECT 1" ) . unwrap ( ) ;
1047- let stmt = Statement :: Statement ( Box :: new ( statements[ 0 ] . clone ( ) ) ) ;
1043+ let stmt = & statements[ 0 ] ;
10481044
10491045 // Hook should not intercept
1050- let result = hook. handle_query ( & stmt, & ctx, & client) . await ;
1046+ let result = hook. handle_query ( stmt, & ctx, & client) . await ;
10511047 assert ! ( result. is_none( ) ) ;
10521048 }
10531049}
0 commit comments