Skip to content

Commit 18807d7

Browse files
committed
Add extended query support to QueryHook
Add extended query support to QueryHook,. Also re-use the parsed statement instead of parsing it again because it is was wasted work, and it also made the code less clean. (We still potentially handle the params more than once, which can be fixed later) Finally, ensure the client is Send + Sync. (Should this be enforced in pgwire?)
1 parent 259eb55 commit 18807d7

File tree

1 file changed

+53
-23
lines changed

1 file changed

+53
-23
lines changed

datafusion-postgres/src/handlers.rs

Lines changed: 53 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::sync::Arc;
33

44
use async_trait::async_trait;
55
use datafusion::arrow::datatypes::{DataType, Field, Schema};
6-
use datafusion::common::ToDFSchema;
6+
use datafusion::common::{ParamValues, ToDFSchema};
77
use datafusion::error::DataFusionError;
88
use datafusion::logical_expr::LogicalPlan;
99
use datafusion::prelude::*;
@@ -33,11 +33,22 @@ use datafusion_pg_catalog::sql::PostgresCompatibilityParser;
3333

3434
#[async_trait]
3535
pub trait QueryHook: Send + Sync {
36-
async fn handle_query(
36+
/// called in simple query handler to return response directly
37+
async fn handle_simple_query(
3738
&self,
3839
statement: &sqlparser::ast::Statement,
3940
session_context: &SessionContext,
40-
client: &dyn ClientInfo,
41+
client: &(dyn ClientInfo + Send + Sync),
42+
) -> Option<PgWireResult<Response>>;
43+
44+
/// called at extended query execute phase, for query execution
45+
async fn handle_extended_query(
46+
&self,
47+
statement: &sqlparser::ast::Statement,
48+
logical_plan: &LogicalPlan,
49+
params: &ParamValues,
50+
session_context: &SessionContext,
51+
client: &(dyn ClientInfo + Send + Sync),
4152
) -> Option<PgWireResult<Response>>;
4253
}
4354

@@ -492,7 +503,7 @@ impl SimpleQueryHandler for DfSessionService {
492503
// Call query hooks with the parsed statement
493504
for hook in &self.query_hooks {
494505
if let Some(result) = hook
495-
.handle_query(&statement, &self.session_context, client)
506+
.handle_simple_query(&statement, &self.session_context, client)
496507
.await
497508
{
498509
results.push(result?);
@@ -643,21 +654,29 @@ impl ExtendedQueryHandler for DfSessionService {
643654
log::debug!("Received execute extended query: {query}"); // Log for debugging
644655

645656
// Check query hooks first
646-
for hook in &self.query_hooks {
647-
// Parse the SQL to get the Statement for the hook
648-
let sql = &portal.statement.statement.0;
649-
let statements = self
650-
.parser
651-
.sql_parser
652-
.parse(sql)
653-
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
654-
655-
if let Some(statement) = statements.into_iter().next() {
656-
if let Some(result) = hook
657-
.handle_query(&statement, &self.session_context, client)
658-
.await
659-
{
660-
return result;
657+
if !self.query_hooks.is_empty() {
658+
if let (_, Some((statement, plan))) = &portal.statement.statement {
659+
// TODO: in the case where query hooks all return None, we do the param handling again later.
660+
let param_types = plan
661+
.get_parameter_types()
662+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
663+
664+
let param_values: ParamValues =
665+
df::deserialize_parameters(portal, &ordered_param_types(&param_types))?;
666+
667+
for hook in &self.query_hooks {
668+
if let Some(result) = hook
669+
.handle_extended_query(
670+
statement,
671+
plan,
672+
&param_values,
673+
&self.session_context,
674+
client,
675+
)
676+
.await
677+
{
678+
return result;
679+
}
661680
}
662681
}
663682
}
@@ -1010,18 +1029,29 @@ mod tests {
10101029

10111030
#[async_trait]
10121031
impl QueryHook for TestHook {
1013-
async fn handle_query(
1032+
async fn handle_simple_query(
10141033
&self,
10151034
statement: &sqlparser::ast::Statement,
10161035
_ctx: &SessionContext,
1017-
_client: &dyn ClientInfo,
1036+
_client: &(dyn ClientInfo + Sync + Send),
10181037
) -> Option<PgWireResult<Response>> {
10191038
if statement.to_string().contains("magic") {
10201039
Some(Ok(Response::EmptyQuery))
10211040
} else {
10221041
None
10231042
}
10241043
}
1044+
1045+
async fn handle_extended_query(
1046+
&self,
1047+
_statement: &sqlparser::ast::Statement,
1048+
_logical_plan: &LogicalPlan,
1049+
_params: &ParamValues,
1050+
_session_context: &SessionContext,
1051+
_client: &(dyn ClientInfo + Send + Sync),
1052+
) -> Option<PgWireResult<Response>> {
1053+
todo!();
1054+
}
10251055
}
10261056

10271057
#[tokio::test]
@@ -1036,15 +1066,15 @@ mod tests {
10361066
let stmt = &statements[0];
10371067

10381068
// Hook should intercept
1039-
let result = hook.handle_query(stmt, &ctx, &client).await;
1069+
let result = hook.handle_simple_query(stmt, &ctx, &client).await;
10401070
assert!(result.is_some());
10411071

10421072
// Parse a normal statement
10431073
let statements = parser.parse("SELECT 1").unwrap();
10441074
let stmt = &statements[0];
10451075

10461076
// Hook should not intercept
1047-
let result = hook.handle_query(stmt, &ctx, &client).await;
1077+
let result = hook.handle_simple_query(stmt, &ctx, &client).await;
10481078
assert!(result.is_none());
10491079
}
10501080
}

0 commit comments

Comments
 (0)