Skip to content

Commit d46851c

Browse files
committed
feat: add sql rewrite rule to resolve unqualified identifier
1 parent 5890a8a commit d46851c

File tree

1 file changed

+154
-0
lines changed

1 file changed

+154
-0
lines changed

datafusion-postgres/src/sql.rs

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
1+
use std::collections::HashSet;
12
use std::sync::Arc;
23

34
use datafusion::sql::sqlparser::ast::Expr;
45
use datafusion::sql::sqlparser::ast::Ident;
6+
use datafusion::sql::sqlparser::ast::OrderByKind;
7+
use datafusion::sql::sqlparser::ast::Query;
58
use datafusion::sql::sqlparser::ast::Select;
69
use datafusion::sql::sqlparser::ast::SelectItem;
710
use datafusion::sql::sqlparser::ast::SelectItemQualifiedWildcardKind;
811
use datafusion::sql::sqlparser::ast::SetExpr;
912
use datafusion::sql::sqlparser::ast::Statement;
13+
use datafusion::sql::sqlparser::ast::TableFactor;
14+
use datafusion::sql::sqlparser::ast::TableWithJoins;
1015
use datafusion::sql::sqlparser::dialect::PostgreSqlDialect;
1116
use datafusion::sql::sqlparser::parser::Parser;
1217
use datafusion::sql::sqlparser::parser::ParserError;
@@ -135,6 +140,122 @@ impl SqlStatementRewriteRule for AliasDuplicatedProjectionRewrite {
135140
}
136141
}
137142

143+
/// Rewrite rule that qualifies bare column identifiers in the `WHERE` and
/// `ORDER BY` clauses of a query.
///
/// Postgres accepts unqualified identifiers in `WHERE` and `ORDER BY` when the
/// projection uses a qualified wildcard, but DataFusion does not. For example
///
/// ```sql
/// SELECT n.* FROM pg_catalog.pg_namespace n WHERE nspname = 'x' ORDER BY nspname
/// ```
///
/// is rewritten so that `nspname` becomes `n.nspname`.
///
/// The rule only applies when the projection contains exactly one qualified
/// wildcard (such as `n.*`) and the `FROM` clause declares at least one table
/// alias; every other statement is returned unchanged.
#[derive(Debug)]
pub struct ResolveUnqualifiedIdentifer;
149+
150+
impl ResolveUnqualifiedIdentifer {
151+
fn rewrite_unqualified_identifiers(query: &mut Box<Query>) {
152+
if let SetExpr::Select(select) = query.body.as_mut() {
153+
// Step 1: Find all table aliases from FROM and JOIN clauses.
154+
let table_aliases = Self::get_table_aliases(&select.from);
155+
156+
// Step 2: Check for a single qualified wildcard in the projection.
157+
let qualified_wildcard_alias = Self::get_qualified_wildcard_alias(&select.projection);
158+
if qualified_wildcard_alias.is_none() || table_aliases.is_empty() {
159+
return; // Conditions not met.
160+
}
161+
162+
let wildcard_alias = qualified_wildcard_alias.unwrap();
163+
164+
// Step 3: Rewrite expressions in the WHERE and ORDER BY clauses.
165+
if let Some(selection) = &mut select.selection {
166+
Self::rewrite_expr(selection, &wildcard_alias, &table_aliases);
167+
}
168+
169+
if let Some(OrderByKind::Expressions(order_by_exprs)) =
170+
query.order_by.as_mut().map(|o| &mut o.kind)
171+
{
172+
for order_by_expr in order_by_exprs {
173+
Self::rewrite_expr(&mut order_by_expr.expr, &wildcard_alias, &table_aliases);
174+
}
175+
}
176+
}
177+
}
178+
179+
fn get_table_aliases(tables: &[TableWithJoins]) -> HashSet<String> {
180+
let mut aliases = HashSet::new();
181+
for table_with_joins in tables {
182+
if let TableFactor::Table { alias, .. } = &table_with_joins.relation {
183+
if let Some(alias) = &alias {
184+
aliases.insert(alias.name.value.clone());
185+
}
186+
}
187+
for join in &table_with_joins.joins {
188+
if let TableFactor::Table { alias, .. } = &join.relation {
189+
if let Some(alias) = &alias {
190+
aliases.insert(alias.name.value.clone());
191+
}
192+
}
193+
}
194+
}
195+
aliases
196+
}
197+
198+
fn get_qualified_wildcard_alias(projection: &[SelectItem]) -> Option<String> {
199+
let mut qualified_wildcards = projection
200+
.iter()
201+
.filter_map(|item| {
202+
if let SelectItem::QualifiedWildcard(
203+
SelectItemQualifiedWildcardKind::ObjectName(objname),
204+
_,
205+
) = item
206+
{
207+
Some(
208+
objname
209+
.0
210+
.iter()
211+
.map(|v| v.as_ident().unwrap().value.clone())
212+
.collect::<Vec<_>>()
213+
.join("."),
214+
)
215+
} else {
216+
None
217+
}
218+
})
219+
.collect::<Vec<_>>();
220+
221+
if qualified_wildcards.len() == 1 {
222+
Some(qualified_wildcards.remove(0))
223+
} else {
224+
None
225+
}
226+
}
227+
228+
fn rewrite_expr(expr: &mut Expr, wildcard_alias: &str, table_aliases: &HashSet<String>) {
229+
match expr {
230+
Expr::Identifier(ident) => {
231+
// If the identifier is not a table alias itself, rewrite it.
232+
if !table_aliases.contains(&ident.value) {
233+
*expr = Expr::CompoundIdentifier(vec![
234+
Ident::new(wildcard_alias.to_string()),
235+
ident.clone(),
236+
]);
237+
}
238+
}
239+
Expr::BinaryOp { left, right, .. } => {
240+
Self::rewrite_expr(left, wildcard_alias, table_aliases);
241+
Self::rewrite_expr(right, wildcard_alias, table_aliases);
242+
}
243+
// Add more cases for other expression types as needed (e.g., `InList`, `Between`, etc.)
244+
_ => {}
245+
}
246+
}
247+
}
248+
249+
impl SqlStatementRewriteRule for ResolveUnqualifiedIdentifer {
250+
fn rewrite(&self, mut statement: Statement) -> Statement {
251+
if let Statement::Query(query) = &mut statement {
252+
Self::rewrite_unqualified_identifiers(query);
253+
}
254+
255+
statement
256+
}
257+
}
258+
138259
#[cfg(test)]
139260
mod tests {
140261
use super::*;
@@ -171,4 +292,37 @@ mod tests {
171292
"SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id"
172293
);
173294
}
295+
296+
#[test]
fn test_qualifier_prepend() {
    let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
        vec![Arc::new(ResolveUnqualifiedIdentifer)];

    // Each case is (input SQL, expected SQL after applying the rule).
    let cases = [
        // Single qualified wildcard: WHERE and ORDER BY identifiers get the qualifier.
        (
            "SELECT n.* FROM pg_catalog.pg_namespace n WHERE nspname = 'pg_catalog' ORDER BY nspname",
            "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname",
        ),
        // Unqualified wildcard and no alias: the statement is left as is.
        (
            "SELECT * FROM pg_catalog.pg_namespace ORDER BY nspname",
            "SELECT * FROM pg_catalog.pg_namespace ORDER BY nspname",
        ),
        // Qualified wildcard mixed with other projection items and a join.
        (
            "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace' ORDER BY nspsname",
            "SELECT n.oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY n.nspsname",
        ),
    ];

    for (sql, expected) in cases {
        let statement = parse(sql).expect("Failed to parse").remove(0);

        let statement = rewrite(statement, &rules);
        assert_eq!(statement.to_string(), expected);
    }
}
174328
}

0 commit comments

Comments
 (0)