Skip to content

Commit 4021d91

Browse files
milanisvet authored and cloud-fan committed
[SPARK-50838][SQL] Performs additional checks inside recursive CTEs to throw an error if a forbidden case is encountered
### What changes were proposed in this pull request? Performs additional checks inside recursive CTEs to throw an error if a forbidden case is encountered: 1. A recursive term can contain only one recursive reference. 2. Recursive references can't be used in some kinds of joins and aggregations. 3. Recursive references are not allowed in subqueries. In addition, the `recursive` function inside `CTERelationDef` is renamed to `hasSelfReferenceAsCTERef`, and a `hasSelfReferenceAsUnionLoopRef` function is added, as it is also needed to check whether a cteDef is recursive after substitution. A small bug in `CTESubstitution` is fixed, which now enables substitution of self-references within subqueries as well (but not their resolution, as they are not allowed). ### Why are the changes needed? Support for recursive CTEs. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Will be tested in #49571. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #49518 from milanisvet/checkRecursion. Authored-by: Milan Cupac <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent 5c1f7c2 commit 4021d91

File tree

6 files changed

+112
-11
lines changed

6 files changed

+112
-11
lines changed

common/utils/src/main/resources/error/error-conditions.json

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3135,6 +3135,24 @@
31353135
],
31363136
"sqlState" : "42836"
31373137
},
3138+
"INVALID_RECURSIVE_REFERENCE" : {
3139+
"message" : [
3140+
"Invalid recursive reference found inside WITH RECURSIVE clause."
3141+
],
3142+
"subClass" : {
3143+
"NUMBER" : {
3144+
"message" : [
3145+
"Multiple self-references to one recursive CTE are not allowed."
3146+
]
3147+
},
3148+
"PLACE" : {
3149+
"message" : [
3150+
"Recursive references cannot be used on the right side of left outer/semi/anti joins, on the left side of right outer joins, in full outer joins, in aggregates, and in subquery expressions."
3151+
]
3152+
}
3153+
},
3154+
"sqlState" : "42836"
3155+
},
31383156
"INVALID_REGEXP_REPLACE" : {
31393157
"message" : [
31403158
"Could not perform regexp_replace for source = \"<source>\", pattern = \"<pattern>\", replacement = \"<replacement>\" and position = <position>."

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,7 @@ object CTESubstitution extends Rule[LogicalPlan] {
402402
other.transformExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) {
403403
case e: SubqueryExpression =>
404404
e.withNewPlan(
405-
apply(substituteCTE(e.plan, alwaysInline, cteRelations, None)))
405+
apply(substituteCTE(e.plan, alwaysInline, cteRelations, recursiveCTERelation)))
406406
}
407407
}
408408
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import org.apache.spark.{SparkException, SparkThrowable}
2222
import org.apache.spark.internal.{Logging, LogKeys, MDC}
2323
import org.apache.spark.sql.AnalysisException
2424
import org.apache.spark.sql.catalyst.ExtendedAnalysisException
25+
import org.apache.spark.sql.catalyst.analysis.ResolveWithCTE.{checkForSelfReferenceInSubquery, checkIfSelfReferenceIsPlacedCorrectly}
2526
import org.apache.spark.sql.catalyst.expressions._
2627
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
2728
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, ListAgg, Median, PercentileCont, PercentileDisc}
@@ -274,10 +275,17 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
274275
checkTrailingCommaInSelect(proj)
275276
case agg: Aggregate =>
276277
checkTrailingCommaInSelect(agg)
278+
case unionLoop: UnionLoop =>
279+
// For recursive CTEs, Union has already been replaced with UnionLoop at this stage.
280+
// Here we perform additional checks for them.
281+
checkIfSelfReferenceIsPlacedCorrectly(unionLoop, unionLoop.id)
277282

278283
case _ =>
279284
}
280285

286+
// Check if there is any self-reference within subqueries
287+
checkForSelfReferenceInSubquery(plan)
288+
281289
// We transform up and order the rules so as to catch the first possible failure instead
282290
// of the result of cascading resolution failures.
283291
plan.foreachUp {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveWithCTE.scala

Lines changed: 76 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import scala.collection.mutable
2121

2222
import org.apache.spark.sql.AnalysisException
2323
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
24+
import org.apache.spark.sql.catalyst.plans.{Inner, LeftAnti, LeftOuter, LeftSemi, RightOuter}
2425
import org.apache.spark.sql.catalyst.plans.logical._
2526
import org.apache.spark.sql.catalyst.rules.Rule
2627
import org.apache.spark.sql.catalyst.trees.TreePattern.{CTE, PLAN_EXPRESSION}
@@ -49,17 +50,18 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
4950
plan.resolveOperatorsDownWithPruning(_.containsAllPatterns(CTE)) {
5051
case withCTE @ WithCTE(_, cteDefs) =>
5152
val newCTEDefs = cteDefs.map {
52-
// `cteDef.recursive` means "presence of a recursive CTERelationRef under cteDef". The
53-
// side effect of node substitution below is that after CTERelationRef substitution
54-
// its cteDef is no more considered `recursive`. This code path is common for `cteDef`
55-
// that were non-recursive from the get go, as well as those that are no more recursive
56-
// due to node substitution.
57-
case cteDef if !cteDef.recursive =>
53+
// cteDef in the first case is either recursive and all the recursive CTERelationRefs
54+
// are already substituted to UnionLoopRef in the previous pass, or it is not recursive
55+
// at all. In both cases we need to put it in the map in case it is resolved.
56+
// The second case performs the substitution of recursive CTERelationRefs.
57+
case cteDef if !cteDef.hasSelfReferenceAsCTERef =>
5858
if (cteDef.resolved) {
5959
cteDefMap.put(cteDef.id, cteDef)
6060
}
6161
cteDef
6262
case cteDef =>
63+
// Multiple self-references are not allowed within one cteDef.
64+
checkNumberOfSelfReferences(cteDef)
6365
cteDef.child match {
6466
// If it is a supported recursive CTE query pattern (4 so far), extract the anchor and
6567
// recursive plans from the Union and rewrite Union with UnionLoop. The recursive CTE
@@ -183,4 +185,72 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
183185
columnNames.map(UnresolvedSubqueryColumnAliases(_, ref)).getOrElse(ref)
184186
}
185187
}
188+
189+
/**
190+
* Checks if there is any self-reference within subqueries and throws an error
191+
* if that is the case.
192+
*/
193+
def checkForSelfReferenceInSubquery(plan: LogicalPlan): Unit = {
194+
plan.subqueriesAll.foreach { subquery =>
195+
subquery.foreach {
196+
case r: CTERelationRef if r.recursive =>
197+
throw new AnalysisException(
198+
errorClass = "INVALID_RECURSIVE_REFERENCE.PLACE",
199+
messageParameters = Map.empty)
200+
case _ =>
201+
}
202+
}
203+
}
204+
205+
/**
206+
* Counts the number of self-references in a recursive CTE definition and throws an error
207+
* if that number is greater than 1.
208+
*/
209+
private def checkNumberOfSelfReferences(cteDef: CTERelationDef): Unit = {
210+
val numOfSelfRef = cteDef.collectWithSubqueries {
211+
case ref: CTERelationRef if ref.cteId == cteDef.id => ref
212+
}.length
213+
if (numOfSelfRef > 1) {
214+
cteDef.failAnalysis(
215+
errorClass = "INVALID_RECURSIVE_REFERENCE.NUMBER",
216+
messageParameters = Map.empty)
217+
}
218+
}
219+
220+
/**
221+
* Throws an error if a self-reference is placed somewhere it is not allowed:
222+
* right side of left outer/semi/anti joins, left side of right outer joins,
223+
* in full outer joins and in aggregates.
224+
*/
225+
def checkIfSelfReferenceIsPlacedCorrectly(
226+
plan: LogicalPlan,
227+
cteId: Long,
228+
allowRecursiveRef: Boolean = true): Unit = plan match {
229+
case Join(left, right, Inner, _, _) =>
230+
checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef)
231+
checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef)
232+
case Join(left, right, LeftOuter, _, _) =>
233+
checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef)
234+
checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef = false)
235+
case Join(left, right, RightOuter, _, _) =>
236+
checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef = false)
237+
checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef)
238+
case Join(left, right, LeftSemi, _, _) =>
239+
checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef)
240+
checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef = false)
241+
case Join(left, right, LeftAnti, _, _) =>
242+
checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef)
243+
checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef = false)
244+
case Join(left, right, _, _, _) =>
245+
checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef = false)
246+
checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef = false)
247+
case Aggregate(_, _, child, _) =>
248+
checkIfSelfReferenceIsPlacedCorrectly(child, cteId, allowRecursiveRef = false)
249+
case r: UnionLoopRef if !allowRecursiveRef && r.loopId == cteId =>
250+
throw new AnalysisException(
251+
errorClass = "INVALID_RECURSIVE_REFERENCE.PLACE",
252+
messageParameters = Map.empty)
253+
case other =>
254+
other.children.foreach(checkIfSelfReferenceIsPlacedCorrectly(_, cteId, allowRecursiveRef))
255+
}
186256
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,10 @@ case class InlineCTE(
6161
// 1) It is fine to inline a CTE if it references another CTE that is non-deterministic;
6262
// 2) Any `CTERelationRef` that contains `OuterReference` would have been inlined first.
6363
refCount == 1 ||
64-
cteDef.deterministic ||
64+
// Don't inline recursive CTEs unless necessary, as recursion is very costly.
65+
// The check if cteDef is recursive is performed by checking if it contains
66+
// a UnionLoopRef with the same ID.
67+
(cteDef.deterministic && !cteDef.hasSelfReferenceAsUnionLoopRef) ||
6568
cteDef.child.exists(_.expressions.exists(_.isInstanceOf[OuterReference]))
6669
}
6770

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/cteOperators.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,14 @@ case class CTERelationDef(
100100

101101
override def output: Seq[Attribute] = if (resolved) child.output else Nil
102102

103-
lazy val recursive: Boolean = child.exists{
104-
// If the reference is found inside the child, referencing to this CTE definition,
105-
// and already marked as recursive, then this CTE definition is recursive.
103+
lazy val hasSelfReferenceAsCTERef: Boolean = child.exists{
106104
case CTERelationRef(this.id, _, _, _, _, true) => true
107105
case _ => false
108106
}
107+
lazy val hasSelfReferenceAsUnionLoopRef: Boolean = child.exists{
108+
case UnionLoopRef(this.id, _, _) => true
109+
case _ => false
110+
}
109111
}
110112

111113
object CTERelationDef {

0 commit comments

Comments
 (0)