Commit d5729f0

TeodorDjelic authored and cloud-fan committed
[SPARK-53621][CORE] Adding Support for Executing CONTINUE HANDLER
### What changes were proposed in this pull request?

- Added support for executing CONTINUE exception handlers in SQL scripting
- Extended the existing exception handling framework to support both EXIT and CONTINUE handler types
- Added an interrupt capability for conditional statements to support CONTINUE handlers
- Enhanced the frame type system to distinguish between EXIT_HANDLER and CONTINUE_HANDLER
- Updated test coverage with comprehensive CONTINUE handler scenarios

The feature is gated behind a new feature switch, `spark.sql.scripting.continueHandlerEnabled`, in `SQLConf.scala`.

### Why are the changes needed?

This is part of a series of PRs adding support for `CONTINUE HANDLER`s. A follow-up PR will contain more tests.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

- Added extensive unit tests in `SqlScriptingExecutionSuite.scala` covering various CONTINUE handler scenarios
- Added an E2E test in `SqlScriptingE2eSuite.scala` demonstrating CONTINUE handler functionality
- Tests cover duplicate handler detection for both EXIT and CONTINUE types
- Tests verify proper execution flow continuation after exception handling

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #52371 from TeodorDjelic/executing-continue-handlers.

Authored-by: Teodor Djelic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 46ac78e commit d5729f0
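
For context, a minimal end-to-end sketch of what this commit enables. This is not from the PR: the object name, the session-building boilerplate, and the assumptions that `spark.sql.scripting.enabled` must also be on and that ANSI mode (the 4.x default) makes `1/0` raise `DIVIDE_BY_ZERO` are mine. The `spark.sql.scripting.continueHandlerEnabled` key and the handler syntax come from the commit message and the E2E test below.

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical demo object; not part of the PR.
object ContinueHandlerDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("continue-handler-demo")
      .config("spark.sql.scripting.enabled", "true")                // assumed prerequisite
      .config("spark.sql.scripting.continueHandlerEnabled", "true") // new switch from this PR
      .getOrCreate()

    // A CONTINUE handler resumes execution after the failing statement,
    // where an EXIT handler would leave the enclosing compound instead.
    val result = spark.sql(
      """
        |BEGIN
        |  DECLARE flag INT = -1;
        |  DECLARE CONTINUE HANDLER FOR DIVIDE_BY_ZERO
        |  BEGIN
        |    SET flag = 1;
        |  END;
        |  SELECT 1/0;  -- raises DIVIDE_BY_ZERO; the handler runs, the script continues
        |  SELECT flag; -- still executed because the handler is CONTINUE, not EXIT
        |END
        |""".stripMargin)

    result.show() // expected: a single row with flag = 1
    spark.stop()
  }
}
```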

File tree: 6 files changed, +618 −44 lines


sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala

Lines changed: 43 additions & 4 deletions
```diff
@@ -21,7 +21,7 @@ import org.apache.spark.SparkThrowable
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.SqlScriptingContextManager
 import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.plans.logical.{CommandResult, CompoundBody, LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{CommandResult, CompoundBody, ExceptionHandlerType, LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.classic.{DataFrame, SparkSession}
 import org.apache.spark.sql.types.StructType
@@ -78,14 +78,35 @@ class SqlScriptingExecution(
    */
   private def injectLeaveStatement(executionPlan: NonLeafStatementExec, label: String): Unit = {
     // Go as deep as possible, to find a leaf node. Instead of a statement that
-    // should be executed next, inject LEAVE statement in its place.
+    // should be executed next, inject LEAVE statement in its place.
     var currExecPlan = executionPlan
     while (currExecPlan.curr.exists(_.isInstanceOf[NonLeafStatementExec])) {
       currExecPlan = currExecPlan.curr.get.asInstanceOf[NonLeafStatementExec]
     }
     currExecPlan.curr = Some(new LeaveStatementExec(label))
   }

+  /**
+   * Helper method to execute interrupts to ConditionalStatements.
+   * This method should only interrupt when the statement that throws is a conditional statement.
+   * @param executionPlan Execution plan.
+   */
+  private def interruptConditionalStatements(executionPlan: NonLeafStatementExec): Unit = {
+    // Go as deep as possible into the execution plan children nodes, to find a leaf node.
+    // That leaf node is the next statement that is to be executed. If the parent node of that
+    // leaf node is a conditional statement, skip the conditional statement entirely.
+    var currExecPlan = executionPlan
+    while (currExecPlan.curr.exists(_.isInstanceOf[NonLeafStatementExec])) {
+      currExecPlan = currExecPlan.curr.get.asInstanceOf[NonLeafStatementExec]
+    }
+
+    currExecPlan match {
+      case exec: ConditionalStatementExec =>
+        exec.interrupted = true
+      case _ =>
+    }
+  }
+
   /** Helper method to iterate get next statements from the first available frame. */
   private def getNextStatement: Option[CompoundStatementExec] = {
     // Remove frames that are already executed.
@@ -103,15 +124,29 @@ class SqlScriptingExecution(

       // If the last frame is a handler, set leave statement to be the next one in the
       // innermost scope that should be exited.
-      if (lastFrame.frameType == SqlScriptingFrameType.HANDLER && context.frames.nonEmpty) {
+      if (lastFrame.frameType == SqlScriptingFrameType.EXIT_HANDLER
+        && context.frames.nonEmpty) {
         // Remove the scope if handler is executed.
         if (context.firstHandlerScopeLabel.isDefined
           && lastFrame.scopeLabel.get == context.firstHandlerScopeLabel.get) {
           context.firstHandlerScopeLabel = None
         }
+
         // Inject leave statement into the execution plan of the last frame.
         injectLeaveStatement(context.frames.last.executionPlan, lastFrame.scopeLabel.get)
       }
+
+      if (lastFrame.frameType == SqlScriptingFrameType.CONTINUE_HANDLER
+        && context.frames.nonEmpty) {
+        // Remove the scope if handler is executed.
+        if (context.firstHandlerScopeLabel.isDefined
+          && lastFrame.scopeLabel.get == context.firstHandlerScopeLabel.get) {
+          context.firstHandlerScopeLabel = None
+        }
+
+        // Interrupt conditional statements
+        interruptConditionalStatements(context.frames.last.executionPlan)
+      }
     }
     // If there are still frames available, get the next statement.
     if (context.frames.nonEmpty) {
@@ -169,7 +204,11 @@ class SqlScriptingExecution(
       case Some(handler) =>
         val handlerFrame = new SqlScriptingExecutionFrame(
           handler.body,
-          SqlScriptingFrameType.HANDLER,
+          if (handler.handlerType == ExceptionHandlerType.CONTINUE) {
+            SqlScriptingFrameType.CONTINUE_HANDLER
+          } else {
+            SqlScriptingFrameType.EXIT_HANDLER
+          },
           handler.scopeLabel
         )
         context.frames.append(
```
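
A toy model of the traversal `interruptConditionalStatements` performs may help: descend through `curr` pointers to the deepest non-leaf node and, if it is a conditional, flip its `interrupted` flag. Everything below (`InterruptSketch`, `Exec`, `CompoundExec`, and friends) is invented for illustration; only the traversal shape mirrors the diff above.

```scala
// Invented stand-ins for the execution-node hierarchy; not Spark's classes.
object InterruptSketch extends App {
  trait Exec
  trait NonLeafExec extends Exec { var curr: Option[Exec] = None }
  final class CompoundExec extends NonLeafExec
  final class ConditionalExec extends NonLeafExec { var interrupted: Boolean = false }
  final class LeafExec extends Exec

  def interruptConditionals(root: NonLeafExec): Unit = {
    // Descend to the deepest non-leaf node, as in the diff above.
    var node = root
    while (node.curr.exists(_.isInstanceOf[NonLeafExec])) {
      node = node.curr.get.asInstanceOf[NonLeafExec]
    }
    node match {
      case c: ConditionalExec => c.interrupted = true // skip the conditional that raised
      case _ =>                                       // next statement is not conditional: no-op
    }
  }

  // A script whose next statement is a WHILE-like conditional about to run a leaf statement.
  val loop = new ConditionalExec
  loop.curr = Some(new LeafExec)
  val script = new CompoundExec
  script.curr = Some(loop)

  interruptConditionals(script)
  println(loop.interrupted) // true: the loop's iterator will now report hasNext = false
}
```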

sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala

Lines changed: 6 additions & 3 deletions
```diff
@@ -59,7 +59,8 @@ class SqlScriptingExecutionContext extends SqlScriptingExecutionContextExtension
     }

     // If the last frame is a handler, try to find a handler in its body first.
-    if (frames.last.frameType == SqlScriptingFrameType.HANDLER) {
+    if (frames.last.frameType == SqlScriptingFrameType.EXIT_HANDLER
+      || frames.last.frameType == SqlScriptingFrameType.CONTINUE_HANDLER) {
       val handler = frames.last.findHandler(condition, sqlState, firstHandlerScopeLabel)
       if (handler.isDefined) {
         return handler
@@ -83,7 +84,7 @@ class SqlScriptingExecutionContext extends SqlScriptingExecutionContextExtension

 object SqlScriptingFrameType extends Enumeration {
   type SqlScriptingFrameType = Value
-  val SQL_SCRIPT, HANDLER = Value
+  val SQL_SCRIPT, EXIT_HANDLER, CONTINUE_HANDLER = Value
 }

 /**
@@ -141,7 +142,9 @@ class SqlScriptingExecutionFrame(
     sqlState: String,
     firstHandlerScopeLabel: Option[String]): Option[ExceptionHandlerExec] = {

-    val searchScopes = if (frameType == SqlScriptingFrameType.HANDLER) {
+    val searchScopes =
+      if (frameType == SqlScriptingFrameType.EXIT_HANDLER
+        || frameType == SqlScriptingFrameType.CONTINUE_HANDLER) {
       // If the frame is a handler, search for the handler in its body. Don't skip any scopes.
       scopes.reverseIterator
     } else if (firstHandlerScopeLabel.isEmpty) {
```
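
Since the `EXIT_HANDLER || CONTINUE_HANDLER` test now appears at two call sites in this file, a small predicate could centralize it. The helper below is hypothetical (not in the commit); the enum is copied locally so the sketch compiles on its own.

```scala
object FrameTypeSketch extends App {
  // Local copy of the new three-valued enum from SqlScriptingExecutionContext.scala.
  object FrameType extends Enumeration {
    val SQL_SCRIPT, EXIT_HANDLER, CONTINUE_HANDLER = Value
  }
  import FrameType._

  // Hypothetical helper: both handler kinds count as "handler frames" for lookup purposes.
  def isHandlerFrame(frameType: FrameType.Value): Boolean =
    frameType == EXIT_HANDLER || frameType == CONTINUE_HANDLER

  println(isHandlerFrame(SQL_SCRIPT))       // false
  println(isHandlerFrame(CONTINUE_HANDLER)) // true
}
```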

sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala

Lines changed: 25 additions & 17 deletions
```diff
@@ -106,6 +106,13 @@ trait NonLeafStatementExec extends CompoundStatementExec {
   }
 }

+/**
+ * Conditional node in the execution tree. It is a conditional non-leaf node.
+ */
+trait ConditionalStatementExec extends NonLeafStatementExec {
+  protected[scripting] var interrupted: Boolean = false
+}
+
 /**
  * Executable node for SingleStatement.
  * @param parsedPlan
@@ -401,7 +408,7 @@ class IfElseStatementExec(
     conditions: Seq[SingleStatementExec],
     conditionalBodies: Seq[CompoundBodyExec],
     elseBody: Option[CompoundBodyExec],
-    session: SparkSession) extends NonLeafStatementExec {
+    session: SparkSession) extends ConditionalStatementExec {
   private object IfElseState extends Enumeration {
     val Condition, Body = Value
   }
@@ -415,7 +422,7 @@ class IfElseStatementExec(

   private lazy val treeIterator: Iterator[CompoundStatementExec] =
     new Iterator[CompoundStatementExec] {
-      override def hasNext: Boolean = curr.nonEmpty
+      override def hasNext: Boolean = !interrupted && curr.nonEmpty

       override def next(): CompoundStatementExec = {
         if (curr.exists(_.isInstanceOf[LeaveStatementExec])) {
@@ -467,6 +474,7 @@ class IfElseStatementExec(
     state = IfElseState.Condition
     curr = Some(conditions.head)
     clauseIdx = 0
+    interrupted = false
     conditions.foreach(c => c.reset())
     conditionalBodies.foreach(b => b.reset())
     elseBody.foreach(b => b.reset())
@@ -484,7 +492,7 @@ class WhileStatementExec(
     condition: SingleStatementExec,
     body: CompoundBodyExec,
     label: Option[String],
-    session: SparkSession) extends NonLeafStatementExec {
+    session: SparkSession) extends ConditionalStatementExec {

   private object WhileState extends Enumeration {
     val Condition, Body = Value
@@ -495,7 +503,7 @@ class WhileStatementExec(

   private lazy val treeIterator: Iterator[CompoundStatementExec] =
     new Iterator[CompoundStatementExec] {
-      override def hasNext: Boolean = curr.nonEmpty
+      override def hasNext: Boolean = !interrupted && curr.nonEmpty

       override def next(): CompoundStatementExec = state match {
         case WhileState.Condition =>
@@ -551,6 +559,7 @@ class WhileStatementExec(
   override def reset(): Unit = {
     state = WhileState.Condition
     curr = Some(condition)
+    interrupted = false
     condition.reset()
     body.reset()
   }
@@ -575,7 +584,7 @@ class SearchedCaseStatementExec(
     conditions: Seq[SingleStatementExec],
     conditionalBodies: Seq[CompoundBodyExec],
     elseBody: Option[CompoundBodyExec],
-    session: SparkSession) extends NonLeafStatementExec {
+    session: SparkSession) extends ConditionalStatementExec {
   private object CaseState extends Enumeration {
     val Condition, Body = Value
   }
@@ -588,7 +597,7 @@ class SearchedCaseStatementExec(

   private lazy val treeIterator: Iterator[CompoundStatementExec] =
     new Iterator[CompoundStatementExec] {
-      override def hasNext: Boolean = curr.nonEmpty
+      override def hasNext: Boolean = !interrupted && curr.nonEmpty

       override def next(): CompoundStatementExec = {
         if (curr.exists(_.isInstanceOf[LeaveStatementExec])) {
@@ -640,6 +649,7 @@ class SearchedCaseStatementExec(
     state = CaseState.Condition
     curr = Some(conditions.head)
     clauseIdx = 0
+    interrupted = false
     conditions.foreach(c => c.reset())
     conditionalBodies.foreach(b => b.reset())
     elseBody.foreach(b => b.reset())
@@ -662,7 +672,7 @@ class SimpleCaseStatementExec(
     conditionalBodies: Seq[CompoundBodyExec],
     elseBody: Option[CompoundBodyExec],
     session: SparkSession,
-    context: SqlScriptingExecutionContext) extends NonLeafStatementExec {
+    context: SqlScriptingExecutionContext) extends ConditionalStatementExec {
   private object CaseState extends Enumeration {
     val Condition, Body = Value
   }
@@ -699,7 +709,7 @@ class SimpleCaseStatementExec(

   private lazy val treeIterator: Iterator[CompoundStatementExec] =
     new Iterator[CompoundStatementExec] {
-      override def hasNext: Boolean = state match {
+      override def hasNext: Boolean = !interrupted && (state match {
         case CaseState.Condition =>
           // Equivalent to the "iteration hasn't started yet" - to avoid computing cache
           // before the first actual iteration.
@@ -710,7 +720,7 @@ class SimpleCaseStatementExec(
           cachedConditionBodyIterator.hasNext ||
             elseBody.isDefined
         case CaseState.Body => bodyExec.exists(_.getTreeIterator.hasNext)
-      }
+      })

       override def next(): CompoundStatementExec = state match {
         case CaseState.Condition =>
@@ -779,6 +789,7 @@ class SimpleCaseStatementExec(
     bodyExec = None
     curr = None
     isCacheValid = false
+    interrupted = false
     caseVariableExec.reset()
     conditionalBodies.foreach(b => b.reset())
     elseBody.foreach(b => b.reset())
@@ -797,7 +808,7 @@ class RepeatStatementExec(
     condition: SingleStatementExec,
     body: CompoundBodyExec,
     label: Option[String],
-    session: SparkSession) extends NonLeafStatementExec {
+    session: SparkSession) extends ConditionalStatementExec {

   private object RepeatState extends Enumeration {
     val Condition, Body = Value
@@ -808,7 +819,7 @@ class RepeatStatementExec(

   private lazy val treeIterator: Iterator[CompoundStatementExec] =
     new Iterator[CompoundStatementExec] {
-      override def hasNext: Boolean = curr.nonEmpty
+      override def hasNext: Boolean = !interrupted && curr.nonEmpty

       override def next(): CompoundStatementExec = state match {
         case RepeatState.Condition =>
@@ -863,6 +874,7 @@ class RepeatStatementExec(
   override def reset(): Unit = {
     state = RepeatState.Body
     curr = Some(body)
+    interrupted = false
     body.reset()
     condition.reset()
   }
@@ -989,7 +1001,7 @@ class ForStatementExec(
     statements: Seq[CompoundStatementExec],
     val label: Option[String],
     session: SparkSession,
-    context: SqlScriptingExecutionContext) extends NonLeafStatementExec {
+    context: SqlScriptingExecutionContext) extends ConditionalStatementExec {

   private object ForState extends Enumeration {
     val VariableAssignment, Body = Value
@@ -1015,11 +1027,6 @@ class ForStatementExec(

   private var bodyWithVariables: Option[CompoundBodyExec] = None

-  /**
-   * For can be interrupted by LeaveStatementExec
-   */
-  private var interrupted: Boolean = false
-
   /**
    * Whether this iteration of the FOR loop is the first one.
    */
@@ -1028,6 +1035,7 @@ class ForStatementExec(
   private lazy val treeIterator: Iterator[CompoundStatementExec] =
     new Iterator[CompoundStatementExec] {

+      // Variable interrupted is being used by both EXIT and CONTINUE handlers
       override def hasNext: Boolean = !interrupted && (state match {
         // `firstIteration` NEEDS to be the first condition! This is to handle edge-cases when
         // query fails with an exception. If the `cachedQueryResult().hasNext` is first, this
```
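
The recurring `!interrupted && ...` change in this file is easiest to see in isolation. The iterator below is a minimal stand-in (not Spark code): once `interrupted` is set, it reports exhaustion even though elements remain, which is exactly how a finished CONTINUE handler makes the enclosing conditional stop yielding statements until `reset()` clears the flag.

```scala
// Minimal stand-in for the gated treeIterator pattern; names are invented.
final class GatedIterator[A](underlying: Iterator[A]) extends Iterator[A] {
  var interrupted: Boolean = false
  override def hasNext: Boolean = !interrupted && underlying.hasNext
  override def next(): A = underlying.next()
}

object GatingSketch extends App {
  val stmts = new GatedIterator(Iterator("stmt1", "stmt2", "stmt3"))
  println(stmts.next())    // stmt1
  stmts.interrupted = true // what a completed CONTINUE handler effectively triggers
  println(stmts.hasNext)   // false: stmt2 and stmt3 are skipped
}
```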

sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala

Lines changed: 3 additions & 9 deletions
```diff
@@ -21,7 +21,7 @@ import scala.collection.mutable.HashMap

 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.expressions.{Alias, Expression}
-import org.apache.spark.sql.catalyst.plans.logical.{CompoundBody, CompoundPlanStatement, ExceptionHandlerType, ForStatement, IfElseStatement, IterateStatement, LeaveStatement, LoopStatement, OneRowRelation, Project, RepeatStatement, SearchedCaseStatement, SimpleCaseStatement, SingleStatement, WhileStatement}
+import org.apache.spark.sql.catalyst.plans.logical.{CompoundBody, CompoundPlanStatement, ForStatement, IfElseStatement, IterateStatement, LeaveStatement, LoopStatement, OneRowRelation, Project, RepeatStatement, SearchedCaseStatement, SimpleCaseStatement, SingleStatement, WhileStatement}
 import org.apache.spark.sql.catalyst.trees.CurrentOrigin
 import org.apache.spark.sql.classic.SparkSession
 import org.apache.spark.sql.errors.SqlScriptingErrors
@@ -87,17 +87,11 @@ case class SqlScriptingInterpreter(session: SparkSession) {
         args,
         context)

-      // Execution node of handler.
-      val handlerScopeLabel = if (handler.handlerType == ExceptionHandlerType.EXIT) {
-        Some(compoundBody.label.get)
-      } else {
-        None
-      }
-
+      // Scope label should be Some(compoundBody.label.get) for both handler types
      val handlerExec = new ExceptionHandlerExec(
        handlerBodyExec,
        handler.handlerType,
-       handlerScopeLabel)
+       Some(compoundBody.label.get))

      // For each condition handler is defined for, add corresponding key value pair
      // to the conditionHandlerMap.
```

sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala

Lines changed: 42 additions & 1 deletion
```diff
@@ -35,6 +35,17 @@ import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
  * For full functionality tests, see SqlScriptingParserSuite and SqlScriptingInterpreterSuite.
  */
 class SqlScriptingE2eSuite extends QueryTest with SharedSparkSession {
+
+  protected override def beforeAll(): Unit = {
+    super.beforeAll()
+    conf.setConf(SQLConf.SQL_SCRIPTING_CONTINUE_HANDLER_ENABLED, true)
+  }
+
+  protected override def afterAll(): Unit = {
+    conf.unsetConf(SQLConf.SQL_SCRIPTING_CONTINUE_HANDLER_ENABLED.key)
+    super.afterAll()
+  }
+
   // Helpers
   private def verifySqlScriptResult(
       sqlText: String,
@@ -77,7 +88,7 @@ class SqlScriptingE2eSuite extends QueryTest with SharedSparkSession {
     }
   }

-  test("Scripting with exception handlers") {
+  test("Scripting with exit exception handlers") {
     val sqlScript =
       """
         |BEGIN
@@ -104,6 +115,36 @@ class SqlScriptingE2eSuite extends QueryTest with SharedSparkSession {
     verifySqlScriptResult(sqlScript, Seq(Row(2)))
   }

+  test("Scripting with continue exception handlers") {
+    val sqlScript =
+      """
+        |BEGIN
+        | DECLARE flag1 INT = -1;
+        | DECLARE flag2 INT = -1;
+        | DECLARE CONTINUE HANDLER FOR DIVIDE_BY_ZERO
+        | BEGIN
+        |   SELECT flag1;
+        |   SET flag1 = 1;
+        | END;
+        | BEGIN
+        |   DECLARE CONTINUE HANDLER FOR SQLSTATE '22012'
+        |   BEGIN
+        |     SELECT flag1;
+        |     SET flag1 = 2;
+        |   END;
+        |   SELECT 5;
+        |   SET flag2 = 1;
+        |   SELECT 1/0;
+        |   SELECT 6;
+        |   SET flag2 = 2;
+        | END;
+        | SELECT 7;
+        | SELECT flag1, flag2;
+        |END
+        |""".stripMargin
+    verifySqlScriptResult(sqlScript, Seq(Row(2, 2)))
+  }
+
   test("single select") {
     val sqlText = "SELECT 1;"
     verifySqlScriptResult(sqlText, Seq(Row(1)))
```
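
Why the new continue-handler test expects `Row(2, 2)`: as I read the script, the divide-by-zero from `SELECT 1/0` is caught by the innermost matching handler, the `SQLSTATE '22012'` one declared in the inner block, which sets `flag1 = 2`. Because that handler is CONTINUE rather than EXIT, execution resumes after the failing statement, so `SELECT 6` and `SET flag2 = 2` still run, and the final `SELECT flag1, flag2` returns `(2, 2)`.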
