Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 30 additions & 4 deletions datafusion-postgres/src/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ use std::collections::HashMap;
use std::sync::Arc;

use crate::auth::{AuthManager, Permission, ResourceType};
use crate::sql::{parse, rewrite, AliasDuplicatedProjectionRewrite, SqlStatementRewriteRule};
use async_trait::async_trait;
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::LogicalPlan;
use datafusion::prelude::*;
use datafusion::sql::parser::Statement;
use pgwire::api::auth::noop::NoopStartupHandler;
use pgwire::api::auth::StartupHandler;
use pgwire::api::portal::{Format, Portal};
Expand Down Expand Up @@ -63,21 +65,26 @@ pub struct DfSessionService {
parser: Arc<Parser>,
timezone: Arc<Mutex<String>>,
auth_manager: Arc<AuthManager>,
sql_rewrite_rules: Vec<Arc<dyn SqlStatementRewriteRule>>,
}

impl DfSessionService {
pub fn new(
session_context: Arc<SessionContext>,
auth_manager: Arc<AuthManager>,
) -> DfSessionService {
let sql_rewrite_rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
vec![Arc::new(AliasDuplicatedProjectionRewrite)];
let parser = Arc::new(Parser {
session_context: session_context.clone(),
sql_rewrite_rules: sql_rewrite_rules.clone(),
});
DfSessionService {
session_context,
parser,
timezone: Arc::new(Mutex::new("UTC".to_string())),
auth_manager,
sql_rewrite_rules,
}
}

Expand Down Expand Up @@ -308,8 +315,17 @@ impl SimpleQueryHandler for DfSessionService {
where
C: ClientInfo + Unpin + Send + Sync,
{
let query_lower = query.to_lowercase().trim().to_string();
log::debug!("Received query: {query}"); // Log the query for debugging
let mut statements = parse(query).map_err(|e| PgWireError::ApiError(Box::new(e)))?;

// TODO: deal with multiple statements
let mut statement = statements.remove(0);

// Attempt to rewrite
statement = rewrite(statement, &self.sql_rewrite_rules);

// TODO: improve statement check by using statement directly
let query_lower = statement.to_string().to_lowercase().trim().to_string();

// Check permissions for the query (skip for SET, transaction, and SHOW statements)
if !query_lower.starts_with("set")
Expand Down Expand Up @@ -526,6 +542,7 @@ impl ExtendedQueryHandler for DfSessionService {

pub struct Parser {
session_context: Arc<SessionContext>,
sql_rewrite_rules: Vec<Arc<dyn SqlStatementRewriteRule>>,
}

#[async_trait]
Expand All @@ -538,14 +555,23 @@ impl QueryParser for Parser {
sql: &str,
_types: &[Type],
) -> PgWireResult<Self::Statement> {
log::debug!("Received parse extended query: {sql}"); // Log for debugging
log::debug!("Received parse extended query: {sql}"); // Log for
// debugging
let mut statements = parse(sql).map_err(|e| PgWireError::ApiError(Box::new(e)))?;
let mut statement = statements.remove(0);

// Attempt to rewrite
statement = rewrite(statement, &self.sql_rewrite_rules);

let query = statement.to_string();

let context = &self.session_context;
let state = context.state();
let logical_plan = state
.create_logical_plan(sql)
.statement_to_plan(Statement::Statement(Box::new(statement)))
.await
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
Ok((sql.to_string(), logical_plan))
Ok((query, logical_plan))
}
}

Expand Down
1 change: 1 addition & 0 deletions datafusion-postgres/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
mod handlers;
pub mod pg_catalog;
mod sql;

use std::fs::File;
use std::io::{BufReader, Error as IOError, ErrorKind};
Expand Down
174 changes: 174 additions & 0 deletions datafusion-postgres/src/sql.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
use std::sync::Arc;

use datafusion::sql::sqlparser::ast::Expr;
use datafusion::sql::sqlparser::ast::Ident;
use datafusion::sql::sqlparser::ast::Select;
use datafusion::sql::sqlparser::ast::SelectItem;
use datafusion::sql::sqlparser::ast::SelectItemQualifiedWildcardKind;
use datafusion::sql::sqlparser::ast::SetExpr;
use datafusion::sql::sqlparser::ast::Statement;
use datafusion::sql::sqlparser::dialect::PostgreSqlDialect;
use datafusion::sql::sqlparser::parser::Parser;
use datafusion::sql::sqlparser::parser::ParserError;

pub fn parse(sql: &str) -> Result<Vec<Statement>, ParserError> {
let dialect = PostgreSqlDialect {};

Parser::parse_sql(&dialect, sql)
}

pub fn rewrite(mut s: Statement, rules: &[Arc<dyn SqlStatementRewriteRule>]) -> Statement {
for rule in rules {
s = rule.rewrite(s);
}

s
}

pub trait SqlStatementRewriteRule: Send + Sync {
fn rewrite(&self, s: Statement) -> Statement;
}

/// Rewrite rule for adding alias to duplicated projection
///
/// This rule is to deal with sql like `SELECT n.oid, n.* FROM n`, which is a
/// valid statement in postgres. But datafusion treat it as illegal because of
/// duplicated column oid in projection.
///
/// This rule will add alias to column, when there is a wildcard found in
/// projection.
#[derive(Debug)]
pub struct AliasDuplicatedProjectionRewrite;

impl AliasDuplicatedProjectionRewrite {
// Rewrites a SELECT statement to alias explicit columns from the same table as a qualified wildcard.
fn rewrite_select_with_alias(select: &mut Box<Select>) {
// 1. Collect all table aliases from qualified wildcards.
let mut wildcard_tables = Vec::new();
let mut has_simple_wildcard = false;
for p in &select.projection {
match p {
SelectItem::QualifiedWildcard(name, _) => match name {
SelectItemQualifiedWildcardKind::ObjectName(objname) => {
// for n.oid,
let idents = objname
.0
.iter()
.map(|v| v.as_ident().unwrap().value.clone())
.collect::<Vec<_>>()
.join(".");

wildcard_tables.push(idents);
}
SelectItemQualifiedWildcardKind::Expr(_expr) => {
// FIXME:
}
},
SelectItem::Wildcard(_) => {
has_simple_wildcard = true;
}
_ => {}
}
}

// If there are no qualified wildcards, there's nothing to do.
if wildcard_tables.is_empty() && !has_simple_wildcard {
return;
}

// 2. Rewrite the projection, adding aliases to matching columns.
let mut new_projection = vec![];
for p in select.projection.drain(..) {
match p {
SelectItem::UnnamedExpr(expr) => {
let alias_partial = match &expr {
// Case for `oid` (unqualified identifier)
Expr::Identifier(ident) => Some(ident.clone()),
// Case for `n.oid` (compound identifier)
Expr::CompoundIdentifier(idents) => {
// compare every ident but the last
if idents.len() > 1 {
let table_name = &idents[..idents.len() - 1]
.iter()
.map(|i| i.value.clone())
.collect::<Vec<_>>()
.join(".");
if wildcard_tables.iter().any(|name| name == table_name) {
Some(idents[idents.len() - 1].clone())
} else {
None
}
} else {
None
}
}
_ => None,
};

if let Some(name) = alias_partial {
let alias = format!("__alias_{name}");
new_projection.push(SelectItem::ExprWithAlias {
expr,
alias: Ident::new(alias),
});
} else {
new_projection.push(SelectItem::UnnamedExpr(expr));
}
}
// Preserve existing aliases and wildcards.
_ => new_projection.push(p),
}
}
select.projection = new_projection;
}
}

impl SqlStatementRewriteRule for AliasDuplicatedProjectionRewrite {
fn rewrite(&self, mut statement: Statement) -> Statement {
if let Statement::Query(query) = &mut statement {
if let SetExpr::Select(select) = query.body.as_mut() {
Self::rewrite_select_with_alias(select);
}
}

statement
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_alias_rewrite() {
let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
vec![Arc::new(AliasDuplicatedProjectionRewrite)];

let sql = "SELECT n.oid, n.* FROM pg_catalog.pg_namespace n";
let statement = parse(sql).expect("Failed to parse").remove(0);

let statement = rewrite(statement, &rules);
assert_eq!(
statement.to_string(),
"SELECT n.oid AS __alias_oid, n.* FROM pg_catalog.pg_namespace AS n"
);

let sql = "SELECT oid, * FROM pg_catalog.pg_namespace";
let statement = parse(sql).expect("Failed to parse").remove(0);

let statement = rewrite(statement, &rules);
assert_eq!(
statement.to_string(),
"SELECT oid AS __alias_oid, * FROM pg_catalog.pg_namespace"
);

let sql = "SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id";
let statement = parse(sql).expect("Failed to parse").remove(0);

let statement = rewrite(statement, &rules);
assert_eq!(
statement.to_string(),
"SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id"
);
}
}
Loading