Skip to content

Commit 90eba67

Browse files
authored
Merge branch 'master' into feat/pg-catalog-sql-dbeaver
2 parents 07f2193 + 73c126e commit 90eba67

File tree

3 files changed

+205
-4
lines changed

3 files changed

+205
-4
lines changed

datafusion-postgres/src/handlers.rs

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@ use std::collections::HashMap;
22
use std::sync::Arc;
33

44
use crate::auth::{AuthManager, Permission, ResourceType};
5+
use crate::sql::{parse, rewrite, AliasDuplicatedProjectionRewrite, SqlStatementRewriteRule};
56
use async_trait::async_trait;
67
use datafusion::arrow::datatypes::DataType;
78
use datafusion::logical_expr::LogicalPlan;
89
use datafusion::prelude::*;
10+
use datafusion::sql::parser::Statement;
911
use log::warn;
1012
use pgwire::api::auth::noop::NoopStartupHandler;
1113
use pgwire::api::auth::StartupHandler;
@@ -64,21 +66,26 @@ pub struct DfSessionService {
6466
parser: Arc<Parser>,
6567
timezone: Arc<Mutex<String>>,
6668
auth_manager: Arc<AuthManager>,
69+
sql_rewrite_rules: Vec<Arc<dyn SqlStatementRewriteRule>>,
6770
}
6871

6972
impl DfSessionService {
7073
pub fn new(
7174
session_context: Arc<SessionContext>,
7275
auth_manager: Arc<AuthManager>,
7376
) -> DfSessionService {
77+
let sql_rewrite_rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
78+
vec![Arc::new(AliasDuplicatedProjectionRewrite)];
7479
let parser = Arc::new(Parser {
7580
session_context: session_context.clone(),
81+
sql_rewrite_rules: sql_rewrite_rules.clone(),
7682
});
7783
DfSessionService {
7884
session_context,
7985
parser,
8086
timezone: Arc::new(Mutex::new("UTC".to_string())),
8187
auth_manager,
88+
sql_rewrite_rules,
8289
}
8390
}
8491

@@ -307,8 +314,17 @@ impl SimpleQueryHandler for DfSessionService {
307314
where
308315
C: ClientInfo + Unpin + Send + Sync,
309316
{
310-
let query_lower = query.to_lowercase().trim().to_string();
311317
log::debug!("Received query: {query}"); // Log the query for debugging
318+
let mut statements = parse(query).map_err(|e| PgWireError::ApiError(Box::new(e)))?;
319+
320+
// TODO: deal with multiple statements
321+
let mut statement = statements.remove(0);
322+
323+
// Attempt to rewrite
324+
statement = rewrite(statement, &self.sql_rewrite_rules);
325+
326+
// TODO: improve statement check by using statement directly
327+
let query_lower = statement.to_string().to_lowercase().trim().to_string();
312328

313329
// Check permissions for the query (skip for SET, transaction, and SHOW statements)
314330
if !query_lower.starts_with("set")
@@ -525,6 +541,7 @@ impl ExtendedQueryHandler for DfSessionService {
525541

526542
pub struct Parser {
527543
session_context: Arc<SessionContext>,
544+
sql_rewrite_rules: Vec<Arc<dyn SqlStatementRewriteRule>>,
528545
}
529546

530547
#[async_trait]
@@ -537,14 +554,23 @@ impl QueryParser for Parser {
537554
sql: &str,
538555
_types: &[Type],
539556
) -> PgWireResult<Self::Statement> {
540-
log::debug!("Received parse extended query: {sql}"); // Log for debugging
557+
log::debug!("Received parse extended query: {sql}"); // Log for
558+
// debugging
559+
let mut statements = parse(sql).map_err(|e| PgWireError::ApiError(Box::new(e)))?;
560+
let mut statement = statements.remove(0);
561+
562+
// Attempt to rewrite
563+
statement = rewrite(statement, &self.sql_rewrite_rules);
564+
565+
let query = statement.to_string();
566+
541567
let context = &self.session_context;
542568
let state = context.state();
543569
let logical_plan = state
544-
.create_logical_plan(sql)
570+
.statement_to_plan(Statement::Statement(Box::new(statement)))
545571
.await
546572
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
547-
Ok((sql.to_string(), logical_plan))
573+
Ok((query, logical_plan))
548574
}
549575
}
550576

datafusion-postgres/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
mod handlers;
22
pub mod pg_catalog;
3+
mod sql;
34

45
use std::fs::File;
56
use std::io::{BufReader, Error as IOError, ErrorKind};

datafusion-postgres/src/sql.rs

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
use std::sync::Arc;
2+
3+
use datafusion::sql::sqlparser::ast::Expr;
4+
use datafusion::sql::sqlparser::ast::Ident;
5+
use datafusion::sql::sqlparser::ast::Select;
6+
use datafusion::sql::sqlparser::ast::SelectItem;
7+
use datafusion::sql::sqlparser::ast::SelectItemQualifiedWildcardKind;
8+
use datafusion::sql::sqlparser::ast::SetExpr;
9+
use datafusion::sql::sqlparser::ast::Statement;
10+
use datafusion::sql::sqlparser::dialect::PostgreSqlDialect;
11+
use datafusion::sql::sqlparser::parser::Parser;
12+
use datafusion::sql::sqlparser::parser::ParserError;
13+
14+
pub fn parse(sql: &str) -> Result<Vec<Statement>, ParserError> {
15+
let dialect = PostgreSqlDialect {};
16+
17+
Parser::parse_sql(&dialect, sql)
18+
}
19+
20+
pub fn rewrite(mut s: Statement, rules: &[Arc<dyn SqlStatementRewriteRule>]) -> Statement {
21+
for rule in rules {
22+
s = rule.rewrite(s);
23+
}
24+
25+
s
26+
}
27+
28+
pub trait SqlStatementRewriteRule: Send + Sync {
29+
fn rewrite(&self, s: Statement) -> Statement;
30+
}
31+
32+
/// Rewrite rule for adding alias to duplicated projection
33+
///
34+
/// This rule is to deal with sql like `SELECT n.oid, n.* FROM n`, which is a
35+
/// valid statement in postgres. But datafusion treat it as illegal because of
36+
/// duplicated column oid in projection.
37+
///
38+
/// This rule will add alias to column, when there is a wildcard found in
39+
/// projection.
40+
#[derive(Debug)]
41+
pub struct AliasDuplicatedProjectionRewrite;
42+
43+
impl AliasDuplicatedProjectionRewrite {
44+
// Rewrites a SELECT statement to alias explicit columns from the same table as a qualified wildcard.
45+
fn rewrite_select_with_alias(select: &mut Box<Select>) {
46+
// 1. Collect all table aliases from qualified wildcards.
47+
let mut wildcard_tables = Vec::new();
48+
let mut has_simple_wildcard = false;
49+
for p in &select.projection {
50+
match p {
51+
SelectItem::QualifiedWildcard(name, _) => match name {
52+
SelectItemQualifiedWildcardKind::ObjectName(objname) => {
53+
// for n.oid,
54+
let idents = objname
55+
.0
56+
.iter()
57+
.map(|v| v.as_ident().unwrap().value.clone())
58+
.collect::<Vec<_>>()
59+
.join(".");
60+
61+
wildcard_tables.push(idents);
62+
}
63+
SelectItemQualifiedWildcardKind::Expr(_expr) => {
64+
// FIXME:
65+
}
66+
},
67+
SelectItem::Wildcard(_) => {
68+
has_simple_wildcard = true;
69+
}
70+
_ => {}
71+
}
72+
}
73+
74+
// If there are no qualified wildcards, there's nothing to do.
75+
if wildcard_tables.is_empty() && !has_simple_wildcard {
76+
return;
77+
}
78+
79+
// 2. Rewrite the projection, adding aliases to matching columns.
80+
let mut new_projection = vec![];
81+
for p in select.projection.drain(..) {
82+
match p {
83+
SelectItem::UnnamedExpr(expr) => {
84+
let alias_partial = match &expr {
85+
// Case for `oid` (unqualified identifier)
86+
Expr::Identifier(ident) => Some(ident.clone()),
87+
// Case for `n.oid` (compound identifier)
88+
Expr::CompoundIdentifier(idents) => {
89+
// compare every ident but the last
90+
if idents.len() > 1 {
91+
let table_name = &idents[..idents.len() - 1]
92+
.iter()
93+
.map(|i| i.value.clone())
94+
.collect::<Vec<_>>()
95+
.join(".");
96+
if wildcard_tables.iter().any(|name| name == table_name) {
97+
Some(idents[idents.len() - 1].clone())
98+
} else {
99+
None
100+
}
101+
} else {
102+
None
103+
}
104+
}
105+
_ => None,
106+
};
107+
108+
if let Some(name) = alias_partial {
109+
let alias = format!("__alias_{name}");
110+
new_projection.push(SelectItem::ExprWithAlias {
111+
expr,
112+
alias: Ident::new(alias),
113+
});
114+
} else {
115+
new_projection.push(SelectItem::UnnamedExpr(expr));
116+
}
117+
}
118+
// Preserve existing aliases and wildcards.
119+
_ => new_projection.push(p),
120+
}
121+
}
122+
select.projection = new_projection;
123+
}
124+
}
125+
126+
impl SqlStatementRewriteRule for AliasDuplicatedProjectionRewrite {
127+
fn rewrite(&self, mut statement: Statement) -> Statement {
128+
if let Statement::Query(query) = &mut statement {
129+
if let SetExpr::Select(select) = query.body.as_mut() {
130+
Self::rewrite_select_with_alias(select);
131+
}
132+
}
133+
134+
statement
135+
}
136+
}
137+
138+
#[cfg(test)]
139+
mod tests {
140+
use super::*;
141+
142+
#[test]
143+
fn test_alias_rewrite() {
144+
let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
145+
vec![Arc::new(AliasDuplicatedProjectionRewrite)];
146+
147+
let sql = "SELECT n.oid, n.* FROM pg_catalog.pg_namespace n";
148+
let statement = parse(sql).expect("Failed to parse").remove(0);
149+
150+
let statement = rewrite(statement, &rules);
151+
assert_eq!(
152+
statement.to_string(),
153+
"SELECT n.oid AS __alias_oid, n.* FROM pg_catalog.pg_namespace AS n"
154+
);
155+
156+
let sql = "SELECT oid, * FROM pg_catalog.pg_namespace";
157+
let statement = parse(sql).expect("Failed to parse").remove(0);
158+
159+
let statement = rewrite(statement, &rules);
160+
assert_eq!(
161+
statement.to_string(),
162+
"SELECT oid AS __alias_oid, * FROM pg_catalog.pg_namespace"
163+
);
164+
165+
let sql = "SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id";
166+
let statement = parse(sql).expect("Failed to parse").remove(0);
167+
168+
let statement = rewrite(statement, &rules);
169+
assert_eq!(
170+
statement.to_string(),
171+
"SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id"
172+
);
173+
}
174+
}

0 commit comments

Comments
 (0)