Commit 4eacb60

kosiew and Jefffrey authored
Enable Projection Pushdown Optimization for Recursive CTEs (#16696)
* Add column pruning support for RecursiveQuery operator in optimizer
  - Extend optimize_projections to handle LogicalPlan::RecursiveQuery by applying projection pushdown to its inputs, improving query performance.
  - Add integration test `recursive_query_column_pruning` verifying plan shows correct projection pruning for recursive CTEs.
  - Implement `create_cte_work_table` in test context provider to support CTE tests.
  - Add .github/copilot-instructions.md and AGENTS.md docs with Rust coding, linting, formatting, and contribution guidelines to maintain quality and consistency in generated and user code.
* Amend cte results
* Add recursive_cte.rs integration test and improve recursive CTE projection pushdown in cte.slt
  - Add a new async test `recursive_cte_alias_instability` covering a complex recursive CTE query to the DataFusion core SQL tests, validating recursive CTE alias handling and query stability.
  - Enhance existing sqllogictest file `cte.slt` by adding column projections to recursive CTE TableScans, improving plan efficiency and consistency.
  - Fix projection pushdown for recursive CTEs in logical and physical plans for nodes, balances, recursive_cte, and numbers test cases.
  - This addresses alias instability and projection inefficiencies previously observed in recursive CTE handling and improves test coverage for recursive SQL features.
* Remove redundant handling of RecursiveQuery in optimize_projections function
* main cte.slt
* Add tests for recursive CTE projection pushdown scenarios
* Remove unused create_cte_work_table function from MyContextProvider implementation
* Refactor TableScan projections in recursive query tests for clarity
* consolidate recursive cte tests
* Remove test that is included in slt
* Enhance optimization for recursive queries by adding checks for problematic structures
* Refactor recursive CTE tests to improve projection pushdown validation and add new test for alias instability
* refactor: rename function to clarify purpose of subquery alias detection in recursive queries
* test: add comments to clarify purpose of subquery alias handling in recursive CTE tests
* Refactor `plan_contains_subquery_alias` to count subquery aliases instead of returning a bool
  - Changed `plan_contains_subquery_alias` to count occurrences of `SubqueryAlias` nodes in the plan.
  - Added helper function `count_subquery_aliases` to recursively update the count.
  - Updated logic to return true if there are two or more subquery aliases, preserving original behavior in `recursive_cte_alias_instability` test.
* Optimize subquery alias counting by early termination
  - Updated `plan_contains_subquery_alias` to short-circuit counting once threshold (2) is reached.
  - Modified `count_subquery_aliases` to accept a threshold and return early when count meets or exceeds it.
  - Improves performance by avoiding unnecessary traversal of the entire plan.
  - Added integration test snapshot for recursive query column pruning reflecting these changes.
* Refactor count_subquery_aliases to return count instead of using mutable reference
  - Changed count_subquery_aliases to take count usize and return updated count, removing mutable reference.
  - Added early-exit when count reaches threshold to avoid unnecessary traversal.
  - Updated is_projection_unnecessary to use new count_subquery_aliases signature.
  - Removed obsolete integration test snapshot file.
  - Fixed recursive CTE plan and physical plan in optimizer integration and sqllogictest to include projection pruning on TableScans inside recursive queries.

  This improves clarity, readability, and efficiency of subquery alias counting in logical plans, and fixes recursive query projections pruning as per related issue.
* docs(tests): add note referencing similar SQL in cte.slt for recursive CTE alias test

  Add a comment in the recursive_cte_alias_instability test highlighting the similarity of the SQL query to one in datafusion/sqllogictest/test_files/cte.slt. This improves test clarity and cross-reference for maintainers.
* fix: clarify comment regarding alias ambiguity in recursive CTEs

  feat: add initial example for explain_memory
* refactor: remove recursive_cte test module
* fix: add comment for clarity on recursive query handling

  This commit adds a comment to the `optimize_projections` function in the `mod.rs` file to clarify the handling of recursive queries. The comment references a discussion on GitHub related to the implementation.
* remove stray file
* Revert "refactor: remove recursive_cte test module"

  This reverts commit d8f6dd1.
* remove RecursiveQuery bypass
* refactor: improve handling of non-CTE subqueries in recursive queries
* refactor: remove unused recursive_cte module from SQL tests
* refactor: enhance projection optimization by handling non-CTE subqueries in recursive queries
* refactor: simplify TableScan projections in recursive query tests
* refactor: reorder filter and projection in recursive query logical plan
* refactor: restrict projection pushdown in recursive queries to only allow CTE references
* refactor: optimize TableScan projection in recursive query logical plan
* fix: correct dynamic filter predicate order in hash join test
* Update datafusion/optimizer/src/optimize_projections/mod.rs

  Co-authored-by: Jeffrey Vo <[email protected]>
* Update datafusion/optimizer/src/optimize_projections/mod.rs

  Co-authored-by: Jeffrey Vo <[email protected]>
* Enhance test description for recursive CTE projection pushdown
* Rename plan_contains_non_cte_subquery to plan_contains_other_subqueries for clarity
* Add test for recursive CTE with nested subquery to validate projection pushdown behavior
* Remove recursive_query_column_pruning

---------

Co-authored-by: Jeffrey Vo <[email protected]>
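The threshold-based counting that the commit message describes as an intermediate step (take the running count by value, return the updated count, stop once the threshold is reached) can be sketched in isolation. This is a minimal sketch over a hypothetical two-variant `Node` tree, not DataFusion's `LogicalPlan`; the names `Node`, `count_aliases`, and `has_at_least_two_aliases` are illustrative assumptions, and the final diff on this page does not contain such a helper.

```rust
/// Hypothetical plan tree used only for illustration (not a DataFusion type).
#[derive(Debug)]
enum Node {
    /// An alias node wrapping exactly one input.
    Alias(Box<Node>),
    /// Any other node, with zero or more inputs.
    Other(Vec<Node>),
}

/// Counts alias nodes, taking the running count by value and returning the
/// updated count. Traversal stops as soon as `threshold` is reached.
fn count_aliases(node: &Node, mut count: usize, threshold: usize) -> usize {
    if count >= threshold {
        return count;
    }
    match node {
        Node::Alias(input) => {
            count += 1;
            count_aliases(input, count, threshold)
        }
        Node::Other(children) => {
            for child in children {
                count = count_aliases(child, count, threshold);
                if count >= threshold {
                    break; // early exit: remaining siblings are not visited
                }
            }
            count
        }
    }
}

/// Returns true when the tree contains two or more alias nodes.
fn has_at_least_two_aliases(node: &Node) -> bool {
    count_aliases(node, 0, 2) >= 2
}
```

Passing the count by value keeps the function free of `&mut` state, at the cost of threading the value through every recursive call; the early return means a tree with many aliases is only walked until the second one is found.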
1 parent 52565a7 commit 4eacb60

3 files changed, +210 -11 lines changed

datafusion/optimizer/src/optimize_projections/mod.rs

Lines changed: 64 additions & 1 deletion
@@ -356,12 +356,35 @@ fn optimize_projections(
                 .collect::<Result<Vec<_>>>()?
         }
         LogicalPlan::EmptyRelation(_)
-        | LogicalPlan::RecursiveQuery(_)
         | LogicalPlan::Values(_)
         | LogicalPlan::DescribeTable(_) => {
             // These operators have no inputs, so stop the optimization process.
             return Ok(Transformed::no(plan));
         }
+        LogicalPlan::RecursiveQuery(recursive) => {
+            // Only allow subqueries that reference the current CTE; nested subqueries are not yet
+            // supported for projection pushdown for simplicity.
+            // TODO: be able to do projection pushdown on recursive CTEs with subqueries
+            if plan_contains_other_subqueries(
+                recursive.static_term.as_ref(),
+                &recursive.name,
+            ) || plan_contains_other_subqueries(
+                recursive.recursive_term.as_ref(),
+                &recursive.name,
+            ) {
+                return Ok(Transformed::no(plan));
+            }
+
+            plan.inputs()
+                .into_iter()
+                .map(|input| {
+                    indices
+                        .clone()
+                        .with_projection_beneficial()
+                        .with_plan_exprs(&plan, input.schema())
+                })
+                .collect::<Result<Vec<_>>>()?
+        }
         LogicalPlan::Join(join) => {
             let left_len = join.left.schema().fields().len();
             let (left_req_indices, right_req_indices) =
@@ -850,6 +873,46 @@ pub fn is_projection_unnecessary(
     ))
 }

+/// Returns true if the plan subtree contains any subqueries that are not the
+/// CTE reference itself. This treats any non-CTE [`LogicalPlan::SubqueryAlias`]
+/// node (including aliased relations) as a blocker, along with expression-level
+/// subqueries like scalar, EXISTS, or IN. These cases prevent projection
+/// pushdown for now because we cannot safely reason about their column usage.
+fn plan_contains_other_subqueries(plan: &LogicalPlan, cte_name: &str) -> bool {
+    if let LogicalPlan::SubqueryAlias(alias) = plan {
+        if alias.alias.table() != cte_name {
+            return true;
+        }
+    }
+
+    let mut found = false;
+    plan.apply_expressions(|expr| {
+        if expr_contains_subquery(expr) {
+            found = true;
+            Ok(TreeNodeRecursion::Stop)
+        } else {
+            Ok(TreeNodeRecursion::Continue)
+        }
+    })
+    .expect("expression traversal never fails");
+    if found {
+        return true;
+    }
+
+    plan.inputs()
+        .into_iter()
+        .any(|child| plan_contains_other_subqueries(child, cte_name))
+}
+
+fn expr_contains_subquery(expr: &Expr) -> bool {
+    expr.exists(|e| match e {
+        Expr::ScalarSubquery(_) | Expr::Exists(_) | Expr::InSubquery(_) => Ok(true),
+        _ => Ok(false),
+    })
+    // Safe unwrap since we are doing a simple boolean check
+    .unwrap()
+}
+
 #[cfg(test)]
 mod tests {
     use std::cmp::Ordering;
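
The bailout check added in this file can be illustrated on a much smaller plan type. The following is a minimal sketch under stated assumptions: `ToyPlan` and `contains_other_subqueries` are hypothetical stand-ins, not DataFusion's `LogicalPlan` or the function in the diff, and expression-level subqueries (scalar, EXISTS, IN) are reduced to a single boolean flag on a filter node.

```rust
/// Toy stand-in for a logical plan node. This is NOT DataFusion's `LogicalPlan`.
#[derive(Debug)]
enum ToyPlan {
    /// Leaf scan with no inputs.
    Scan,
    /// `input AS alias`.
    Alias { alias: String, input: Box<ToyPlan> },
    /// Filter whose predicate may contain a scalar/EXISTS/IN subquery.
    Filter { has_subquery_expr: bool, input: Box<ToyPlan> },
    /// Binary join.
    Join { left: Box<ToyPlan>, right: Box<ToyPlan> },
}

impl ToyPlan {
    /// Direct children of this node.
    fn inputs(&self) -> Vec<&ToyPlan> {
        match self {
            ToyPlan::Scan => vec![],
            ToyPlan::Alias { input, .. } | ToyPlan::Filter { input, .. } => vec![&**input],
            ToyPlan::Join { left, right } => vec![&**left, &**right],
        }
    }
}

/// Returns true if the subtree contains an alias other than `cte_name`, or a
/// node flagged as carrying an expression-level subquery.
fn contains_other_subqueries(plan: &ToyPlan, cte_name: &str) -> bool {
    match plan {
        ToyPlan::Alias { alias, .. } if alias.as_str() != cte_name => return true,
        ToyPlan::Filter { has_subquery_expr: true, .. } => return true,
        _ => {}
    }
    // Otherwise recurse; `any` short-circuits on the first blocking child.
    plan.inputs()
        .into_iter()
        .any(|child| contains_other_subqueries(child, cte_name))
}
```

In this toy model, as in the doc comment in the diff, an alias is a blocker whenever its name differs from the CTE name, so a plain aliased relation such as `test t` also disables pushdown for the whole recursive query.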

datafusion/optimizer/tests/optimizer_integration.rs

Lines changed: 138 additions & 0 deletions
@@ -46,6 +46,48 @@ fn init() {
     let _ = env_logger::try_init();
 }

+#[test]
+fn recursive_cte_with_nested_subquery() -> Result<()> {
+    // Covers bailout path in `plan_contains_other_subqueries`, ensuring nested subqueries
+    // within recursive CTE branches prevent projection pushdown.
+    let sql = r#"
+        WITH RECURSIVE numbers(id, level) AS (
+            SELECT sub.id, sub.level FROM (
+                SELECT col_int32 AS id, 1 AS level FROM test
+            ) sub
+            UNION ALL
+            SELECT t.col_int32, numbers.level + 1
+            FROM test t
+            JOIN numbers ON t.col_int32 = numbers.id + 1
+        )
+        SELECT id, level FROM numbers
+    "#;
+
+    let plan = test_sql(sql)?;
+
+    assert_snapshot!(
+        format!("{plan}"),
+        @r#"
+    SubqueryAlias: numbers
+      Projection: sub.id AS id, sub.level AS level
+        RecursiveQuery: is_distinct=false
+          Projection: sub.id, sub.level
+            SubqueryAlias: sub
+              Projection: test.col_int32 AS id, Int64(1) AS level
+                TableScan: test
+          Projection: t.col_int32, numbers.level + Int64(1)
+            Inner Join: CAST(t.col_int32 AS Int64) = CAST(numbers.id AS Int64) + Int64(1)
+              SubqueryAlias: t
+                Filter: CAST(test.col_int32 AS Int64) IS NOT NULL
+                  TableScan: test
+              Filter: CAST(numbers.id AS Int64) + Int64(1) IS NOT NULL
+                TableScan: numbers
+    "#
+    );
+
+    Ok(())
+}
+
 #[test]
 fn case_when() -> Result<()> {
     let sql = "SELECT CASE WHEN col_int32 > 0 THEN 1 ELSE 0 END FROM test";
@@ -478,6 +520,94 @@ fn select_correlated_predicate_subquery_with_uppercase_ident() {
     );
 }

+#[test]
+fn recursive_cte_projection_pushdown() -> Result<()> {
+    // Test that projection pushdown works with recursive CTEs by ensuring
+    // only the required columns are projected from the base table, even when
+    // the CTE definition includes unused columns
+    let sql = "WITH RECURSIVE nodes AS (\
+        SELECT col_int32 AS id, col_utf8 AS name, col_uint32 AS extra FROM test \
+        UNION ALL \
+        SELECT id + 1, name, extra FROM nodes WHERE id < 3\
+    ) SELECT id FROM nodes";
+    let plan = test_sql(sql)?;
+
+    // The optimizer successfully performs projection pushdown by only selecting the needed
+    // columns from the base table and recursive table, eliminating unused columns
+    assert_snapshot!(
+        format!("{plan}"),
+        @r#"SubqueryAlias: nodes
+  RecursiveQuery: is_distinct=false
+    Projection: test.col_int32 AS id
+      TableScan: test projection=[col_int32]
+    Projection: CAST(CAST(nodes.id AS Int64) + Int64(1) AS Int32)
+      Filter: nodes.id < Int32(3)
+        TableScan: nodes projection=[id]
+"#
+    );
+    Ok(())
+}
+
+#[test]
+fn recursive_cte_with_unused_columns() -> Result<()> {
+    // Test projection pushdown with a recursive CTE where the base case
+    // includes columns that are never used in the recursive part or final result
+    let sql = "WITH RECURSIVE series AS (\
+        SELECT 1 AS n, col_utf8, col_uint32, col_date32 FROM test WHERE col_int32 = 1 \
+        UNION ALL \
+        SELECT n + 1, col_utf8, col_uint32, col_date32 FROM series WHERE n < 3\
+    ) SELECT n FROM series";
+    let plan = test_sql(sql)?;
+
+    // The optimizer successfully performs projection pushdown by eliminating unused columns
+    // even when they're defined in the CTE but not actually needed
+    assert_snapshot!(
+        format!("{plan}"),
+        @r#"SubqueryAlias: series
+  RecursiveQuery: is_distinct=false
+    Projection: Int64(1) AS n
+      Filter: test.col_int32 = Int32(1)
+        TableScan: test projection=[col_int32]
+    Projection: series.n + Int64(1)
+      Filter: series.n < Int64(3)
+        TableScan: series projection=[n]
+"#
+    );
+    Ok(())
+}
+
+#[test]
+/// Asserts the minimal plan shape once projection pushdown succeeds for a recursive CTE.
+/// Unlike the previous two tests that retain extra columns in either the base or recursive
+/// branches, this baseline shows the optimizer trimming everything down to the single
+/// column required by the final projection.
+fn recursive_cte_projection_pushdown_baseline() -> Result<()> {
+    // Test case that truly demonstrates projection pushdown working:
+    // The base case only selects needed columns
+    let sql = "WITH RECURSIVE countdown AS (\
+        SELECT col_int32 AS n FROM test WHERE col_int32 = 5 \
+        UNION ALL \
+        SELECT n - 1 FROM countdown WHERE n > 1\
+    ) SELECT n FROM countdown";
+    let plan = test_sql(sql)?;
+
+    // This demonstrates optimal projection pushdown where only col_int32 is projected from the base table,
+    // and only the needed column is selected from the recursive table
+    assert_snapshot!(
+        format!("{plan}"),
+        @r#"SubqueryAlias: countdown
+  RecursiveQuery: is_distinct=false
+    Projection: test.col_int32 AS n
+      Filter: test.col_int32 = Int32(5)
+        TableScan: test projection=[col_int32]
+    Projection: CAST(CAST(countdown.n AS Int64) - Int64(1) AS Int32)
+      Filter: countdown.n > Int32(1)
+        TableScan: countdown projection=[n]
+"#
+    );
+    Ok(())
+}
+
 fn test_sql(sql: &str) -> Result<LogicalPlan> {
     // parse the SQL
     let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ...
@@ -587,6 +717,14 @@ impl ContextProvider for MyContextProvider {
         None
     }

+    fn create_cte_work_table(
+        &self,
+        _name: &str,
+        schema: SchemaRef,
+    ) -> Result<Arc<dyn TableSource>> {
+        Ok(Arc::new(MyTableSource { schema }))
+    }
+
     fn options(&self) -> &ConfigOptions {
         &self.options
     }
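
The `TableScan: ... projection=[...]` lines in the snapshots above list only the columns that some operator above the scan still needs. As a rough illustration of that idea, the sketch below filters a table's columns down to a required set and renders the result in the same textual form. Everything here is an assumption made for illustration: `scan_line` is a hypothetical helper, the schema slices are abbreviated, and keeping the columns in schema order is a choice of this sketch rather than a statement about the optimizer's internals.

```rust
/// Renders a scan line that keeps only the required columns, in schema order.
/// Hypothetical helper for illustration; not part of DataFusion.
fn scan_line(table: &str, schema: &[&str], required: &[&str]) -> String {
    let kept: Vec<&str> = schema
        .iter()
        .copied()
        .filter(|col| required.contains(col))
        .collect();
    format!("TableScan: {} projection=[{}]", table, kept.join(", "))
}
```

For the `series` test, the filter `col_int32 = 1` is the only consumer of a base-table column, so calling `scan_line("test", &["col_int32", "col_uint32", "col_utf8"], &["col_int32"])` yields `TableScan: test projection=[col_int32]`, matching the snapshot.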

datafusion/sqllogictest/test_files/cte.slt

Lines changed: 8 additions & 10 deletions
@@ -110,7 +110,7 @@ logical_plan
 04)------EmptyRelation: rows=1
 05)----Projection: nodes.id + Int64(1) AS id
 06)------Filter: nodes.id < Int64(10)
-07)--------TableScan: nodes
+07)--------TableScan: nodes projection=[id]
 physical_plan
 01)RecursiveQueryExec: name=nodes, is_distinct=false
 02)--ProjectionExec: expr=[1 as id]
@@ -152,11 +152,10 @@ logical_plan
 01)Sort: balances.time ASC NULLS LAST, balances.name ASC NULLS LAST, balances.account_balance ASC NULLS LAST
 02)--SubqueryAlias: balances
 03)----RecursiveQuery: is_distinct=false
-04)------Projection: balance.time, balance.name, balance.account_balance
-05)--------TableScan: balance
-06)------Projection: balances.time + Int64(1) AS time, balances.name, balances.account_balance + Int64(10) AS account_balance
-07)--------Filter: balances.time < Int64(10)
-08)----------TableScan: balances
+04)------TableScan: balance projection=[time, name, account_balance]
+05)------Projection: balances.time + Int64(1) AS time, balances.name, balances.account_balance + Int64(10) AS account_balance
+06)--------Filter: balances.time < Int64(10)
+07)----------TableScan: balances projection=[time, name, account_balance]
 physical_plan
 01)SortExec: expr=[time@0 ASC NULLS LAST, name@1 ASC NULLS LAST, account_balance@2 ASC NULLS LAST], preserve_partitioning=[false]
 02)--RecursiveQueryExec: name=balances, is_distinct=false
@@ -958,7 +957,7 @@ logical_plan
 04)------EmptyRelation: rows=1
 05)----Projection: numbers.n + Int64(1)
 06)------Filter: numbers.n < Int64(10)
-07)--------TableScan: numbers
+07)--------TableScan: numbers projection=[n]
 physical_plan
 01)RecursiveQueryExec: name=numbers, is_distinct=false
 02)--ProjectionExec: expr=[1 as n]
@@ -984,7 +983,7 @@ logical_plan
 04)------EmptyRelation: rows=1
 05)----Projection: numbers.n + Int64(1)
 06)------Filter: numbers.n < Int64(10)
-07)--------TableScan: numbers
+07)--------TableScan: numbers projection=[n]
 physical_plan
 01)RecursiveQueryExec: name=numbers, is_distinct=false
 02)--ProjectionExec: expr=[1 as n]
@@ -1041,8 +1040,7 @@ logical_plan
 04)------Projection: Int64(0) AS k, Int64(0) AS v
 05)--------EmptyRelation: rows=1
 06)------Sort: r.v ASC NULLS LAST, fetch=1
-07)--------Projection: r.k, r.v
-08)----------TableScan: r
+07)--------TableScan: r projection=[k, v]
 physical_plan
 01)GlobalLimitExec: skip=0, fetch=5
 02)--RecursiveQueryExec: name=r, is_distinct=false
