Skip to content

Commit 08bdc60

Browse files
committed
feat: add rewrite rule to transform any operation
1 parent 3cb9adc commit 08bdc60

File tree

2 files changed

+99
-1
lines changed

2 files changed

+99
-1
lines changed

datafusion-postgres/src/handlers.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use std::sync::Arc;
44
use crate::auth::{AuthManager, Permission, ResourceType};
55
use crate::sql::{
66
parse, rewrite, AliasDuplicatedProjectionRewrite, RemoveUnsupportedTypes,
7-
ResolveUnqualifiedIdentifer, SqlStatementRewriteRule,
7+
ResolveUnqualifiedIdentifer, RewriteArrayAnyOperation, SqlStatementRewriteRule,
88
};
99
use async_trait::async_trait;
1010
use datafusion::arrow::datatypes::DataType;
@@ -81,6 +81,7 @@ impl DfSessionService {
8181
Arc::new(AliasDuplicatedProjectionRewrite),
8282
Arc::new(ResolveUnqualifiedIdentifer),
8383
Arc::new(RemoveUnsupportedTypes::new()),
84+
Arc::new(RewriteArrayAnyOperation),
8485
];
8586
let parser = Arc::new(Parser {
8687
session_context: session_context.clone(),

datafusion-postgres/src/sql.rs

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,15 @@ use std::collections::HashSet;
22
use std::ops::ControlFlow;
33
use std::sync::Arc;
44

5+
use datafusion::sql::sqlparser::ast::BinaryOperator;
56
use datafusion::sql::sqlparser::ast::Expr;
7+
use datafusion::sql::sqlparser::ast::Function;
8+
use datafusion::sql::sqlparser::ast::FunctionArg;
9+
use datafusion::sql::sqlparser::ast::FunctionArgExpr;
10+
use datafusion::sql::sqlparser::ast::FunctionArgumentList;
11+
use datafusion::sql::sqlparser::ast::FunctionArguments;
612
use datafusion::sql::sqlparser::ast::Ident;
13+
use datafusion::sql::sqlparser::ast::ObjectName;
714
use datafusion::sql::sqlparser::ast::OrderByKind;
815
use datafusion::sql::sqlparser::ast::Query;
916
use datafusion::sql::sqlparser::ast::Select;
@@ -13,6 +20,7 @@ use datafusion::sql::sqlparser::ast::SetExpr;
1320
use datafusion::sql::sqlparser::ast::Statement;
1421
use datafusion::sql::sqlparser::ast::TableFactor;
1522
use datafusion::sql::sqlparser::ast::TableWithJoins;
23+
use datafusion::sql::sqlparser::ast::UnaryOperator;
1624
use datafusion::sql::sqlparser::ast::Value;
1725
use datafusion::sql::sqlparser::ast::VisitMut;
1826
use datafusion::sql::sqlparser::ast::VisitorMut;
@@ -327,6 +335,72 @@ impl SqlStatementRewriteRule for RemoveUnsupportedTypes {
327335
}
328336
}
329337

338+
/// Rewrite Postgres's ANY operator to array_contains
339+
#[derive(Debug)]
340+
pub struct RewriteArrayAnyOperation;
341+
342+
struct RewriteArrayAnyOperationVisitor;
343+
344+
impl RewriteArrayAnyOperationVisitor {
345+
fn any_to_array_cofntains(&self, left: &Expr, right: &Expr) -> Expr {
346+
Expr::Function(Function {
347+
name: ObjectName::from(vec![Ident::new("array_contains")]),
348+
args: FunctionArguments::List(FunctionArgumentList {
349+
args: vec![
350+
FunctionArg::Unnamed(FunctionArgExpr::Expr(right.clone())),
351+
FunctionArg::Unnamed(FunctionArgExpr::Expr(left.clone())),
352+
],
353+
duplicate_treatment: None,
354+
clauses: vec![],
355+
}),
356+
uses_odbc_syntax: false,
357+
parameters: FunctionArguments::None,
358+
filter: None,
359+
null_treatment: None,
360+
over: None,
361+
within_group: vec![],
362+
})
363+
}
364+
}
365+
366+
impl VisitorMut for RewriteArrayAnyOperationVisitor {
367+
type Break = ();
368+
369+
fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
370+
if let Expr::AnyOp {
371+
left,
372+
compare_op,
373+
right,
374+
..
375+
} = expr
376+
{
377+
match compare_op {
378+
BinaryOperator::Eq => {
379+
*expr = self.any_to_array_cofntains(left.as_ref(), right.as_ref());
380+
}
381+
BinaryOperator::NotEq => {
382+
*expr = Expr::UnaryOp {
383+
op: UnaryOperator::Not,
384+
expr: Box::new(self.any_to_array_cofntains(left.as_ref(), right.as_ref())),
385+
}
386+
}
387+
_ => {}
388+
}
389+
}
390+
ControlFlow::Continue(())
391+
}
392+
}
393+
394+
impl SqlStatementRewriteRule for RewriteArrayAnyOperation {
395+
fn rewrite(&self, mut s: Statement) -> Statement {
396+
let mut visitor = RewriteArrayAnyOperationVisitor;
397+
398+
let _ = s.visit(&mut visitor);
399+
400+
s
401+
}
402+
}
403+
330404
#[cfg(test)]
331405
mod tests {
332406
use super::*;
@@ -427,4 +501,27 @@ mod tests {
427501
"SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname"
428502
);
429503
}
504+
505+
#[test]
506+
fn test_any_to_array_contains() {
507+
let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![Arc::new(RewriteArrayAnyOperation)];
508+
509+
assert_rewrite!(
510+
&rules,
511+
"SELECT a = ANY(current_schemas(true))",
512+
"SELECT array_contains(current_schemas(true), a)"
513+
);
514+
515+
assert_rewrite!(
516+
&rules,
517+
"SELECT a != ANY(current_schemas(true))",
518+
"SELECT NOT array_contains(current_schemas(true), a)"
519+
);
520+
521+
assert_rewrite!(
522+
&rules,
523+
"SELECT a FROM tbl WHERE a = ANY(current_schemas(true))",
524+
"SELECT a FROM tbl WHERE array_contains(current_schemas(true), a)"
525+
);
526+
}
430527
}

0 commit comments

Comments
 (0)