Skip to content
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 102 additions & 40 deletions datafusion-postgres/src/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::sync::Arc;

use async_trait::async_trait;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::common::ToDFSchema;
use datafusion::common::{ParamValues, ToDFSchema};
use datafusion::error::DataFusionError;
use datafusion::logical_expr::LogicalPlan;
use datafusion::prelude::*;
Expand Down Expand Up @@ -33,11 +33,30 @@ use datafusion_pg_catalog::sql::PostgresCompatibilityParser;

#[async_trait]
pub trait QueryHook: Send + Sync {
async fn handle_query(
/// called in simple query handler to return response directly
async fn handle_simple_query(
&self,
statement: &sqlparser::ast::Statement,
session_context: &SessionContext,
client: &dyn ClientInfo,
client: &(dyn ClientInfo + Send + Sync),
) -> Option<PgWireResult<Response>>;

/// called at extended query parse phase, for generating `LogicalPlan`from statement
async fn handle_extended_parse_query(
&self,
statement: &sqlparser::ast::Statement,
session_context: &SessionContext,
client: &(dyn ClientInfo + Send + Sync),
) -> Option<PgWireResult<LogicalPlan>>;

/// called at extended query execute phase, for query execution
async fn handle_extended_query(
&self,
statement: &sqlparser::ast::Statement,
logical_plan: &LogicalPlan,
params: &ParamValues,
session_context: &SessionContext,
client: &(dyn ClientInfo + Send + Sync),
) -> Option<PgWireResult<Response>>;
}

Expand Down Expand Up @@ -117,6 +136,7 @@ impl DfSessionService {
let parser = Arc::new(Parser {
session_context: session_context.clone(),
sql_parser: PostgresCompatibilityParser::new(),
query_hooks: query_hooks.clone(),
});
DfSessionService {
session_context,
Expand Down Expand Up @@ -492,7 +512,7 @@ impl SimpleQueryHandler for DfSessionService {
// Call query hooks with the parsed statement
for hook in &self.query_hooks {
if let Some(result) = hook
.handle_query(&statement, &self.session_context, client)
.handle_simple_query(&statement, &self.session_context, client)
.await
{
results.push(result?);
Expand Down Expand Up @@ -566,7 +586,7 @@ impl SimpleQueryHandler for DfSessionService {

#[async_trait]
impl ExtendedQueryHandler for DfSessionService {
type Statement = (String, Option<LogicalPlan>);
type Statement = (String, Option<(sqlparser::ast::Statement, LogicalPlan)>);
type QueryParser = Parser;

fn query_parser(&self) -> Arc<Self::QueryParser> {
Expand All @@ -581,7 +601,7 @@ impl ExtendedQueryHandler for DfSessionService {
where
C: ClientInfo + Unpin + Send + Sync,
{
if let (_, Some(plan)) = &target.statement {
if let (_, Some((_, plan))) = &target.statement {
let schema = plan.schema();
let fields = arrow_schema_to_pg_fields(schema.as_arrow(), &Format::UnifiedBinary)?;
let params = plan
Expand Down Expand Up @@ -613,7 +633,7 @@ impl ExtendedQueryHandler for DfSessionService {
where
C: ClientInfo + Unpin + Send + Sync,
{
if let (_, Some(plan)) = &target.statement.statement {
if let (_, Some((_, plan))) = &target.statement.statement {
let format = &target.result_column_format;
let schema = plan.schema();
let fields = arrow_schema_to_pg_fields(schema.as_arrow(), format)?;
Expand Down Expand Up @@ -643,21 +663,29 @@ impl ExtendedQueryHandler for DfSessionService {
log::debug!("Received execute extended query: {query}"); // Log for debugging

// Check query hooks first
for hook in &self.query_hooks {
// Parse the SQL to get the Statement for the hook
let sql = &portal.statement.statement.0;
let statements = self
.parser
.sql_parser
.parse(sql)
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;

if let Some(statement) = statements.into_iter().next() {
if let Some(result) = hook
.handle_query(&statement, &self.session_context, client)
.await
{
return result;
if !self.query_hooks.is_empty() {
if let (_, Some((statement, plan))) = &portal.statement.statement {
// TODO: in the case where query hooks all return None, we do the param handling again later.
let param_types = plan
.get_parameter_types()
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;

let param_values: ParamValues =
df::deserialize_parameters(portal, &ordered_param_types(&param_types))?;

for hook in &self.query_hooks {
if let Some(result) = hook
.handle_extended_query(
statement,
plan,
&param_values,
&self.session_context,
client,
)
.await
{
return result;
}
}
}
}
Expand Down Expand Up @@ -695,7 +723,7 @@ impl ExtendedQueryHandler for DfSessionService {
)));
}

if let (_, Some(plan)) = &portal.statement.statement {
if let (_, Some((_, plan))) = &portal.statement.statement {
let param_types = plan
.get_parameter_types()
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
Expand Down Expand Up @@ -780,6 +808,7 @@ async fn map_rows_affected_for_insert(df: &DataFrame) -> PgWireResult<Response>
pub struct Parser {
session_context: Arc<SessionContext>,
sql_parser: PostgresCompatibilityParser,
query_hooks: Vec<Arc<dyn QueryHook>>,
}

impl Parser {
Expand Down Expand Up @@ -834,24 +863,19 @@ impl Parser {

#[async_trait]
impl QueryParser for Parser {
type Statement = (String, Option<LogicalPlan>);
type Statement = (String, Option<(sqlparser::ast::Statement, LogicalPlan)>);

async fn parse_sql<C>(
&self,
_client: &C,
client: &C,
sql: &str,
_types: &[Type],
) -> PgWireResult<Self::Statement> {
) -> PgWireResult<Self::Statement>
where
C: ClientInfo + Unpin + Send + Sync,
{
log::debug!("Received parse extended query: {sql}"); // Log for debugging

// Check for transaction commands that shouldn't be parsed by DataFusion
if let Some(plan) = self
.try_shortcut_parse_plan(sql)
.map_err(|e| PgWireError::ApiError(Box::new(e)))?
{
return Ok((sql.to_string(), Some(plan)));
}

let mut statements = self
.sql_parser
.parse(sql)
Expand All @@ -862,15 +886,33 @@ impl QueryParser for Parser {

let statement = statements.remove(0);

// Check for transaction commands that shouldn't be parsed by DataFusion
if let Some(plan) = self
.try_shortcut_parse_plan(sql)
.map_err(|e| PgWireError::ApiError(Box::new(e)))?
{
return Ok((sql.to_string(), Some((statement, plan))));
}

let query = statement.to_string();

let context = &self.session_context;
let state = context.state();

for hook in &self.query_hooks {
if let Some(logical_plan) = hook
.handle_extended_parse_query(&statement, context, client)
.await
{
return Ok((query, Some((statement, logical_plan?))));
}
}

let logical_plan = state
.statement_to_plan(Statement::Statement(Box::new(statement)))
.statement_to_plan(Statement::Statement(Box::new(statement.clone())))
.await
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
Ok((query, Some(logical_plan)))
Ok((query, Some((statement, logical_plan))))
}
}

Expand Down Expand Up @@ -1010,18 +1052,38 @@ mod tests {

#[async_trait]
impl QueryHook for TestHook {
async fn handle_query(
async fn handle_simple_query(
&self,
statement: &sqlparser::ast::Statement,
_ctx: &SessionContext,
_client: &dyn ClientInfo,
_client: &(dyn ClientInfo + Sync + Send),
) -> Option<PgWireResult<Response>> {
if statement.to_string().contains("magic") {
Some(Ok(Response::EmptyQuery))
} else {
None
}
}

async fn handle_extended_parse_query(
&self,
_statement: &sqlparser::ast::Statement,
_session_context: &SessionContext,
_client: &(dyn ClientInfo + Send + Sync),
) -> Option<PgWireResult<LogicalPlan>> {
todo!()
}

async fn handle_extended_query(
&self,
_statement: &sqlparser::ast::Statement,
_logical_plan: &LogicalPlan,
_params: &ParamValues,
_session_context: &SessionContext,
_client: &(dyn ClientInfo + Send + Sync),
) -> Option<PgWireResult<Response>> {
todo!();
}
}

#[tokio::test]
Expand All @@ -1036,15 +1098,15 @@ mod tests {
let stmt = &statements[0];

// Hook should intercept
let result = hook.handle_query(stmt, &ctx, &client).await;
let result = hook.handle_simple_query(stmt, &ctx, &client).await;
assert!(result.is_some());

// Parse a normal statement
let statements = parser.parse("SELECT 1").unwrap();
let stmt = &statements[0];

// Hook should not intercept
let result = hook.handle_query(stmt, &ctx, &client).await;
let result = hook.handle_simple_query(stmt, &ctx, &client).await;
assert!(result.is_none());
}
}
Loading