diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index 3ce6148..f5d5d52 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -2,7 +2,10 @@ use std::collections::HashMap; use std::sync::Arc; use crate::auth::{AuthManager, Permission, ResourceType}; -use crate::sql::{parse, rewrite, AliasDuplicatedProjectionRewrite, SqlStatementRewriteRule}; +use crate::sql::{ + parse, rewrite, AliasDuplicatedProjectionRewrite, RemoveUnsupportedTypes, + ResolveUnqualifiedIdentifer, SqlStatementRewriteRule, +}; use async_trait::async_trait; use datafusion::arrow::datatypes::DataType; use datafusion::logical_expr::LogicalPlan; @@ -73,8 +76,11 @@ impl DfSessionService { session_context: Arc, auth_manager: Arc, ) -> DfSessionService { - let sql_rewrite_rules: Vec> = - vec![Arc::new(AliasDuplicatedProjectionRewrite)]; + let sql_rewrite_rules: Vec> = vec![ + Arc::new(AliasDuplicatedProjectionRewrite), + Arc::new(ResolveUnqualifiedIdentifer), + Arc::new(RemoveUnsupportedTypes::new()), + ]; let parser = Arc::new(Parser { session_context: session_context.clone(), sql_rewrite_rules: sql_rewrite_rules.clone(), diff --git a/datafusion-postgres/src/sql.rs b/datafusion-postgres/src/sql.rs index 7261847..92a7170 100644 --- a/datafusion-postgres/src/sql.rs +++ b/datafusion-postgres/src/sql.rs @@ -12,6 +12,7 @@ use datafusion::sql::sqlparser::ast::SetExpr; use datafusion::sql::sqlparser::ast::Statement; use datafusion::sql::sqlparser::ast::TableFactor; use datafusion::sql::sqlparser::ast::TableWithJoins; +use datafusion::sql::sqlparser::ast::Value; use datafusion::sql::sqlparser::dialect::PostgreSqlDialect; use datafusion::sql::sqlparser::parser::Parser; use datafusion::sql::sqlparser::parser::ParserError; @@ -258,6 +259,70 @@ impl SqlStatementRewriteRule for ResolveUnqualifiedIdentifer { } } +/// Remove datafusion unsupported type annotations +#[derive(Debug)] +pub struct RemoveUnsupportedTypes { + unsupported_types: HashSet, +} + +impl RemoveUnsupportedTypes { + pub fn new() -> Self { + let mut unsupported_types = HashSet::new(); + unsupported_types.insert("regclass".to_owned()); + + Self { unsupported_types } + } + + fn rewrite_expr_unsupported_types(&self, expr: &mut Expr) { + match expr { + // This is the key part: identify constants with type annotations. + Expr::TypedString { value, data_type } => { + if self + .unsupported_types + .contains(data_type.to_string().to_lowercase().as_str()) + { + *expr = + Expr::Value(Value::SingleQuotedString(value.to_string()).with_empty_span()); + } + } + Expr::Cast { + data_type, + expr: value, + .. + } => { + if self + .unsupported_types + .contains(data_type.to_string().to_lowercase().as_str()) + { + *expr = *value.clone(); + } + } + // Handle binary operations by recursively rewriting both sides. + Expr::BinaryOp { left, right, .. } => { + self.rewrite_expr_unsupported_types(left); + self.rewrite_expr_unsupported_types(right); + } + // Add more match arms for other expression types (e.g., `Function`, `InList`) as needed. + _ => {} + } + } +} + +impl SqlStatementRewriteRule for RemoveUnsupportedTypes { + fn rewrite(&self, mut s: Statement) -> Statement { + // Traverse the AST to find the WHERE clause and rewrite it. + if let Statement::Query(query) = &mut s { + if let SetExpr::Select(select) = query.body.as_mut() { + if let Some(expr) = &mut select.selection { + self.rewrite_expr_unsupported_types(expr); + } + } + } + + s + } +} + #[cfg(test)] mod tests { use super::*; @@ -327,4 +392,28 @@ mod tests { "SELECT n.oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY n.nspsname" ); } + + #[test] + fn test_remove_unsupported_types() { + let rules: Vec> = + vec![Arc::new(RemoveUnsupportedTypes::new())]; + + let sql = "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog'::regclass ORDER BY n.nspname"; + let statement = parse(sql).expect("Failed to parse").remove(0); + + let statement = rewrite(statement, &rules); + assert_eq!( + statement.to_string(), + "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname" + ); + + let sql = "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname"; + let statement = parse(sql).expect("Failed to parse").remove(0); + + let statement = rewrite(statement, &rules); + assert_eq!( + statement.to_string(), + "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname" + ); + } }