Skip to content

Commit b366428

Browse files
committed
since handle_query only accepts a single statement, only a single response is expected.
1 parent 822d624 commit b366428

File tree

1 file changed

+10
-14
lines changed

1 file changed

+10
-14
lines changed

datafusion-postgres/src/handlers.rs

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ pub trait QueryHook: Send + Sync {
3838
statement: &sqlparser::ast::Statement,
3939
session_context: &SessionContext,
4040
client: &dyn ClientInfo,
41-
) -> Option<PgWireResult<Vec<Response>>>;
41+
) -> Option<PgWireResult<Response>>;
4242
}
4343

4444
// Metadata keys for session-level settings
@@ -495,7 +495,7 @@ impl SimpleQueryHandler for DfSessionService {
495495
.handle_query(&statement, &self.session_context, client)
496496
.await
497497
{
498-
return result;
498+
return result.map(|response| vec![response]);
499499
}
500500
}
501501

@@ -656,11 +656,7 @@ impl ExtendedQueryHandler for DfSessionService {
656656
.handle_query(&statement, &self.session_context, client)
657657
.await
658658
{
659-
// Convert Vec<Response> to single Response
660-
// For extended query, we expect a single response
661-
if let Some(response) = result?.into_iter().next() {
662-
return Ok(response);
663-
}
659+
return result;
664660
}
665661
}
666662
}
@@ -1015,12 +1011,12 @@ mod tests {
10151011
impl QueryHook for TestHook {
10161012
async fn handle_query(
10171013
&self,
1018-
statement: &Statement,
1014+
statement: &sqlparser::ast::Statement,
10191015
_ctx: &SessionContext,
10201016
_client: &dyn ClientInfo,
1021-
) -> Option<PgWireResult<Vec<Response>>> {
1017+
) -> Option<PgWireResult<Response>> {
10221018
if statement.to_string().contains("magic") {
1023-
Some(Ok(vec![Response::EmptyQuery]))
1019+
Some(Ok(Response::EmptyQuery))
10241020
} else {
10251021
None
10261022
}
@@ -1036,18 +1032,18 @@ mod tests {
10361032
// Parse a statement that contains "magic"
10371033
let parser = PostgresCompatibilityParser::new();
10381034
let statements = parser.parse("SELECT magic").unwrap();
1039-
let stmt = Statement::Statement(Box::new(statements[0].clone()));
1035+
let stmt = &statements[0];
10401036

10411037
// Hook should intercept
1042-
let result = hook.handle_query(&stmt, &ctx, &client).await;
1038+
let result = hook.handle_query(stmt, &ctx, &client).await;
10431039
assert!(result.is_some());
10441040

10451041
// Parse a normal statement
10461042
let statements = parser.parse("SELECT 1").unwrap();
1047-
let stmt = Statement::Statement(Box::new(statements[0].clone()));
1043+
let stmt = &statements[0];
10481044

10491045
// Hook should not intercept
1050-
let result = hook.handle_query(&stmt, &ctx, &client).await;
1046+
let result = hook.handle_query(stmt, &ctx, &client).await;
10511047
assert!(result.is_none());
10521048
}
10531049
}

0 commit comments

Comments
 (0)