Skip to content

Commit 81d0f42

Browse files
committed
switch to accepting datafusion statements
1 parent 2142e40 commit 81d0f42

File tree

2 files changed

+36
-32
lines changed

2 files changed

+36
-32
lines changed

datafusion-postgres/src/handlers.rs

Lines changed: 35 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -30,17 +30,14 @@ use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type};
3030
use datafusion_pg_catalog::pg_catalog::context::{Permission, ResourceType};
3131
use datafusion_pg_catalog::sql::PostgresCompatibilityParser;
3232

33-
/// Statement type represents a parsed SQL query with its logical plan
34-
pub type Statement = (String, Option<LogicalPlan>);
35-
3633
#[async_trait]
3734
pub trait QueryHook: Send + Sync {
3835
async fn handle_query(
3936
&self,
4037
statement: &Statement,
4138
session_context: &SessionContext,
4239
client: &dyn ClientInfo,
43-
) -> Option<PgWireResult<Vec<Response<'static>>>>;
40+
) -> Option<PgWireResult<Vec<Response>>>;
4441
}
4542

4643
#[derive(Debug, Clone, Copy, PartialEq)]
@@ -66,8 +63,11 @@ pub struct HandlerFactory {
6663

6764
impl HandlerFactory {
6865
pub fn new(session_context: Arc<SessionContext>, auth_manager: Arc<AuthManager>) -> Self {
69-
let session_service =
70-
Arc::new(DfSessionService::new(session_context, auth_manager.clone(), None));
66+
let session_service = Arc::new(DfSessionService::new(
67+
session_context,
68+
auth_manager.clone(),
69+
None,
70+
));
7171
HandlerFactory { session_service }
7272
}
7373
}
@@ -491,26 +491,16 @@ impl SimpleQueryHandler for DfSessionService {
491491
self.check_query_permission(client, &query).await?;
492492
}
493493

494-
// Parse query into logical plan for hook
495-
if let Some(hook) = &self.query_hook {
496-
// Create logical plan from query
497-
let state = self.session_context.state();
498-
let logical_plan_result = state.create_logical_plan(query).await;
499-
500-
if let Ok(logical_plan) = logical_plan_result {
501-
// Optimize the logical plan
502-
let optimized_result = state.optimize(&logical_plan);
503-
504-
if let Ok(optimized) = optimized_result {
505-
// Create Statement tuple and call hook
506-
let statement = (query.to_string(), optimized);
507-
if let Some(result) = hook.handle_query(&statement, &self.session_context, client).await {
508-
return result;
509-
}
494+
// Call query hook with the parsed statement
495+
if let Some(hook) = &self.query_hook {
496+
let wrapped_statement = Statement::Statement(Box::new(statement.clone()));
497+
if let Some(result) = hook
498+
.handle_query(&wrapped_statement, &self.session_context, client)
499+
.await
500+
{
501+
return result;
510502
}
511503
}
512-
// If parsing or optimization fails, we'll continue with normal processing
513-
}
514504

515505
if let Some(resp) = self
516506
.try_respond_set_statements(client, &query_lower)
@@ -578,7 +568,7 @@ impl SimpleQueryHandler for DfSessionService {
578568

579569
#[async_trait]
580570
impl ExtendedQueryHandler for DfSessionService {
581-
type Statement = Statement;
571+
type Statement = (String, Option<LogicalPlan>);
582572
type QueryParser = Parser;
583573

584574
fn query_parser(&self) -> Arc<Self::QueryParser> {
@@ -656,11 +646,25 @@ impl ExtendedQueryHandler for DfSessionService {
656646

657647
// Check query hook first
658648
if let Some(hook) = &self.query_hook {
659-
if let Some(result) = hook.handle_query(&portal.statement.statement, &self.session_context, client).await {
660-
// Convert Vec<Response> to single Response
661-
// For extended query, we expect a single response
662-
if let Some(response) = result?.into_iter().next() {
663-
return Ok(response);
649+
// Parse the SQL to get the Statement for the hook
650+
let sql = &portal.statement.statement.0;
651+
let statements = self
652+
.parser
653+
.sql_parser
654+
.parse(sql)
655+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
656+
657+
if let Some(statement) = statements.into_iter().next() {
658+
let wrapped_statement = Statement::Statement(Box::new(statement));
659+
if let Some(result) = hook
660+
.handle_query(&wrapped_statement, &self.session_context, client)
661+
.await
662+
{
663+
// Convert Vec<Response> to single Response
664+
// For extended query, we expect a single response
665+
if let Some(response) = result?.into_iter().next() {
666+
return Ok(response);
667+
}
664668
}
665669
}
666670
}
@@ -837,7 +841,7 @@ impl Parser {
837841

838842
#[async_trait]
839843
impl QueryParser for Parser {
840-
type Statement = Statement;
844+
type Statement = (String, Option<LogicalPlan>);
841845

842846
async fn parse_sql<C>(
843847
&self,

datafusion-postgres/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use tokio_rustls::TlsAcceptor;
2020

2121
use crate::auth::AuthManager;
2222
use handlers::HandlerFactory;
23-
pub use handlers::{DfSessionService, Parser, QueryHook, Statement};
23+
pub use handlers::{DfSessionService, Parser, QueryHook};
2424

2525
/// re-exports
2626
pub use arrow_pg;

0 commit comments

Comments
 (0)