Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 199 additions & 3 deletions datafusion/core/tests/optimizer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
//! Tests for the DataFusion SQL query planner that require functions from the
//! datafusion-functions crate.

use datafusion_expr::simplify::SimplifyContext;
use datafusion_optimizer::simplify_expressions::ExprSimplifier;
use insta::assert_snapshot;
use std::any::Any;
use std::collections::HashMap;
Expand All @@ -28,11 +30,13 @@ use arrow::datatypes::{
};
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::TransformedResult;
use datafusion_common::{DFSchema, Result, ScalarValue, TableReference, plan_err};
use datafusion_common::{
DFSchema, DFSchemaRef, Result, ScalarValue, TableReference, plan_err,
};
use datafusion_expr::interval_arithmetic::{Interval, NullableInterval};
use datafusion_expr::{
AggregateUDF, BinaryExpr, Expr, ExprSchemable, LogicalPlan, Operator, ScalarUDF,
TableSource, WindowUDF, col, lit,
TableSource, WindowUDF, and, col, lit, or,
};
use datafusion_functions::core::expr_ext::FieldAccessor;
use datafusion_optimizer::analyzer::Analyzer;
Expand All @@ -45,7 +49,7 @@ use datafusion_sql::sqlparser::parser::Parser;

use chrono::DateTime;
use datafusion_expr::expr_rewriter::rewrite_with_guarantees;
use datafusion_functions::datetime;
use datafusion_functions::datetime::{self, expr_fn};

#[cfg(test)]
#[ctor::ctor]
Expand Down Expand Up @@ -378,3 +382,195 @@ fn validate_unchanged_cases(guarantees: &[(Expr, NullableInterval)], cases: &[Ex
);
}
}

// DatePart preimage tests
#[test]
fn test_preimage_date_part_date32_eq() {
let schema = expr_test_schema();
// date_part(c1, DatePart::Year) = 2024 -> c1 >= 2024-01-01 AND c1 < 2025-01-01
let expr_lt = expr_fn::date_part(lit("year"), col("date32")).eq(lit(2024i32));
let expected = and(
col("date32").gt_eq(lit(ScalarValue::Date32(Some(19723)))),
col("date32").lt(lit(ScalarValue::Date32(Some(20089)))),
);
assert_eq!(optimize_test(expr_lt, &schema), expected)
}

#[test]
fn test_preimage_date_part_date64_not_eq() {
let schema = expr_test_schema();
// date_part(c1, DatePart::Year) <> 2024 -> c1 < 2024-01-01 AND c1 >= 2025-01-01
let expr_lt = expr_fn::date_part(lit("year"), col("date64")).not_eq(lit(2024i32));
let expected = or(
col("date64").lt(lit(ScalarValue::Date64(Some(19723 * 86_400_000)))),
col("date64").gt_eq(lit(ScalarValue::Date64(Some(20089 * 86_400_000)))),
);
assert_eq!(optimize_test(expr_lt, &schema), expected)
}

#[test]
fn test_preimage_date_part_timestamp_nano_lt() {
let schema = expr_test_schema();
let expr_lt = expr_fn::date_part(lit("year"), col("ts_nano_none")).lt(lit(2024i32));
let expected = col("ts_nano_none").lt(lit(ScalarValue::TimestampNanosecond(
Some(19723 * 86_400_000_000_000),
None,
)));
assert_eq!(optimize_test(expr_lt, &schema), expected)
}

#[test]
fn test_preimage_date_part_timestamp_nano_utc_gt() {
let schema = expr_test_schema();
let expr_lt = expr_fn::date_part(lit("year"), col("ts_nano_utc")).gt(lit(2024i32));
let expected = col("ts_nano_utc").gt_eq(lit(ScalarValue::TimestampNanosecond(
Some(20089 * 86_400_000_000_000),
None,
)));
assert_eq!(optimize_test(expr_lt, &schema), expected)
}

#[test]
fn test_preimage_date_part_timestamp_sec_est_gt_eq() {
let schema = expr_test_schema();
let expr_lt = expr_fn::date_part(lit("year"), col("ts_sec_est")).gt_eq(lit(2024i32));
let expected = col("ts_sec_est").gt_eq(lit(ScalarValue::TimestampSecond(
Some(19723 * 86_400),
None,
)));
assert_eq!(optimize_test(expr_lt, &schema), expected)
}

#[test]
fn test_preimage_date_part_timestamp_sec_est_lt_eq() {
let schema = expr_test_schema();
let expr_lt = expr_fn::date_part(lit("year"), col("ts_mic_pt")).lt_eq(lit(2024i32));
let expected = col("ts_mic_pt").lt(lit(ScalarValue::TimestampMicrosecond(
Some(20089 * 86_400_000_000),
None,
)));
assert_eq!(optimize_test(expr_lt, &schema), expected)
}

#[test]
fn test_preimage_date_part_timestamp_nano_lt_swap() {
let schema = expr_test_schema();
let expr_lt = lit(2024i32).gt(expr_fn::date_part(lit("year"), col("ts_nano_none")));
let expected = col("ts_nano_none").lt(lit(ScalarValue::TimestampNanosecond(
Some(19723 * 86_400_000_000_000),
None,
)));
assert_eq!(optimize_test(expr_lt, &schema), expected)
}

#[test]
fn test_preimage_date_part_date32_is_not_distinct_from() {
let schema = expr_test_schema();
// date_part(c1, DatePart::Year) is not distinct from 2024 -> c1 >= 2024-01-01 AND c1 < 2025-01-01 (the null handling part is dropped since rhs is not null)
let expr_lt = Expr::BinaryExpr(BinaryExpr {
left: Box::new(expr_fn::date_part(lit("year"), col("date32"))),
op: Operator::IsNotDistinctFrom,
right: Box::new(lit(2024i32)),
});
let expected = and(
col("date32").gt_eq(lit(ScalarValue::Date32(Some(19723)))),
col("date32").lt(lit(ScalarValue::Date32(Some(20089)))),
);
assert_eq!(optimize_test(expr_lt, &schema), expected)
}

#[test]
// Should not simplify - interval can't be calculated
fn test_preimage_date_part_date32_is_not_distinct_from_null() {
let schema = expr_test_schema();
// date_part(c1, DatePart::Year) is not distinct from Null -> unchanged
let expr_lt = Expr::BinaryExpr(BinaryExpr {
left: Box::new(expr_fn::date_part(lit("year"), col("date32"))),
op: Operator::IsNotDistinctFrom,
right: Box::new(lit(ScalarValue::Null)),
});
assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt)
}

#[test]
fn test_preimage_date_part_date64_is_distinct_from() {
let schema = expr_test_schema();
// date_part(c1, DatePart::Year) is distinct from 2024 -> c1 < 2024-01-01 OR c1 >= 2025-01-01 or c1 is NULL
let expr_lt = Expr::BinaryExpr(BinaryExpr {
left: Box::new(expr_fn::date_part(lit("year"), col("date64"))),
op: Operator::IsDistinctFrom,
right: Box::new(lit(2024i32)),
});
let expected = col("date64")
.lt(lit(ScalarValue::Date64(Some(19723 * 86_400_000))))
.or(col("date64").gt_eq(lit(ScalarValue::Date64(Some(20089 * 86_400_000)))))
.or(col("date64").is_null());
assert_eq!(optimize_test(expr_lt, &schema), expected)
}

#[test]
// Should not simplify - interval can't be calculated
fn test_preimage_date_part_date64_is_distinct_from_null() {
let schema = expr_test_schema();
// date_part(c1, DatePart::Year) is distinct from 2024 -> c1 < 2024-01-01 OR c1 >= unchanged
let expr_lt = Expr::BinaryExpr(BinaryExpr {
left: Box::new(expr_fn::date_part(lit("year"), col("date64"))),
op: Operator::IsDistinctFrom,
right: Box::new(lit(ScalarValue::Null)),
});
assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt)
}

#[test]
// Should not simplify
fn test_preimage_date_part_not_year_date32_eq() {
let schema = expr_test_schema();
// date_part(c1, DatePart::Year) = 2024 -> c1 >= 2024-01-01 AND c1 < 2025-01-01
let expr_lt = expr_fn::date_part(lit("month"), col("date32")).eq(lit(1i32));
assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt)
}

fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr {
let simplifier =
ExprSimplifier::new(SimplifyContext::default().with_schema(Arc::clone(schema)));

simplifier.simplify(expr).unwrap()
}

fn expr_test_schema() -> DFSchemaRef {
Arc::new(
DFSchema::from_unqualified_fields(
vec![
Field::new("date32", DataType::Date32, true),
Field::new("date64", DataType::Date64, true),
Field::new("ts_nano_none", timestamp_nano_none_type(), true),
Field::new("ts_nano_utc", timestamp_nano_utc_type(), true),
Field::new("ts_sec_est", timestamp_sec_est_type(), true),
Field::new("ts_mic_pt", timestamp_mic_pt_type(), true),
]
.into(),
HashMap::new(),
)
.unwrap(),
)
}

fn timestamp_nano_none_type() -> DataType {
DataType::Timestamp(TimeUnit::Nanosecond, None)
}

// this is the type that now() returns
fn timestamp_nano_utc_type() -> DataType {
let utc = Some("+0:00".into());
DataType::Timestamp(TimeUnit::Nanosecond, utc)
}

fn timestamp_sec_est_type() -> DataType {
let est = Some("-5:00".into());
DataType::Timestamp(TimeUnit::Second, est)
}

fn timestamp_mic_pt_type() -> DataType {
let pt = Some("-8::00".into());
DataType::Timestamp(TimeUnit::Microsecond, pt)
}
62 changes: 62 additions & 0 deletions datafusion/expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,25 @@ impl ScalarUDF {
self.inner.is_nullable(args, schema)
}

/// Return a preimage
///
/// See [`ScalarUDFImpl::preimage`] for more details.
pub fn preimage(
&self,
args: &[Expr],
lit_expr: &Expr,
info: &SimplifyContext,
) -> Result<Option<Interval>> {
self.inner.preimage(args, lit_expr, info)
}

/// Return inner column from function args
///
/// See [`ScalarUDFImpl::column_expr`]
pub fn column_expr(&self, args: &[Expr]) -> Option<Expr> {
self.inner.column_expr(args)
}

/// Invoke the function on `args`, returning the appropriate result.
///
/// See [`ScalarUDFImpl::invoke_with_args`] for details.
Expand Down Expand Up @@ -696,6 +715,36 @@ pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync {
Ok(ExprSimplifyResult::Original(args))
}

/// Returns the [preimage] for this function and the specified scalar value, if any.
///
/// A preimage is a single contiguous [`Interval`] of values where the function
/// will always return `lit_value`
///
/// This rewrite is described in the [ClickHouse Paper] and is particularly
/// useful for simplifying expressions `date_part` or equivalent functions. The
/// idea is that if you have an expression like `date_part(YEAR, k) = 2024` and you
/// can find a [preimage] for `date_part(YEAR, k)`, which is the range of dates
/// covering the entire year of 2024. Thus, you can rewrite the expression to `k
/// >= '2024-01-01' AND k < '2025-01-01' which is often more optimizable.
///
/// This should only return a preimage if the function takes a single argument
///
/// [ClickHouse Paper]: https://www.vldb.org/pvldb/vol17/p3731-schulze.pdf
/// [preimage]: https://en.wikipedia.org/wiki/Image_(mathematics)#Inverse_image
fn preimage(
&self,
_args: &[Expr],
_lit_expr: &Expr,
_info: &SimplifyContext,
) -> Result<Option<Interval>> {
Ok(None)
}

// Return the inner column expression from this function
fn column_expr(&self, _args: &[Expr]) -> Option<Expr> {
None
}

/// Returns true if some of this `exprs` subexpressions may not be evaluated
/// and thus any side effects (like divide by zero) may not be encountered.
///
Expand Down Expand Up @@ -926,6 +975,19 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl {
self.inner.simplify(args, info)
}

fn preimage(
&self,
args: &[Expr],
lit_expr: &Expr,
info: &SimplifyContext,
) -> Result<Option<Interval>> {
self.inner.preimage(args, lit_expr, info)
}

fn column_expr(&self, args: &[Expr]) -> Option<Expr> {
self.inner.column_expr(args)
}

fn conditional_arguments<'a>(
&self,
args: &'a [Expr],
Expand Down
Loading