@@ -2,8 +2,15 @@ use std::collections::HashSet;
22use std:: ops:: ControlFlow ;
33use std:: sync:: Arc ;
44
5+ use datafusion:: sql:: sqlparser:: ast:: BinaryOperator ;
56use datafusion:: sql:: sqlparser:: ast:: Expr ;
7+ use datafusion:: sql:: sqlparser:: ast:: Function ;
8+ use datafusion:: sql:: sqlparser:: ast:: FunctionArg ;
9+ use datafusion:: sql:: sqlparser:: ast:: FunctionArgExpr ;
10+ use datafusion:: sql:: sqlparser:: ast:: FunctionArgumentList ;
11+ use datafusion:: sql:: sqlparser:: ast:: FunctionArguments ;
612use datafusion:: sql:: sqlparser:: ast:: Ident ;
13+ use datafusion:: sql:: sqlparser:: ast:: ObjectName ;
714use datafusion:: sql:: sqlparser:: ast:: OrderByKind ;
815use datafusion:: sql:: sqlparser:: ast:: Query ;
916use datafusion:: sql:: sqlparser:: ast:: Select ;
@@ -13,6 +20,7 @@ use datafusion::sql::sqlparser::ast::SetExpr;
1320use datafusion:: sql:: sqlparser:: ast:: Statement ;
1421use datafusion:: sql:: sqlparser:: ast:: TableFactor ;
1522use datafusion:: sql:: sqlparser:: ast:: TableWithJoins ;
23+ use datafusion:: sql:: sqlparser:: ast:: UnaryOperator ;
1624use datafusion:: sql:: sqlparser:: ast:: Value ;
1725use datafusion:: sql:: sqlparser:: ast:: VisitMut ;
1826use datafusion:: sql:: sqlparser:: ast:: VisitorMut ;
@@ -327,6 +335,72 @@ impl SqlStatementRewriteRule for RemoveUnsupportedTypes {
327335 }
328336}
329337
338+ /// Rewrite Postgres's ANY operator to array_contains
339+ #[ derive( Debug ) ]
340+ pub struct RewriteArrayAnyOperation ;
341+
342+ struct RewriteArrayAnyOperationVisitor ;
343+
344+ impl RewriteArrayAnyOperationVisitor {
345+ fn any_to_array_cofntains ( & self , left : & Expr , right : & Expr ) -> Expr {
346+ Expr :: Function ( Function {
347+ name : ObjectName :: from ( vec ! [ Ident :: new( "array_contains" ) ] ) ,
348+ args : FunctionArguments :: List ( FunctionArgumentList {
349+ args : vec ! [
350+ FunctionArg :: Unnamed ( FunctionArgExpr :: Expr ( right. clone( ) ) ) ,
351+ FunctionArg :: Unnamed ( FunctionArgExpr :: Expr ( left. clone( ) ) ) ,
352+ ] ,
353+ duplicate_treatment : None ,
354+ clauses : vec ! [ ] ,
355+ } ) ,
356+ uses_odbc_syntax : false ,
357+ parameters : FunctionArguments :: None ,
358+ filter : None ,
359+ null_treatment : None ,
360+ over : None ,
361+ within_group : vec ! [ ] ,
362+ } )
363+ }
364+ }
365+
366+ impl VisitorMut for RewriteArrayAnyOperationVisitor {
367+ type Break = ( ) ;
368+
369+ fn pre_visit_expr ( & mut self , expr : & mut Expr ) -> ControlFlow < Self :: Break > {
370+ if let Expr :: AnyOp {
371+ left,
372+ compare_op,
373+ right,
374+ ..
375+ } = expr
376+ {
377+ match compare_op {
378+ BinaryOperator :: Eq => {
379+ * expr = self . any_to_array_cofntains ( left. as_ref ( ) , right. as_ref ( ) ) ;
380+ }
381+ BinaryOperator :: NotEq => {
382+ * expr = Expr :: UnaryOp {
383+ op : UnaryOperator :: Not ,
384+ expr : Box :: new ( self . any_to_array_cofntains ( left. as_ref ( ) , right. as_ref ( ) ) ) ,
385+ }
386+ }
387+ _ => { }
388+ }
389+ }
390+ ControlFlow :: Continue ( ( ) )
391+ }
392+ }
393+
394+ impl SqlStatementRewriteRule for RewriteArrayAnyOperation {
395+ fn rewrite ( & self , mut s : Statement ) -> Statement {
396+ let mut visitor = RewriteArrayAnyOperationVisitor ;
397+
398+ let _ = s. visit ( & mut visitor) ;
399+
400+ s
401+ }
402+ }
403+
330404#[ cfg( test) ]
331405mod tests {
332406 use super :: * ;
@@ -427,4 +501,27 @@ mod tests {
427501 "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname"
428502 ) ;
429503 }
504+
505+ #[ test]
506+ fn test_any_to_array_contains ( ) {
507+ let rules: Vec < Arc < dyn SqlStatementRewriteRule > > = vec ! [ Arc :: new( RewriteArrayAnyOperation ) ] ;
508+
509+ assert_rewrite ! (
510+ & rules,
511+ "SELECT a = ANY(current_schemas(true))" ,
512+ "SELECT array_contains(current_schemas(true), a)"
513+ ) ;
514+
515+ assert_rewrite ! (
516+ & rules,
517+ "SELECT a != ANY(current_schemas(true))" ,
518+ "SELECT NOT array_contains(current_schemas(true), a)"
519+ ) ;
520+
521+ assert_rewrite ! (
522+ & rules,
523+ "SELECT a FROM tbl WHERE a = ANY(current_schemas(true))" ,
524+ "SELECT a FROM tbl WHERE array_contains(current_schemas(true), a)"
525+ ) ;
526+ }
430527}
0 commit comments