Skip to content

Commit daadf6f

Browse files
authored
feat: add sql rewrite rule to resolve unqualified identifier (#144)
* feat: add sql rewrite rule to resolve unqualified identifier * chore: lint fix
1 parent 5890a8a commit daadf6f

File tree

1 file changed

+156
-0
lines changed

1 file changed

+156
-0
lines changed

datafusion-postgres/src/sql.rs

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
1+
use std::collections::HashSet;
12
use std::sync::Arc;
23

34
use datafusion::sql::sqlparser::ast::Expr;
45
use datafusion::sql::sqlparser::ast::Ident;
6+
use datafusion::sql::sqlparser::ast::OrderByKind;
7+
use datafusion::sql::sqlparser::ast::Query;
58
use datafusion::sql::sqlparser::ast::Select;
69
use datafusion::sql::sqlparser::ast::SelectItem;
710
use datafusion::sql::sqlparser::ast::SelectItemQualifiedWildcardKind;
811
use datafusion::sql::sqlparser::ast::SetExpr;
912
use datafusion::sql::sqlparser::ast::Statement;
13+
use datafusion::sql::sqlparser::ast::TableFactor;
14+
use datafusion::sql::sqlparser::ast::TableWithJoins;
1015
use datafusion::sql::sqlparser::dialect::PostgreSqlDialect;
1116
use datafusion::sql::sqlparser::parser::Parser;
1217
use datafusion::sql::sqlparser::parser::ParserError;
@@ -135,6 +140,124 @@ impl SqlStatementRewriteRule for AliasDuplicatedProjectionRewrite {
135140
}
136141
}
137142

143+
/// Prepends the table qualifier to unqualified identifiers in WHERE and ORDER BY when the projection has a single qualified wildcard
144+
///
145+
/// Postgres allows unqualified identifiers in ORDER BY and WHERE clauses, but they are not
146+
/// accepted by DataFusion.
147+
#[derive(Debug)]
148+
pub struct ResolveUnqualifiedIdentifer;
149+
150+
impl ResolveUnqualifiedIdentifer {
151+
fn rewrite_unqualified_identifiers(query: &mut Box<Query>) {
152+
if let SetExpr::Select(select) = query.body.as_mut() {
153+
// Step 1: Find all table aliases from FROM and JOIN clauses.
154+
let table_aliases = Self::get_table_aliases(&select.from);
155+
156+
// Step 2: Check for a single qualified wildcard in the projection.
157+
let qualified_wildcard_alias = Self::get_qualified_wildcard_alias(&select.projection);
158+
if qualified_wildcard_alias.is_none() || table_aliases.is_empty() {
159+
return; // Conditions not met.
160+
}
161+
162+
let wildcard_alias = qualified_wildcard_alias.unwrap();
163+
164+
// Step 3: Rewrite expressions in the WHERE and ORDER BY clauses.
165+
if let Some(selection) = &mut select.selection {
166+
Self::rewrite_expr(selection, &wildcard_alias, &table_aliases);
167+
}
168+
169+
if let Some(OrderByKind::Expressions(order_by_exprs)) =
170+
query.order_by.as_mut().map(|o| &mut o.kind)
171+
{
172+
for order_by_expr in order_by_exprs {
173+
Self::rewrite_expr(&mut order_by_expr.expr, &wildcard_alias, &table_aliases);
174+
}
175+
}
176+
}
177+
}
178+
179+
fn get_table_aliases(tables: &[TableWithJoins]) -> HashSet<String> {
180+
let mut aliases = HashSet::new();
181+
for table_with_joins in tables {
182+
if let TableFactor::Table {
183+
alias: Some(alias), ..
184+
} = &table_with_joins.relation
185+
{
186+
aliases.insert(alias.name.value.clone());
187+
}
188+
for join in &table_with_joins.joins {
189+
if let TableFactor::Table {
190+
alias: Some(alias), ..
191+
} = &join.relation
192+
{
193+
aliases.insert(alias.name.value.clone());
194+
}
195+
}
196+
}
197+
aliases
198+
}
199+
200+
fn get_qualified_wildcard_alias(projection: &[SelectItem]) -> Option<String> {
201+
let mut qualified_wildcards = projection
202+
.iter()
203+
.filter_map(|item| {
204+
if let SelectItem::QualifiedWildcard(
205+
SelectItemQualifiedWildcardKind::ObjectName(objname),
206+
_,
207+
) = item
208+
{
209+
Some(
210+
objname
211+
.0
212+
.iter()
213+
.map(|v| v.as_ident().unwrap().value.clone())
214+
.collect::<Vec<_>>()
215+
.join("."),
216+
)
217+
} else {
218+
None
219+
}
220+
})
221+
.collect::<Vec<_>>();
222+
223+
if qualified_wildcards.len() == 1 {
224+
Some(qualified_wildcards.remove(0))
225+
} else {
226+
None
227+
}
228+
}
229+
230+
fn rewrite_expr(expr: &mut Expr, wildcard_alias: &str, table_aliases: &HashSet<String>) {
231+
match expr {
232+
Expr::Identifier(ident) => {
233+
// If the identifier is not a table alias itself, rewrite it.
234+
if !table_aliases.contains(&ident.value) {
235+
*expr = Expr::CompoundIdentifier(vec![
236+
Ident::new(wildcard_alias.to_string()),
237+
ident.clone(),
238+
]);
239+
}
240+
}
241+
Expr::BinaryOp { left, right, .. } => {
242+
Self::rewrite_expr(left, wildcard_alias, table_aliases);
243+
Self::rewrite_expr(right, wildcard_alias, table_aliases);
244+
}
245+
// Add more cases for other expression types as needed (e.g., `InList`, `Between`, etc.)
246+
_ => {}
247+
}
248+
}
249+
}
250+
251+
impl SqlStatementRewriteRule for ResolveUnqualifiedIdentifer {
252+
fn rewrite(&self, mut statement: Statement) -> Statement {
253+
if let Statement::Query(query) = &mut statement {
254+
Self::rewrite_unqualified_identifiers(query);
255+
}
256+
257+
statement
258+
}
259+
}
260+
138261
#[cfg(test)]
139262
mod tests {
140263
use super::*;
@@ -171,4 +294,37 @@ mod tests {
171294
"SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id"
172295
);
173296
}
297+
298+
#[test]
fn test_qualifier_prepend() {
    let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
        vec![Arc::new(ResolveUnqualifiedIdentifer)];

    // Each entry is (input SQL, expected SQL after the rewrite).
    let cases = [
        // Single qualified wildcard with an alias: WHERE and ORDER BY get qualified.
        (
            "SELECT n.* FROM pg_catalog.pg_namespace n WHERE nspname = 'pg_catalog' ORDER BY nspname",
            "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname",
        ),
        // Unqualified wildcard and no alias: the statement is left unchanged.
        (
            "SELECT * FROM pg_catalog.pg_namespace ORDER BY nspname",
            "SELECT * FROM pg_catalog.pg_namespace ORDER BY nspname",
        ),
        // Qualified wildcard mixed with other projection items and a join.
        (
            "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace' ORDER BY nspsname",
            "SELECT n.oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY n.nspsname",
        ),
    ];

    for (sql, expected) in cases {
        let statement = parse(sql).expect("Failed to parse").remove(0);
        let rewritten = rewrite(statement, &rules);
        assert_eq!(rewritten.to_string(), expected);
    }
}
174330
}

0 commit comments

Comments
 (0)