Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 25 additions & 11 deletions repl/src/main/scala/org/apache/livy/repl/Session.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.livy.repl

import java.util.{LinkedHashMap => JLinkedHashMap}
import java.util.Map.Entry
import java.util.concurrent.Executors
import java.util.concurrent.{ConcurrentHashMap, Executors}
import java.util.concurrent.atomic.AtomicInteger

import scala.collection.JavaConverters._
Expand Down Expand Up @@ -63,6 +63,8 @@ class Session(
private val cancelExecutor = ExecutionContext.fromExecutorService(
Executors.newSingleThreadExecutor())

private val statementThreads = new ConcurrentHashMap[Int, Thread]()

private implicit val formats = DefaultFormats

private var _state: SessionState = SessionState.NotStarted
Expand Down Expand Up @@ -161,18 +163,29 @@ class Session(
_statements.synchronized { _statements(statementId) = statement }

Future {
setJobGroup(tpe, statementId)
statement.compareAndTransit(StatementState.Waiting, StatementState.Running)
val currentThread = Thread.currentThread()
statementThreads.put(statementId, currentThread)
try {
setJobGroup(tpe, statementId)
statement.compareAndTransit(StatementState.Waiting, StatementState.Running)

if (statement.state.get() == StatementState.Running) {
statement.started = System.currentTimeMillis()
statement.output = executeCode(interpreter(tpe), statementId, code)
}
if (statement.state.get() == StatementState.Running) {
statement.started = System.currentTimeMillis()
statement.output = executeCode(interpreter(tpe), statementId, code)
}

statement.compareAndTransit(StatementState.Running, StatementState.Available)
statement.compareAndTransit(StatementState.Cancelling, StatementState.Cancelled)
statement.updateProgress(1.0)
statement.completed = System.currentTimeMillis()
statement.compareAndTransit(StatementState.Running, StatementState.Available)
statement.compareAndTransit(StatementState.Cancelling, StatementState.Cancelled)
statement.updateProgress(1.0)
statement.completed = System.currentTimeMillis()
} finally {
statementThreads.remove(statementId, currentThread)
// Clear the interrupt flag, but log if the thread was interrupted.
if (Thread.interrupted()) {
warn(s"Thread was interrupted during execution of statement $statementId; " +
"interrupt flag cleared.")
}
}
}(interpreterExecutor)

statementId
Expand Down Expand Up @@ -212,6 +225,7 @@ class Session(
info(s"Failed to cancel statement $statementId.")
statement.compareAndTransit(StatementState.Cancelling, StatementState.Cancelled)
} else {
Option(statementThreads.get(statementId)).foreach(_.interrupt())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ArnavBalyan - Can you please verify that this change is only intended to interrupt interruptible tasks such as sleep, object.wait, etc? As @gyogal has mentioned this will not interrupt actual long running threads.

I am also curious about the fact that this is being called within the while loop. Are we waiting for the state to be successfully set to Cancelled instead of still Cancelling? The upon failure, we set it to Cancelled after timing out therefore it will not try to be interrupted again?

sc.cancelJobGroup(statementId.toString)
if (statement.state.get() == StatementState.Cancelling) {
Thread.sleep(livyConf.getTimeAsMs(RSCConf.Entry.JOB_CANCEL_TRIGGER_INTERVAL))
Expand Down
27 changes: 27 additions & 0 deletions repl/src/test/scala/org/apache/livy/repl/SparkSessionSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,33 @@ class SparkSessionSpec extends BaseSessionSpec(Spark) {
}
}

it should "cancel driver code without spark jobs" in withSession { session =>
val stmtId = session.execute(
"""
|Thread.sleep(5000)
|val r = 1 + 1
|r
""".stripMargin)

eventually(timeout(30 seconds), interval(100 millis)) {
assert(session.statements(stmtId).state.get() == StatementState.Running)
}

session.cancel(stmtId)

eventually(timeout(30 seconds), interval(100 millis)) {
val statement = session.statements(stmtId)
assert(statement.state.get() == StatementState.Cancelled)
val resultJson = parse(statement.output)
(resultJson \ "status").extract[String] should equal ("error")
statement.output should not include ("r: Int = 2")
}

val followUp = execute(session)("r")
val followUpResult = parse(followUp.output)
(followUpResult \ "status").extract[String] should equal ("error")
}

it should "correctly calculate progress" in withSession { session =>
val executeCode =
"""
Expand Down