Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 68 additions & 38 deletions datafusion-postgres/src/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::sync::Arc;

use async_trait::async_trait;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::common::ToDFSchema;
use datafusion::common::{ParamValues, ToDFSchema};
use datafusion::error::DataFusionError;
use datafusion::logical_expr::LogicalPlan;
use datafusion::prelude::*;
Expand Down Expand Up @@ -33,11 +33,22 @@ use datafusion_pg_catalog::sql::PostgresCompatibilityParser;

#[async_trait]
pub trait QueryHook: Send + Sync {
async fn handle_query(
/// called in simple query handler to return response directly
async fn handle_simple_query(
&self,
statement: &sqlparser::ast::Statement,
session_context: &SessionContext,
client: &dyn ClientInfo,
client: &(dyn ClientInfo + Send + Sync),
) -> Option<PgWireResult<Response>>;

/// called at extended query execute phase, for query execution
async fn handle_extended_query(
&self,
statement: &sqlparser::ast::Statement,
logical_plan: &LogicalPlan,
params: &ParamValues,
session_context: &SessionContext,
client: &(dyn ClientInfo + Send + Sync),
) -> Option<PgWireResult<Response>>;
}

Expand Down Expand Up @@ -492,7 +503,7 @@ impl SimpleQueryHandler for DfSessionService {
// Call query hooks with the parsed statement
for hook in &self.query_hooks {
if let Some(result) = hook
.handle_query(&statement, &self.session_context, client)
.handle_simple_query(&statement, &self.session_context, client)
.await
{
results.push(result?);
Expand Down Expand Up @@ -566,7 +577,7 @@ impl SimpleQueryHandler for DfSessionService {

#[async_trait]
impl ExtendedQueryHandler for DfSessionService {
type Statement = (String, Option<LogicalPlan>);
type Statement = (String, Option<(sqlparser::ast::Statement, LogicalPlan)>);
type QueryParser = Parser;

fn query_parser(&self) -> Arc<Self::QueryParser> {
Expand All @@ -581,7 +592,7 @@ impl ExtendedQueryHandler for DfSessionService {
where
C: ClientInfo + Unpin + Send + Sync,
{
if let (_, Some(plan)) = &target.statement {
if let (_, Some((_, plan))) = &target.statement {
let schema = plan.schema();
let fields = arrow_schema_to_pg_fields(schema.as_arrow(), &Format::UnifiedBinary)?;
let params = plan
Expand Down Expand Up @@ -613,7 +624,7 @@ impl ExtendedQueryHandler for DfSessionService {
where
C: ClientInfo + Unpin + Send + Sync,
{
if let (_, Some(plan)) = &target.statement.statement {
if let (_, Some((_, plan))) = &target.statement.statement {
let format = &target.result_column_format;
let schema = plan.schema();
let fields = arrow_schema_to_pg_fields(schema.as_arrow(), format)?;
Expand Down Expand Up @@ -643,21 +654,29 @@ impl ExtendedQueryHandler for DfSessionService {
log::debug!("Received execute extended query: {query}"); // Log for debugging

// Check query hooks first
for hook in &self.query_hooks {
// Parse the SQL to get the Statement for the hook
let sql = &portal.statement.statement.0;
let statements = self
.parser
.sql_parser
.parse(sql)
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;

if let Some(statement) = statements.into_iter().next() {
if let Some(result) = hook
.handle_query(&statement, &self.session_context, client)
.await
{
return result;
if !self.query_hooks.is_empty() {
if let (_, Some((statement, plan))) = &portal.statement.statement {
// TODO: in the case where query hooks all return None, we do the param handling again later.
let param_types = plan
.get_parameter_types()
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;

let param_values: ParamValues =
df::deserialize_parameters(portal, &ordered_param_types(&param_types))?;

for hook in &self.query_hooks {
if let Some(result) = hook
.handle_extended_query(
statement,
plan,
&param_values,
&self.session_context,
client,
)
.await
{
return result;
}
}
}
}
Expand Down Expand Up @@ -695,7 +714,7 @@ impl ExtendedQueryHandler for DfSessionService {
)));
}

if let (_, Some(plan)) = &portal.statement.statement {
if let (_, Some((_, plan))) = &portal.statement.statement {
let param_types = plan
.get_parameter_types()
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
Expand Down Expand Up @@ -834,7 +853,7 @@ impl Parser {

#[async_trait]
impl QueryParser for Parser {
type Statement = (String, Option<LogicalPlan>);
type Statement = (String, Option<(sqlparser::ast::Statement, LogicalPlan)>);

async fn parse_sql<C>(
&self,
Expand All @@ -844,14 +863,6 @@ impl QueryParser for Parser {
) -> PgWireResult<Self::Statement> {
log::debug!("Received parse extended query: {sql}"); // Log for debugging

// Check for transaction commands that shouldn't be parsed by DataFusion
if let Some(plan) = self
.try_shortcut_parse_plan(sql)
.map_err(|e| PgWireError::ApiError(Box::new(e)))?
{
return Ok((sql.to_string(), Some(plan)));
}

let mut statements = self
.sql_parser
.parse(sql)
Expand All @@ -862,15 +873,23 @@ impl QueryParser for Parser {

let statement = statements.remove(0);

// Check for transaction commands that shouldn't be parsed by DataFusion
if let Some(plan) = self
.try_shortcut_parse_plan(sql)
.map_err(|e| PgWireError::ApiError(Box::new(e)))?
{
return Ok((sql.to_string(), Some((statement, plan))));
}

let query = statement.to_string();

let context = &self.session_context;
let state = context.state();
let logical_plan = state
.statement_to_plan(Statement::Statement(Box::new(statement)))
.statement_to_plan(Statement::Statement(Box::new(statement.clone())))
.await
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
Ok((query, Some(logical_plan)))
Ok((query, Some((statement, logical_plan))))
}
}

Expand Down Expand Up @@ -1010,18 +1029,29 @@ mod tests {

#[async_trait]
impl QueryHook for TestHook {
async fn handle_query(
async fn handle_simple_query(
&self,
statement: &sqlparser::ast::Statement,
_ctx: &SessionContext,
_client: &dyn ClientInfo,
_client: &(dyn ClientInfo + Sync + Send),
) -> Option<PgWireResult<Response>> {
if statement.to_string().contains("magic") {
Some(Ok(Response::EmptyQuery))
} else {
None
}
}

async fn handle_extended_query(
&self,
_statement: &sqlparser::ast::Statement,
_logical_plan: &LogicalPlan,
_params: &ParamValues,
_session_context: &SessionContext,
_client: &(dyn ClientInfo + Send + Sync),
) -> Option<PgWireResult<Response>> {
todo!();
}
}

#[tokio::test]
Expand All @@ -1036,15 +1066,15 @@ mod tests {
let stmt = &statements[0];

// Hook should intercept
let result = hook.handle_query(stmt, &ctx, &client).await;
let result = hook.handle_simple_query(stmt, &ctx, &client).await;
assert!(result.is_some());

// Parse a normal statement
let statements = parser.parse("SELECT 1").unwrap();
let stmt = &statements[0];

// Hook should not intercept
let result = hook.handle_query(stmt, &ctx, &client).await;
let result = hook.handle_simple_query(stmt, &ctx, &client).await;
assert!(result.is_none());
}
}
Loading