diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index 66e6ebe..3ce6148 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -2,10 +2,12 @@ use std::collections::HashMap; use std::sync::Arc; use crate::auth::{AuthManager, Permission, ResourceType}; +use crate::sql::{parse, rewrite, AliasDuplicatedProjectionRewrite, SqlStatementRewriteRule}; use async_trait::async_trait; use datafusion::arrow::datatypes::DataType; use datafusion::logical_expr::LogicalPlan; use datafusion::prelude::*; +use datafusion::sql::parser::Statement; use pgwire::api::auth::noop::NoopStartupHandler; use pgwire::api::auth::StartupHandler; use pgwire::api::portal::{Format, Portal}; @@ -63,6 +65,7 @@ pub struct DfSessionService { parser: Arc, timezone: Arc>, auth_manager: Arc, + sql_rewrite_rules: Vec>, } impl DfSessionService { @@ -70,14 +73,18 @@ impl DfSessionService { session_context: Arc, auth_manager: Arc, ) -> DfSessionService { + let sql_rewrite_rules: Vec> = + vec![Arc::new(AliasDuplicatedProjectionRewrite)]; let parser = Arc::new(Parser { session_context: session_context.clone(), + sql_rewrite_rules: sql_rewrite_rules.clone(), }); DfSessionService { session_context, parser, timezone: Arc::new(Mutex::new("UTC".to_string())), auth_manager, + sql_rewrite_rules, } } @@ -308,8 +315,17 @@ impl SimpleQueryHandler for DfSessionService { where C: ClientInfo + Unpin + Send + Sync, { - let query_lower = query.to_lowercase().trim().to_string(); log::debug!("Received query: {query}"); // Log the query for debugging + let mut statements = parse(query).map_err(|e| PgWireError::ApiError(Box::new(e)))?; + + // TODO: deal with multiple statements + let mut statement = statements.remove(0); + + // Attempt to rewrite + statement = rewrite(statement, &self.sql_rewrite_rules); + + // TODO: improve statement check by using statement directly + let query_lower = statement.to_string().to_lowercase().trim().to_string(); // Check permissions for the query (skip for SET, transaction, and SHOW statements) if !query_lower.starts_with("set") @@ -526,6 +542,7 @@ impl ExtendedQueryHandler for DfSessionService { pub struct Parser { session_context: Arc, + sql_rewrite_rules: Vec>, } #[async_trait] @@ -538,14 +555,23 @@ impl QueryParser for Parser { sql: &str, _types: &[Type], ) -> PgWireResult { - log::debug!("Received parse extended query: {sql}"); // Log for debugging + log::debug!("Received parse extended query: {sql}"); // Log for + // debugging + let mut statements = parse(sql).map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let mut statement = statements.remove(0); + + // Attempt to rewrite + statement = rewrite(statement, &self.sql_rewrite_rules); + + let query = statement.to_string(); + let context = &self.session_context; let state = context.state(); let logical_plan = state - .create_logical_plan(sql) + .statement_to_plan(Statement::Statement(Box::new(statement))) .await .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - Ok((sql.to_string(), logical_plan)) + Ok((query, logical_plan)) } } diff --git a/datafusion-postgres/src/lib.rs b/datafusion-postgres/src/lib.rs index ba43d00..b08e830 100644 --- a/datafusion-postgres/src/lib.rs +++ b/datafusion-postgres/src/lib.rs @@ -1,5 +1,6 @@ mod handlers; pub mod pg_catalog; +mod sql; use std::fs::File; use std::io::{BufReader, Error as IOError, ErrorKind}; diff --git a/datafusion-postgres/src/sql.rs b/datafusion-postgres/src/sql.rs new file mode 100644 index 0000000..9002811 --- /dev/null +++ b/datafusion-postgres/src/sql.rs @@ -0,0 +1,174 @@ +use std::sync::Arc; + +use datafusion::sql::sqlparser::ast::Expr; +use datafusion::sql::sqlparser::ast::Ident; +use datafusion::sql::sqlparser::ast::Select; +use datafusion::sql::sqlparser::ast::SelectItem; +use datafusion::sql::sqlparser::ast::SelectItemQualifiedWildcardKind; +use datafusion::sql::sqlparser::ast::SetExpr; +use datafusion::sql::sqlparser::ast::Statement; +use datafusion::sql::sqlparser::dialect::PostgreSqlDialect; +use datafusion::sql::sqlparser::parser::Parser; +use datafusion::sql::sqlparser::parser::ParserError; + +pub fn parse(sql: &str) -> Result, ParserError> { + let dialect = PostgreSqlDialect {}; + + Parser::parse_sql(&dialect, sql) +} + +pub fn rewrite(mut s: Statement, rules: &[Arc]) -> Statement { + for rule in rules { + s = rule.rewrite(s); + } + + s +} + +pub trait SqlStatementRewriteRule: Send + Sync { + fn rewrite(&self, s: Statement) -> Statement; +} + +/// Rewrite rule for adding alias to duplicated projection +/// +/// This rule is to deal with sql like `SELECT n.oid, n.* FROM n`, which is a +/// valid statement in postgres. But datafusion treat it as illegal because of +/// duplicated column oid in projection. +/// +/// This rule will add alias to column, when there is a wildcard found in +/// projection. +#[derive(Debug)] +pub struct AliasDuplicatedProjectionRewrite; + +impl AliasDuplicatedProjectionRewrite { + // Rewrites a SELECT statement to alias explicit columns from the same table as a qualified wildcard. + fn rewrite_select_with_alias(select: &mut Box