Skip to content
242 changes: 240 additions & 2 deletions datafusion/functions/src/math/floor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,15 @@ use arrow::datatypes::{
};
use datafusion_common::{Result, ScalarValue, exec_err};
use datafusion_expr::interval_arithmetic::Interval;
use datafusion_expr::preimage::PreimageResult;
use datafusion_expr::simplify::SimplifyContext;
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
use datafusion_expr::{
Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
TypeSignature, TypeSignatureClass, Volatility,
Coercion, ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDFImpl,
Signature, TypeSignature, TypeSignatureClass, Volatility,
};
use datafusion_macros::user_doc;
use num_traits::{CheckedAdd, Float, One};

use super::decimal::{apply_decimal_op, floor_decimal_value};

Expand Down Expand Up @@ -200,7 +203,242 @@ impl ScalarUDFImpl for FloorFunc {
Interval::make_unbounded(&data_type)
}

/// Compute the preimage for floor function.
///
/// For `floor(x) = N`, the preimage is `x >= N AND x < N + 1`
/// because floor(x) = N for all x in [N, N+1).
///
/// This enables predicate pushdown optimizations, transforming:
/// `floor(col) = 100` into `col >= 100 AND col < 101`
fn preimage(
&self,
args: &[Expr],
lit_expr: &Expr,
_info: &SimplifyContext,
) -> Result<PreimageResult> {
// floor takes exactly one argument
if args.len() != 1 {
return Ok(PreimageResult::None);
}

let arg = args[0].clone();

// Extract the literal value being compared to
let Expr::Literal(lit_value, _) = lit_expr else {
return Ok(PreimageResult::None);
};

// Compute lower bound (N) and upper bound (N + 1) using helper functions
let Some((lower, upper)) = (match lit_value {
// Floating-point types
ScalarValue::Float64(Some(n)) => float_preimage_bounds(*n).map(|(lo, hi)| {
(
ScalarValue::Float64(Some(lo)),
ScalarValue::Float64(Some(hi)),
)
}),
ScalarValue::Float32(Some(n)) => float_preimage_bounds(*n).map(|(lo, hi)| {
(
ScalarValue::Float32(Some(lo)),
ScalarValue::Float32(Some(hi)),
)
}),

// Integer types
ScalarValue::Int8(Some(n)) => int_preimage_bounds(*n).map(|(lo, hi)| {
(ScalarValue::Int8(Some(lo)), ScalarValue::Int8(Some(hi)))
}),
ScalarValue::Int16(Some(n)) => int_preimage_bounds(*n).map(|(lo, hi)| {
(ScalarValue::Int16(Some(lo)), ScalarValue::Int16(Some(hi)))
}),
ScalarValue::Int32(Some(n)) => int_preimage_bounds(*n).map(|(lo, hi)| {
(ScalarValue::Int32(Some(lo)), ScalarValue::Int32(Some(hi)))
}),
ScalarValue::Int64(Some(n)) => int_preimage_bounds(*n).map(|(lo, hi)| {
(ScalarValue::Int64(Some(lo)), ScalarValue::Int64(Some(hi)))
}),

// Unsupported types
_ => None,
}) else {
return Ok(PreimageResult::None);
};

Ok(PreimageResult::Range {
expr: arg,
interval: Box::new(Interval::try_new(lower, upper)?),
})
}

/// Returns this function's documentation by delegating to `self.doc()`.
// NOTE(review): `doc()` is presumably generated by the `#[user_doc]` attribute
// on `FloorFunc` (the struct definition is not visible here) — confirm.
fn documentation(&self) -> Option<&Documentation> {
self.doc()
}
}

// ============ Helper functions for preimage bounds ============

/// Compute preimage bounds for the floor function on floating-point types.
///
/// For `floor(x) = n` with an integral `n`, the preimage is `[n, n + 1)`.
///
/// Returns `None` when no usable range exists:
/// - `n` is NaN or infinite;
/// - `n` has a fractional part: `floor` only ever produces integral values,
///   so `floor(x) = n` has no solutions and `[n, n + 1)` would match rows it
///   must not (e.g. `floor(x) = 1.5` must not become `x >= 1.5 AND x < 2.5`);
/// - `n + 1` is not strictly greater than `n` because of precision loss at
///   large magnitudes.
fn float_preimage_bounds<F: Float>(n: F) -> Option<(F, F)> {
    // Non-finite literals have no meaningful bounds, and a literal that is not
    // its own floor can never be the result of `floor`.
    if !n.is_finite() || n.floor() != n {
        return None;
    }

    // At large magnitudes adding one can round back to `n`, which would give
    // an empty range instead of one containing `n`.
    let upper = n + F::one();
    if upper <= n {
        return None;
    }

    Some((n, upper))
}

/// Preimage bounds of the floor function for integer literals.
///
/// `floor(x) = n` corresponds to the range `[n, n + 1)`, so the result is
/// `(n, n + 1)`. Yields `None` when `n + 1` overflows the integer type.
fn int_preimage_bounds<I: CheckedAdd + One + Copy>(n: I) -> Option<(I, I)> {
    n.checked_add(&I::one()).map(|next| (n, next))
}

#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_expr::col;

    /// Evaluates `FloorFunc::preimage` for `floor(x)` compared against `literal`.
    fn preimage_for(literal: &ScalarValue) -> PreimageResult {
        let lit_expr = Expr::Literal(literal.clone(), None);
        FloorFunc::new()
            .preimage(&[col("x")], &lit_expr, &SimplifyContext::default())
            .unwrap()
    }

    /// Asserts that `input` produces a `Range` over `x` with the given bounds.
    fn assert_preimage_range(
        input: ScalarValue,
        expected_lower: ScalarValue,
        expected_upper: ScalarValue,
    ) {
        let PreimageResult::Range { expr, interval } = preimage_for(&input) else {
            panic!("Expected Range, got None for input {input:?}")
        };

        assert_eq!(expr, col("x"));
        assert_eq!(interval.lower(), &expected_lower);
        assert_eq!(interval.upper(), &expected_upper);
    }

    /// Asserts that `input` produces no preimage at all.
    fn assert_preimage_none(input: ScalarValue) {
        let result = preimage_for(&input);
        assert!(
            matches!(result, PreimageResult::None),
            "Expected None for input {input:?}"
        );
    }

    #[test]
    fn test_floor_preimage_valid_cases() {
        // (literal, expected lower bound, expected upper bound)
        let cases = [
            // Float64
            (
                ScalarValue::Float64(Some(100.0)),
                ScalarValue::Float64(Some(100.0)),
                ScalarValue::Float64(Some(101.0)),
            ),
            // Float32
            (
                ScalarValue::Float32(Some(50.0)),
                ScalarValue::Float32(Some(50.0)),
                ScalarValue::Float32(Some(51.0)),
            ),
            // Int64
            (
                ScalarValue::Int64(Some(42)),
                ScalarValue::Int64(Some(42)),
                ScalarValue::Int64(Some(43)),
            ),
            // Int32
            (
                ScalarValue::Int32(Some(100)),
                ScalarValue::Int32(Some(100)),
                ScalarValue::Int32(Some(101)),
            ),
            // Negative values
            (
                ScalarValue::Float64(Some(-5.0)),
                ScalarValue::Float64(Some(-5.0)),
                ScalarValue::Float64(Some(-4.0)),
            ),
            // Zero
            (
                ScalarValue::Float64(Some(0.0)),
                ScalarValue::Float64(Some(0.0)),
                ScalarValue::Float64(Some(1.0)),
            ),
        ];

        for (input, lower, upper) in cases {
            assert_preimage_range(input, lower, upper);
        }
    }

    #[test]
    fn test_floor_preimage_integer_overflow() {
        // N + 1 does not fit at MAX, so no integer type may produce a range.
        let at_max = [
            ScalarValue::Int64(Some(i64::MAX)),
            ScalarValue::Int32(Some(i32::MAX)),
            ScalarValue::Int16(Some(i16::MAX)),
            ScalarValue::Int8(Some(i8::MAX)),
        ];

        for input in at_max {
            assert_preimage_none(input);
        }
    }

    #[test]
    fn test_floor_preimage_float_edge_cases() {
        let edge_cases = [
            // Float64: non-finite values, then precision loss at MAX
            ScalarValue::Float64(Some(f64::INFINITY)),
            ScalarValue::Float64(Some(f64::NEG_INFINITY)),
            ScalarValue::Float64(Some(f64::NAN)),
            ScalarValue::Float64(Some(f64::MAX)),
            // Float32: non-finite values, then precision loss at MAX
            ScalarValue::Float32(Some(f32::INFINITY)),
            ScalarValue::Float32(Some(f32::NEG_INFINITY)),
            ScalarValue::Float32(Some(f32::NAN)),
            ScalarValue::Float32(Some(f32::MAX)),
        ];

        for input in edge_cases {
            assert_preimage_none(input);
        }
    }

    #[test]
    fn test_floor_preimage_null_values() {
        let nulls = [
            ScalarValue::Float64(None),
            ScalarValue::Float32(None),
            ScalarValue::Int64(None),
        ];

        for input in nulls {
            assert_preimage_none(input);
        }
    }

    #[test]
    fn test_floor_preimage_invalid_inputs() {
        let floor_func = FloorFunc::new();
        let info = SimplifyContext::default();
        let lit = Expr::Literal(ScalarValue::Float64(Some(100.0)), None);

        // The comparison value is a column rather than a literal.
        let non_literal = floor_func.preimage(&[col("x")], &col("y"), &info).unwrap();
        assert!(
            matches!(non_literal, PreimageResult::None),
            "Expected None for non-literal"
        );

        // Two arguments instead of one.
        let too_many = floor_func
            .preimage(&[col("x"), col("y")], &lit, &info)
            .unwrap();
        assert!(
            matches!(too_many, PreimageResult::None),
            "Expected None for wrong arg count"
        );

        // No arguments at all.
        let no_args = floor_func.preimage(&[], &lit, &info).unwrap();
        assert!(
            matches!(no_args, PreimageResult::None),
            "Expected None for zero args"
        );
    }
}