Skip to content

Commit d65ee81

Browse files
[SPARK-46741][SQL] Cache Table with CTE should work when CTE in plan expression subquery
### What changes were proposed in this pull request?
Follow up on the review comment in #53333 (comment).

### Why are the changes needed?
Support all cases.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
UT

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #53526 from AngersZhuuuu/SPARK-46741-FOLLOWUP.

Lead-authored-by: Angerszhuuuu <[email protected]>
Co-authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 1a1ab19 commit d65ee81

File tree

2 files changed

+48
-8
lines changed

2 files changed

+48
-8
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,29 +17,42 @@
1717

1818
package org.apache.spark.sql.catalyst.normalizer
1919

20+
import java.util.concurrent.atomic.AtomicLong
21+
22+
import scala.collection.mutable
23+
2024
import org.apache.spark.sql.catalyst.plans.logical.{CacheTableAsSelect, CTERelationRef, LogicalPlan, UnionLoop, UnionLoopRef, WithCTE}
2125
import org.apache.spark.sql.catalyst.rules.Rule
2226

23-
object NormalizeCTEIds extends Rule[LogicalPlan]{
27+
object NormalizeCTEIds extends Rule[LogicalPlan] {
2428
override def apply(plan: LogicalPlan): LogicalPlan = {
2529
val curId = new java.util.concurrent.atomic.AtomicLong()
26-
plan transformDown {
30+
val cteIdToNewId = mutable.Map.empty[Long, Long]
31+
applyInternal(plan, curId, cteIdToNewId)
32+
}
2733

34+
private def applyInternal(
35+
plan: LogicalPlan,
36+
curId: AtomicLong,
37+
cteIdToNewId: mutable.Map[Long, Long]): LogicalPlan = {
38+
plan transformDownWithSubqueries {
2839
case ctas @ CacheTableAsSelect(_, plan, _, _, _, _, _) =>
29-
ctas.copy(plan = apply(plan))
40+
ctas.copy(plan = applyInternal(plan, curId, cteIdToNewId))
3041

3142
case withCTE @ WithCTE(plan, cteDefs) =>
32-
val defIdToNewId = withCTE.cteDefs.map(_.id).map((_, curId.getAndIncrement())).toMap
33-
val normalizedPlan = canonicalizeCTE(plan, defIdToNewId)
3443
val newCteDefs = cteDefs.map { cteDef =>
35-
val normalizedCteDef = canonicalizeCTE(cteDef.child, defIdToNewId)
36-
cteDef.copy(child = normalizedCteDef, id = defIdToNewId(cteDef.id))
44+
cteIdToNewId.getOrElseUpdate(cteDef.id, curId.getAndIncrement())
45+
val normalizedCteDef = canonicalizeCTE(cteDef.child, cteIdToNewId)
46+
cteDef.copy(child = normalizedCteDef, id = cteIdToNewId(cteDef.id))
3747
}
48+
val normalizedPlan = canonicalizeCTE(plan, cteIdToNewId)
3849
withCTE.copy(plan = normalizedPlan, cteDefs = newCteDefs)
3950
}
4051
}
4152

42-
def canonicalizeCTE(plan: LogicalPlan, defIdToNewId: Map[Long, Long]): LogicalPlan = {
53+
private def canonicalizeCTE(
54+
plan: LogicalPlan,
55+
defIdToNewId: mutable.Map[Long, Long]): LogicalPlan = {
4356
plan.transformDownWithSubqueries {
4457
// For nested WithCTE, if defIndex didn't contain the cteId,
4558
// means it's not current WithCTE's ref.

sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2633,6 +2633,33 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
26332633
}
26342634
assert(inMemoryTableScan.size == 1)
26352635
checkAnswer(df, Row(5) :: Nil)
2636+
2637+
sql(
2638+
"""
2639+
|CACHE TABLE cache_subquery_cte_table
2640+
|WITH v AS (
2641+
| SELECT c1 * c2 c3 from t1
2642+
|)
2643+
|SELECT *
2644+
|FROM v
2645+
|WHERE EXISTS (
2646+
| WITH cte AS (SELECT 1 AS id)
2647+
| SELECT 1
2648+
| FROM cte
2649+
| WHERE cte.id = v.c3
2650+
|)
2651+
|""".stripMargin)
2652+
2653+
val cteInSubquery = sql(
2654+
"""
2655+
|SELECT * FROM cache_subquery_cte_table
2656+
|""".stripMargin)
2657+
2658+
val subqueryInMemoryTableScan = collect(cteInSubquery.queryExecution.executedPlan) {
2659+
case i: InMemoryTableScanExec => i
2660+
}
2661+
assert(subqueryInMemoryTableScan.size == 1)
2662+
checkAnswer(cteInSubquery, Row(1) :: Nil)
26362663
}
26372664
}
26382665

0 commit comments

Comments
 (0)