1616// under the License.
1717
1818use datafusion_common:: { Result , internal_err, tree_node:: Transformed } ;
19- use datafusion_expr:: {
20- Expr , Operator , and, binary_expr, lit, or, simplify:: SimplifyContext ,
21- } ;
19+ use datafusion_expr:: { Expr , Operator , and, lit, or, simplify:: SimplifyContext } ;
2220use datafusion_expr_common:: interval_arithmetic:: Interval ;
2321
2422/// Rewrites a binary expression using its "preimage"
@@ -46,32 +44,32 @@ pub(super) fn rewrite_with_preimage(
4644) -> Result < Transformed < Expr > > {
4745 let ( lower, upper) = preimage_interval. into_bounds ( ) ;
4846 let ( lower, upper) = ( lit ( lower) , lit ( upper) ) ;
47+ let expr = * expr;
4948
5049 let rewritten_expr = match op {
5150 // <expr> < x ==> <expr> < lower
5251 // <expr> >= x ==> <expr> >= lower
53- Operator :: Lt | Operator :: GtEq => binary_expr ( * expr, op, lower) ,
52+ Operator :: Lt => expr. lt ( lower) ,
53+ Operator :: GtEq => expr. gt_eq ( lower) ,
5454 // <expr> > x ==> <expr> >= upper
55- Operator :: Gt => binary_expr ( * expr, Operator :: GtEq , upper) ,
55+ Operator :: Gt => expr. gt_eq ( upper) ,
5656 // <expr> <= x ==> <expr> < upper
57- Operator :: LtEq => binary_expr ( * expr, Operator :: Lt , upper) ,
57+ Operator :: LtEq => expr. lt ( upper) ,
5858 // <expr> = x ==> (<expr> >= lower) and (<expr> < upper)
5959 //
6060 // <expr> is not distinct from x ==> (<expr> is NULL and x is NULL) or ((<expr> >= lower) and (<expr> < upper))
6161 // but since x is always not NULL => (<expr> >= lower) and (<expr> < upper)
62- Operator :: Eq | Operator :: IsNotDistinctFrom => and (
63- binary_expr ( * expr. clone ( ) , Operator :: GtEq , lower) ,
64- binary_expr ( * expr, Operator :: Lt , upper) ,
65- ) ,
62+ Operator :: Eq | Operator :: IsNotDistinctFrom => {
63+ and ( expr. clone ( ) . gt_eq ( lower) , expr. lt ( upper) )
64+ }
6665 // <expr> != x ==> (<expr> < lower) or (<expr> >= upper)
67- Operator :: NotEq => or (
68- binary_expr ( * expr. clone ( ) , Operator :: Lt , lower) ,
69- binary_expr ( * expr, Operator :: GtEq , upper) ,
70- ) ,
66+ Operator :: NotEq => or ( expr. clone ( ) . lt ( lower) , expr. gt_eq ( upper) ) ,
7167 // <expr> is distinct from x ==> (<expr> < lower) or (<expr> >= upper) or (<expr> is NULL and x is not NULL) or (<expr> is not NULL and x is NULL)
7268 // but given that x is always not NULL => (<expr> < lower) or (<expr> >= upper) or (<expr> is NULL)
73- Operator :: IsDistinctFrom => binary_expr ( * expr. clone ( ) , Operator :: Lt , lower)
74- . or ( binary_expr ( * expr. clone ( ) , Operator :: GtEq , upper) )
69+ Operator :: IsDistinctFrom => expr
70+ . clone ( )
71+ . lt ( lower)
72+ . or ( expr. clone ( ) . gt_eq ( upper) )
7573 . or ( expr. is_null ( ) ) ,
7674 _ => return internal_err ! ( "Expect comparison operators" ) ,
7775 } ;
@@ -86,17 +84,56 @@ mod test {
8684 use arrow:: datatypes:: { DataType , Field } ;
8785 use datafusion_common:: { DFSchema , DFSchemaRef , Result , ScalarValue } ;
8886 use datafusion_expr:: {
89- ColumnarValue , Expr , Operator , ScalarFunctionArgs , ScalarUDF , ScalarUDFImpl ,
90- Signature , Volatility , and, binary_expr, col, expr:: ScalarFunction , lit,
91- simplify:: SimplifyContext ,
87+ BinaryExpr , ColumnarValue , Expr , Operator , ScalarFunctionArgs , ScalarUDF ,
88+ ScalarUDFImpl , Signature , Volatility , and, col, lit, simplify:: SimplifyContext ,
9289 } ;
9390
9491 use super :: Interval ;
9592 use crate :: simplify_expressions:: ExprSimplifier ;
9693
94+ fn is_distinct_from ( left : Expr , right : Expr ) -> Expr {
95+ Expr :: BinaryExpr ( BinaryExpr {
96+ left : Box :: new ( left) ,
97+ op : Operator :: IsDistinctFrom ,
98+ right : Box :: new ( right) ,
99+ } )
100+ }
101+
102+ fn is_not_distinct_from ( left : Expr , right : Expr ) -> Expr {
103+ Expr :: BinaryExpr ( BinaryExpr {
104+ left : Box :: new ( left) ,
105+ op : Operator :: IsNotDistinctFrom ,
106+ right : Box :: new ( right) ,
107+ } )
108+ }
109+
97110 #[ derive( Debug , PartialEq , Eq , Hash ) ]
98111 struct PreimageUdf {
112+ /// Defaults to an exact signature with one Int32 argument and Immutable volatility
99113 signature : Signature ,
114+ /// If true, returns a preimage; otherwise, returns None
115+ enabled : bool ,
116+ }
117+
118+ impl PreimageUdf {
119+ fn new ( ) -> Self {
120+ Self {
121+ signature : Signature :: exact ( vec ! [ DataType :: Int32 ] , Volatility :: Immutable ) ,
122+ enabled : true ,
123+ }
124+ }
125+
126+ /// Set the enabled flag
127+ fn with_enabled ( mut self , enabled : bool ) -> Self {
128+ self . enabled = enabled;
129+ self
130+ }
131+
132+ /// Set the volatility
133+ fn with_volatility ( mut self , volatility : Volatility ) -> Self {
134+ self . signature . volatility = volatility;
135+ self
136+ }
100137 }
101138
102139 impl ScalarUDFImpl for PreimageUdf {
@@ -126,6 +163,9 @@ mod test {
126163 lit_expr : & Expr ,
127164 _info : & SimplifyContext ,
128165 ) -> Result < Option < Interval > > {
166+ if !self . enabled {
167+ return Ok ( None ) ;
168+ }
129169 if args. len ( ) != 1 {
130170 return Ok ( None ) ;
131171 }
@@ -146,19 +186,24 @@ mod test {
146186 }
147187
148188 fn optimize_test ( expr : Expr , schema : & DFSchemaRef ) -> Expr {
149- let simplifier = ExprSimplifier :: new (
150- SimplifyContext :: default ( ) . with_schema ( Arc :: clone ( schema) ) ,
151- ) ;
152-
153- simplifier. simplify ( expr) . unwrap ( )
189+ let simplify_context = SimplifyContext :: default ( ) . with_schema ( Arc :: clone ( schema) ) ;
190+ ExprSimplifier :: new ( simplify_context)
191+ . simplify ( expr)
192+ . unwrap ( )
154193 }
155194
156195 fn preimage_udf_expr ( ) -> Expr {
157- let udf = ScalarUDF :: new_from_impl ( PreimageUdf {
158- signature : Signature :: exact ( vec ! [ DataType :: Int32 ] , Volatility :: Immutable ) ,
159- } ) ;
196+ ScalarUDF :: new_from_impl ( PreimageUdf :: new ( ) ) . call ( vec ! [ col( "x" ) ] )
197+ }
160198
161- Expr :: ScalarFunction ( ScalarFunction :: new_udf ( Arc :: new ( udf) , vec ! [ col( "x" ) ] ) )
199+ fn non_immutable_udf_expr ( ) -> Expr {
200+ ScalarUDF :: new_from_impl ( PreimageUdf :: new ( ) . with_volatility ( Volatility :: Volatile ) )
201+ . call ( vec ! [ col( "x" ) ] )
202+ }
203+
204+ fn no_preimage_udf_expr ( ) -> Expr {
205+ ScalarUDF :: new_from_impl ( PreimageUdf :: new ( ) . with_enabled ( false ) )
206+ . call ( vec ! [ col( "x" ) ] )
162207 }
163208
164209 fn test_schema ( ) -> DFSchemaRef {
@@ -171,100 +216,150 @@ mod test {
171216 )
172217 }
173218
219+ fn test_schema_xy ( ) -> DFSchemaRef {
220+ Arc :: new (
221+ DFSchema :: from_unqualified_fields (
222+ vec ! [
223+ Field :: new( "x" , DataType :: Int32 , false ) ,
224+ Field :: new( "y" , DataType :: Int32 , false ) ,
225+ ]
226+ . into ( ) ,
227+ Default :: default ( ) ,
228+ )
229+ . unwrap ( ) ,
230+ )
231+ }
232+
174233 #[ test]
175234 fn test_preimage_eq_rewrite ( ) {
235+ // Equality rewrite when preimage and column expression are available.
176236 let schema = test_schema ( ) ;
177- let expr = binary_expr ( preimage_udf_expr ( ) , Operator :: Eq , lit ( 500 ) ) ;
178- let expected = and (
179- binary_expr ( col ( "x" ) , Operator :: GtEq , lit ( 100 ) ) ,
180- binary_expr ( col ( "x" ) , Operator :: Lt , lit ( 200 ) ) ,
181- ) ;
237+ let expr = preimage_udf_expr ( ) . eq ( lit ( 500 ) ) ;
238+ let expected = and ( col ( "x" ) . gt_eq ( lit ( 100 ) ) , col ( "x" ) . lt ( lit ( 200 ) ) ) ;
182239
183240 assert_eq ! ( optimize_test( expr, & schema) , expected) ;
184241 }
185242
186243 #[ test]
187244 fn test_preimage_noteq_rewrite ( ) {
245+ // Inequality rewrite expands to disjoint ranges.
188246 let schema = test_schema ( ) ;
189- let expr = binary_expr ( preimage_udf_expr ( ) , Operator :: NotEq , lit ( 500 ) ) ;
190- let expected = binary_expr ( col ( "x" ) , Operator :: Lt , lit ( 100 ) ) . or ( binary_expr (
191- col ( "x" ) ,
192- Operator :: GtEq ,
193- lit ( 200 ) ,
194- ) ) ;
247+ let expr = preimage_udf_expr ( ) . not_eq ( lit ( 500 ) ) ;
248+ let expected = col ( "x" ) . lt ( lit ( 100 ) ) . or ( col ( "x" ) . gt_eq ( lit ( 200 ) ) ) ;
195249
196250 assert_eq ! ( optimize_test( expr, & schema) , expected) ;
197251 }
198252
199253 #[ test]
200254 fn test_preimage_eq_rewrite_swapped ( ) {
255+ // Equality rewrite works when the literal appears on the left.
201256 let schema = test_schema ( ) ;
202- let expr = binary_expr ( lit ( 500 ) , Operator :: Eq , preimage_udf_expr ( ) ) ;
203- let expected = and (
204- binary_expr ( col ( "x" ) , Operator :: GtEq , lit ( 100 ) ) ,
205- binary_expr ( col ( "x" ) , Operator :: Lt , lit ( 200 ) ) ,
206- ) ;
257+ let expr = lit ( 500 ) . eq ( preimage_udf_expr ( ) ) ;
258+ let expected = and ( col ( "x" ) . gt_eq ( lit ( 100 ) ) , col ( "x" ) . lt ( lit ( 200 ) ) ) ;
207259
208260 assert_eq ! ( optimize_test( expr, & schema) , expected) ;
209261 }
210262
211263 #[ test]
212264 fn test_preimage_lt_rewrite ( ) {
265+ // Less-than comparison rewrites to the lower bound.
213266 let schema = test_schema ( ) ;
214- let expr = binary_expr ( preimage_udf_expr ( ) , Operator :: Lt , lit ( 500 ) ) ;
215- let expected = binary_expr ( col ( "x" ) , Operator :: Lt , lit ( 100 ) ) ;
267+ let expr = preimage_udf_expr ( ) . lt ( lit ( 500 ) ) ;
268+ let expected = col ( "x" ) . lt ( lit ( 100 ) ) ;
216269
217270 assert_eq ! ( optimize_test( expr, & schema) , expected) ;
218271 }
219272
220273 #[ test]
221274 fn test_preimage_lteq_rewrite ( ) {
275+ // Less-than-or-equal comparison rewrites to the upper bound.
222276 let schema = test_schema ( ) ;
223- let expr = binary_expr ( preimage_udf_expr ( ) , Operator :: LtEq , lit ( 500 ) ) ;
224- let expected = binary_expr ( col ( "x" ) , Operator :: Lt , lit ( 200 ) ) ;
277+ let expr = preimage_udf_expr ( ) . lt_eq ( lit ( 500 ) ) ;
278+ let expected = col ( "x" ) . lt ( lit ( 200 ) ) ;
225279
226280 assert_eq ! ( optimize_test( expr, & schema) , expected) ;
227281 }
228282
229283 #[ test]
230284 fn test_preimage_gt_rewrite ( ) {
285+ // Greater-than comparison rewrites to the upper bound (inclusive).
231286 let schema = test_schema ( ) ;
232- let expr = binary_expr ( preimage_udf_expr ( ) , Operator :: Gt , lit ( 500 ) ) ;
233- let expected = binary_expr ( col ( "x" ) , Operator :: GtEq , lit ( 200 ) ) ;
287+ let expr = preimage_udf_expr ( ) . gt ( lit ( 500 ) ) ;
288+ let expected = col ( "x" ) . gt_eq ( lit ( 200 ) ) ;
234289
235290 assert_eq ! ( optimize_test( expr, & schema) , expected) ;
236291 }
237292
238293 #[ test]
239294 fn test_preimage_gteq_rewrite ( ) {
295+ // Greater-than-or-equal comparison rewrites to the lower bound.
240296 let schema = test_schema ( ) ;
241- let expr = binary_expr ( preimage_udf_expr ( ) , Operator :: GtEq , lit ( 500 ) ) ;
242- let expected = binary_expr ( col ( "x" ) , Operator :: GtEq , lit ( 100 ) ) ;
297+ let expr = preimage_udf_expr ( ) . gt_eq ( lit ( 500 ) ) ;
298+ let expected = col ( "x" ) . gt_eq ( lit ( 100 ) ) ;
243299
244300 assert_eq ! ( optimize_test( expr, & schema) , expected) ;
245301 }
246302
247303 #[ test]
248304 fn test_preimage_is_not_distinct_from_rewrite ( ) {
305+ // IS NOT DISTINCT FROM is treated like equality for non-null literal RHS.
249306 let schema = test_schema ( ) ;
250- let expr =
251- binary_expr ( preimage_udf_expr ( ) , Operator :: IsNotDistinctFrom , lit ( 500 ) ) ;
252- let expected = and (
253- binary_expr ( col ( "x" ) , Operator :: GtEq , lit ( 100 ) ) ,
254- binary_expr ( col ( "x" ) , Operator :: Lt , lit ( 200 ) ) ,
255- ) ;
307+ let expr = is_not_distinct_from ( preimage_udf_expr ( ) , lit ( 500 ) ) ;
308+ let expected = and ( col ( "x" ) . gt_eq ( lit ( 100 ) ) , col ( "x" ) . lt ( lit ( 200 ) ) ) ;
256309
257310 assert_eq ! ( optimize_test( expr, & schema) , expected) ;
258311 }
259312
260313 #[ test]
261314 fn test_preimage_is_distinct_from_rewrite ( ) {
315+ // IS DISTINCT FROM adds an explicit NULL branch for the column.
262316 let schema = test_schema ( ) ;
263- let expr = binary_expr ( preimage_udf_expr ( ) , Operator :: IsDistinctFrom , lit ( 500 ) ) ;
264- let expected = binary_expr ( col ( "x" ) , Operator :: Lt , lit ( 100 ) )
265- . or ( binary_expr ( col ( "x" ) , Operator :: GtEq , lit ( 200 ) ) )
317+ let expr = is_distinct_from ( preimage_udf_expr ( ) , lit ( 500 ) ) ;
318+ let expected = col ( "x" )
319+ . lt ( lit ( 100 ) )
320+ . or ( col ( "x" ) . gt_eq ( lit ( 200 ) ) )
266321 . or ( col ( "x" ) . is_null ( ) ) ;
267322
268323 assert_eq ! ( optimize_test( expr, & schema) , expected) ;
269324 }
325+
326+ #[ test]
327+ fn test_preimage_non_literal_rhs_no_rewrite ( ) {
328+ // Non-literal RHS should not be rewritten.
329+ let schema = test_schema_xy ( ) ;
330+ let expr = preimage_udf_expr ( ) . eq ( col ( "y" ) ) ;
331+ let expected = expr. clone ( ) ;
332+
333+ assert_eq ! ( optimize_test( expr, & schema) , expected) ;
334+ }
335+
336+ #[ test]
337+ fn test_preimage_null_literal_no_rewrite ( ) {
338+ // NULL literal RHS should not be rewritten.
339+ let schema = test_schema ( ) ;
340+ let expr = preimage_udf_expr ( ) . eq ( lit ( ScalarValue :: Int32 ( None ) ) ) ;
341+ let expected = expr. clone ( ) ;
342+
343+ assert_eq ! ( optimize_test( expr, & schema) , expected) ;
344+ }
345+
346+ #[ test]
347+ fn test_preimage_non_immutable_no_rewrite ( ) {
348+ // Non-immutable UDFs should not participate in preimage rewrites.
349+ let schema = test_schema ( ) ;
350+ let expr = non_immutable_udf_expr ( ) . eq ( lit ( 500 ) ) ;
351+ let expected = expr. clone ( ) ;
352+
353+ assert_eq ! ( optimize_test( expr, & schema) , expected) ;
354+ }
355+
356+ #[ test]
357+ fn test_preimage_no_preimage_no_rewrite ( ) {
358+ // If the UDF provides no preimage, the expression should remain unchanged.
359+ let schema = test_schema ( ) ;
360+ let expr = no_preimage_udf_expr ( ) . eq ( lit ( 500 ) ) ;
361+ let expected = expr. clone ( ) ;
362+
363+ assert_eq ! ( optimize_test( expr, & schema) , expected) ;
364+ }
270365}
0 commit comments