diff --git a/datafusion-postgres/src/sql.rs b/datafusion-postgres/src/sql.rs index 92a7170..3a7fe77 100644 --- a/datafusion-postgres/src/sql.rs +++ b/datafusion-postgres/src/sql.rs @@ -1,4 +1,5 @@ use std::collections::HashSet; +use std::ops::ControlFlow; use std::sync::Arc; use datafusion::sql::sqlparser::ast::Expr; @@ -13,6 +14,8 @@ use datafusion::sql::sqlparser::ast::Statement; use datafusion::sql::sqlparser::ast::TableFactor; use datafusion::sql::sqlparser::ast::TableWithJoins; use datafusion::sql::sqlparser::ast::Value; +use datafusion::sql::sqlparser::ast::VisitMut; +use datafusion::sql::sqlparser::ast::VisitorMut; use datafusion::sql::sqlparser::dialect::PostgreSqlDialect; use datafusion::sql::sqlparser::parser::Parser; use datafusion::sql::sqlparser::parser::ParserError; @@ -272,8 +275,16 @@ impl RemoveUnsupportedTypes { Self { unsupported_types } } +} + +struct RemoveUnsupportedTypesVisitor<'a> { + unsupported_types: &'a HashSet, +} + +impl<'a> VisitorMut for RemoveUnsupportedTypesVisitor<'a> { + type Break = (); - fn rewrite_expr_unsupported_types(&self, expr: &mut Expr) { + fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow { match expr { // This is the key part: identify constants with type annotations. Expr::TypedString { value, data_type } => { @@ -297,29 +308,21 @@ impl RemoveUnsupportedTypes { *expr = *value.clone(); } } - // Handle binary operations by recursively rewriting both sides. - Expr::BinaryOp { left, right, .. } => { - self.rewrite_expr_unsupported_types(left); - self.rewrite_expr_unsupported_types(right); - } // Add more match arms for other expression types (e.g., `Function`, `InList`) as needed. _ => {} } + + ControlFlow::Continue(()) } } impl SqlStatementRewriteRule for RemoveUnsupportedTypes { - fn rewrite(&self, mut s: Statement) -> Statement { - // Traverse the AST to find the WHERE clause and rewrite it. - if let Statement::Query(query) = &mut s { - if let SetExpr::Select(select) = query.body.as_mut() { - if let Some(expr) = &mut select.selection { - self.rewrite_expr_unsupported_types(expr); - } - } - } - - s + fn rewrite(&self, mut statement: Statement) -> Statement { + let mut visitor = RemoveUnsupportedTypesVisitor { + unsupported_types: &self.unsupported_types, + }; + let _ = statement.visit(&mut visitor); + statement } } @@ -327,35 +330,36 @@ impl SqlStatementRewriteRule for RemoveUnsupportedTypes { mod tests { use super::*; + macro_rules! assert_rewrite { + ($rules:expr, $orig:expr, $rewt:expr) => { + let sql = $orig; + let statement = parse(sql).expect("Failed to parse").remove(0); + + let statement = rewrite(statement, $rules); + assert_eq!(statement.to_string(), $rewt); + }; + } + #[test] fn test_alias_rewrite() { let rules: Vec> = vec![Arc::new(AliasDuplicatedProjectionRewrite)]; - let sql = "SELECT n.oid, n.* FROM pg_catalog.pg_namespace n"; - let statement = parse(sql).expect("Failed to parse").remove(0); - - let statement = rewrite(statement, &rules); - assert_eq!( - statement.to_string(), + assert_rewrite!( + &rules, + "SELECT n.oid, n.* FROM pg_catalog.pg_namespace n", "SELECT n.oid AS __alias_oid, n.* FROM pg_catalog.pg_namespace AS n" ); - let sql = "SELECT oid, * FROM pg_catalog.pg_namespace"; - let statement = parse(sql).expect("Failed to parse").remove(0); - - let statement = rewrite(statement, &rules); - assert_eq!( - statement.to_string(), + assert_rewrite!( + &rules, + "SELECT oid, * FROM pg_catalog.pg_namespace", "SELECT oid AS __alias_oid, * FROM pg_catalog.pg_namespace" ); - let sql = "SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id"; - let statement = parse(sql).expect("Failed to parse").remove(0); - - let statement = rewrite(statement, &rules); - assert_eq!( - statement.to_string(), + assert_rewrite!( + &rules, + "SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id", "SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id" ); } @@ -365,30 +369,21 @@ mod tests { let rules: Vec> = vec![Arc::new(ResolveUnqualifiedIdentifer)]; - let sql = "SELECT n.* FROM pg_catalog.pg_namespace n WHERE nspname = 'pg_catalog' ORDER BY nspname"; - let statement = parse(sql).expect("Failed to parse").remove(0); - - let statement = rewrite(statement, &rules); - assert_eq!( - statement.to_string(), + assert_rewrite!( + &rules, + "SELECT n.* FROM pg_catalog.pg_namespace n WHERE nspname = 'pg_catalog' ORDER BY nspname", "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname" ); - let sql = "SELECT * FROM pg_catalog.pg_namespace ORDER BY nspname"; - let statement = parse(sql).expect("Failed to parse").remove(0); - - let statement = rewrite(statement, &rules); - assert_eq!( - statement.to_string(), + assert_rewrite!( + &rules, + "SELECT * FROM pg_catalog.pg_namespace ORDER BY nspname", "SELECT * FROM pg_catalog.pg_namespace ORDER BY nspname" ); - let sql = "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace' ORDER BY nspsname"; - let statement = parse(sql).expect("Failed to parse").remove(0); - - let statement = rewrite(statement, &rules); - assert_eq!( - statement.to_string(), + assert_rewrite!( + &rules, + "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace' ORDER BY nspsname", "SELECT n.oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY n.nspsname" ); } @@ -398,21 +393,27 @@ mod tests { let rules: Vec> = vec![Arc::new(RemoveUnsupportedTypes::new())]; - let sql = "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog'::regclass ORDER BY n.nspname"; - let statement = parse(sql).expect("Failed to parse").remove(0); - - let statement = rewrite(statement, &rules); - assert_eq!( - statement.to_string(), + assert_rewrite!( + &rules, + "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog'::regclass ORDER BY n.nspname", "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname" ); - let sql = "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname"; - let statement = parse(sql).expect("Failed to parse").remove(0); + assert_rewrite!( + &rules, + "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.oid = 1 AND n.nspname = 'pg_catalog'::regclass ORDER BY n.nspname", + "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.oid = 1 AND n.nspname = 'pg_catalog' ORDER BY n.nspname" + ); + + assert_rewrite!( + &rules, + "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace'::regclass ORDER BY nspname", + "SELECT n.oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY nspname" + ); - let statement = rewrite(statement, &rules); - assert_eq!( - statement.to_string(), + assert_rewrite!( + &rules, + "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname", "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname" ); }