Skip to content

Commit e5644ec

Browse files
authored
feat: rewrite to remove unsupported type cast (#145)
1 parent daadf6f commit e5644ec

File tree

2 files changed

+98
-3
lines changed

2 files changed

+98
-3
lines changed

datafusion-postgres/src/handlers.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@ use std::collections::HashMap;
22
use std::sync::Arc;
33

44
use crate::auth::{AuthManager, Permission, ResourceType};
5-
use crate::sql::{parse, rewrite, AliasDuplicatedProjectionRewrite, SqlStatementRewriteRule};
5+
use crate::sql::{
6+
parse, rewrite, AliasDuplicatedProjectionRewrite, RemoveUnsupportedTypes,
7+
ResolveUnqualifiedIdentifer, SqlStatementRewriteRule,
8+
};
69
use async_trait::async_trait;
710
use datafusion::arrow::datatypes::DataType;
811
use datafusion::logical_expr::LogicalPlan;
@@ -73,8 +76,11 @@ impl DfSessionService {
7376
session_context: Arc<SessionContext>,
7477
auth_manager: Arc<AuthManager>,
7578
) -> DfSessionService {
76-
let sql_rewrite_rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
77-
vec![Arc::new(AliasDuplicatedProjectionRewrite)];
79+
let sql_rewrite_rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![
80+
Arc::new(AliasDuplicatedProjectionRewrite),
81+
Arc::new(ResolveUnqualifiedIdentifer),
82+
Arc::new(RemoveUnsupportedTypes::new()),
83+
];
7884
let parser = Arc::new(Parser {
7985
session_context: session_context.clone(),
8086
sql_rewrite_rules: sql_rewrite_rules.clone(),

datafusion-postgres/src/sql.rs

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use datafusion::sql::sqlparser::ast::SetExpr;
1212
use datafusion::sql::sqlparser::ast::Statement;
1313
use datafusion::sql::sqlparser::ast::TableFactor;
1414
use datafusion::sql::sqlparser::ast::TableWithJoins;
15+
use datafusion::sql::sqlparser::ast::Value;
1516
use datafusion::sql::sqlparser::dialect::PostgreSqlDialect;
1617
use datafusion::sql::sqlparser::parser::Parser;
1718
use datafusion::sql::sqlparser::parser::ParserError;
@@ -258,6 +259,70 @@ impl SqlStatementRewriteRule for ResolveUnqualifiedIdentifer {
258259
}
259260
}
260261

262+
/// Remove datafusion unsupported type annotations
263+
#[derive(Debug)]
264+
pub struct RemoveUnsupportedTypes {
265+
unsupported_types: HashSet<String>,
266+
}
267+
268+
impl RemoveUnsupportedTypes {
269+
pub fn new() -> Self {
270+
let mut unsupported_types = HashSet::new();
271+
unsupported_types.insert("regclass".to_owned());
272+
273+
Self { unsupported_types }
274+
}
275+
276+
fn rewrite_expr_unsupported_types(&self, expr: &mut Expr) {
277+
match expr {
278+
// This is the key part: identify constants with type annotations.
279+
Expr::TypedString { value, data_type } => {
280+
if self
281+
.unsupported_types
282+
.contains(data_type.to_string().to_lowercase().as_str())
283+
{
284+
*expr =
285+
Expr::Value(Value::SingleQuotedString(value.to_string()).with_empty_span());
286+
}
287+
}
288+
Expr::Cast {
289+
data_type,
290+
expr: value,
291+
..
292+
} => {
293+
if self
294+
.unsupported_types
295+
.contains(data_type.to_string().to_lowercase().as_str())
296+
{
297+
*expr = *value.clone();
298+
}
299+
}
300+
// Handle binary operations by recursively rewriting both sides.
301+
Expr::BinaryOp { left, right, .. } => {
302+
self.rewrite_expr_unsupported_types(left);
303+
self.rewrite_expr_unsupported_types(right);
304+
}
305+
// Add more match arms for other expression types (e.g., `Function`, `InList`) as needed.
306+
_ => {}
307+
}
308+
}
309+
}
310+
311+
impl SqlStatementRewriteRule for RemoveUnsupportedTypes {
312+
fn rewrite(&self, mut s: Statement) -> Statement {
313+
// Traverse the AST to find the WHERE clause and rewrite it.
314+
if let Statement::Query(query) = &mut s {
315+
if let SetExpr::Select(select) = query.body.as_mut() {
316+
if let Some(expr) = &mut select.selection {
317+
self.rewrite_expr_unsupported_types(expr);
318+
}
319+
}
320+
}
321+
322+
s
323+
}
324+
}
325+
261326
#[cfg(test)]
262327
mod tests {
263328
use super::*;
@@ -327,4 +392,28 @@ mod tests {
327392
"SELECT n.oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY n.nspsname"
328393
);
329394
}
395+
396+
#[test]
397+
fn test_remove_unsupported_types() {
398+
let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
399+
vec![Arc::new(RemoveUnsupportedTypes::new())];
400+
401+
let sql = "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog'::regclass ORDER BY n.nspname";
402+
let statement = parse(sql).expect("Failed to parse").remove(0);
403+
404+
let statement = rewrite(statement, &rules);
405+
assert_eq!(
406+
statement.to_string(),
407+
"SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname"
408+
);
409+
410+
let sql = "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname";
411+
let statement = parse(sql).expect("Failed to parse").remove(0);
412+
413+
let statement = rewrite(statement, &rules);
414+
assert_eq!(
415+
statement.to_string(),
416+
"SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname"
417+
);
418+
}
330419
}

0 commit comments

Comments
 (0)