Skip to content

Commit a59d2fd

Browse files
committed
extend a scan projection if some pushed filters become unsupported
Consider the following optimizer-run scenario: 1. `supports_filters_pushdown` returns `Exact` for some filter, e.g. "a = 1", where the column "a" is not required by the query projection. 2. "a" is removed from the table provider projection by the "optimize projection" rule. 3. `supports_filters_pushdown` changes its decision and returns `Inexact` for this filter the next time it is called, e.g., because the input filters have changed and the provider prefers to use a new one. 4. "a" is not added back to the table provider projection, which leads to a filter that references a column that is not part of its input schema. This patch fixes the issue by introducing the following logic within the filter push-down rule: 1. Collect the columns that are not in the current table provider scan projection but are required by the filter expressions. Call this set `additional_projection`. 2. If `additional_projection` is empty, keep the logic as it was prior to the patch. 3. Otherwise, extend the table provider projection and wrap the plan with an additional projection node to preserve the schema that was in use before the rule ran.
1 parent 653f415 commit a59d2fd

File tree

3 files changed

+132
-21
lines changed

3 files changed

+132
-21
lines changed

datafusion/optimizer/src/push_down_filter.rs

Lines changed: 128 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ use datafusion_expr::utils::{
4040
conjunction, expr_to_columns, split_conjunction, split_conjunction_owned,
4141
};
4242
use datafusion_expr::{
43-
BinaryExpr, Expr, Filter, Operator, Projection, TableProviderFilterPushDown, and, or,
43+
BinaryExpr, Expr, Filter, LogicalPlanBuilder, Operator, Projection,
44+
TableProviderFilterPushDown, and, or,
4445
};
4546

4647
use crate::optimizer::ApplyOrder;
@@ -1132,7 +1133,7 @@ impl OptimizerRule for PushDownFilter {
11321133
LogicalPlan::TableScan(scan) => {
11331134
let filter_predicates: Vec<_> = split_conjunction(&filter.predicate)
11341135
.into_iter()
1135-
// Add already pushed filters.
1136+
// Add already pushed filters to ensure that the rule is idempotent.
11361137
.chain(scan.filters.iter())
11371138
.unique()
11381139
.collect();
@@ -1166,26 +1167,106 @@ impl OptimizerRule for PushDownFilter {
11661167
let new_scan_filters: Vec<Expr> =
11671168
new_scan_filters.unique().cloned().collect();
11681169

1170+
let source_schema = scan.source.schema();
1171+
let mut additional_projection = HashSet::new();
1172+
11691173
// Compose predicates to be of `Unsupported` or `Inexact` pushdown type, and also include volatile filters
11701174
let new_predicate: Vec<Expr> = zip
1171-
.filter(|(_, res)| res != &TableProviderFilterPushDown::Exact)
1175+
.filter(|(expr, res)| {
1176+
if *res == TableProviderFilterPushDown::Exact {
1177+
return false;
1178+
}
1179+
// For each not exactly supported filter we must ensure that all columns are projected,
1180+
// so we collect all columns which are not currently projected.
1181+
expr.apply(|expr| {
1182+
if let Expr::Column(column) = expr
1183+
&& let Ok(idx) = source_schema.index_of(column.name())
1184+
&& scan
1185+
.projection
1186+
.as_ref()
1187+
.is_some_and(|p| !p.contains(&idx))
1188+
{
1189+
additional_projection.insert(idx);
1190+
}
1191+
Ok(TreeNodeRecursion::Continue)
1192+
})
1193+
.unwrap();
1194+
true
1195+
})
11721196
.map(|(pred, _)| pred)
11731197
.chain(volatile_filters)
11741198
.cloned()
11751199
.collect();
11761200

1177-
let new_scan = LogicalPlan::TableScan(TableScan {
1178-
filters: new_scan_filters,
1179-
..scan
1180-
});
1181-
1182-
Transformed::yes(new_scan).transform_data(|new_scan| {
1183-
if let Some(predicate) = conjunction(new_predicate) {
1184-
make_filter(predicate, Arc::new(new_scan)).map(Transformed::yes)
1201+
// Wraps with a filter if some filters are not supported exactly.
1202+
let filtered = move |plan| {
1203+
if let Some(new_predicate) = conjunction(new_predicate) {
1204+
Filter::try_new(new_predicate, Arc::new(plan))
1205+
.map(LogicalPlan::Filter)
11851206
} else {
1186-
Ok(Transformed::no(new_scan))
1207+
Ok(plan)
11871208
}
1188-
})
1209+
};
1210+
1211+
if additional_projection.is_empty() {
1212+
// No additional projection is required.
1213+
let new_scan = LogicalPlan::TableScan(TableScan {
1214+
filters: new_scan_filters,
1215+
..scan
1216+
});
1217+
return filtered(new_scan).map(Transformed::yes);
1218+
}
1219+
1220+
let scan_table_name = &scan.table_name;
1221+
let new_scan = filtered(
1222+
LogicalPlanBuilder::scan_with_filters_fetch(
1223+
scan_table_name.clone(),
1224+
Arc::clone(&scan.source),
1225+
scan.projection.clone().map(|mut projection| {
1226+
// Extend a projection.
1227+
projection.extend(additional_projection);
1228+
projection
1229+
}),
1230+
new_scan_filters,
1231+
scan.fetch,
1232+
)?
1233+
.build()?,
1234+
)?;
1235+
1236+
// Project fields required by the initial projection.
1237+
let new_plan = LogicalPlan::Projection(Projection::try_new_with_schema(
1238+
scan.projection
1239+
.as_ref()
1240+
.map(|projection| {
1241+
projection
1242+
.iter()
1243+
.cloned()
1244+
.map(|idx| {
1245+
Expr::Column(Column::new(
1246+
Some(scan_table_name.clone()),
1247+
source_schema.field(idx).name(),
1248+
))
1249+
})
1250+
.collect()
1251+
})
1252+
.unwrap_or_else(|| {
1253+
source_schema
1254+
.fields()
1255+
.iter()
1256+
.map(|field| {
1257+
Expr::Column(Column::new(
1258+
Some(scan_table_name.clone()),
1259+
field.name(),
1260+
))
1261+
})
1262+
.collect()
1263+
}),
1264+
Arc::new(new_scan),
1265+
// Preserve a projected schema metadata.
1266+
scan.projected_schema,
1267+
)?);
1268+
1269+
Ok(Transformed::yes(new_plan))
11891270
}
11901271
LogicalPlan::Extension(extension_plan) => {
11911272
// This check prevents the Filter from being removed when the extension node has no children,
@@ -3235,7 +3316,7 @@ mod tests {
32353316
let plan = table_scan_with_pushdown_provider_builder(
32363317
TableProviderFilterPushDown::Inexact,
32373318
vec![col("a").eq(lit(10i64)), col("b").gt(lit(11i64))],
3238-
Some(vec![0]),
3319+
Some(vec![0, 1]),
32393320
)?
32403321
.filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))?
32413322
.project(vec![col("a"), col("b")])?
@@ -3246,7 +3327,7 @@ mod tests {
32463327
@r"
32473328
Projection: a, b
32483329
Filter: a = Int64(10) AND b > Int64(11)
3249-
TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)]
3330+
TableScan: test projection=[a, b], partial_filters=[a = Int64(10), b > Int64(11)]
32503331
"
32513332
)
32523333
}
@@ -3256,7 +3337,7 @@ mod tests {
32563337
let plan = table_scan_with_pushdown_provider_builder(
32573338
TableProviderFilterPushDown::Exact,
32583339
vec![],
3259-
Some(vec![0]),
3340+
Some(vec![0, 1]),
32603341
)?
32613342
.filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))?
32623343
.project(vec![col("a"), col("b")])?
@@ -3266,7 +3347,7 @@ mod tests {
32663347
plan,
32673348
@r"
32683349
Projection: a, b
3269-
TableScan: test projection=[a], full_filters=[a = Int64(10), b > Int64(11)]
3350+
TableScan: test projection=[a, b], full_filters=[a = Int64(10), b > Int64(11)]
32703351
"
32713352
)
32723353
}
@@ -4386,4 +4467,34 @@ mod tests {
43864467
"
43874468
)
43884469
}
4470+
4471+
#[test]
4472+
fn test_projection_is_updated_when_filter_becomes_unsupported() -> Result<()> {
4473+
let test_provider = PushDownProvider {
4474+
filter_support: TableProviderFilterPushDown::Unsupported,
4475+
};
4476+
4477+
let projected_schema = test_provider.schema().project(&[0])?;
4478+
let table_scan = LogicalPlan::TableScan(TableScan {
4479+
table_name: "test".into(),
4480+
// Emulate that there were pushed filters but now
4481+
// provider cannot support it.
4482+
filters: vec![col("b").eq(lit(1i64))],
4483+
projected_schema: Arc::new(DFSchema::try_from(projected_schema)?),
4484+
projection: Some(vec![0]),
4485+
source: Arc::new(test_provider),
4486+
fetch: None,
4487+
});
4488+
4489+
let plan = LogicalPlanBuilder::from(table_scan)
4490+
.filter(col("a").eq(lit(1i64)))?
4491+
.build()?;
4492+
4493+
assert_optimized_plan_equal!(plan,
4494+
@r"
4495+
Projection: test.a
4496+
Filter: a = Int64(1) AND b = Int64(1)
4497+
TableScan: test projection=[a, b]"
4498+
)
4499+
}
43894500
}

datafusion/sqllogictest/test_files/tpch/plans/q19.slt.part

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,9 @@ logical_plan
6060
04)------Inner Join: lineitem.l_partkey = part.p_partkey Filter: part.p_brand = Utf8View("Brand#12") AND part.p_container IN ([Utf8View("SM CASE"), Utf8View("SM BOX"), Utf8View("SM PACK"), Utf8View("SM PKG")]) AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_container IN ([Utf8View("MED BAG"), Utf8View("MED BOX"), Utf8View("MED PKG"), Utf8View("MED PACK")]) AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_container IN ([Utf8View("LG CASE"), Utf8View("LG BOX"), Utf8View("LG PACK"), Utf8View("LG PKG")]) AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15)
6161
05)--------Projection: lineitem.l_partkey, lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount
6262
06)----------Filter: (lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)) AND (lineitem.l_shipmode = Utf8View("AIR") OR lineitem.l_shipmode = Utf8View("AIR REG")) AND lineitem.l_shipinstruct = Utf8View("DELIVER IN PERSON")
63-
07)------------TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice, l_discount, l_shipinstruct, l_shipmode], partial_filters=[lineitem.l_shipmode = Utf8View("AIR") OR lineitem.l_shipmode = Utf8View("AIR REG"), lineitem.l_shipinstruct = Utf8View("DELIVER IN PERSON"), lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)]
63+
07)------------TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice, l_discount, l_shipinstruct, l_shipmode], partial_filters=[lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2), lineitem.l_shipmode = Utf8View("AIR") OR lineitem.l_shipmode = Utf8View("AIR REG"), lineitem.l_shipinstruct = Utf8View("DELIVER IN PERSON")]
6464
08)--------Filter: (part.p_brand = Utf8View("Brand#12") AND part.p_container IN ([Utf8View("SM CASE"), Utf8View("SM BOX"), Utf8View("SM PACK"), Utf8View("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_container IN ([Utf8View("MED BAG"), Utf8View("MED BOX"), Utf8View("MED PKG"), Utf8View("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_container IN ([Utf8View("LG CASE"), Utf8View("LG BOX"), Utf8View("LG PACK"), Utf8View("LG PKG")]) AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1)
65-
09)----------TableScan: part projection=[p_partkey, p_brand, p_size, p_container], partial_filters=[part.p_size >= Int32(1), part.p_brand = Utf8View("Brand#12") AND part.p_container IN ([Utf8View("SM CASE"), Utf8View("SM BOX"), Utf8View("SM PACK"), Utf8View("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_container IN ([Utf8View("MED BAG"), Utf8View("MED BOX"), Utf8View("MED PKG"), Utf8View("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_container IN ([Utf8View("LG CASE"), Utf8View("LG BOX"), Utf8View("LG PACK"), Utf8View("LG PKG")]) AND part.p_size <= Int32(15)]
65+
09)----------TableScan: part projection=[p_partkey, p_brand, p_size, p_container], partial_filters=[part.p_brand = Utf8View("Brand#12") AND part.p_container IN ([Utf8View("SM CASE"), Utf8View("SM BOX"), Utf8View("SM PACK"), Utf8View("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_container IN ([Utf8View("MED BAG"), Utf8View("MED BOX"), Utf8View("MED PKG"), Utf8View("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_container IN ([Utf8View("LG CASE"), Utf8View("LG BOX"), Utf8View("LG PACK"), Utf8View("LG PKG")]) AND part.p_size <= Int32(15), part.p_size >= Int32(1)]
6666
physical_plan
6767
01)ProjectionExec: expr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@0 as revenue]
6868
02)--AggregateExec: mode=Final, gby=[], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]

datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ logical_plan
6464
06)----------Inner Join: Filter: CAST(customer.c_acctbal AS Decimal128(19, 6)) > __scalar_sq_2.avg(customer.c_acctbal)
6565
07)------------Projection: customer.c_phone, customer.c_acctbal
6666
08)--------------LeftAnti Join: customer.c_custkey = __correlated_sq_1.o_custkey
67-
09)----------------Filter: substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")])
67+
09)----------------Filter: substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")]) AND Boolean(true)
6868
10)------------------TableScan: customer projection=[c_custkey, c_phone, c_acctbal], partial_filters=[substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")]), Boolean(true)]
6969
11)----------------SubqueryAlias: __correlated_sq_1
7070
12)------------------TableScan: orders projection=[o_custkey]
@@ -87,7 +87,7 @@ physical_plan
8787
11)--------------------CoalescePartitionsExec
8888
12)----------------------HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(c_custkey@0, o_custkey@0)], projection=[c_phone@1, c_acctbal@2]
8989
13)------------------------RepartitionExec: partitioning=Hash([c_custkey@0], 4), input_partitions=4
90-
14)--------------------------FilterExec: substr(c_phone@1, 1, 2) IN (SET) ([13, 31, 23, 29, 30, 18, 17])
90+
14)--------------------------FilterExec: substr(c_phone@1, 1, 2) IN (SET) ([13, 31, 23, 29, 30, 18, 17]) AND true
9191
15)----------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
9292
16)------------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey, c_phone, c_acctbal], file_type=csv, has_header=false
9393
17)------------------------RepartitionExec: partitioning=Hash([o_custkey@0], 4), input_partitions=4

0 commit comments

Comments
 (0)