diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala index 096ad11dd0657..0e02ca7adf293 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala @@ -51,7 +51,7 @@ class SqlScriptingExecution( val executionPlan = interpreter.buildExecutionPlan(sqlScript, args, ctx) // Add frame which represents SQL Script to the context. ctx.frames.append( - new SqlScriptingExecutionFrame(executionPlan, SqlScriptingFrameType.SQL_SCRIPT)) + new SqlScriptingSqlScriptExecutionFrame(executionPlan)) // Enter the scope of the top level compound. // We exit this scope explicitly in the getNextStatement method when there are no more // statements to execute. @@ -71,42 +71,6 @@ class SqlScriptingExecution( contextManagerHandle.runWith(f) } - /** - * Helper method to inject leave statement into the execution plan. - * @param executionPlan Execution plan to inject leave statement into. - * @param label Label of the leave statement. - */ - private def injectLeaveStatement(executionPlan: NonLeafStatementExec, label: String): Unit = { - // Go as deep as possible, to find a leaf node. Instead of a statement that - // should be executed next, inject LEAVE statement in its place. - var currExecPlan = executionPlan - while (currExecPlan.curr.exists(_.isInstanceOf[NonLeafStatementExec])) { - currExecPlan = currExecPlan.curr.get.asInstanceOf[NonLeafStatementExec] - } - currExecPlan.curr = Some(new LeaveStatementExec(label)) - } - - /** - * Helper method to execute interrupts to ConditionalStatements. - * This method should only interrupt when the statement that throws is a conditional statement. - * @param executionPlan Execution plan. - */ - private def interruptConditionalStatements(executionPlan: NonLeafStatementExec): Unit = { - // Go as deep as possible into the execution plan children nodes, to find a leaf node. - // That leaf node is the next statement that is to be executed. If the parent node of that - // leaf node is a conditional statement, skip the conditional statement entirely. - var currExecPlan = executionPlan - while (currExecPlan.curr.exists(_.isInstanceOf[NonLeafStatementExec])) { - currExecPlan = currExecPlan.curr.get.asInstanceOf[NonLeafStatementExec] - } - - currExecPlan match { - case exec: ConditionalStatementExec => - exec.interrupted = true - case _ => - } - } - /** Helper method to iterate get next statements from the first available frame. */ private def getNextStatement: Option[CompoundStatementExec] = { // Remove frames that are already executed. @@ -122,30 +86,8 @@ class SqlScriptingExecution( context.frames.remove(context.frames.size - 1) - // If the last frame is a handler, set leave statement to be the next one in the - // innermost scope that should be exited. - if (lastFrame.frameType == SqlScriptingFrameType.EXIT_HANDLER - && context.frames.nonEmpty) { - // Remove the scope if handler is executed. - if (context.firstHandlerScopeLabel.isDefined - && lastFrame.scopeLabel.get == context.firstHandlerScopeLabel.get) { - context.firstHandlerScopeLabel = None - } - - // Inject leave statement into the execution plan of the last frame. - injectLeaveStatement(context.frames.last.executionPlan, lastFrame.scopeLabel.get) - } - - if (lastFrame.frameType == SqlScriptingFrameType.CONTINUE_HANDLER - && context.frames.nonEmpty) { - // Remove the scope if handler is executed. - if (context.firstHandlerScopeLabel.isDefined - && lastFrame.scopeLabel.get == context.firstHandlerScopeLabel.get) { - context.firstHandlerScopeLabel = None - } - - // Interrupt conditional statements - interruptConditionalStatements(context.frames.last.executionPlan) + if (context.frames.nonEmpty) { + lastFrame.exitExecutionFrame(context) } } // If there are still frames available, get the next statement. @@ -202,15 +144,17 @@ class SqlScriptingExecution( private def handleException(e: SparkThrowable): Unit = { context.findHandler(e.getCondition, e.getSqlState) match { case Some(handler) => - val handlerFrame = new SqlScriptingExecutionFrame( - handler.body, - if (handler.handlerType == ExceptionHandlerType.CONTINUE) { - SqlScriptingFrameType.CONTINUE_HANDLER - } else { - SqlScriptingFrameType.EXIT_HANDLER - }, - handler.scopeLabel - ) + val handlerFrame = handler.handlerType match { + case ExceptionHandlerType.EXIT => new SqlScriptingExitHandlerExecutionFrame( + handler.body, + handler.scopeLabel + ) + case ExceptionHandlerType.CONTINUE => new SqlScriptingContinueHandlerExecutionFrame( + handler.body, + handler.scopeLabel + ) + } + context.frames.append( handlerFrame ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala index 08ba54e6e4e4d..1cc7c5bd222f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala @@ -24,7 +24,6 @@ import scala.collection.mutable.ListBuffer import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.catalog.{SqlScriptingExecutionContextExtension, VariableDefinition} -import org.apache.spark.sql.scripting.SqlScriptingFrameType.SqlScriptingFrameType /** * SQL scripting execution context - keeps track of the current execution state. @@ -59,8 +58,7 @@ class SqlScriptingExecutionContext extends SqlScriptingExecutionContextExtension } // If the last frame is a handler, try to find a handler in its body first. - if (frames.last.frameType == SqlScriptingFrameType.EXIT_HANDLER - || frames.last.frameType == SqlScriptingFrameType.CONTINUE_HANDLER) { + if (frames.last.isInstanceOf[SqlScriptingHandlerExecutionFrame]) { val handler = frames.last.findHandler(condition, sqlState, firstHandlerScopeLabel) if (handler.isDefined) { return handler @@ -74,7 +72,7 @@ class SqlScriptingExecutionContext extends SqlScriptingExecutionContextExtension val scriptFrame = frames.head val handler = scriptFrame.findHandler(condition, sqlState, firstHandlerScopeLabel) if (handler.isDefined) { - firstHandlerScopeLabel = handler.get.scopeLabel + firstHandlerScopeLabel = Some(handler.get.scopeLabel) return handler } @@ -82,24 +80,14 @@ class SqlScriptingExecutionContext extends SqlScriptingExecutionContextExtension } } -object SqlScriptingFrameType extends Enumeration { - type SqlScriptingFrameType = Value - val SQL_SCRIPT, EXIT_HANDLER, CONTINUE_HANDLER = Value -} - /** - * SQL scripting executor - executes script and returns result statements. + * SQL scripting executor - base class for executing a CompoundBody returning result statements. * This supports returning multiple result statements from a single script. * * @param executionPlan CompoundBody which need to be executed. - * @param frameType Type of the frame. - * @param scopeLabel Label of the scope where handler is defined. - * Available only for frameType = HANDLER. */ -class SqlScriptingExecutionFrame( - val executionPlan: CompoundBodyExec, - val frameType: SqlScriptingFrameType, - val scopeLabel: Option[String] = None) extends Iterator[CompoundStatementExec] { +abstract class SqlScriptingExecutionFrame( + val executionPlan: CompoundBodyExec) extends Iterator[CompoundStatementExec] { // List of scopes that are currently active. private[scripting] val scopes: ListBuffer[SqlScriptingExecutionScope] = ListBuffer.empty @@ -135,16 +123,13 @@ class SqlScriptingExecutionFrame( def currentScope: SqlScriptingExecutionScope = scopes.last - // TODO: Introduce a separate class for different frame types (Script, Stored Procedure, - // Error Handler) implementing SqlScriptingExecutionFrame interface. def findHandler( condition: String, sqlState: String, firstHandlerScopeLabel: Option[String]): Option[ExceptionHandlerExec] = { val searchScopes = - if (frameType == SqlScriptingFrameType.EXIT_HANDLER - || frameType == SqlScriptingFrameType.CONTINUE_HANDLER) { + if (this.isInstanceOf[SqlScriptingHandlerExecutionFrame]) { // If the frame is a handler, search for the handler in its body. Don't skip any scopes. scopes.reverseIterator } else if (firstHandlerScopeLabel.isEmpty) { @@ -167,6 +152,124 @@ class SqlScriptingExecutionFrame( None } + + def exitExecutionFrame(context: SqlScriptingExecutionContext): Unit +} + +object SqlScriptingExecutionFrame { + /** + * Helper method to inject leave statement into the execution plan. + * @param executionPlan Execution plan to inject leave statement into. + * @param label Label of the leave statement. + */ + private[sql] def injectLeaveStatement( + executionPlan: NonLeafStatementExec, label: String): Unit = { + // Go as deep as possible, to find a leaf node. Instead of a statement that + // should be executed next, inject LEAVE statement in its place. + var currExecPlan = executionPlan + while (currExecPlan.curr.exists(_.isInstanceOf[NonLeafStatementExec])) { + currExecPlan = currExecPlan.curr.get.asInstanceOf[NonLeafStatementExec] + } + currExecPlan.curr = Some(new LeaveStatementExec(label)) + } + + /** + * Helper method to execute interrupts to ConditionalStatements. + * This method should only interrupt when the statement that throws is a conditional statement. + * @param executionPlan Execution plan. + */ + private[sql] def interruptConditionalStatements(executionPlan: NonLeafStatementExec): Unit = { + // Go as deep as possible into the execution plan children nodes, to find a leaf node. + // That leaf node is the next statement that is to be executed. If the parent node of that + // leaf node is a conditional statement, skip the conditional statement entirely. + var currExecPlan = executionPlan + while (currExecPlan.curr.exists(_.isInstanceOf[NonLeafStatementExec])) { + currExecPlan = currExecPlan.curr.get.asInstanceOf[NonLeafStatementExec] + } + + currExecPlan match { + case exec: ConditionalStatementExec => + exec.interrupted = true + case _ => + } + } +} + +/** + * SQL scripting script frame executor - executes script and returns result statements. + * This supports returning multiple result statements from a single script. + * @param executionPlan CompoundBody which need to be executed. + */ +class SqlScriptingSqlScriptExecutionFrame( + override val executionPlan: CompoundBodyExec) + extends SqlScriptingExecutionFrame(executionPlan) { + override def exitExecutionFrame(context: SqlScriptingExecutionContext): Unit = {} +} + +/** + * SQL scripting executor - base class for executing handlers. + * @param executionPlan CompoundBody which need to be executed. + * @param scopeLabel Label of the scope where handler is defined. + */ +abstract class SqlScriptingHandlerExecutionFrame( + override val executionPlan: CompoundBodyExec, + val scopeLabel: String) extends SqlScriptingExecutionFrame(executionPlan) + +/** + * SQL scripting executor - executes an exit handler and returns result statements. + * This supports returning multiple result statements from a single script. + * @param executionPlan CompoundBody which need to be executed. + * @param scopeLabel Label of the scope where handler is defined. + */ +class SqlScriptingExitHandlerExecutionFrame( + override val executionPlan: CompoundBodyExec, + override val scopeLabel: String) + extends SqlScriptingHandlerExecutionFrame(executionPlan, scopeLabel) { + + /** + * Set leave statement to be the next one in the innermost scope that should be exited. + * @param context Execution context after the current frame was removed from the + * frame stack. + */ + override def exitExecutionFrame(context: SqlScriptingExecutionContext): Unit = { + // Remove the scope if handler is executed. + if (context.firstHandlerScopeLabel.isDefined + && scopeLabel == context.firstHandlerScopeLabel.get) { + context.firstHandlerScopeLabel = None + } + + // Inject leave statement into the execution plan of the last frame. + SqlScriptingExecutionFrame.injectLeaveStatement(context.frames.last.executionPlan, scopeLabel) + } +} + +/** + * SQL scripting executor - executes a continue handler and returns result statements. + * This supports returning multiple result statements from a single script. + * @param executionPlan CompoundBody which need to be executed. + * @param scopeLabel Label of the scope where handler is defined. + */ +class SqlScriptingContinueHandlerExecutionFrame( + override val executionPlan: CompoundBodyExec, + override val scopeLabel: String) + extends SqlScriptingHandlerExecutionFrame(executionPlan, scopeLabel) { + + /** + * If the last frame is a handler, set leave statement to be the next one in the + * innermost scope that should be exited. + * @param context Execution context after the current frame was removed from the + * frame stack. + */ + override def exitExecutionFrame(context: SqlScriptingExecutionContext): Unit = { + // Remove the scope if handler is executed. + if (context.firstHandlerScopeLabel.isDefined + && scopeLabel == context.firstHandlerScopeLabel.get) { + context.firstHandlerScopeLabel = None + } + + // Interrupt conditional statements. + SqlScriptingExecutionFrame.interruptConditionalStatements(context.frames.last.executionPlan) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 598e379c73ac9..6081c132ec254 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -1225,7 +1225,7 @@ class ForStatementExec( class ExceptionHandlerExec( val body: CompoundBodyExec, val handlerType: ExceptionHandlerType, - val scopeLabel: Option[String]) extends NonLeafStatementExec { + val scopeLabel: String) extends NonLeafStatementExec { protected[scripting] var curr: Option[CompoundStatementExec] = body.curr diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index eebdb681f62c1..59f27cd666151 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -91,7 +91,7 @@ case class SqlScriptingInterpreter(session: SparkSession) { val handlerExec = new ExceptionHandlerExec( handlerBodyExec, handler.handlerType, - Some(compoundBody.label.get)) + compoundBody.label.get) // For each condition handler is defined for, add corresponding key value pair // to the conditionHandlerMap. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingLocalVariableManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingLocalVariableManager.scala index c9a243db4f08b..4c7affb06ea51 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingLocalVariableManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingLocalVariableManager.scala @@ -87,7 +87,10 @@ class SqlScriptingLocalVariableManager(context: SqlScriptingExecutionContext) // including the scope where the previously checked exception handler frame is defined. // Exception handler frames should not have access to variables from scopes // which are nested below the scope where the handler is defined. - var previousFrameDefinitionLabel = context.currentFrame.scopeLabel + var previousFrameDefinitionLabel = context.currentFrame match { + case frame: SqlScriptingHandlerExecutionFrame => Some(frame.scopeLabel) + case _ => None + } // dropRight(1) removes the current frame, which we already checked above. context.frames.dropRight(1).reverseIterator.foreach(frame => { @@ -105,7 +108,10 @@ class SqlScriptingLocalVariableManager(context: SqlScriptingExecutionContext) // in this frame. If we still have not found the variable, we now have to find the definition // of this new frame, so we reassign the frame definition label to search for. if (candidateScopes.nonEmpty) { - previousFrameDefinitionLabel = frame.scopeLabel + previousFrameDefinitionLabel = frame match { + case frame: SqlScriptingHandlerExecutionFrame => Some(frame.scopeLabel) + case _ => None + } } }) None diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingTestUtils.scala index e508de6547cd8..19ac3a661d3e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingTestUtils.scala @@ -53,7 +53,7 @@ trait SqlScriptingTestUtils { val context = new SqlScriptingExecutionContext() val executionPlan = interpreter.buildExecutionPlan(compoundBody, args, context) context.frames.append( - new SqlScriptingExecutionFrame(executionPlan, SqlScriptingFrameType.SQL_SCRIPT) + new SqlScriptingSqlScriptExecutionFrame(executionPlan) ) executionPlan.enterScope()