diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 0654370ac7ebf..55c50f6dd8567 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -232,6 +232,25 @@ impl ScalarUDF { self.inner.is_nullable(args, schema) } + /// Return a preimage + /// + /// See [`ScalarUDFImpl::preimage`] for more details. + pub fn preimage( + &self, + args: &[Expr], + lit_expr: &Expr, + info: &SimplifyContext, + ) -> Result> { + self.inner.preimage(args, lit_expr, info) + } + + /// Return inner column from function args + /// + /// See [`ScalarUDFImpl::column_expr`] + pub fn column_expr(&self, args: &[Expr]) -> Option { + self.inner.column_expr(args) + } + /// Invoke the function on `args`, returning the appropriate result. /// /// See [`ScalarUDFImpl::invoke_with_args`] for details. @@ -696,6 +715,36 @@ pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync { Ok(ExprSimplifyResult::Original(args)) } + /// Returns the [preimage] for this function and the specified scalar value, if any. + /// + /// A preimage is a single contiguous [`Interval`] of values where the function + /// will always return `lit_value` + /// + /// This rewrite is described in the [ClickHouse Paper] and is particularly + /// useful for simplifying expressions `date_part` or equivalent functions. The + /// idea is that if you have an expression like `date_part(YEAR, k) = 2024` and you + /// can find a [preimage] for `date_part(YEAR, k)`, which is the range of dates + /// covering the entire year of 2024. Thus, you can rewrite the expression to `k + /// >= '2024-01-01' AND k < '2025-01-01' which is often more optimizable. + /// + /// This should only return a preimage if the function takes a single argument + /// + /// [ClickHouse Paper]: https://www.vldb.org/pvldb/vol17/p3731-schulze.pdf + /// [preimage]: https://en.wikipedia.org/wiki/Image_(mathematics)#Inverse_image + fn preimage( + &self, + _args: &[Expr], + _lit_expr: &Expr, + _info: &SimplifyContext, + ) -> Result> { + Ok(None) + } + + // Return the inner column expression from this function + fn column_expr(&self, _args: &[Expr]) -> Option { + None + } + /// Returns true if some of this `exprs` subexpressions may not be evaluated /// and thus any side effects (like divide by zero) may not be encountered. /// @@ -926,6 +975,19 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { self.inner.simplify(args, info) } + fn preimage( + &self, + args: &[Expr], + lit_expr: &Expr, + info: &SimplifyContext, + ) -> Result> { + self.inner.preimage(args, lit_expr, info) + } + + fn column_expr(&self, args: &[Expr]) -> Option { + self.inner.column_expr(args) + } + fn conditional_arguments<'a>( &self, args: &'a [Expr], diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 55bff5849c5cb..c99978a1cc4c6 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -39,7 +39,7 @@ use datafusion_common::{ }; use datafusion_expr::{ BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Volatility, and, - binary::BinaryTypeCoercer, lit, or, + binary::BinaryTypeCoercer, interval_arithmetic::Interval, lit, or, }; use datafusion_expr::{Cast, TryCast, simplify::ExprSimplifyResult}; use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval}; @@ -51,7 +51,6 @@ use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionP use super::inlist_simplifier::ShortenInListSimplifier; use super::utils::*; -use crate::analyzer::type_coercion::TypeCoercionRewriter; use crate::simplify_expressions::SimplifyContext; use crate::simplify_expressions::regex::simplify_regex_expr; use crate::simplify_expressions::unwrap_cast::{ @@ -59,6 +58,10 @@ use crate::simplify_expressions::unwrap_cast::{ is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist, unwrap_cast_in_comparison_for_binary, }; +use crate::{ + analyzer::type_coercion::TypeCoercionRewriter, + simplify_expressions::udf_preimage::rewrite_with_preimage, +}; use datafusion_expr::expr_rewriter::rewrite_with_guarantees_map; use datafusion_expr_common::casts::try_cast_literal_to_type; use indexmap::IndexSet; @@ -1952,12 +1955,98 @@ impl TreeNodeRewriter for Simplifier<'_> { })) } + // ======================================= + // preimage_in_comparison + // ======================================= + // + // For case: + // date_part(expr as 'YEAR') op literal + // + // Background: + // Datasources such as Parquet can prune partitions using simple predicates, + // but they cannot do so for complex expressions. + // For a complex predicate like `date_part('YEAR', c1) < 2000`, pruning is not possible. + // After rewriting it to `c1 < 2000-01-01`, pruning becomes feasible. + // NOTE: we only consider immutable UDFs with literal RHS values + Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + use datafusion_expr::Operator::*; + let is_preimage_op = matches!( + op, + Eq | NotEq + | Lt + | LtEq + | Gt + | GtEq + | IsDistinctFrom + | IsNotDistinctFrom + ); + if !is_preimage_op { + return Ok(Transformed::no(Expr::BinaryExpr(BinaryExpr { + left, + op, + right, + }))); + } + + if let (Some(interval), Some(col_expr)) = + get_preimage(left.as_ref(), right.as_ref(), info)? + { + rewrite_with_preimage(info, interval, op, Box::new(col_expr))? + } else if let Some(swapped) = op.swap() { + if let (Some(interval), Some(col_expr)) = + get_preimage(right.as_ref(), left.as_ref(), info)? + { + rewrite_with_preimage( + info, + interval, + swapped, + Box::new(col_expr), + )? + } else { + Transformed::no(Expr::BinaryExpr(BinaryExpr { left, op, right })) + } + } else { + Transformed::no(Expr::BinaryExpr(BinaryExpr { left, op, right })) + } + } + // no additional rewrites possible expr => Transformed::no(expr), }) } } +fn get_preimage( + left_expr: &Expr, + right_expr: &Expr, + info: &SimplifyContext, +) -> Result<(Option, Option)> { + let Expr::ScalarFunction(ScalarFunction { func, args }) = left_expr else { + return Ok((None, None)); + }; + if !is_literal_or_literal_cast(right_expr) { + return Ok((None, None)); + } + if func.signature().volatility != Volatility::Immutable { + return Ok((None, None)); + } + Ok(( + func.preimage(args, right_expr, info)?, + func.column_expr(args), + )) +} + +fn is_literal_or_literal_cast(expr: &Expr) -> bool { + match expr { + Expr::Literal(_, _) => true, + Expr::Cast(Cast { expr, .. }) => matches!(expr.as_ref(), Expr::Literal(_, _)), + Expr::TryCast(TryCast { expr, .. }) => { + matches!(expr.as_ref(), Expr::Literal(_, _)) + } + _ => false, + } +} + fn as_string_scalar(expr: &Expr) -> Option<(DataType, &Option)> { match expr { Expr::Literal(ScalarValue::Utf8(s), _) => Some((DataType::Utf8, s)), diff --git a/datafusion/optimizer/src/simplify_expressions/mod.rs b/datafusion/optimizer/src/simplify_expressions/mod.rs index 3ab76119cca84..b85b000821ad8 100644 --- a/datafusion/optimizer/src/simplify_expressions/mod.rs +++ b/datafusion/optimizer/src/simplify_expressions/mod.rs @@ -24,6 +24,7 @@ mod regex; pub mod simplify_exprs; pub mod simplify_literal; mod simplify_predicates; +mod udf_preimage; mod unwrap_cast; mod utils; diff --git a/datafusion/optimizer/src/simplify_expressions/udf_preimage.rs b/datafusion/optimizer/src/simplify_expressions/udf_preimage.rs new file mode 100644 index 0000000000000..960c8df322d15 --- /dev/null +++ b/datafusion/optimizer/src/simplify_expressions/udf_preimage.rs @@ -0,0 +1,114 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_common::{Result, internal_err, tree_node::Transformed}; +use datafusion_expr::{ + BinaryExpr, Expr, Operator, and, lit, or, simplify::SimplifyContext, +}; +use datafusion_expr_common::interval_arithmetic::Interval; + +/// Rewrites a binary expression using its "preimage" +/// +/// Specifically it rewrites expressions of the form ` OP x` (e.g. ` = +/// x`) where `` is known to have a pre-image (aka the entire single +/// range for which it is valid) +/// +/// This rewrite is described in the [ClickHouse Paper] and is particularly +/// useful for simplifying expressions `date_part` or equivalent functions. The +/// idea is that if you have an expression like `date_part(YEAR, k) = 2024` and you +/// can find a [preimage] for `date_part(YEAR, k)`, which is the range of dates +/// covering the entire year of 2024. Thus, you can rewrite the expression to `k +/// >= '2024-01-01' AND k < '2025-01-01' which is often more optimizable. +/// +/// [ClickHouse Paper]: https://www.vldb.org/pvldb/vol17/p3731-schulze.pdf +/// [preimage]: https://en.wikipedia.org/wiki/Image_(mathematics)#Inverse_image +/// +pub(super) fn rewrite_with_preimage( + _info: &SimplifyContext, + preimage_interval: Interval, + op: Operator, + expr: Box, +) -> Result> { + let (lower, upper) = preimage_interval.into_bounds(); + let (lower, upper) = (lit(lower), lit(upper)); + + let rewritten_expr = match op { + // < x ==> < lower + // >= x ==> >= lower + Operator::Lt | Operator::GtEq => Expr::BinaryExpr(BinaryExpr { + left: expr, + op, + right: Box::new(lower), + }), + // > x ==> >= upper + Operator::Gt => Expr::BinaryExpr(BinaryExpr { + left: expr, + op: Operator::GtEq, + right: Box::new(upper), + }), + // <= x ==> < upper + Operator::LtEq => Expr::BinaryExpr(BinaryExpr { + left: expr, + op: Operator::Lt, + right: Box::new(upper), + }), + // = x ==> ( >= lower) and ( < upper) + // + // is not distinct from x ==> ( is NULL and x is NULL) or (( >= lower) and ( < upper)) + // but since x is always not NULL => ( >= lower) and ( < upper) + Operator::Eq | Operator::IsNotDistinctFrom => and( + Expr::BinaryExpr(BinaryExpr { + left: expr.clone(), + op: Operator::GtEq, + right: Box::new(lower), + }), + Expr::BinaryExpr(BinaryExpr { + left: expr, + op: Operator::Lt, + right: Box::new(upper), + }), + ), + // != x ==> ( < lower) or ( >= upper) + Operator::NotEq => or( + Expr::BinaryExpr(BinaryExpr { + left: expr.clone(), + op: Operator::Lt, + right: Box::new(lower), + }), + Expr::BinaryExpr(BinaryExpr { + left: expr, + op: Operator::GtEq, + right: Box::new(upper), + }), + ), + // is distinct from x ==> ( < lower) or ( >= upper) or ( is NULL and x is not NULL) or ( is not NULL and x is NULL) + // but given that x is always not NULL => ( < lower) or ( >= upper) or ( is NULL) + Operator::IsDistinctFrom => Expr::BinaryExpr(BinaryExpr { + left: expr.clone(), + op: Operator::Lt, + right: Box::new(lower.clone()), + }) + .or(Expr::BinaryExpr(BinaryExpr { + left: expr.clone(), + op: Operator::GtEq, + right: Box::new(upper), + })) + .or(expr.is_null()), + _ => return internal_err!("Expect comparison operators"), + }; + Ok(Transformed::yes(rewritten_expr)) +}