@@ -25,12 +25,15 @@ use arrow::datatypes::{
2525} ;
2626use datafusion_common:: { Result , ScalarValue , exec_err} ;
2727use datafusion_expr:: interval_arithmetic:: Interval ;
28+ use datafusion_expr:: preimage:: PreimageResult ;
29+ use datafusion_expr:: simplify:: SimplifyContext ;
2830use datafusion_expr:: sort_properties:: { ExprProperties , SortProperties } ;
2931use datafusion_expr:: {
30- Coercion , ColumnarValue , Documentation , ScalarFunctionArgs , ScalarUDFImpl , Signature ,
31- TypeSignature , TypeSignatureClass , Volatility ,
32+ Coercion , ColumnarValue , Documentation , Expr , ScalarFunctionArgs , ScalarUDFImpl ,
33+ Signature , TypeSignature , TypeSignatureClass , Volatility ,
3234} ;
3335use datafusion_macros:: user_doc;
36+ use num_traits:: { CheckedAdd , Float , One } ;
3437
3538use super :: decimal:: { apply_decimal_op, floor_decimal_value} ;
3639
@@ -200,7 +203,231 @@ impl ScalarUDFImpl for FloorFunc {
200203 Interval :: make_unbounded ( & data_type)
201204 }
202205
206+ /// Compute the preimage for floor function.
207+ ///
208+ /// For `floor(x) = N`, the preimage is `x >= N AND x < N + 1`
209+ /// because floor(x) = N for all x in [N, N+1).
210+ ///
211+ /// This enables predicate pushdown optimizations, transforming:
212+ /// `floor(col) = 100` into `col >= 100 AND col < 101`
213+ fn preimage (
214+ & self ,
215+ args : & [ Expr ] ,
216+ lit_expr : & Expr ,
217+ _info : & SimplifyContext ,
218+ ) -> Result < PreimageResult > {
219+ // floor takes exactly one argument
220+ if args. len ( ) != 1 {
221+ return Ok ( PreimageResult :: None ) ;
222+ }
223+
224+ let arg = args[ 0 ] . clone ( ) ;
225+
226+ // Extract the literal value being compared to
227+ let Expr :: Literal ( lit_value, _) = lit_expr else {
228+ return Ok ( PreimageResult :: None ) ;
229+ } ;
230+
231+ // Compute lower bound (N) and upper bound (N + 1) using helper functions
232+ let Some ( ( lower, upper) ) = ( match lit_value {
233+ // Floating-point types
234+ ScalarValue :: Float64 ( Some ( n) ) => float_preimage_bounds ( * n)
235+ . map ( |( lo, hi) | ( ScalarValue :: Float64 ( Some ( lo) ) , ScalarValue :: Float64 ( Some ( hi) ) ) ) ,
236+ ScalarValue :: Float32 ( Some ( n) ) => float_preimage_bounds ( * n)
237+ . map ( |( lo, hi) | ( ScalarValue :: Float32 ( Some ( lo) ) , ScalarValue :: Float32 ( Some ( hi) ) ) ) ,
238+
239+ // Integer types
240+ ScalarValue :: Int8 ( Some ( n) ) => int_preimage_bounds ( * n)
241+ . map ( |( lo, hi) | ( ScalarValue :: Int8 ( Some ( lo) ) , ScalarValue :: Int8 ( Some ( hi) ) ) ) ,
242+ ScalarValue :: Int16 ( Some ( n) ) => int_preimage_bounds ( * n)
243+ . map ( |( lo, hi) | ( ScalarValue :: Int16 ( Some ( lo) ) , ScalarValue :: Int16 ( Some ( hi) ) ) ) ,
244+ ScalarValue :: Int32 ( Some ( n) ) => int_preimage_bounds ( * n)
245+ . map ( |( lo, hi) | ( ScalarValue :: Int32 ( Some ( lo) ) , ScalarValue :: Int32 ( Some ( hi) ) ) ) ,
246+ ScalarValue :: Int64 ( Some ( n) ) => int_preimage_bounds ( * n)
247+ . map ( |( lo, hi) | ( ScalarValue :: Int64 ( Some ( lo) ) , ScalarValue :: Int64 ( Some ( hi) ) ) ) ,
248+
249+ // Unsupported types
250+ _ => None ,
251+ } ) else {
252+ return Ok ( PreimageResult :: None ) ;
253+ } ;
254+
255+ Ok ( PreimageResult :: Range {
256+ expr : arg,
257+ interval : Box :: new ( Interval :: try_new ( lower, upper) ?) ,
258+ } )
259+ }
260+
203261 fn documentation ( & self ) -> Option < & Documentation > {
204262 self . doc ( )
205263 }
206264}
265+
266+ // ============ Helper functions for preimage bounds ============
267+
268+ /// Compute preimage bounds for floor function on floating-point types.
269+ /// For floor(x) = n, the preimage is [n, n+1).
270+ /// Returns None if the value is non-finite or would lose precision.
271+ fn float_preimage_bounds < F : Float > ( n : F ) -> Option < ( F , F ) > {
272+ let one = F :: one ( ) ;
273+ // Check for non-finite values (infinity, NaN) or precision loss at extreme values
274+ if !n. is_finite ( ) || n + one <= n {
275+ return None ;
276+ }
277+ Some ( ( n, n + one) )
278+ }
279+
280+ /// Compute preimage bounds for floor function on integer types.
281+ /// For floor(x) = n, the preimage is [n, n+1).
282+ /// Returns None if adding 1 would overflow.
283+ fn int_preimage_bounds < I : CheckedAdd + One + Copy > ( n : I ) -> Option < ( I , I ) > {
284+ let upper = n. checked_add ( & I :: one ( ) ) ?;
285+ Some ( ( n, upper) )
286+ }
287+
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_expr::col;

    /// Runs `FloorFunc::preimage` for `floor(x) = literal` with a default
    /// simplify context and returns the unwrapped result.
    fn floor_preimage_of(literal: &ScalarValue) -> PreimageResult {
        let udf = FloorFunc::new();
        let ctx = SimplifyContext::default();
        let rhs = Expr::Literal(literal.clone(), None);
        udf.preimage(&[col("x")], &rhs, &ctx).unwrap()
    }

    /// Asserts that the preimage for `input` is a range over `col("x")` with
    /// the given lower and upper bounds.
    fn assert_preimage_range(
        input: ScalarValue,
        expected_lower: ScalarValue,
        expected_upper: ScalarValue,
    ) {
        let PreimageResult::Range { expr, interval } = floor_preimage_of(&input) else {
            panic!("Expected Range, got None for input {:?}", input)
        };

        assert_eq!(expr, col("x"));
        assert_eq!(interval.lower(), &expected_lower);
        assert_eq!(interval.upper(), &expected_upper);
    }

    /// Asserts that no preimage is produced for `input`.
    fn assert_preimage_none(input: ScalarValue) {
        let outcome = floor_preimage_of(&input);
        assert!(
            matches!(outcome, PreimageResult::None),
            "Expected None for input {:?}",
            input
        );
    }

    #[test]
    fn test_floor_preimage_valid_cases() {
        // (literal, expected lower bound, expected upper bound)
        let cases = [
            // Float64
            (
                ScalarValue::Float64(Some(100.0)),
                ScalarValue::Float64(Some(100.0)),
                ScalarValue::Float64(Some(101.0)),
            ),
            // Float32
            (
                ScalarValue::Float32(Some(50.0)),
                ScalarValue::Float32(Some(50.0)),
                ScalarValue::Float32(Some(51.0)),
            ),
            // Int64
            (
                ScalarValue::Int64(Some(42)),
                ScalarValue::Int64(Some(42)),
                ScalarValue::Int64(Some(43)),
            ),
            // Int32
            (
                ScalarValue::Int32(Some(100)),
                ScalarValue::Int32(Some(100)),
                ScalarValue::Int32(Some(101)),
            ),
            // Negative values
            (
                ScalarValue::Float64(Some(-5.0)),
                ScalarValue::Float64(Some(-5.0)),
                ScalarValue::Float64(Some(-4.0)),
            ),
            // Zero
            (
                ScalarValue::Float64(Some(0.0)),
                ScalarValue::Float64(Some(0.0)),
                ScalarValue::Float64(Some(1.0)),
            ),
        ];

        for (input, lower, upper) in cases {
            assert_preimage_range(input, lower, upper);
        }
    }

    #[test]
    fn test_floor_preimage_integer_overflow() {
        // Every integer type at its MAX value cannot represent N + 1.
        let at_max = [
            ScalarValue::Int64(Some(i64::MAX)),
            ScalarValue::Int32(Some(i32::MAX)),
            ScalarValue::Int16(Some(i16::MAX)),
            ScalarValue::Int8(Some(i8::MAX)),
        ];

        for input in at_max {
            assert_preimage_none(input);
        }
    }

    #[test]
    fn test_floor_preimage_float_edge_cases() {
        let edge_cases = [
            // Float64: infinities, NaN, and MAX (adding one loses precision)
            ScalarValue::Float64(Some(f64::INFINITY)),
            ScalarValue::Float64(Some(f64::NEG_INFINITY)),
            ScalarValue::Float64(Some(f64::NAN)),
            ScalarValue::Float64(Some(f64::MAX)),
            // Float32: infinities, NaN, and MAX (adding one loses precision)
            ScalarValue::Float32(Some(f32::INFINITY)),
            ScalarValue::Float32(Some(f32::NEG_INFINITY)),
            ScalarValue::Float32(Some(f32::NAN)),
            ScalarValue::Float32(Some(f32::MAX)),
        ];

        for input in edge_cases {
            assert_preimage_none(input);
        }
    }

    #[test]
    fn test_floor_preimage_null_values() {
        let nulls = [
            ScalarValue::Float64(None),
            ScalarValue::Float32(None),
            ScalarValue::Int64(None),
        ];

        for input in nulls {
            assert_preimage_none(input);
        }
    }

    #[test]
    fn test_floor_preimage_invalid_inputs() {
        let udf = FloorFunc::new();
        let ctx = SimplifyContext::default();

        // Non-literal comparison value
        let non_literal = udf.preimage(&[col("x")], &col("y"), &ctx).unwrap();
        assert!(
            matches!(non_literal, PreimageResult::None),
            "Expected None for non-literal"
        );

        let hundred = Expr::Literal(ScalarValue::Float64(Some(100.0)), None);

        // Wrong argument count (too many)
        let too_many = udf
            .preimage(&[col("x"), col("y")], &hundred, &ctx)
            .unwrap();
        assert!(
            matches!(too_many, PreimageResult::None),
            "Expected None for wrong arg count"
        );

        // Wrong argument count (zero)
        let no_args = udf.preimage(&[], &hundred, &ctx).unwrap();
        assert!(
            matches!(no_args, PreimageResult::None),
            "Expected None for zero args"
        );
    }
}
0 commit comments