Skip to content

Commit cac222f

Browse files
committed
fix: pass filters pushed earlier to supports_filter_pushdown
Part of apache#19929 Let "t" be a table provider that supports exactly any single filter but not a conjunction. Consider the following optimizer pipeline: 1. Try to push `a = 1, b = 1`. `supports_filters_pushdown` returns [Exact, Inexact] OK: the optimizer records that a = 1 is pushed and creates a filter node for b = 1. ... Another optimization iteration. 2. Try to push `b = 1`. `supports_filters_pushdown` returns [Exact]. Of course, the table provider can't remember all previously pushed filters, so it has no choice but to answer `Exact`. Now, the optimizer thinks the conjunction a = 1 AND b = 1 is supported exactly, but it is not. To prevent this problem, this patch passes filters that were already pushed into the scan earlier to `supports_filters_pushdown`.
1 parent b2c29ac commit cac222f

File tree

2 files changed

+83
-12
lines changed

2 files changed

+83
-12
lines changed

datafusion/optimizer/src/push_down_filter.rs

Lines changed: 82 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1126,7 +1126,12 @@ impl OptimizerRule for PushDownFilter {
11261126
}
11271127
LogicalPlan::Join(join) => push_down_join(join, Some(&filter.predicate)),
11281128
LogicalPlan::TableScan(scan) => {
1129-
let filter_predicates = split_conjunction(&filter.predicate);
1129+
let filter_predicates: Vec<_> = split_conjunction(&filter.predicate)
1130+
.into_iter()
1131+
// Add already pushed filters.
1132+
.chain(scan.filters.iter())
1133+
.unique()
1134+
.collect();
11301135

11311136
let (volatile_filters, non_volatile_filters): (Vec<&Expr>, Vec<&Expr>) =
11321137
filter_predicates
@@ -1154,13 +1159,8 @@ impl OptimizerRule for PushDownFilter {
11541159
.map(|(pred, _)| pred);
11551160

11561161
// Add new scan filters
1157-
let new_scan_filters: Vec<Expr> = scan
1158-
.filters
1159-
.iter()
1160-
.chain(new_scan_filters)
1161-
.unique()
1162-
.cloned()
1163-
.collect();
1162+
let new_scan_filters: Vec<Expr> =
1163+
new_scan_filters.unique().cloned().collect();
11641164

11651165
// Compose predicates to be of `Unsupported` or `Inexact` pushdown type, and also include volatile filters
11661166
let new_predicate: Vec<Expr> = zip
@@ -1438,8 +1438,8 @@ mod tests {
14381438
use datafusion_expr::{
14391439
ColumnarValue, ExprFunctionExt, Extension, LogicalPlanBuilder,
14401440
ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, TableSource, TableType,
1441-
UserDefinedLogicalNodeCore, Volatility, WindowFunctionDefinition, col, in_list,
1442-
in_subquery, lit,
1441+
UserDefinedLogicalNodeCore, Volatility, WindowFunctionDefinition, binary_expr,
1442+
col, in_list, in_subquery, lit,
14431443
};
14441444

14451445
use crate::OptimizerContext;
@@ -1459,7 +1459,17 @@ mod tests {
14591459
$plan:expr,
14601460
@ $expected:literal $(,)?
14611461
) => {{
1462-
let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
1462+
assert_optimized_plan_equal_with_pases_num!($plan, 1, @ $expected,)
1463+
}};
1464+
}
1465+
1466+
macro_rules! assert_optimized_plan_equal_with_pases_num {
1467+
(
1468+
$plan:expr,
1469+
$max_pases: expr,
1470+
@ $expected:literal $(,)?
1471+
) => {{
1472+
let optimizer_ctx = OptimizerContext::new().with_max_passes($max_pases);
14631473
let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(PushDownFilter::new())];
14641474
assert_optimized_plan_eq_snapshot!(
14651475
optimizer_ctx,
@@ -4223,4 +4233,65 @@ mod tests {
42234233
"
42244234
)
42254235
}
4236+
4237+
struct SingleFilterSupportSource {
4238+
schema: SchemaRef,
4239+
}
4240+
4241+
#[async_trait]
4242+
impl TableSource for SingleFilterSupportSource {
4243+
fn schema(&self) -> SchemaRef {
4244+
Arc::clone(&self.schema)
4245+
}
4246+
4247+
fn table_type(&self) -> TableType {
4248+
TableType::Base
4249+
}
4250+
4251+
fn supports_filters_pushdown(
4252+
&self,
4253+
filters: &[&Expr],
4254+
) -> Result<Vec<TableProviderFilterPushDown>> {
4255+
// Support exactly any single filter.
4256+
let mut res = vec![TableProviderFilterPushDown::Unsupported; filters.len()];
4257+
res[0] = TableProviderFilterPushDown::Exact;
4258+
Ok(res)
4259+
}
4260+
4261+
fn as_any(&self) -> &dyn Any {
4262+
self
4263+
}
4264+
}
4265+
4266+
#[test]
4267+
fn test_pushed_filters_passed_again() -> Result<()> {
4268+
let schema = Arc::new(Schema::new(vec![
4269+
Field::new("a", DataType::Int32, true),
4270+
Field::new("b", DataType::Int32, true),
4271+
]));
4272+
4273+
let plan = LogicalPlanBuilder::scan(
4274+
"t",
4275+
Arc::new(SingleFilterSupportSource {
4276+
schema: Arc::clone(&schema),
4277+
}),
4278+
None,
4279+
)?
4280+
.filter(binary_expr(
4281+
binary_expr(col("a"), Operator::Eq, lit(1)),
4282+
Operator::And,
4283+
binary_expr(col("b"), Operator::Eq, lit(4)),
4284+
))?
4285+
.build()?;
4286+
4287+
// Verify that the only one filter is pushed.
4288+
assert_optimized_plan_equal_with_pases_num!(
4289+
plan,
4290+
5,
4291+
@r"
4292+
Filter: t.b = Int32(4)
4293+
TableScan: t, full_filters=[t.a = Int32(1)]
4294+
"
4295+
)
4296+
}
42264297
}

datafusion/sqllogictest/test_files/predicates.slt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -666,7 +666,7 @@ logical_plan
666666
03)----Filter: lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)
667667
04)------TableScan: lineitem projection=[l_partkey, l_quantity], partial_filters=[lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)]
668668
05)----Filter: (part.p_brand = Utf8View("Brand#12") AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1)
669-
06)------TableScan: part projection=[p_partkey, p_brand, p_size], partial_filters=[part.p_size >= Int32(1), part.p_brand = Utf8View("Brand#12") AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_size <= Int32(15)]
669+
06)------TableScan: part projection=[p_partkey, p_brand, p_size], partial_filters=[part.p_brand = Utf8View("Brand#12") AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_size <= Int32(15), part.p_size >= Int32(1)]
670670
physical_plan
671671
01)HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)], filter=p_brand@1 = Brand#12 AND l_quantity@0 >= Some(100),15,2 AND l_quantity@0 <= Some(1100),15,2 AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND l_quantity@0 >= Some(1000),15,2 AND l_quantity@0 <= Some(2000),15,2 AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND l_quantity@0 >= Some(2000),15,2 AND l_quantity@0 <= Some(3000),15,2 AND p_size@2 <= 15, projection=[l_partkey@0]
672672
02)--RepartitionExec: partitioning=Hash([l_partkey@0], 4), input_partitions=4

0 commit comments

Comments
 (0)