Skip to content

Commit 0b6490f

Browse files
committed
Merge branch 'master' into feat/pg-catalog-sql-dbeaver
2 parents 95f0487 + c01cf50 commit 0b6490f

File tree

1 file changed

+65
-64
lines changed

1 file changed

+65
-64
lines changed

datafusion-postgres/src/sql.rs

Lines changed: 65 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use std::collections::HashSet;
2+
use std::ops::ControlFlow;
23
use std::sync::Arc;
34

45
use datafusion::sql::sqlparser::ast::Expr;
@@ -13,6 +14,8 @@ use datafusion::sql::sqlparser::ast::Statement;
1314
use datafusion::sql::sqlparser::ast::TableFactor;
1415
use datafusion::sql::sqlparser::ast::TableWithJoins;
1516
use datafusion::sql::sqlparser::ast::Value;
17+
use datafusion::sql::sqlparser::ast::VisitMut;
18+
use datafusion::sql::sqlparser::ast::VisitorMut;
1619
use datafusion::sql::sqlparser::dialect::PostgreSqlDialect;
1720
use datafusion::sql::sqlparser::parser::Parser;
1821
use datafusion::sql::sqlparser::parser::ParserError;
@@ -272,8 +275,16 @@ impl RemoveUnsupportedTypes {
272275

273276
Self { unsupported_types }
274277
}
278+
}
279+
280+
struct RemoveUnsupportedTypesVisitor<'a> {
281+
unsupported_types: &'a HashSet<String>,
282+
}
283+
284+
impl<'a> VisitorMut for RemoveUnsupportedTypesVisitor<'a> {
285+
type Break = ();
275286

276-
fn rewrite_expr_unsupported_types(&self, expr: &mut Expr) {
287+
fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
277288
match expr {
278289
// This is the key part: identify constants with type annotations.
279290
Expr::TypedString { value, data_type } => {
@@ -297,65 +308,58 @@ impl RemoveUnsupportedTypes {
297308
*expr = *value.clone();
298309
}
299310
}
300-
// Handle binary operations by recursively rewriting both sides.
301-
Expr::BinaryOp { left, right, .. } => {
302-
self.rewrite_expr_unsupported_types(left);
303-
self.rewrite_expr_unsupported_types(right);
304-
}
305311
// Add more match arms for other expression types (e.g., `Function`, `InList`) as needed.
306312
_ => {}
307313
}
314+
315+
ControlFlow::Continue(())
308316
}
309317
}
310318

311319
impl SqlStatementRewriteRule for RemoveUnsupportedTypes {
312-
fn rewrite(&self, mut s: Statement) -> Statement {
313-
// Traverse the AST to find the WHERE clause and rewrite it.
314-
if let Statement::Query(query) = &mut s {
315-
if let SetExpr::Select(select) = query.body.as_mut() {
316-
if let Some(expr) = &mut select.selection {
317-
self.rewrite_expr_unsupported_types(expr);
318-
}
319-
}
320-
}
321-
322-
s
320+
fn rewrite(&self, mut statement: Statement) -> Statement {
321+
let mut visitor = RemoveUnsupportedTypesVisitor {
322+
unsupported_types: &self.unsupported_types,
323+
};
324+
let _ = statement.visit(&mut visitor);
325+
statement
323326
}
324327
}
325328

326329
#[cfg(test)]
327330
mod tests {
328331
use super::*;
329332

333+
macro_rules! assert_rewrite {
334+
($rules:expr, $orig:expr, $rewt:expr) => {
335+
let sql = $orig;
336+
let statement = parse(sql).expect("Failed to parse").remove(0);
337+
338+
let statement = rewrite(statement, $rules);
339+
assert_eq!(statement.to_string(), $rewt);
340+
};
341+
}
342+
330343
#[test]
331344
fn test_alias_rewrite() {
332345
let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
333346
vec![Arc::new(AliasDuplicatedProjectionRewrite)];
334347

335-
let sql = "SELECT n.oid, n.* FROM pg_catalog.pg_namespace n";
336-
let statement = parse(sql).expect("Failed to parse").remove(0);
337-
338-
let statement = rewrite(statement, &rules);
339-
assert_eq!(
340-
statement.to_string(),
348+
assert_rewrite!(
349+
&rules,
350+
"SELECT n.oid, n.* FROM pg_catalog.pg_namespace n",
341351
"SELECT n.oid AS __alias_oid, n.* FROM pg_catalog.pg_namespace AS n"
342352
);
343353

344-
let sql = "SELECT oid, * FROM pg_catalog.pg_namespace";
345-
let statement = parse(sql).expect("Failed to parse").remove(0);
346-
347-
let statement = rewrite(statement, &rules);
348-
assert_eq!(
349-
statement.to_string(),
354+
assert_rewrite!(
355+
&rules,
356+
"SELECT oid, * FROM pg_catalog.pg_namespace",
350357
"SELECT oid AS __alias_oid, * FROM pg_catalog.pg_namespace"
351358
);
352359

353-
let sql = "SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id";
354-
let statement = parse(sql).expect("Failed to parse").remove(0);
355-
356-
let statement = rewrite(statement, &rules);
357-
assert_eq!(
358-
statement.to_string(),
360+
assert_rewrite!(
361+
&rules,
362+
"SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id",
359363
"SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id"
360364
);
361365

@@ -374,30 +378,21 @@ mod tests {
374378
let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
375379
vec![Arc::new(ResolveUnqualifiedIdentifer)];
376380

377-
let sql = "SELECT n.* FROM pg_catalog.pg_namespace n WHERE nspname = 'pg_catalog' ORDER BY nspname";
378-
let statement = parse(sql).expect("Failed to parse").remove(0);
379-
380-
let statement = rewrite(statement, &rules);
381-
assert_eq!(
382-
statement.to_string(),
381+
assert_rewrite!(
382+
&rules,
383+
"SELECT n.* FROM pg_catalog.pg_namespace n WHERE nspname = 'pg_catalog' ORDER BY nspname",
383384
"SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname"
384385
);
385386

386-
let sql = "SELECT * FROM pg_catalog.pg_namespace ORDER BY nspname";
387-
let statement = parse(sql).expect("Failed to parse").remove(0);
388-
389-
let statement = rewrite(statement, &rules);
390-
assert_eq!(
391-
statement.to_string(),
387+
assert_rewrite!(
388+
&rules,
389+
"SELECT * FROM pg_catalog.pg_namespace ORDER BY nspname",
392390
"SELECT * FROM pg_catalog.pg_namespace ORDER BY nspname"
393391
);
394392

395-
let sql = "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace' ORDER BY nspsname";
396-
let statement = parse(sql).expect("Failed to parse").remove(0);
397-
398-
let statement = rewrite(statement, &rules);
399-
assert_eq!(
400-
statement.to_string(),
393+
assert_rewrite!(
394+
&rules,
395+
"SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace' ORDER BY nspsname",
401396
"SELECT n.oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY n.nspsname"
402397
);
403398
}
@@ -407,21 +402,27 @@ mod tests {
407402
let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
408403
vec![Arc::new(RemoveUnsupportedTypes::new())];
409404

410-
let sql = "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog'::regclass ORDER BY n.nspname";
411-
let statement = parse(sql).expect("Failed to parse").remove(0);
412-
413-
let statement = rewrite(statement, &rules);
414-
assert_eq!(
415-
statement.to_string(),
405+
assert_rewrite!(
406+
&rules,
407+
"SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog'::regclass ORDER BY n.nspname",
416408
"SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname"
417409
);
418410

419-
let sql = "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname";
420-
let statement = parse(sql).expect("Failed to parse").remove(0);
411+
assert_rewrite!(
412+
&rules,
413+
"SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.oid = 1 AND n.nspname = 'pg_catalog'::regclass ORDER BY n.nspname",
414+
"SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.oid = 1 AND n.nspname = 'pg_catalog' ORDER BY n.nspname"
415+
);
421416

422-
let statement = rewrite(statement, &rules);
423-
assert_eq!(
424-
statement.to_string(),
417+
assert_rewrite!(
418+
&rules,
419+
"SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace'::regclass ORDER BY nspname",
420+
"SELECT n.oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY nspname"
421+
);
422+
423+
assert_rewrite!(
424+
&rules,
425+
"SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname",
425426
"SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname"
426427
);
427428
}

0 commit comments

Comments
 (0)