Skip to content

Commit 969ed5e

Browse files
authored
Simplify predicates in PushDownFilter optimizer rule (#16362)
* Simplify predicates in filter * add slt test * Use BTreeMap to make tests stable * process edge corner * add doc for simplify_predicates.rs * add as_literal to make code neat * reorganize file * reduce clone call
1 parent 334d449 commit 969ed5e

File tree

6 files changed

+513
-8
lines changed

6 files changed

+513
-8
lines changed

Cargo.lock

Lines changed: 8 additions & 8 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/expr/src/expr.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2069,6 +2069,15 @@ impl Expr {
20692069
_ => None,
20702070
}
20712071
}
2072+
2073+
/// Check if the Expr is literal and get the literal value if it is.
2074+
pub fn as_literal(&self) -> Option<&ScalarValue> {
2075+
if let Expr::Literal(lit, _) = self {
2076+
Some(lit)
2077+
} else {
2078+
None
2079+
}
2080+
}
20722081
}
20732082

20742083
impl Normalizeable for Expr {

datafusion/optimizer/src/push_down_filter.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ use datafusion_expr::{
4040
};
4141

4242
use crate::optimizer::ApplyOrder;
43+
use crate::simplify_expressions::simplify_predicates;
4344
use crate::utils::{has_all_column_refs, is_restrict_null_predicate};
4445
use crate::{OptimizerConfig, OptimizerRule};
4546

@@ -779,6 +780,18 @@ impl OptimizerRule for PushDownFilter {
779780
return Ok(Transformed::no(plan));
780781
};
781782

783+
let predicate = split_conjunction_owned(filter.predicate.clone());
784+
let old_predicate_len = predicate.len();
785+
let new_predicates = simplify_predicates(predicate)?;
786+
if old_predicate_len != new_predicates.len() {
787+
let Some(new_predicate) = conjunction(new_predicates) else {
788+
// new_predicates is empty - remove the filter entirely
789+
// Return the child plan without the filter
790+
return Ok(Transformed::yes(Arc::unwrap_or_clone(filter.input)));
791+
};
792+
filter.predicate = new_predicate;
793+
}
794+
782795
match Arc::unwrap_or_clone(filter.input) {
783796
LogicalPlan::Filter(child_filter) => {
784797
let parents_predicates = split_conjunction_owned(filter.predicate);

datafusion/optimizer/src/simplify_expressions/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ mod guarantees;
2323
mod inlist_simplifier;
2424
mod regex;
2525
pub mod simplify_exprs;
26+
mod simplify_predicates;
2627
mod unwrap_cast;
2728
mod utils;
2829

@@ -31,6 +32,7 @@ pub use datafusion_expr::simplify::{SimplifyContext, SimplifyInfo};
3132

3233
pub use expr_simplifier::*;
3334
pub use simplify_exprs::*;
35+
pub use simplify_predicates::simplify_predicates;
3436

3537
// Export for test in datafusion/core/tests/optimizer_integration.rs
3638
pub use guarantees::GuaranteeRewriter;
Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! Simplifies predicates by reducing redundant or overlapping conditions.
19+
//!
20+
//! This module provides functionality to optimize logical predicates used in query planning
21+
//! by eliminating redundant conditions, thus reducing the number of predicates to evaluate.
22+
//! Unlike the simplifier in `simplify_expressions/simplify_exprs.rs`, which focuses on
23+
//! general expression simplification (e.g., constant folding and algebraic simplifications),
24+
//! this module specifically targets predicate optimization by handling containment relationships.
25+
//! For example, it can simplify `x > 5 AND x > 6` to just `x > 6`, as the latter condition
26+
//! encompasses the former, resulting in fewer checks during query execution.
27+
28+
use datafusion_common::{Column, Result, ScalarValue};
29+
use datafusion_expr::{BinaryExpr, Cast, Expr, Operator};
30+
use std::collections::BTreeMap;
31+
32+
/// Simplifies a list of predicates by removing redundancies.
33+
///
34+
/// This function takes a vector of predicate expressions and groups them by the column they reference.
35+
/// Predicates that reference a single column and are comparison operations (e.g., >, >=, <, <=, =)
36+
/// are analyzed to remove redundant conditions. For instance, `x > 5 AND x > 6` is simplified to
37+
/// `x > 6`. Other predicates that do not fit this pattern are retained as-is.
38+
///
39+
/// # Arguments
40+
/// * `predicates` - A vector of `Expr` representing the predicates to simplify.
41+
///
42+
/// # Returns
43+
/// A `Result` containing a vector of simplified `Expr` predicates.
44+
pub fn simplify_predicates(predicates: Vec<Expr>) -> Result<Vec<Expr>> {
45+
// Early return for simple cases
46+
if predicates.len() <= 1 {
47+
return Ok(predicates);
48+
}
49+
50+
// Group predicates by their column reference
51+
let mut column_predicates: BTreeMap<Column, Vec<Expr>> = BTreeMap::new();
52+
let mut other_predicates = Vec::new();
53+
54+
for pred in predicates {
55+
match &pred {
56+
Expr::BinaryExpr(BinaryExpr {
57+
left,
58+
op:
59+
Operator::Gt
60+
| Operator::GtEq
61+
| Operator::Lt
62+
| Operator::LtEq
63+
| Operator::Eq,
64+
right,
65+
}) => {
66+
let left_col = extract_column_from_expr(left);
67+
let right_col = extract_column_from_expr(right);
68+
if let (Some(col), Some(_)) = (&left_col, right.as_literal()) {
69+
column_predicates.entry(col.clone()).or_default().push(pred);
70+
} else if let (Some(_), Some(col)) = (left.as_literal(), &right_col) {
71+
column_predicates.entry(col.clone()).or_default().push(pred);
72+
} else {
73+
other_predicates.push(pred);
74+
}
75+
}
76+
_ => other_predicates.push(pred),
77+
}
78+
}
79+
80+
// Process each column's predicates to remove redundancies
81+
let mut result = other_predicates;
82+
for (_, preds) in column_predicates {
83+
let simplified = simplify_column_predicates(preds)?;
84+
result.extend(simplified);
85+
}
86+
87+
Ok(result)
88+
}
89+
90+
/// Simplifies predicates related to a single column.
91+
///
92+
/// This function processes a list of predicates that all reference the same column and
93+
/// simplifies them based on their operators. It groups predicates into greater-than (>, >=),
94+
/// less-than (<, <=), and equality (=) categories, then selects the most restrictive condition
95+
/// in each category to reduce redundancy. For example, among `x > 5` and `x > 6`, only `x > 6`
96+
/// is retained as it is more restrictive.
97+
///
98+
/// # Arguments
99+
/// * `predicates` - A vector of `Expr` representing predicates for a single column.
100+
///
101+
/// # Returns
102+
/// A `Result` containing a vector of simplified `Expr` predicates for the column.
103+
fn simplify_column_predicates(predicates: Vec<Expr>) -> Result<Vec<Expr>> {
104+
if predicates.len() <= 1 {
105+
return Ok(predicates);
106+
}
107+
108+
// Group by operator type, but combining similar operators
109+
let mut greater_predicates = Vec::new(); // Combines > and >=
110+
let mut less_predicates = Vec::new(); // Combines < and <=
111+
let mut eq_predicates = Vec::new();
112+
113+
for pred in predicates {
114+
match &pred {
115+
Expr::BinaryExpr(BinaryExpr { left: _, op, right }) => {
116+
match (op, right.as_literal().is_some()) {
117+
(Operator::Gt, true)
118+
| (Operator::Lt, false)
119+
| (Operator::GtEq, true)
120+
| (Operator::LtEq, false) => greater_predicates.push(pred),
121+
(Operator::Lt, true)
122+
| (Operator::Gt, false)
123+
| (Operator::LtEq, true)
124+
| (Operator::GtEq, false) => less_predicates.push(pred),
125+
(Operator::Eq, _) => eq_predicates.push(pred),
126+
_ => unreachable!("Unexpected operator: {}", op),
127+
}
128+
}
129+
_ => unreachable!("Unexpected predicate {}", pred.to_string()),
130+
}
131+
}
132+
133+
let mut result = Vec::new();
134+
135+
if !eq_predicates.is_empty() {
136+
// If there are many equality predicates, we can only keep one if they are all the same
137+
if eq_predicates.len() == 1
138+
|| eq_predicates.iter().all(|e| e == &eq_predicates[0])
139+
{
140+
result.push(eq_predicates.pop().unwrap());
141+
} else {
142+
// If they are not the same, add a false predicate
143+
result.push(Expr::Literal(ScalarValue::Boolean(Some(false)), None));
144+
}
145+
}
146+
147+
// Handle all greater-than-style predicates (keep the most restrictive - highest value)
148+
if !greater_predicates.is_empty() {
149+
if let Some(most_restrictive) =
150+
find_most_restrictive_predicate(&greater_predicates, true)?
151+
{
152+
result.push(most_restrictive);
153+
} else {
154+
result.extend(greater_predicates);
155+
}
156+
}
157+
158+
// Handle all less-than-style predicates (keep the most restrictive - lowest value)
159+
if !less_predicates.is_empty() {
160+
if let Some(most_restrictive) =
161+
find_most_restrictive_predicate(&less_predicates, false)?
162+
{
163+
result.push(most_restrictive);
164+
} else {
165+
result.extend(less_predicates);
166+
}
167+
}
168+
169+
Ok(result)
170+
}
171+
172+
/// Finds the most restrictive predicate from a list based on literal values.
173+
///
174+
/// This function iterates through a list of predicates to identify the most restrictive one
175+
/// by comparing their literal values. For greater-than predicates, the highest value is most
176+
/// restrictive, while for less-than predicates, the lowest value is most restrictive.
177+
///
178+
/// # Arguments
179+
/// * `predicates` - A slice of `Expr` representing predicates to compare.
180+
/// * `find_greater` - A boolean indicating whether to find the highest value (true for >, >=)
181+
/// or the lowest value (false for <, <=).
182+
///
183+
/// # Returns
184+
/// A `Result` containing an `Option<Expr>` with the most restrictive predicate, if any.
185+
fn find_most_restrictive_predicate(
186+
predicates: &[Expr],
187+
find_greater: bool,
188+
) -> Result<Option<Expr>> {
189+
if predicates.is_empty() {
190+
return Ok(None);
191+
}
192+
193+
let mut most_restrictive_idx = 0;
194+
let mut best_value: Option<&ScalarValue> = None;
195+
196+
for (idx, pred) in predicates.iter().enumerate() {
197+
if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = pred {
198+
// Extract the literal value based on which side has it
199+
let scalar_value = match (right.as_literal(), left.as_literal()) {
200+
(Some(scalar), _) => Some(scalar),
201+
(_, Some(scalar)) => Some(scalar),
202+
_ => None,
203+
};
204+
205+
if let Some(scalar) = scalar_value {
206+
if let Some(current_best) = best_value {
207+
if let Some(comparison) = scalar.partial_cmp(current_best) {
208+
let is_better = if find_greater {
209+
comparison == std::cmp::Ordering::Greater
210+
} else {
211+
comparison == std::cmp::Ordering::Less
212+
};
213+
214+
if is_better {
215+
best_value = Some(scalar);
216+
most_restrictive_idx = idx;
217+
}
218+
}
219+
} else {
220+
best_value = Some(scalar);
221+
most_restrictive_idx = idx;
222+
}
223+
}
224+
}
225+
}
226+
227+
Ok(Some(predicates[most_restrictive_idx].clone()))
228+
}
229+
230+
/// Extracts a column reference from an expression, if present.
231+
///
232+
/// This function checks if the given expression is a column reference or contains one,
233+
/// such as within a cast operation. It returns the `Column` if found.
234+
///
235+
/// # Arguments
236+
/// * `expr` - A reference to an `Expr` to inspect for a column reference.
237+
///
238+
/// # Returns
239+
/// An `Option<Column>` containing the column reference if found, otherwise `None`.
240+
fn extract_column_from_expr(expr: &Expr) -> Option<Column> {
241+
match expr {
242+
Expr::Column(col) => Some(col.clone()),
243+
// Handle cases where the column might be wrapped in a cast or other operation
244+
Expr::Cast(Cast { expr, .. }) => extract_column_from_expr(expr),
245+
_ => None,
246+
}
247+
}

0 commit comments

Comments
 (0)