Skip to content

Commit 0b45b9a

Browse files
authored
Improve TableScan with filters pushdown unparsing (joins) (#13132)
* Improve TableScan with filters pushdown unparsing (joins)
* Fix formatting
* Add test with filters before and after join
1 parent 1fd6116 commit 0b45b9a

File tree

3 files changed

+240
-17
lines changed

3 files changed

+240
-17
lines changed

datafusion/sql/src/unparser/plan.rs

Lines changed: 66 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -27,8 +27,8 @@ use super::{
2727
},
2828
utils::{
2929
find_agg_node_within_select, find_unnest_node_within_select,
30-
find_window_nodes_within_select, unproject_sort_expr, unproject_unnest_expr,
31-
unproject_window_exprs,
30+
find_window_nodes_within_select, try_transform_to_simple_table_scan_with_filters,
31+
unproject_sort_expr, unproject_unnest_expr, unproject_window_exprs,
3232
},
3333
Unparser,
3434
};
@@ -39,8 +39,8 @@ use datafusion_common::{
3939
Column, DataFusionError, Result, TableReference,
4040
};
4141
use datafusion_expr::{
42-
expr::Alias, Distinct, Expr, JoinConstraint, JoinType, LogicalPlan,
43-
LogicalPlanBuilder, Projection, SortExpr, TableScan,
42+
expr::Alias, BinaryExpr, Distinct, Expr, JoinConstraint, JoinType, LogicalPlan,
43+
LogicalPlanBuilder, Operator, Projection, SortExpr, TableScan,
4444
};
4545
use sqlparser::ast::{self, Ident, SetExpr};
4646
use std::sync::Arc;
@@ -468,22 +468,77 @@ impl Unparser<'_> {
468468
self.select_to_sql_recursively(input, query, select, relation)
469469
}
470470
LogicalPlan::Join(join) => {
471-
let join_constraint = self.join_constraint_to_sql(
472-
join.join_constraint,
473-
&join.on,
474-
join.filter.as_ref(),
471+
let mut table_scan_filters = vec![];
472+
473+
let left_plan =
474+
match try_transform_to_simple_table_scan_with_filters(&join.left)? {
475+
Some((plan, filters)) => {
476+
table_scan_filters.extend(filters);
477+
Arc::new(plan)
478+
}
479+
None => Arc::clone(&join.left),
480+
};
481+
482+
self.select_to_sql_recursively(
483+
left_plan.as_ref(),
484+
query,
485+
select,
486+
relation,
475487
)?;
476488

489+
let right_plan =
490+
match try_transform_to_simple_table_scan_with_filters(&join.right)? {
491+
Some((plan, filters)) => {
492+
table_scan_filters.extend(filters);
493+
Arc::new(plan)
494+
}
495+
None => Arc::clone(&join.right),
496+
};
497+
477498
let mut right_relation = RelationBuilder::default();
478499

479500
self.select_to_sql_recursively(
480-
join.left.as_ref(),
501+
right_plan.as_ref(),
481502
query,
482503
select,
483-
relation,
504+
&mut right_relation,
484505
)?;
506+
507+
let join_filters = if table_scan_filters.is_empty() {
508+
join.filter.clone()
509+
} else {
510+
// Combine `table_scan_filters` into a single filter using `AND`
511+
let Some(combined_filters) =
512+
table_scan_filters.into_iter().reduce(|acc, filter| {
513+
Expr::BinaryExpr(BinaryExpr {
514+
left: Box::new(acc),
515+
op: Operator::And,
516+
right: Box::new(filter),
517+
})
518+
})
519+
else {
520+
return internal_err!("Failed to combine TableScan filters");
521+
};
522+
523+
// Combine `join.filter` with `combined_filters` using `AND`
524+
match &join.filter {
525+
Some(filter) => Some(Expr::BinaryExpr(BinaryExpr {
526+
left: Box::new(filter.clone()),
527+
op: Operator::And,
528+
right: Box::new(combined_filters),
529+
})),
530+
None => Some(combined_filters),
531+
}
532+
};
533+
534+
let join_constraint = self.join_constraint_to_sql(
535+
join.join_constraint,
536+
&join.on,
537+
join_filters.as_ref(),
538+
)?;
539+
485540
self.select_to_sql_recursively(
486-
join.right.as_ref(),
541+
right_plan.as_ref(),
487542
query,
488543
select,
489544
&mut right_relation,

datafusion/sql/src/unparser/utils.rs

Lines changed: 87 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,20 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use std::cmp::Ordering;
18+
use std::{cmp::Ordering, sync::Arc, vec};
1919

2020
use datafusion_common::{
2121
internal_err,
22-
tree_node::{Transformed, TreeNode},
23-
Column, Result, ScalarValue,
22+
tree_node::{Transformed, TransformedResult, TreeNode},
23+
Column, DataFusionError, Result, ScalarValue,
2424
};
2525
use datafusion_expr::{
26-
expr, utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, Projection,
27-
SortExpr, Unnest, Window,
26+
expr, utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan,
27+
LogicalPlanBuilder, Projection, SortExpr, Unnest, Window,
2828
};
2929
use sqlparser::ast;
3030

31-
use super::{dialect::DateFieldExtractStyle, Unparser};
31+
use super::{dialect::DateFieldExtractStyle, rewrite::TableAliasRewriter, Unparser};
3232

3333
/// Recursively searches children of [LogicalPlan] to find an Aggregate node if exists
3434
/// prior to encountering a Join, TableScan, or a nested subquery (derived table factor).
@@ -288,6 +288,87 @@ pub(crate) fn unproject_sort_expr(
288288
Ok(sort_expr)
289289
}
290290

291+
/// Walks down a [LogicalPlan] through SubqueryAlias and Filter nodes looking for a TableScan,
292+
/// and gives up as soon as any other node (for example a Projection, i.e. a SELECT) is encountered.
293+
/// If a TableScan node is found, returns a rebuilt TableScan without filters, along with the collected filters separately.
294+
/// If any other kind of node is reached first (e.g. a Projection), returns None.
295+
///
296+
/// Note: If a table alias is present, TableScan filters are rewritten to reference the alias.
/// Predicates taken from Filter nodes are returned unchanged. If several SubqueryAlias nodes
/// are nested, the last one visited (the innermost) is the alias that is used.
///
/// The returned filters are ordered as visited: Filter node predicates first (outermost
/// first), followed by the filters stored on the TableScan itself.
297+
///
298+
/// LogicalPlan example:
299+
/// Filter: ta.j1_id < 5
300+
/// Alias: ta
301+
/// TableScan: j1, j1_id > 10
302+
///
303+
/// Will return LogicalPlan below:
304+
/// Alias: ta
305+
/// TableScan: j1
306+
/// And filters: [ta.j1_id < 5, ta.j1_id > 10]
307+
pub(crate) fn try_transform_to_simple_table_scan_with_filters(
308+
plan: &LogicalPlan,
309+
) -> Result<Option<(LogicalPlan, Vec<Expr>)>> {
310+
let mut filters: Vec<Expr> = vec![];
311+
// Work list for the walk; every arm below either pushes exactly one child or returns,
// so it never holds more than one plan at a time.
let mut plan_stack = vec![plan];
312+
let mut table_alias = None;
313+
314+
while let Some(current_plan) = plan_stack.pop() {
315+
match current_plan {
316+
LogicalPlan::SubqueryAlias(alias) => {
317+
// Remember the alias (replacing any previously seen one) and keep descending.
table_alias = Some(alias.alias.clone());
318+
plan_stack.push(alias.input.as_ref());
319+
}
320+
LogicalPlan::Filter(filter) => {
321+
// Collect the predicate as-is and keep descending.
filters.push(filter.predicate.clone());
322+
plan_stack.push(filter.input.as_ref());
323+
}
324+
LogicalPlan::TableScan(table_scan) => {
325+
let table_schema = table_scan.source.schema();
326+
// optional rewriter if table has an alias
327+
let mut filter_alias_rewriter =
328+
table_alias.as_ref().map(|alias_name| TableAliasRewriter {
329+
table_schema: &table_schema,
330+
alias_name: alias_name.clone(),
331+
});
332+
333+
// rewrite filters to use table alias if present
334+
let table_scan_filters = table_scan
335+
.filters
336+
.iter()
337+
.cloned()
338+
.map(|expr| {
339+
if let Some(ref mut rewriter) = filter_alias_rewriter {
340+
expr.rewrite(rewriter).data()
341+
} else {
342+
Ok(expr)
343+
}
344+
})
345+
.collect::<Result<Vec<_>, DataFusionError>>()?;
346+
347+
filters.extend(table_scan_filters);
348+
349+
// Rebuild the scan from the table name and source only, so the scan's own
// filters are not carried over into the returned plan.
// NOTE(review): `None` is passed as the third argument, so any projection
// (and presumably fetch) on the original scan is not preserved — confirm intended.
let mut builder = LogicalPlanBuilder::scan(
350+
table_scan.table_name.clone(),
351+
Arc::clone(&table_scan.source),
352+
None,
353+
)?;
354+
355+
// Re-apply the remembered alias on top of the rebuilt scan.
if let Some(alias) = table_alias.take() {
356+
builder = builder.alias(alias)?;
357+
}
358+
359+
let plan = builder.build()?;
360+
361+
return Ok(Some((plan, filters)));
362+
}
363+
// Any other node type: this is not a plain (aliased/filtered) table scan.
_ => {
364+
return Ok(None);
365+
}
366+
}
367+
}
368+
369+
// Work list emptied without reaching a TableScan.
Ok(None)
370+
}
371+
291372
/// Converts a date_part function to SQL, tailoring it to the supported date field extraction style.
292373
pub(crate) fn date_part_to_sql(
293374
unparser: &Unparser,

datafusion/sql/tests/cases/plan_to_sql.rs

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1008,6 +1008,93 @@ fn test_sort_with_push_down_fetch() -> Result<()> {
10081008
Ok(())
10091009
}
10101010

1011+
#[test]
1012+
// Checks the SQL produced for inner joins whose inputs are table scans carrying their own
// filters: the scan filters are expected to show up in the JOIN ... ON condition.
fn test_join_with_table_scan_filters() -> Result<()> {
1013+
let schema_left = Schema::new(vec![
1014+
Field::new("id", DataType::Utf8, false),
1015+
Field::new("name", DataType::Utf8, false),
1016+
]);
1017+
1018+
let schema_right = Schema::new(vec![
1019+
Field::new("id", DataType::Utf8, false),
1020+
Field::new("age", DataType::Utf8, false),
1021+
]);
1022+
1023+
// Left input: scan with a LIKE filter, aliased as "left".
let left_plan = table_scan_with_filters(
1024+
Some("left_table"),
1025+
&schema_left,
1026+
None,
1027+
vec![col("name").like(lit("some_name"))],
1028+
)?
1029+
.alias("left")?
1030+
.build()?;
1031+
1032+
// Right input: scan with a comparison filter and no alias.
let right_plan = table_scan_with_filters(
1033+
Some("right_table"),
1034+
&schema_right,
1035+
None,
1036+
vec![col("age").gt(lit(10))],
1037+
)?
1038+
.build()?;
1039+
1040+
// Case 1: the join has its own filter in addition to the scan filters.
let join_plan_with_filter = LogicalPlanBuilder::from(left_plan.clone())
1041+
.join(
1042+
right_plan.clone(),
1043+
datafusion_expr::JoinType::Inner,
1044+
(vec!["left.id"], vec!["right_table.id"]),
1045+
Some(col("left.id").gt(lit(5))),
1046+
)?
1047+
.build()?;
1048+
1049+
let sql = plan_to_sql(&join_plan_with_filter)?;
1050+
1051+
// The join filter comes first, followed by the left and right scan filters.
let expected_sql = r#"SELECT * FROM left_table AS "left" JOIN right_table ON "left".id = right_table.id AND (("left".id > 5) AND ("left"."name" LIKE 'some_name' AND (age > 10)))"#;
1052+
1053+
assert_eq!(sql.to_string(), expected_sql);
1054+
1055+
// Case 2: same inputs, but the join itself has no filter.
let join_plan_no_filter = LogicalPlanBuilder::from(left_plan.clone())
1056+
.join(
1057+
right_plan,
1058+
datafusion_expr::JoinType::Inner,
1059+
(vec!["left.id"], vec!["right_table.id"]),
1060+
None,
1061+
)?
1062+
.build()?;
1063+
1064+
let sql = plan_to_sql(&join_plan_no_filter)?;
1065+
1066+
// Only the equijoin condition and the scan filters appear in the ON clause.
let expected_sql = r#"SELECT * FROM left_table AS "left" JOIN right_table ON "left".id = right_table.id AND ("left"."name" LIKE 'some_name' AND (age > 10))"#;
1067+
1068+
assert_eq!(sql.to_string(), expected_sql);
1069+
1070+
// Case 3: the right input has an extra Filter node on top of the scan (before the join),
// and another Filter is applied on top of the join (after the join).
// NOTE(review): `schema_right` declares only `id` and `age`, yet the predicate below
// references `right_table.name` — confirm this is intentional.
let right_plan_with_filter = table_scan_with_filters(
1071+
Some("right_table"),
1072+
&schema_right,
1073+
None,
1074+
vec![col("age").gt(lit(10))],
1075+
)?
1076+
.filter(col("right_table.name").eq(lit("before_join_filter_val")))?
1077+
.build()?;
1078+
1079+
let join_plan_multiple_filters = LogicalPlanBuilder::from(left_plan.clone())
1080+
.join(
1081+
right_plan_with_filter,
1082+
datafusion_expr::JoinType::Inner,
1083+
(vec!["left.id"], vec!["right_table.id"]),
1084+
Some(col("left.id").gt(lit(5))),
1085+
)?
1086+
.filter(col("left.name").eq(lit("after_join_filter_val")))?
1087+
.build()?;
1088+
1089+
let sql = plan_to_sql(&join_plan_multiple_filters)?;
1090+
1091+
// The before-join filter is expected in the ON clause; the after-join filter in WHERE.
let expected_sql = r#"SELECT * FROM left_table AS "left" JOIN right_table ON "left".id = right_table.id AND (("left".id > 5) AND (("left"."name" LIKE 'some_name' AND (right_table."name" = 'before_join_filter_val')) AND (age > 10))) WHERE ("left"."name" = 'after_join_filter_val')"#;
1092+
1093+
assert_eq!(sql.to_string(), expected_sql);
1094+
1095+
Ok(())
1096+
}
1097+
10111098
#[test]
10121099
fn test_interval_lhs_eq() {
10131100
sql_round_trip(

0 commit comments

Comments
 (0)