@@ -246,3 +246,119 @@ fn calculate_selectivity(
246246 acc * cardinality_ratio ( & initial. interval , & target. interval )
247247 } )
248248}
249+
250+ #[ cfg( test) ]
251+ mod tests {
252+ use std:: sync:: Arc ;
253+
254+ use arrow_schema:: { DataType , Field , Schema } ;
255+ use datafusion_common:: { assert_contains, DFSchema } ;
256+ use datafusion_expr:: {
257+ col, execution_props:: ExecutionProps , interval_arithmetic:: Interval , lit, Expr ,
258+ } ;
259+
260+ use crate :: { create_physical_expr, AnalysisContext } ;
261+
262+ use super :: { analyze, ExprBoundaries } ;
263+
264+ fn make_field ( name : & str , data_type : DataType ) -> Field {
265+ let nullable = false ;
266+ Field :: new ( name, data_type, nullable)
267+ }
268+
/// Runs `analyze` over a table of comparison predicates on the single
/// `Int32` column `a` and checks the interval reported for that column.
#[test]
fn test_analyze_boundary_exprs() {
    let schema = Arc::new(Schema::new(vec![make_field("a", DataType::Int32)]));

    /// One scenario: (predicate, expected lower bound, expected upper bound).
    /// The two bounds are handed to `Interval::make` unchanged.
    type TestCase = (Expr, Option<i32>, Option<i32>);

    // Shorthands for the four single-comparison predicates on column `a`.
    let gt = |v: i32| col("a").gt(lit(v));
    let lt = |v: i32| col("a").lt(lit(v));
    let gt_eq = |v: i32| col("a").gt_eq(lit(v));
    let lt_eq = |v: i32| col("a").lt_eq(lit(v));

    let test_cases: Vec<TestCase> = vec![
        // a > 10
        (gt(10), Some(11), None),
        // a < 20
        (lt(20), None, Some(19)),
        // a > 10 AND a < 20
        (gt(10).and(lt(20)), Some(11), Some(19)),
        // a >= 10
        (gt_eq(10), Some(10), None),
        // a <= 20
        (lt_eq(20), None, Some(20)),
        // a >= 10 AND a <= 20
        (gt_eq(10).and(lt_eq(20)), Some(10), Some(20)),
        // a > 10 AND a < 20 AND a < 15
        (gt(10).and(lt(20)).and(lt(15)), Some(11), Some(14)),
        // (a > 10 AND a < 20) AND (a > 15 AND a < 25)
        (
            gt(10).and(lt(20)).and(gt(15)).and(lt(25)),
            Some(16),
            Some(19),
        ),
        // (a > 10 AND a < 20) AND (a > 20 AND a < 30)
        // This case pins the current expectation: the reported interval is
        // `Interval::make(None, None)`.
        (gt(10).and(lt(20)).and(gt(20)).and(lt(30)), None, None),
    ];

    for (expr, lower, upper) in test_cases {
        // Every scenario starts from freshly created unbounded boundaries,
        // because the analysis context takes ownership of them.
        let context =
            AnalysisContext::new(ExprBoundaries::try_new_unbounded(&schema).unwrap());
        let df_schema = DFSchema::try_from(Arc::clone(&schema)).unwrap();
        let physical_expr =
            create_physical_expr(&expr, &df_schema, &ExecutionProps::new()).unwrap();

        let analysis_result =
            analyze(&physical_expr, context, df_schema.as_ref()).unwrap();

        // The schema has a single column, so its boundary sits at index 0.
        let actual = &analysis_result.boundaries[0].interval;
        let expected = Interval::make(lower, upper).unwrap();
        assert_eq!(
            &expected, actual,
            "did not get correct interval for SQL expression: {expr:?}"
        );
    }
}
346+
/// Checks that `analyze` returns an error mentioning the unsupported `OR`
/// operator when given a disjunctive predicate on column `a`.
#[test]
fn test_analyze_invalid_boundary_exprs() {
    let schema = Arc::new(Schema::new(vec![make_field("a", DataType::Int32)]));

    // a < 10 OR a > 20
    let below = col("a").lt(lit(10));
    let above = col("a").gt(lit(20));
    let expr = below.or(above);
    let expected_error = "Interval arithmetic does not support the operator OR";

    let context =
        AnalysisContext::new(ExprBoundaries::try_new_unbounded(&schema).unwrap());
    let df_schema = DFSchema::try_from(Arc::clone(&schema)).unwrap();
    let physical_expr =
        create_physical_expr(&expr, &df_schema, &ExecutionProps::new()).unwrap();

    // The analysis itself must fail; `unwrap_err` panics if it succeeds.
    let analysis_error =
        analyze(&physical_expr, context, df_schema.as_ref()).unwrap_err();
    assert_contains!(analysis_error.to_string(), expected_error);
}
364+ }
0 commit comments