Skip to content

Commit 4e5e765

Browse files
authored
Avoid pushdown of volatile functions to tablescan (#13475)
* Avoid pushdown of volatile functions to tablescan
* Apply formatting and linting
* Do not check for volatility for unsupported predicates
* Check volatility before passing expressions to TableSource
* Document non-volatile expressions contract for supports_filters_pushdown
* Add test for Unsupported pushdown type
* Refactor tests using supports_filters_pushdown
1 parent 207e855 commit 4e5e765

File tree

2 files changed

+134
-51
lines changed

2 files changed

+134
-51
lines changed

datafusion/expr/src/table_source.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ pub trait TableSource: Sync + Send {
9999
}
100100

101101
/// Tests whether the table provider can make use of any or all filter expressions
102-
/// to optimise data retrieval.
102+
/// to optimise data retrieval. Only non-volatile expressions are passed to this function.
103103
fn supports_filters_pushdown(
104104
&self,
105105
filters: &[&Expr],

datafusion/optimizer/src/push_down_filter.rs

Lines changed: 133 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -988,32 +988,46 @@ impl OptimizerRule for PushDownFilter {
988988
LogicalPlan::Join(join) => push_down_join(join, Some(&filter.predicate)),
989989
LogicalPlan::TableScan(scan) => {
990990
let filter_predicates = split_conjunction(&filter.predicate);
991-
let results = scan
991+
992+
let (volatile_filters, non_volatile_filters): (Vec<&Expr>, Vec<&Expr>) =
993+
filter_predicates
994+
.into_iter()
995+
.partition(|pred| pred.is_volatile());
996+
997+
// Check which non-volatile filters are supported by source
998+
let supported_filters = scan
992999
.source
993-
.supports_filters_pushdown(filter_predicates.as_slice())?;
994-
if filter_predicates.len() != results.len() {
1000+
.supports_filters_pushdown(non_volatile_filters.as_slice())?;
1001+
if non_volatile_filters.len() != supported_filters.len() {
9951002
return internal_err!(
9961003
"Vec returned length: {} from supports_filters_pushdown is not the same size as the filters passed, which length is: {}",
997-
results.len(),
998-
filter_predicates.len());
1004+
supported_filters.len(),
1005+
non_volatile_filters.len());
9991006
}
10001007

1001-
let zip = filter_predicates.into_iter().zip(results);
1008+
// Compose scan filters from non-volatile filters of `Exact` or `Inexact` pushdown type
1009+
let zip = non_volatile_filters.into_iter().zip(supported_filters);
10021010

10031011
let new_scan_filters = zip
10041012
.clone()
10051013
.filter(|(_, res)| res != &TableProviderFilterPushDown::Unsupported)
10061014
.map(|(pred, _)| pred);
1015+
1016+
// Add new scan filters
10071017
let new_scan_filters: Vec<Expr> = scan
10081018
.filters
10091019
.iter()
10101020
.chain(new_scan_filters)
10111021
.unique()
10121022
.cloned()
10131023
.collect();
1024+
1025+
// Compose predicates to be of `Unsupported` or `Inexact` pushdown type, and also include volatile filters
10141026
let new_predicate: Vec<Expr> = zip
10151027
.filter(|(_, res)| res != &TableProviderFilterPushDown::Exact)
1016-
.map(|(pred, _)| pred.clone())
1028+
.map(|(pred, _)| pred)
1029+
.chain(volatile_filters)
1030+
.cloned()
10171031
.collect();
10181032

10191033
let new_scan = LogicalPlan::TableScan(TableScan {
@@ -2515,23 +2529,31 @@ mod tests {
25152529
}
25162530
}
25172531

2518-
fn table_scan_with_pushdown_provider(
2532+
fn table_scan_with_pushdown_provider_builder(
25192533
filter_support: TableProviderFilterPushDown,
2520-
) -> Result<LogicalPlan> {
2534+
filters: Vec<Expr>,
2535+
projection: Option<Vec<usize>>,
2536+
) -> Result<LogicalPlanBuilder> {
25212537
let test_provider = PushDownProvider { filter_support };
25222538

25232539
let table_scan = LogicalPlan::TableScan(TableScan {
25242540
table_name: "test".into(),
2525-
filters: vec![],
2541+
filters,
25262542
projected_schema: Arc::new(DFSchema::try_from(
25272543
(*test_provider.schema()).clone(),
25282544
)?),
2529-
projection: None,
2545+
projection,
25302546
source: Arc::new(test_provider),
25312547
fetch: None,
25322548
});
25332549

2534-
LogicalPlanBuilder::from(table_scan)
2550+
Ok(LogicalPlanBuilder::from(table_scan))
2551+
}
2552+
2553+
fn table_scan_with_pushdown_provider(
2554+
filter_support: TableProviderFilterPushDown,
2555+
) -> Result<LogicalPlan> {
2556+
table_scan_with_pushdown_provider_builder(filter_support, vec![], None)?
25352557
.filter(col("a").eq(lit(1i64)))?
25362558
.build()
25372559
}
@@ -2588,25 +2610,14 @@ mod tests {
25882610

25892611
#[test]
25902612
fn multi_combined_filter() -> Result<()> {
2591-
let test_provider = PushDownProvider {
2592-
filter_support: TableProviderFilterPushDown::Inexact,
2593-
};
2594-
2595-
let table_scan = LogicalPlan::TableScan(TableScan {
2596-
table_name: "test".into(),
2597-
filters: vec![col("a").eq(lit(10i64)), col("b").gt(lit(11i64))],
2598-
projected_schema: Arc::new(DFSchema::try_from(
2599-
(*test_provider.schema()).clone(),
2600-
)?),
2601-
projection: Some(vec![0]),
2602-
source: Arc::new(test_provider),
2603-
fetch: None,
2604-
});
2605-
2606-
let plan = LogicalPlanBuilder::from(table_scan)
2607-
.filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))?
2608-
.project(vec![col("a"), col("b")])?
2609-
.build()?;
2613+
let plan = table_scan_with_pushdown_provider_builder(
2614+
TableProviderFilterPushDown::Inexact,
2615+
vec![col("a").eq(lit(10i64)), col("b").gt(lit(11i64))],
2616+
Some(vec![0]),
2617+
)?
2618+
.filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))?
2619+
.project(vec![col("a"), col("b")])?
2620+
.build()?;
26102621

26112622
let expected = "Projection: a, b\
26122623
\n Filter: a = Int64(10) AND b > Int64(11)\
@@ -2617,25 +2628,14 @@ mod tests {
26172628

26182629
#[test]
26192630
fn multi_combined_filter_exact() -> Result<()> {
2620-
let test_provider = PushDownProvider {
2621-
filter_support: TableProviderFilterPushDown::Exact,
2622-
};
2623-
2624-
let table_scan = LogicalPlan::TableScan(TableScan {
2625-
table_name: "test".into(),
2626-
filters: vec![],
2627-
projected_schema: Arc::new(DFSchema::try_from(
2628-
(*test_provider.schema()).clone(),
2629-
)?),
2630-
projection: Some(vec![0]),
2631-
source: Arc::new(test_provider),
2632-
fetch: None,
2633-
});
2634-
2635-
let plan = LogicalPlanBuilder::from(table_scan)
2636-
.filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))?
2637-
.project(vec![col("a"), col("b")])?
2638-
.build()?;
2631+
let plan = table_scan_with_pushdown_provider_builder(
2632+
TableProviderFilterPushDown::Exact,
2633+
vec![],
2634+
Some(vec![0]),
2635+
)?
2636+
.filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))?
2637+
.project(vec![col("a"), col("b")])?
2638+
.build()?;
26392639

26402640
let expected = r#"
26412641
Projection: a, b
@@ -3385,4 +3385,87 @@ Projection: a, b
33853385
\n TableScan: test2";
33863386
assert_optimized_plan_eq(plan, expected)
33873387
}
3388+
3389+
#[test]
3390+
fn test_push_down_volatile_table_scan() -> Result<()> {
3391+
// SELECT test.a, test.b FROM test as t WHERE TestScalarUDF() > 0.1;
3392+
let table_scan = test_table_scan()?;
3393+
let fun = ScalarUDF::new_from_impl(TestScalarUDF {
3394+
signature: Signature::exact(vec![], Volatility::Volatile),
3395+
});
3396+
let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
3397+
let plan = LogicalPlanBuilder::from(table_scan)
3398+
.project(vec![col("a"), col("b")])?
3399+
.filter(expr.gt(lit(0.1)))?
3400+
.build()?;
3401+
3402+
let expected_before = "Filter: TestScalarUDF() > Float64(0.1)\
3403+
\n Projection: test.a, test.b\
3404+
\n TableScan: test";
3405+
assert_eq!(format!("{plan}"), expected_before);
3406+
3407+
let expected_after = "Projection: test.a, test.b\
3408+
\n Filter: TestScalarUDF() > Float64(0.1)\
3409+
\n TableScan: test";
3410+
assert_optimized_plan_eq(plan, expected_after)
3411+
}
3412+
3413+
#[test]
3414+
fn test_push_down_volatile_mixed_table_scan() -> Result<()> {
3415+
// SELECT test.a, test.b FROM test as t WHERE TestScalarUDF() > 0.1 and test.a > 5 and test.b > 10;
3416+
let table_scan = test_table_scan()?;
3417+
let fun = ScalarUDF::new_from_impl(TestScalarUDF {
3418+
signature: Signature::exact(vec![], Volatility::Volatile),
3419+
});
3420+
let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
3421+
let plan = LogicalPlanBuilder::from(table_scan)
3422+
.project(vec![col("a"), col("b")])?
3423+
.filter(
3424+
expr.gt(lit(0.1))
3425+
.and(col("t.a").gt(lit(5)))
3426+
.and(col("t.b").gt(lit(10))),
3427+
)?
3428+
.build()?;
3429+
3430+
let expected_before = "Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)\
3431+
\n Projection: test.a, test.b\
3432+
\n TableScan: test";
3433+
assert_eq!(format!("{plan}"), expected_before);
3434+
3435+
let expected_after = "Projection: test.a, test.b\
3436+
\n Filter: TestScalarUDF() > Float64(0.1)\
3437+
\n TableScan: test, full_filters=[t.a > Int32(5), t.b > Int32(10)]";
3438+
assert_optimized_plan_eq(plan, expected_after)
3439+
}
3440+
3441+
#[test]
3442+
fn test_push_down_volatile_mixed_unsupported_table_scan() -> Result<()> {
3443+
// SELECT test.a, test.b FROM test as t WHERE TestScalarUDF() > 0.1 and test.a > 5 and test.b > 10;
3444+
let fun = ScalarUDF::new_from_impl(TestScalarUDF {
3445+
signature: Signature::exact(vec![], Volatility::Volatile),
3446+
});
3447+
let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
3448+
let plan = table_scan_with_pushdown_provider_builder(
3449+
TableProviderFilterPushDown::Unsupported,
3450+
vec![],
3451+
None,
3452+
)?
3453+
.project(vec![col("a"), col("b")])?
3454+
.filter(
3455+
expr.gt(lit(0.1))
3456+
.and(col("t.a").gt(lit(5)))
3457+
.and(col("t.b").gt(lit(10))),
3458+
)?
3459+
.build()?;
3460+
3461+
let expected_before = "Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)\
3462+
\n Projection: a, b\
3463+
\n TableScan: test";
3464+
assert_eq!(format!("{plan}"), expected_before);
3465+
3466+
let expected_after = "Projection: a, b\
3467+
\n Filter: t.a > Int32(5) AND t.b > Int32(10) AND TestScalarUDF() > Float64(0.1)\
3468+
\n TableScan: test";
3469+
assert_optimized_plan_eq(plan, expected_after)
3470+
}
33883471
}

0 commit comments

Comments
 (0)