diff --git a/client-http/pom.xml b/client-http/pom.xml index 875f8cf66..2a72f8722 100644 --- a/client-http/pom.xml +++ b/client-http/pom.xml @@ -112,6 +112,7 @@ org.apache.maven.plugins maven-shade-plugin + 3.4.1 shade diff --git a/docs/rest-api.md b/docs/rest-api.md index fd6142483..46ed3a381 100644 --- a/docs/rest-api.md +++ b/docs/rest-api.md @@ -377,6 +377,96 @@ Returns code completion candidates for the specified code in the session. +### POST /sessions/{sessionId}/tasks + +Submits a pre-compiled Spark job (task) to run in an interactive session. This endpoint allows you to execute compiled Java/Scala Spark jobs within the context of an existing interactive session, providing an alternative to submitting code snippets via statements. + +Unlike statements which execute code strings, tasks run pre-compiled Job implementations that have been serialized and sent to the session. This is useful for running complex, pre-compiled Spark applications while maintaining the interactive session context. + +#### Request Body + + + + + + + + + + + + + +
NameDescriptionType
jobSerialized job data (base64 encoded byte array representing a compiled Job implementation)byte array (required)
jobTypeThe type of job being submitted (e.g., "spark" for Scala/Java jobs)string
+ +#### Response Body + +The task object. + +### GET /sessions/{sessionId}/tasks + +Returns all tasks submitted to this session. + +#### Request Parameters + + + + + + + + + + + + + + + + + + +
NameDescriptionType
fromThe start index to fetch tasksint
sizeNumber of tasks to fetchint
orderProvide value as "desc" to get tasks in descending order (by default, tasks are in ascending order)string
+ +#### Response Body + + + + + + + + + + + + + +
NameDescriptionType
total_tasksTotal number of tasks in this sessionint
tasksTask listlist
+ +### GET /sessions/{sessionId}/tasks/{taskId} + +Returns the status and result of a specific submitted task. + +#### Response Body + +The task object. + +### POST /sessions/{sessionId}/tasks/{taskId}/cancel + +Cancels the specified task in this session. If the task is currently running, Livy will attempt to cancel the associated Spark job group. If the task is waiting, it will be cancelled immediately. + +#### Response Body + + + + + + + + +
NameDescriptionType
msgis always "canceled"string
+ ### GET /batches Returns all the active batch sessions. @@ -893,6 +983,99 @@ A statement represents the result of an execution statement. +### Task + +A task represents a pre-compiled job submitted to an interactive session. Tasks provide a way to execute compiled Spark applications (implementing the `org.apache.livy.Job` interface) within an interactive session context, combining the benefits of pre-compiled code with the flexibility of interactive sessions. + +**Key differences between Tasks and Statements:** +- **Statements** execute code strings (Scala, Python, R, or SQL) interactively +- **Tasks** execute pre-compiled, serialized Job implementations + +Tasks are useful when you have complex Spark logic that has been compiled and tested, but you want to run it in the context of an existing interactive session without creating a separate batch job. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameDescriptionType
idThe task id (unique within the session)integer
stateThe current execution state of the tasktask state
outputThe serialized task result as a byte array (if completed successfully)byte array
errorThe error message (if the task failed)string
serializedExceptionThe serialized exception object (if the task failed)byte array
progressThe execution progress (0.0 to 1.0)double
submittedTimestamp when the task was submitted (milliseconds since epoch)long
completedTimestamp when the task completed (milliseconds since epoch)long
+ +#### Task State + + + + + + + + + + + + + + + + + + + + + + + + + + + +
ValueDescription
waitingTask has been submitted and is waiting to start execution
runningTask is currently executing in the Spark context
availableTask completed successfully and results are available
failedTask execution failed with an error
cancellingTask cancellation has been requested and is in progress
cancelledTask was successfully cancelled before completion
+ +**Valid State Transitions:** +- `waiting` → `running` (task starts execution) +- `waiting` → `cancelled` (task cancelled before starting) +- `running` → `available` (task completes successfully) +- `running` → `failed` (task encounters an error) +- `running` → `cancelling` (cancellation requested) +- `cancelling` → `cancelled` (cancellation completes) +- `cancelling` → `failed` (task fails during cancellation) + ### Batch diff --git a/pom.xml b/pom.xml index a88e449fa..9f86f44b5 100644 --- a/pom.xml +++ b/pom.xml @@ -85,7 +85,7 @@ 2.4.52.4.5${spark.scala-2.11.version} - 5.6.0 + 6.8.13.0.01.153.17.0 @@ -1169,6 +1169,13 @@ + + hadoop3 + + 3 + 3.4.0 + + hadoop2 @@ -1222,7 +1229,7 @@ 1.8 0.10.9.7 3.7.0-M11 - 4.1.96.Final + 4.1.108.Final 2.15.2 2.15.2 spark-${spark.version}-bin-hadoop${hadoop.major-minor.version} diff --git a/repl/pom.xml b/repl/pom.xml index 6a37cb515..6123c94c0 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -176,6 +176,7 @@ org.apache.maven.plugins maven-shade-plugin + 3.4.1 shade diff --git a/repl/scala-2.12/src/main/scala/org/apache/livy/repl/SparkInterpreter.scala b/repl/scala-2.12/src/main/scala/org/apache/livy/repl/SparkInterpreter.scala index c3756a505..fadd62f50 100644 --- a/repl/scala-2.12/src/main/scala/org/apache/livy/repl/SparkInterpreter.scala +++ b/repl/scala-2.12/src/main/scala/org/apache/livy/repl/SparkInterpreter.scala @@ -107,14 +107,14 @@ class SparkInterpreter(protected override val conf: SparkConf) extends AbstractS sparkILoop.interpret(code) } - override protected def completeCandidates(code: String, cursor: Int) : Array[String] = { - val completer : Completion = { + override protected def completeCandidates(code: String, cursor: Int): Array[String] = { + val completer: Completion = { try { val cls = Class.forName("scala.tools.nsc.interpreter.PresentationCompilerCompleter") cls.getDeclaredConstructor(classOf[IMain]).newInstance(sparkILoop.intp) .asInstanceOf[Completion] } catch { - case e : ClassNotFoundException => NoCompletion + case e: ClassNotFoundException => NoCompletion } } completer.complete(code, cursor).candidates.toArray @@ -126,9 +126,9 @@ class SparkInterpreter(protected override val conf: SparkConf) extends AbstractS } override protected def bind(name: String, - tpe: String, - value: Object, - modifier: List[String]): Unit = { + tpe: String, + value: Object, + modifier: List[String]): Unit = { sparkILoop.beQuietDuring { sparkILoop.bind(name, tpe, value, modifier) } diff --git a/repl/src/main/scala/org/apache/livy/repl/ReplDriver.scala b/repl/src/main/scala/org/apache/livy/repl/ReplDriver.scala index 9ac82e82a..13791e61c 100644 --- a/repl/src/main/scala/org/apache/livy/repl/ReplDriver.scala +++ b/repl/src/main/scala/org/apache/livy/repl/ReplDriver.scala @@ -25,8 +25,8 @@ import org.apache.spark.SparkConf import org.apache.livy.{EOLUtils, Logging} import org.apache.livy.client.common.ClientConf -import org.apache.livy.rsc.{BaseProtocol, ReplJobResults, RSCConf} -import org.apache.livy.rsc.BaseProtocol.ReplState +import org.apache.livy.rsc.{BaseProtocol, ReplJobResults, RSCConf, TaskResults} +import org.apache.livy.rsc.BaseProtocol.{CancelTaskRequest, GetTaskResults, ReplState} import org.apache.livy.rsc.driver._ import org.apache.livy.rsc.rpc.Rpc import org.apache.livy.sessions._ @@ -81,7 +81,7 @@ class ReplDriver(conf: SparkConf, livyConf: RSCConf) session.statements.get(msg.from).toArray } else { val until = msg.from + msg.size - session.statements.filterKeys(id => id >= msg.from && id < until).values.toArray + session.statements.filterKeys(id => id >= msg.from).take(until).values.toArray } } @@ -93,6 +93,42 @@ class ReplDriver(conf: SparkConf, livyConf: RSCConf) new ReplJobResults(statements.sortBy(_.id)) } + def handle(ctx: ChannelHandlerContext, msg: BaseProtocol.TaskJobRequest): Int = { + val context = jobContext() + val serializer = this.serializer() + val job = new BypassJob(serializer, msg.serializedJob) + session.submitTask(job, context) + } + + /** + * Return a specific task by ID from the session's task registry + */ + def handle(ctx: ChannelHandlerContext, msg: GetTaskResults): TaskResults = { + val tasks = if (msg.allResults) { + session.tasks.values.toArray + } else { + assert(msg.from != null) + assert(msg.size != null) + if (msg.size == 1) { + session.tasks.get(msg.from).toArray + } else { + val until = msg.from + msg.size + session.tasks.filterKeys(id => id >= msg.from).take(until).values.toArray + } + } + + // Update progress of statements when queried + tasks.foreach { s => + s.updateProgress(session.progressOfTask(s.id)) + } + + new TaskResults(tasks.sortBy(_.id)) + } + + def handle(ctx: ChannelHandlerContext, msg: CancelTaskRequest): Unit = { + session.cancelTask(msg.taskId) + } + override protected def createWrapper(msg: BaseProtocol.BypassJobRequest): BypassJobWrapper = { Kind(msg.jobType) match { case PySpark if session.interpreter(PySpark).isDefined => diff --git a/repl/src/main/scala/org/apache/livy/repl/Session.scala b/repl/src/main/scala/org/apache/livy/repl/Session.scala index 262c811c7..366d7ecb6 100644 --- a/repl/src/main/scala/org/apache/livy/repl/Session.scala +++ b/repl/src/main/scala/org/apache/livy/repl/Session.scala @@ -33,9 +33,10 @@ import org.json4s.jackson.JsonMethods.{compact, render} import org.json4s.DefaultFormats import org.json4s.JsonDSL._ -import org.apache.livy.Logging +import org.apache.livy.{Job, JobContext, Logging} +import org.apache.livy.client.common.Serializer import org.apache.livy.rsc.RSCConf -import org.apache.livy.rsc.driver.{SparkEntries, Statement, StatementState} +import org.apache.livy.rsc.driver.{SparkEntries, Statement, StatementState, Task, TaskState} import org.apache.livy.sessions._ object Session { @@ -70,13 +71,23 @@ class Session( // Number of statements kept in driver's memory private val numRetainedStatements = livyConf.getInt(RSCConf.Entry.RETAINED_STATEMENTS) + // Number of tasks kept in driver's memory + private val numRetainedTasks = livyConf.getInt(RSCConf.Entry.RETAINED_TASKS) + private val _statements = new JLinkedHashMap[Int, Statement] { protected override def removeEldestEntry(eldest: Entry[Int, Statement]): Boolean = { size() > numRetainedStatements } }.asScala - private val newStatementId = new AtomicInteger(0) + // Number of tasks kept in driver's memory (use same config as statements) + private val _tasks = new JLinkedHashMap[Int, Task] { + protected override def removeEldestEntry(eldest: Entry[Int, Task]): Boolean = { + size() > numRetainedTasks + } + }.asScala + + private val newId = new AtomicInteger(0) private val defaultInterpKind = Kind(livyConf.get(RSCConf.Entry.SESSION_KIND)) @@ -86,7 +97,7 @@ class Session( stateChangedCallback(_state) - private def sc: SparkContext = { + private[repl] def sc: SparkContext = { require(entries != null) entries.sc().sc } @@ -156,12 +167,12 @@ class Session( throw new IllegalArgumentException(s"Code type should be specified if session kind is shared") } - val statementId = newStatementId.getAndIncrement() + val statementId = newId.getAndIncrement() val statement = new Statement(statementId, code, StatementState.Waiting, null) _statements.synchronized { _statements(statementId) = statement } Future { - setJobGroup(tpe, statementId) + setJobGroup(tpe, statementIdToJobGroup(statementId)) statement.compareAndTransit(StatementState.Waiting, StatementState.Running) if (statement.state.get() == StatementState.Running) { @@ -232,6 +243,87 @@ class Session( interpGroup.values.foreach(_.close()) } + def tasks: collection.Map[Int, Task] = _tasks.synchronized { + _tasks.toMap + } + + def submitTask(job: Job[Array[Byte]], jc: JobContext): Int = { + val taskId = newId.getAndIncrement() + val task = new Task(taskId, TaskState.Waiting, null, null) + task.submitted = System.currentTimeMillis() + _tasks.synchronized { _tasks(taskId) = task } + + Future { + task.compareAndTransit(TaskState.Waiting, TaskState.Running) + + try { + if (task.state.get() == TaskState.Running) { + task.submitted = System.currentTimeMillis() + jc.sc().setJobGroup(task.id.toString, s"Job group for task ${task.id}") + task.output = job.call(jc) + task.compareAndTransit(TaskState.Running, TaskState.Available) + } + } catch { + case e: Throwable => + task.error = e.toString + task.serializedException = new Serializer().serialize(e).array() + task.compareAndTransit(TaskState.Running, TaskState.Failed) + } + + task.compareAndTransit(TaskState.Cancelling, TaskState.Cancelled) + task.updateProgress(1.0) + task.completed = System.currentTimeMillis() + }(interpreterExecutor) + taskId + } + + def getTask(taskId: Integer): Option[Task] = _tasks.synchronized { + _tasks.get(taskId) + } + + def cancelTask(taskId: Int): Unit = { + val taskOpt = _tasks.synchronized { _tasks.get(taskId) } + if (taskOpt.isEmpty) { + return + } + + val task = taskOpt.get + if (task.state.get().isOneOf( + TaskState.Available, TaskState.Cancelled, TaskState.Cancelling)) { + return + } else { + // statement 1 is running and statement 2 is waiting. User cancels + // statement 2 then cancels statement 1. The 2nd cancel call will loop and block the 1st + // cancel call since cancelExecutor is single threaded. To avoid this, set the statement + // state to cancelled when cancelling a waiting statement. + task.compareAndTransit(TaskState.Waiting, TaskState.Cancelled) + task.compareAndTransit(TaskState.Running, TaskState.Cancelling) + } + + info(s"Cancelling task $taskId...") + + Future { + val deadline = livyConf.getTimeAsMs(RSCConf.Entry.JOB_CANCEL_TIMEOUT).millis.fromNow + + while (task.state.get() == TaskState.Cancelling) { + if (deadline.isOverdue()) { + info(s"Failed to cancel task $taskId.") + task.compareAndTransit(TaskState.Cancelling, TaskState.Cancelled) + } else { + sc.cancelJobGroup(taskId.toString) + if (task.state.get() == TaskState.Cancelling) { + Thread.sleep(livyConf.getTimeAsMs(RSCConf.Entry.JOB_CANCEL_TRIGGER_INTERVAL)) + } + } + } + + if (task.state.get() == TaskState.Cancelled) { + task.completed = System.currentTimeMillis() + info(s"Task $taskId cancelled.") + } + }(cancelExecutor) + } + /** * Get the current progress of given statement id. */ @@ -253,6 +345,24 @@ class Session( } } + def progressOfTask(taskId: Int): Double = { + val jobGroup = taskIdToJobGroup(taskId) + + val jobIds = sc.statusTracker.getJobIdsForGroup(jobGroup) + val jobs = jobIds.flatMap { id => sc.statusTracker.getJobInfo(id) } + val stages = jobs.flatMap { job => + job.stageIds().flatMap(sc.statusTracker.getStageInfo) + } + + val taskCount = stages.map(_.numTasks).sum + val completedTaskCount = stages.map(_.numCompletedTasks).sum + if (taskCount == 0) { + 0.0 + } else { + completedTaskCount.toDouble / taskCount + } + } + private def changeState(newState: SessionState): Unit = { synchronized { _state = newState @@ -266,7 +376,7 @@ class Session( changeState(SessionState.Busy) def transitToIdle() = { - val executingLastStatement = executionCount == newStatementId.intValue() - 1 + val executingLastStatement = executionCount == newId.intValue() - 1 if (_statements.isEmpty || executingLastStatement) { changeState(SessionState.Idle) } @@ -333,32 +443,36 @@ class Session( compact(render(resultInJson)) } - private def setJobGroup(codeType: Kind, statementId: Int): String = { - val jobGroup = statementIdToJobGroup(statementId) + private def setJobGroup(codeType: Kind, jobGroupId: String): String = { val (cmd, tpe) = codeType match { case Spark | SQL => // A dummy value to avoid automatic value binding in scala REPL. - (s"""val _livyJobGroup$jobGroup = sc.setJobGroup("$jobGroup",""" + - s""""Job group for statement $jobGroup")""", + (s"""val _livyJobGroup$jobGroupId = sc.setJobGroup("$jobGroupId",""" + + s""""Job group for statement $jobGroupId")""", Spark) case PySpark => - (s"""sc.setJobGroup("$jobGroup", "Job group for statement $jobGroup")""", PySpark) + (s"""sc.setJobGroup("$jobGroupId", "Job group for statement $jobGroupId")""", PySpark) case SparkR => sc.getConf.get("spark.livy.spark_major_version", "1") match { case "1" => - (s"""setJobGroup(sc, "$jobGroup", "Job group for statement $jobGroup", FALSE)""", + (s"""setJobGroup(sc, "$jobGroupId", "Job group for statement $jobGroupId", FALSE)""", SparkR) case "2" | "3" => - (s"""setJobGroup("$jobGroup", "Job group for statement $jobGroup", FALSE)""", SparkR) + (s"""setJobGroup("$jobGroupId", "Job group for statement $jobGroupId", FALSE)""", + SparkR) case v => throw new IllegalArgumentException(s"Unknown Spark major version [$v]") } } // Set the job group - executeCode(interpreter(tpe), statementId, cmd) + executeCode(interpreter(tpe), jobGroupId.toInt, cmd) } private def statementIdToJobGroup(statementId: Int): String = { statementId.toString } + + private def taskIdToJobGroup(taskId: Int): String = { + taskId.toString + } } diff --git a/repl/src/test/scala/org/apache/livy/repl/SQLInterpreterSpec.scala b/repl/src/test/scala/org/apache/livy/repl/SQLInterpreterSpec.scala index 5e839d4c8..1526922ab 100644 --- a/repl/src/test/scala/org/apache/livy/repl/SQLInterpreterSpec.scala +++ b/repl/src/test/scala/org/apache/livy/repl/SQLInterpreterSpec.scala @@ -30,6 +30,7 @@ import org.apache.livy.rsc.RSCConf import org.apache.livy.rsc.driver.SparkEntries case class People(name: String, age: Int) + case class Person(name: String, birthday: Date) class SQLInterpreterSpec extends BaseInterpreterSpec { @@ -72,7 +73,9 @@ class SQLInterpreterSpec extends BaseInterpreterSpec { ))) ) - val result = Try { resp1 should equal(expectedResult)} + val result = Try { + resp1 should equal(expectedResult) + } if (result.isFailure) { fail(s"$resp1 doesn't equal to expected result") } @@ -104,7 +107,9 @@ class SQLInterpreterSpec extends BaseInterpreterSpec { ))) ) - val result = Try { resp1 should equal(expectedResult)} + val result = Try { + resp1 should equal(expectedResult) + } if (result.isFailure) { fail(s"$resp1 doesn't equal to expected result") } @@ -139,7 +144,7 @@ class SQLInterpreterSpec extends BaseInterpreterSpec { ) // Test empty result - val resp2 = interpreter.execute( + val resp2 = interpreter.execute( """ |SELECT name FROM people WHERE age > 22 """.stripMargin) @@ -192,8 +197,10 @@ class SQLInterpreterSpec extends BaseInterpreterSpec { assert(resp.isInstanceOf[Interpreter.ExecuteError]) val error = resp.asInstanceOf[Interpreter.ExecuteError] - error.ename should be ("Error") - assert(error.evalue.contains("not found") || error.evalue.contains("cannot be found")) + error.ename should be("Error") + assert(error.evalue.contains("not found") + || error.evalue.contains("TABLE_OR_VIEW_NOT_FOUND") + || error.evalue.contains("cannot be found")) } it should "fail if submitting multiple queries" in withInterpreter { interpreter => @@ -204,6 +211,6 @@ class SQLInterpreterSpec extends BaseInterpreterSpec { """.stripMargin) assert(resp.isInstanceOf[Interpreter.ExecuteError]) - resp.asInstanceOf[Interpreter.ExecuteError].ename should be ("Error") + resp.asInstanceOf[Interpreter.ExecuteError].ename should be("Error") } } diff --git a/repl/src/test/scala/org/apache/livy/repl/SessionTaskSpec.scala b/repl/src/test/scala/org/apache/livy/repl/SessionTaskSpec.scala new file mode 100644 index 000000000..adf4b5490 --- /dev/null +++ b/repl/src/test/scala/org/apache/livy/repl/SessionTaskSpec.scala @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.livy.repl + +import java.nio.ByteBuffer + +import scala.concurrent.duration._ +import scala.language.postfixOps +import scala.math.random + +import org.mockito.Mockito.when +import org.scalatest.concurrent.Eventually.eventually +import org.scalatest.concurrent.Waiters.{interval, timeout} +import org.scalatestplus.mockito.MockitoSugar.mock + +import org.apache.livy._ +import org.apache.livy.client.common.Serializer +import org.apache.livy.rsc.driver.TaskState +import org.apache.livy.sessions.Spark + +class SessionTaskSpec extends BaseSessionSpec(Spark) { + + lazy val serializer = new Serializer() + + val computePi: Job[Double] = new Job[Double]() { + override def call(jc: JobContext): Double = { + val slices = 100 + val n = math.min(1000000L * slices, Int.MaxValue).toInt + val xs = 1 until n + val rdd = jc.sc.parallelize(xs, slices) + .setName("'Initial rdd'") + val sample = rdd.map { _ => + val x = random * 2 - 1 + val y = random * 2 - 1 + (x, y) + }.setName("'Random points sample'") + + val inside = sample.filter { case (x, y) => x * x + y * y < 1 } + .setName("'Random points inside circle'") + + val count = inside.count() + + 4.0 * count / n + } + } + + def wrapJob[T](job: Job[T]): Job[Array[Byte]] = new Job[Array[Byte]] { + override def call(jc: JobContext): Array[Byte] = serializer.serialize(job.call(jc)).array() + } + + def makeJob(fn: JobContext => Array[Byte]): Job[Array[Byte]] = new Job[Array[Byte]] { + override def call(jc: JobContext): Array[Byte] = fn(jc) + } + + val wrappedComputePi: Job[Array[Byte]] = wrapJob(computePi) + + + it should "should submit and retrieve a task result after completion" in withSession { session => + val context = mock[JobContext] + when(context.sc()).thenReturn(session.sc) + + val job = makeJob(_ => "result".getBytes) + + val taskId = session.submitTask(job, context) + + eventually(timeout(30 seconds), interval(100 millis)) { + val task = session.getTask(taskId).get + assert(task.state.get() == TaskState.Available) + } + + val task = session.getTask(taskId).get + task.output should be("result".getBytes) + } + + it should "should update task state and it's progression" in withSession { session => + val context = mock[JobContext] + when(context.sc()).thenReturn(session.sc) + + val taskId = session.submitTask(wrappedComputePi, context) + + // Wait for the task to start and check that progress is non-zero. + eventually(timeout(30 seconds), interval(100 millis)) { + session.tasks(taskId).updateProgress(session.progressOfTask(taskId)) + val task = session.getTask(taskId).get + assert(task.state.get() == TaskState.Running) + assert(task.progress > 0.0) + } + } + + it should "should be able to cancel a task" in withSession { session => + val context = mock[JobContext] + when(context.sc()).thenReturn(session.sc) + + val taskId = session.submitTask(wrappedComputePi, context) + + // Wait for the task to start and check that progress is non-zero. + eventually(timeout(30 seconds), interval(100 millis)) { + session.tasks(taskId).updateProgress(session.progressOfTask(taskId)) + val task = session.getTask(taskId).get + assert(task.state.get() == TaskState.Running) + assert(task.progress > 0.0) + } + + session.cancelTask(taskId) + + // Check that the task is cancelled. + eventually(timeout(30 seconds), interval(100 millis)) { + val task = session.getTask(taskId).get + assert(task.state.get() == TaskState.Cancelled) + } + } + + it should "should retrieve any checked exception thrown" in withSession { session => + val context = mock[JobContext] + when(context.sc()).thenReturn(session.sc) + + val job = makeJob(_ => throw new Exception("test")) + val taskId = session.submitTask(job, context) + + // Wait for the task to start and check that progress is non-zero. + eventually(timeout(30 seconds), interval(100 millis)) { + val task = session.getTask(taskId).get + assert(task.state.get() == TaskState.Failed) + } + + val task = session.getTask(taskId).get + val deserializedException = serializer.deserialize(ByteBuffer.wrap(task.serializedException)) + task.output should be(null) + task.error should be("java.lang.Exception: test") + deserializedException.toString should be(new Exception("test").toString) + } + + it should "should retrieve any unchecked exception thrown" in withSession { session => + val context = mock[JobContext] + when(context.sc()).thenReturn(session.sc) + + val job = makeJob(_ => throw new RuntimeException("test")) + val taskId = session.submitTask(job, context) + + // Wait for the task to start and check that progress is non-zero. + eventually(timeout(30 seconds), interval(100 millis)) { + val task = session.getTask(taskId).get + assert(task.state.get() == TaskState.Failed) + } + + val task = session.getTask(taskId).get + val deserializedException = serializer.deserialize(ByteBuffer.wrap(task.serializedException)) + task.output should be(null) + task.error should be("java.lang.RuntimeException: test") + deserializedException.toString should be(new RuntimeException("test").toString) + } +} + diff --git a/rsc/pom.xml b/rsc/pom.xml index df5e92355..74221d751 100644 --- a/rsc/pom.xml +++ b/rsc/pom.xml @@ -64,6 +64,7 @@ io.netty netty-all + ${netty.version} org.apache.spark @@ -114,6 +115,7 @@ org.apache.hadoop hadoop-common + ${hadoop.version} provided diff --git a/rsc/src/main/java/org/apache/livy/rsc/BaseProtocol.java b/rsc/src/main/java/org/apache/livy/rsc/BaseProtocol.java index 6b7bab1df..56894b9bc 100644 --- a/rsc/src/main/java/org/apache/livy/rsc/BaseProtocol.java +++ b/rsc/src/main/java/org/apache/livy/rsc/BaseProtocol.java @@ -243,6 +243,47 @@ public CancelReplJobRequest() { } } + public static class TaskJobRequest { + public final byte[] serializedJob; + + public TaskJobRequest(byte[] serializedJob) { + this.serializedJob = serializedJob; + } + + public TaskJobRequest() { + this(null); + } + } + + public static class GetTaskResults { + public boolean allResults; + public Integer from, size; + + public GetTaskResults(Integer from, Integer size) { + this.allResults = false; + this.from = from; + this.size = size; + } + + public GetTaskResults() { + this.allResults = true; + from = null; + size = null; + } + } + + public static class CancelTaskRequest { + public final int taskId; + + public CancelTaskRequest(int taskId) { + this.taskId = taskId; + } + + public CancelTaskRequest() { + this(-1); + } + } + public static class InitializationError { public final String stackTrace; diff --git a/rsc/src/main/java/org/apache/livy/rsc/RSCClient.java b/rsc/src/main/java/org/apache/livy/rsc/RSCClient.java index ee9c9012f..54e985aec 100644 --- a/rsc/src/main/java/org/apache/livy/rsc/RSCClient.java +++ b/rsc/src/main/java/org/apache/livy/rsc/RSCClient.java @@ -315,6 +315,25 @@ public Future completeReplCode(String code, String codeType, int curso String[].class); } + public Future submitTask(byte[] serializedJob) { + return deferredCall(new BaseProtocol.TaskJobRequest(serializedJob), Integer.class); + } + + /** + * Get task information from the driver's task registry + */ + public Future getTaskResults(Integer from, Integer size) throws Exception { + return deferredCall(new BaseProtocol.GetTaskResults(from, size), TaskResults.class); + } + + public Future getTaskResults() throws Exception { + return deferredCall(new BaseProtocol.GetTaskResults(), TaskResults.class); + } + + public void cancelTask(int taskId) throws Exception { + deferredCall(new BaseProtocol.CancelTaskRequest(taskId), Void.class); + } + /** * @return Return the repl state. If this's not connected to a repl session, it will return null. */ diff --git a/rsc/src/main/java/org/apache/livy/rsc/RSCConf.java b/rsc/src/main/java/org/apache/livy/rsc/RSCConf.java index 4c45956d7..978a34db9 100644 --- a/rsc/src/main/java/org/apache/livy/rsc/RSCConf.java +++ b/rsc/src/main/java/org/apache/livy/rsc/RSCConf.java @@ -79,6 +79,7 @@ public static enum Entry implements ConfEntry { JOB_CANCEL_TIMEOUT("job-cancel.timeout", "30s"), RETAINED_STATEMENTS("retained-statements", 100), + RETAINED_TASKS("retained-tasks", 100), RETAINED_SHARE_VARIABLES("retained.share-variables", 100), // Number of result rows to get for SQL Interpreters. diff --git a/rsc/src/main/java/org/apache/livy/rsc/TaskResults.java b/rsc/src/main/java/org/apache/livy/rsc/TaskResults.java new file mode 100644 index 000000000..c369808c5 --- /dev/null +++ b/rsc/src/main/java/org/apache/livy/rsc/TaskResults.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.livy.rsc; + +import org.apache.livy.rsc.driver.Task; + +public class TaskResults { + public final Task[] tasks; + + public TaskResults(Task[] tasks) { + this.tasks = tasks; + } + + TaskResults() { + this(null); + } +} diff --git a/rsc/src/main/java/org/apache/livy/rsc/driver/BypassJob.java b/rsc/src/main/java/org/apache/livy/rsc/driver/BypassJob.java index f0d14c667..2e7342418 100644 --- a/rsc/src/main/java/org/apache/livy/rsc/driver/BypassJob.java +++ b/rsc/src/main/java/org/apache/livy/rsc/driver/BypassJob.java @@ -24,12 +24,12 @@ import org.apache.livy.client.common.BufferUtils; import org.apache.livy.client.common.Serializer; -class BypassJob implements Job { +public class BypassJob implements Job { private final Serializer serializer; private final byte[] serializedJob; - BypassJob(Serializer serializer, byte[] serializedJob) { + public BypassJob(Serializer serializer, byte[] serializedJob) { this.serializer = serializer; this.serializedJob = serializedJob; } diff --git a/rsc/src/main/java/org/apache/livy/rsc/driver/RSCDriver.java b/rsc/src/main/java/org/apache/livy/rsc/driver/RSCDriver.java index b5b99f624..e1cc1a31d 100644 --- a/rsc/src/main/java/org/apache/livy/rsc/driver/RSCDriver.java +++ b/rsc/src/main/java/org/apache/livy/rsc/driver/RSCDriver.java @@ -52,6 +52,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.livy.JobContext; import org.apache.livy.client.common.Serializer; import org.apache.livy.rsc.BaseProtocol; import org.apache.livy.rsc.BypassJobStatus; @@ -384,11 +385,11 @@ public void submit(JobWrapper job) { } } - JobContextImpl jobContext() { + protected JobContext jobContext() { return jc; } - Serializer serializer() { + protected Serializer serializer() { return serializer; } diff --git a/rsc/src/main/java/org/apache/livy/rsc/driver/Task.java b/rsc/src/main/java/org/apache/livy/rsc/driver/Task.java new file mode 100644 index 000000000..53eb2d6ed --- /dev/null +++ b/rsc/src/main/java/org/apache/livy/rsc/driver/Task.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.livy.rsc.driver; + +import java.util.concurrent.atomic.AtomicReference; + +import com.fasterxml.jackson.annotation.JsonRawValue; + +/** + * Represents a submitted task (job) in an interactive session. + */ +public class Task { + public final Integer id; + public final AtomicReference state; + @JsonRawValue + public volatile byte[] output; + public volatile byte[] serializedException; + public volatile String error; + public double progress; + public long submitted = 0; + public long completed = 0; + + public Task(Integer id, TaskState state, byte[] output, String error) { + this.id = id; + this.state = new AtomicReference<>(state); + this.output = output; + this.error = error; + } + + public Task() { + this(null, null, null, null); + } + + public boolean compareAndTransit(final TaskState from, final TaskState to) { + if (state.compareAndSet(from, to)) { + TaskState.validate(from, to); + return true; + } + return false; + } + + public void updateProgress(double p) { + if (this.state.get().isOneOf(TaskState.Cancelled, TaskState.Available)) { + this.progress = 1.0; + } else { + this.progress = p; + } + } +} diff --git a/rsc/src/main/java/org/apache/livy/rsc/driver/TaskState.java b/rsc/src/main/java/org/apache/livy/rsc/driver/TaskState.java new file mode 100644 index 000000000..c79190851 --- /dev/null +++ b/rsc/src/main/java/org/apache/livy/rsc/driver/TaskState.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.livy.rsc.driver; + +import java.util.*; + +import com.fasterxml.jackson.annotation.JsonValue; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public enum TaskState { + Waiting("waiting"), + Started("started"), + Running("running"), + Available("available"), + Cancelling("cancelling"), + Cancelled("cancelled"), + Starting("starting"), + Failed("failed"); + private static final Logger LOG = LoggerFactory.getLogger(TaskState.class); + + private final String state; + + TaskState(final String text) { + this.state = text; + } + + @JsonValue + @Override + public String toString() { + return state; + } + + public boolean isOneOf(TaskState... states) { + for (TaskState s : states) { + if (s == this) { + return true; + } + } + return false; + } + + private static final Map> PREDECESSORS; + + static void put(TaskState key, + Map> map, + TaskState... values) { + map.put(key, Collections.unmodifiableList(Arrays.asList(values))); + } + + static { + final Map> predecessors = + new EnumMap<>(TaskState.class); + put(Waiting, predecessors); + put(Running, predecessors, Waiting); + put(Available, predecessors, Running); + put(Cancelling, predecessors, Running); + put(Failed, predecessors, Running, Cancelling); + put(Cancelled, predecessors, Waiting, Cancelling); + + PREDECESSORS = Collections.unmodifiableMap(predecessors); + } + + static boolean isValid(TaskState from, TaskState to) { + return PREDECESSORS.get(to).contains(from); + } + + static void validate(TaskState from, TaskState to) { + LOG.debug("{} -> {}", from, to); + if (!isValid(from, to)) { + throw new IllegalStateException("Illegal Transition: " + from + " -> " + to); + } + } + +} diff --git a/server/pom.xml b/server/pom.xml index 7e5bc0dc5..92f26ba1f 100644 --- a/server/pom.xml +++ b/server/pom.xml @@ -52,6 +52,12 @@ ${project.version} + + org.apache.commons + commons-lang3 + 3.12.0 + + org.apache.livy livy-test-lib diff --git a/server/src/main/scala/org/apache/livy/server/batch/BatchSession.scala b/server/src/main/scala/org/apache/livy/server/batch/BatchSession.scala index 8b64a0398..70144b685 100644 --- a/server/src/main/scala/org/apache/livy/server/batch/BatchSession.scala +++ b/server/src/main/scala/org/apache/livy/server/batch/BatchSession.scala @@ -28,19 +28,21 @@ import com.fasterxml.jackson.annotation.JsonIgnoreProperties import org.apache.livy.{LivyConf, Logging, Utils} import org.apache.livy.server.AccessManager import org.apache.livy.server.recovery.SessionStore -import org.apache.livy.sessions.{FinishedSessionState, Session, SessionState} +import org.apache.livy.sessions.{Session, SessionState} import org.apache.livy.sessions.Session._ import org.apache.livy.utils.{AppInfo, SparkApp, SparkAppListener, SparkProcessBuilder} + @JsonIgnoreProperties(ignoreUnknown = true) case class BatchRecoveryMetadata( - id: Int, - name: Option[String], - appId: Option[String], - appTag: String, - owner: String, - proxyUser: Option[String], - version: Int = 1) + id: Int, + name: Option[String], + appId: Option[String], + appTag: String, + owner: String, + proxyUser: Option[String], + namespace: String, + version: Int = 1) extends RecoveryMetadata object BatchSession extends Logging { @@ -53,17 +55,18 @@ object BatchSession extends Logging { } def create( - id: Int, - name: Option[String], - request: CreateBatchRequest, - livyConf: LivyConf, - accessManager: AccessManager, - owner: String, - proxyUser: Option[String], - sessionStore: SessionStore, - mockApp: Option[SparkApp] = None): BatchSession = { + id: Int, + name: Option[String], + request: CreateBatchRequest, + livyConf: LivyConf, + accessManager: AccessManager, + owner: String, + proxyUser: Option[String], + sessionStore: SessionStore, + mockApp: Option[SparkApp] = None): BatchSession = { val appTag = s"livy-batch-$id-${Random.alphanumeric.take(8).mkString}".toLowerCase() val impersonatedUser = accessManager.checkImpersonation(proxyUser, owner) + val namespace = SparkApp.getNamespace(request.conf, livyConf) def createSparkApp(s: BatchSession): SparkApp = { val conf = SparkApp.prepareSparkConf( @@ -106,7 +109,8 @@ object BatchSession extends Logging { childProcesses.decrementAndGet() } } - SparkApp.create(appTag, None, Option(sparkSubmit), livyConf, Option(s)) + val extrasMap: Map[String, String] = Map(SparkApp.SPARK_KUBERNETES_NAMESPACE_KEY -> namespace) + SparkApp.create(appTag, None, Option(sparkSubmit), livyConf, Option(s), extrasMap) } info(s"Creating batch session $id: [owner: $owner, request: $request]") @@ -120,14 +124,15 @@ object BatchSession extends Logging { owner, impersonatedUser, sessionStore, + namespace, mockApp.map { m => (_: BatchSession) => m }.getOrElse(createSparkApp)) } def recover( - m: BatchRecoveryMetadata, - livyConf: LivyConf, - sessionStore: SessionStore, - mockApp: Option[SparkApp] = None): BatchSession = { + m: BatchRecoveryMetadata, + livyConf: LivyConf, + sessionStore: SessionStore, + mockApp: Option[SparkApp] = None): BatchSession = { new BatchSession( m.id, m.name, @@ -137,23 +142,27 @@ object BatchSession extends Logging { m.owner, m.proxyUser, sessionStore, + m.namespace, mockApp.map { m => (_: BatchSession) => m }.getOrElse { s => - SparkApp.create(m.appTag, m.appId, None, livyConf, Option(s)) + SparkApp.create(m.appTag, m.appId, None, + livyConf, Option(s), Map(SparkApp.SPARK_KUBERNETES_NAMESPACE_KEY -> m.namespace)) }) } } class BatchSession( - id: Int, - name: Option[String], - appTag: String, - initialState: SessionState, - livyConf: LivyConf, - owner: String, - override val proxyUser: Option[String], - sessionStore: SessionStore, - sparkApp: BatchSession => SparkApp) + id: Int, + name: Option[String], + appTag: String, + initialState: SessionState, + livyConf: LivyConf, + owner: String, + override val proxyUser: Option[String], + sessionStore: SessionStore, + namespace: String, + sparkApp: BatchSession => SparkApp) extends Session(id, name, owner, livyConf) with SparkAppListener { + import BatchSession._ protected implicit def executor: ExecutionContextExecutor = ExecutionContext.global @@ -201,8 +210,10 @@ class BatchSession( } } - override def infoChanged(appInfo: AppInfo): Unit = { this.appInfo = appInfo } + override def infoChanged(appInfo: AppInfo): Unit = { + this.appInfo = appInfo + } override def recoveryMetadata: RecoveryMetadata = - BatchRecoveryMetadata(id, name, appId, appTag, owner, proxyUser) + BatchRecoveryMetadata(id, name, appId, appTag, owner, proxyUser, namespace) } diff --git a/server/src/main/scala/org/apache/livy/server/interactive/InteractiveSession.scala b/server/src/main/scala/org/apache/livy/server/interactive/InteractiveSession.scala index 4250794dc..00ca2e24f 100644 --- a/server/src/main/scala/org/apache/livy/server/interactive/InteractiveSession.scala +++ b/server/src/main/scala/org/apache/livy/server/interactive/InteractiveSession.scala @@ -26,9 +26,9 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.concurrent.{Future, Promise} +import scala.concurrent.Future import scala.concurrent.duration.{Duration, FiniteDuration} -import scala.util.{Random, Try} +import scala.util.Random import com.fasterxml.jackson.annotation.JsonIgnoreProperties import org.apache.hadoop.fs.Path @@ -37,7 +37,7 @@ import org.apache.spark.launcher.SparkLauncher import org.apache.livy._ import org.apache.livy.client.common.HttpMessages._ import org.apache.livy.rsc.{PingJob, RSCClient, RSCConf} -import org.apache.livy.rsc.driver.Statement +import org.apache.livy.rsc.driver.{Statement, Task, TaskState} import org.apache.livy.server.AccessManager import org.apache.livy.server.recovery.SessionStore import org.apache.livy.sessions._ @@ -70,6 +70,7 @@ case class InteractiveRecoveryMetadata( // proxyUser is deprecated. It is available here only for backward compatibility proxyUser: Option[String], rscDriverUri: Option[URI], + namespace: String, version: Int = 1) extends RecoveryMetadata @@ -93,6 +94,7 @@ object InteractiveSession extends Logging { mockClient: Option[RSCClient] = None): InteractiveSession = { val appTag = s"livy-session-$id-${Random.alphanumeric.take(8).mkString}".toLowerCase() val impersonatedUser = accessManager.checkImpersonation(proxyUser, owner) + val namespace = SparkApp.getNamespace(request.conf, livyConf) val client = mockClient.orElse { val conf = SparkApp.prepareSparkConf(appTag, livyConf, prepareConf( @@ -153,6 +155,7 @@ object InteractiveSession extends Logging { request.numExecutors, request.pyFiles, request.queue, + namespace, mockApp) } @@ -193,6 +196,7 @@ object InteractiveSession extends Logging { metadata.numExecutors, metadata.pyFiles, metadata.queue, + metadata.namespace, mockApp) } @@ -433,6 +437,7 @@ class InteractiveSession( val numExecutors: Option[Int], val pyFiles: List[String], val queue: Option[String], + val namespace: String, mockApp: Option[SparkApp]) // For unit test. extends Session(id, name, owner, ttl, idleTimeout, livyConf) with SessionHeartbeat @@ -448,6 +453,7 @@ class InteractiveSession( } private val operations = mutable.Map[Long, String]() private val operationCounter = new AtomicLong(0) + private val taskRegistry = mutable.Map[Long, String]() // Track job types for tasks private var rscDriverUri: Option[URI] = None private var sessionLog: IndexedSeq[String] = IndexedSeq.empty private val sessionSaveLock = new Object() @@ -462,11 +468,15 @@ class InteractiveSession( app = mockApp.orElse { val driverProcess = client.flatMap { c => Option(c.getDriverProcess) } .map(new LineBufferedProcess(_, livyConf.getInt(LivyConf.SPARK_LOGS_SIZE))) - if (!livyConf.isRunningOnKubernetes()) { - driverProcess.map(_ => SparkApp.create(appTag, appId, driverProcess, livyConf, Some(this))) + val namespace = SparkApp.getNamespace(conf, livyConf) + val extrasMap: Map[String, String] = Map(SparkApp.SPARK_KUBERNETES_NAMESPACE_KEY -> namespace) + if (!livyConf.isRunningOnKubernetes()) { + driverProcess.map(_ => + SparkApp.create(appTag, appId, driverProcess, livyConf, Some(this), extrasMap) + ) } else { // Create SparkApp for Kubernetes anyway - Some(SparkApp.create(appTag, appId, driverProcess, livyConf, Some(this))) + Some(SparkApp.create(appTag, appId, driverProcess, livyConf, Some(this), extrasMap)) } } @@ -535,7 +545,7 @@ class InteractiveSession( heartbeatTimeout.toSeconds.toInt, owner, ttl, idleTimeout, driverMemory, driverCores, executorMemory, executorCores, conf, archives, files, jars, numExecutors, pyFiles, queue, - proxyUser, rscDriverUri) + proxyUser, rscDriverUri, namespace) override def state: SessionState = { if (serverSideState == SessionState.Running) { @@ -651,7 +661,29 @@ class InteractiveSession( operations.remove(id).foreach { client.get.cancel } } - private def transition(newState: SessionState) = synchronized { + def executeTask(job: Array[Byte]): Task = { + ensureRunning() + recordActivity() + val taskId = client.get.submitTask(job).get + new Task(taskId, TaskState.Waiting, null, null) + } + + def tasks: IndexedSeq[Task] = { + ensureActive() + client.get.getTaskResults().get().tasks.toIndexedSeq + } + + def getTask(taskId: Int): Option[Task] = { + ensureActive() + client.get.getTaskResults(taskId, 1).get().tasks.headOption + } + + def cancelTask(taskId: Int): Unit = { + ensureActive() + client.get.cancelTask(taskId) + } + + private def transition(newState: SessionState): Unit = synchronized { // When a statement returns an error, the session should transit to error state. // If the session crashed because of the error, the session should instead go to dead state. // Since these 2 transitions are triggered by different threads, there's a race condition. diff --git a/server/src/main/scala/org/apache/livy/server/interactive/InteractiveSessionServlet.scala b/server/src/main/scala/org/apache/livy/server/interactive/InteractiveSessionServlet.scala index 310504e75..fc6bb92b4 100644 --- a/server/src/main/scala/org/apache/livy/server/interactive/InteractiveSessionServlet.scala +++ b/server/src/main/scala/org/apache/livy/server/interactive/InteractiveSessionServlet.scala @@ -193,6 +193,54 @@ class InteractiveSessionServlet( Ok(Map("msg" -> "canceled")) } } + + get("/:id/tasks") { + withViewAccessSession { session => + val order = params.get("order") + val tasks = if (order.map(_.trim).exists(_.equalsIgnoreCase("desc"))) { + session.tasks.reverse + } else { + session.tasks + } + val from = params.get("from").map(_.toInt).getOrElse(0) + val size = params.get("size").map(_.toInt).getOrElse(tasks.length) + + Map( + "total_tasks" -> tasks.length, + "tasks" -> tasks.view(from, from + size) + ) + } + } + + // Task endpoints - allow users to submit jobs in interactive sessions + val getTask = get("/:id/tasks/:taskId") { + withViewAccessSession { session => + val taskId = params("taskId").toInt + + session.getTask(taskId).getOrElse(NotFound("Task not found")) + } + } + + jpost[SerializedJob]("/:id/tasks") { req => + withModifyAccessSession { session => + require(req.job != null && req.job.length > 0, "no job provided.") + val task = session.executeTask(req.job) + Created(task, + headers = Map( + "Location" -> url(getTask, + "id" -> session.id.toString, + "taskId" -> task.id.toString))) + } + } + + post("/:id/tasks/:taskId/cancel") { + withModifyAccessSession { session => + val taskId = params("taskId").toInt + session.cancelTask(taskId) + Ok(Map("msg" -> "canceled")) + } + } + // This endpoint is used by the client-http module to "connect" to an existing session and // update its last activity time. It performs authorization checks to make sure the caller // has access to the session, so even though it returns the same data, it behaves differently diff --git a/server/src/main/scala/org/apache/livy/utils/SparkApp.scala b/server/src/main/scala/org/apache/livy/utils/SparkApp.scala index e424f80fc..4450260b7 100644 --- a/server/src/main/scala/org/apache/livy/utils/SparkApp.scala +++ b/server/src/main/scala/org/apache/livy/utils/SparkApp.scala @@ -17,6 +17,9 @@ package org.apache.livy.utils +import java.io.{File, FileInputStream} +import java.util.Properties + import scala.collection.JavaConverters._ import org.apache.livy.LivyConf @@ -28,10 +31,12 @@ object AppInfo { } case class AppInfo( - var driverLogUrl: Option[String] = None, - var sparkUiUrl: Option[String] = None, - var executorLogUrls: Option[String] = None) { + var driverLogUrl: Option[String] = None, + var sparkUiUrl: Option[String] = None, + var executorLogUrls: Option[String] = None) { + import AppInfo._ + def asJavaMap: java.util.Map[String, String] = Map( DRIVER_LOG_URL_NAME -> driverLogUrl.orNull, @@ -56,12 +61,27 @@ trait SparkAppListener { */ object SparkApp { private val SPARK_YARN_TAG_KEY = "spark.yarn.tags" + val SPARK_KUBERNETES_NAMESPACE_KEY = "spark.kubernetes.namespace" object State extends Enumeration { val STARTING, RUNNING, FINISHED, FAILED, KILLED = Value } + type State = State.Value + def getNamespace(conf: Map[String, String], livyConf: LivyConf): String = { + var namespace: String = conf.getOrElse(SPARK_KUBERNETES_NAMESPACE_KEY, "") + if (namespace == "") { + val sparkHome = livyConf.sparkHome().get // SPARK_HOME is mandatory for Livy + val sparkDefaultsPath = + sparkHome + File.separator + "conf" + File.separator + "spark-defaults.conf" + val properties = new Properties() + properties.load(new FileInputStream(sparkDefaultsPath)) + namespace = properties.getProperty(SPARK_KUBERNETES_NAMESPACE_KEY, "default") + } + namespace + } + /** * Return cluster manager dependent SparkConf. * @@ -70,9 +90,9 @@ object SparkApp { * @param sparkConf */ def prepareSparkConf( - uniqueAppTag: String, - livyConf: LivyConf, - sparkConf: Map[String, String]): Map[String, String] = { + uniqueAppTag: String, + livyConf: LivyConf, + sparkConf: Map[String, String]): Map[String, String] = { if (livyConf.isRunningOnYarn()) { val userYarnTags = sparkConf.get(SPARK_YARN_TAG_KEY).map("," + _).getOrElse("") val mergedYarnTags = uniqueAppTag + userYarnTags @@ -98,15 +118,16 @@ object SparkApp { * @param uniqueAppTag A tag that can uniquely identify the application. */ def create( - uniqueAppTag: String, - appId: Option[String], - process: Option[LineBufferedProcess], - livyConf: LivyConf, - listener: Option[SparkAppListener]): SparkApp = { + uniqueAppTag: String, + appId: Option[String], + process: Option[LineBufferedProcess], + livyConf: LivyConf, + listener: Option[SparkAppListener], + extrasMap: Map[String, String]): SparkApp = { if (livyConf.isRunningOnYarn()) { new SparkYarnApp(uniqueAppTag, appId, process, listener, livyConf) } else if (livyConf.isRunningOnKubernetes()) { - new SparkKubernetesApp(uniqueAppTag, appId, process, listener, livyConf) + new SparkKubernetesApp(uniqueAppTag, appId, process, listener, livyConf, extrasMap) } else { require(process.isDefined, "process must not be None when Livy master is not YARN or" + "Kubernetes.") @@ -120,5 +141,6 @@ object SparkApp { */ abstract class SparkApp { def kill(): Unit + def log(): IndexedSeq[String] } diff --git a/server/src/main/scala/org/apache/livy/utils/SparkKubernetesApp.scala b/server/src/main/scala/org/apache/livy/utils/SparkKubernetesApp.scala index 0f466095a..25c3c35a5 100644 --- a/server/src/main/scala/org/apache/livy/utils/SparkKubernetesApp.scala +++ b/server/src/main/scala/org/apache/livy/utils/SparkKubernetesApp.scala @@ -21,6 +21,8 @@ import java.util.Collections import java.util.concurrent._ import scala.annotation.tailrec +import scala.collection.JavaConverters.asScalaSetConverter +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.concurrent._ import scala.concurrent.duration._ @@ -52,14 +54,19 @@ object SparkKubernetesApp extends Logging { val iter = leakedAppTags.entrySet().iterator() var isRemoved = false val now = System.currentTimeMillis() - val apps = withRetry(kubernetesClient.getApplications()) + val apps = appNamespaces.flatMap(namespace => + withRetry(kubernetesClient.inNamespace(namespace).getApplications()) + ) while (iter.hasNext) { val entry = iter.next() apps.find(_.getApplicationTag.contains(entry.getKey)) .foreach({ app => info(s"Kill leaked app ${app.getApplicationId}") - withRetry(kubernetesClient.killApplication(app)) + withRetry(kubernetesClient + .inNamespace(app.getApplicationNamespace) + .killApplication(app) + ) iter.remove() isRemoved = true }) @@ -138,6 +145,8 @@ object SparkKubernetesApp extends Logging { private var sessionLeakageCheckInterval: Long = _ var kubernetesClient: DefaultKubernetesClient = _ + var appNamespaces: mutable.Set[String] = + java.util.concurrent.ConcurrentHashMap.newKeySet[String]().asScala private var appLookupThreadPoolSize: Long = _ private var appLookupMaxFailedTimes: Long = _ @@ -146,8 +155,7 @@ object SparkKubernetesApp extends Logging { this.livyConf = livyConf // KubernetesClient is thread safe. Create once, share it across threads. - kubernetesClient = - KubernetesClientFactory.createKubernetesClient(livyConf) + kubernetesClient = KubernetesClientFactory.createKubernetesClient(livyConf) cacheLogSize = livyConf.getInt(LivyConf.SPARK_LOGS_SIZE) appLookupTimeout = livyConf.getTimeAsMs(LivyConf.KUBERNETES_APP_LOOKUP_TIMEOUT).milliseconds @@ -181,7 +189,9 @@ object SparkKubernetesApp extends Logging { // until envoy comes back, which could take upto 30 seconds @tailrec private def withRetry[T](fn: => T, n: Int = 10, retryBackoff: Long = 3000): T = { - Try { fn } match { + Try { + fn + } match { case Success(x) => x case _ if n > 1 => Thread.sleep(Math.max(retryBackoff, 3000)) @@ -245,7 +255,9 @@ class SparkKubernetesApp private[utils] ( process: Option[LineBufferedProcess], listener: Option[SparkAppListener], livyConf: LivyConf, - kubernetesClient: => KubernetesClient = SparkKubernetesApp.kubernetesClient) // For unit test. + extrasMap: Map[String, String], + kubernetesClient: => DefaultKubernetesClient = + SparkKubernetesApp.kubernetesClient) // For unit test. extends SparkApp with Logging { @@ -262,6 +274,9 @@ class SparkKubernetesApp private[utils] ( private var kubernetesTagToAppIdFailedTimes: Int = _ private var kubernetesAppMonitorFailedTimes: Int = _ + private var namespace: String = extrasMap(SparkApp.SPARK_KUBERNETES_NAMESPACE_KEY) + appNamespaces.add(namespace) + private def failToMonitor(): Unit = { changeState(SparkApp.State.FAILED) process.foreach(_.destroy()) @@ -292,10 +307,11 @@ class SparkKubernetesApp private[utils] ( } // Get KubernetesApplication by appTag. val appOption: Option[KubernetesApplication] = try { - getAppFromTag(appTag, pollInterval, appLookupTimeout.fromNow) + getAppFromTag(appTag, pollInterval, appLookupTimeout.fromNow, namespace) } catch { case e: Exception => failToGetAppId() + error(s"Exception getting app from tag $appTag in namespace $namespace with message: ", e) appPromise.failure(e) return } @@ -311,7 +327,7 @@ class SparkKubernetesApp private[utils] ( listener.foreach(_.appIdKnown(appId)) if (livyConf.getBoolean(LivyConf.KUBERNETES_INGRESS_CREATE)) { - withRetry(kubernetesClient.createSparkUIIngress(app, livyConf)) + withRetry(kubernetesClient.inNamespace(namespace).createSparkUIIngress(app, livyConf)) } var appInfo = AppInfo() @@ -326,7 +342,7 @@ class SparkKubernetesApp private[utils] ( debug(s"getApplicationReport, applicationId: ${app.getApplicationId}, " + s"namespace: ${app.getApplicationNamespace} " + s"applicationTag: ${app.getApplicationTag}") - val report = kubernetesClient.getApplicationReport(livyConf, app, + val report = kubernetesClient.inNamespace(namespace).getApplicationReport(livyConf, app, cacheLogSize = cacheLogSize) report } @@ -392,14 +408,17 @@ class SparkKubernetesApp private[utils] ( process.foreach(_.destroy()) def applicationDetails: Option[Try[KubernetesApplication]] = appPromise.future.value + if (applicationDetails.isEmpty) { leakedAppTags.put(appTag, System.currentTimeMillis()) return } + def kubernetesApplication: KubernetesApplication = applicationDetails.get.get + if (kubernetesApplication != null && kubernetesApplication.getApplicationId != null) { try { - withRetry(kubernetesClient.killApplication( + withRetry(kubernetesClient.inNamespace(namespace).killApplication( Await.result(appPromise.future, appLookupTimeout))) } catch { // We cannot kill the Kubernetes app without the appTag. @@ -431,19 +450,22 @@ class SparkKubernetesApp private[utils] ( } /** - * Find the corresponding KubernetesApplication from an application tag. - * - * @param appTag The application tag tagged on the target application. - * If the tag is not unique, it returns the first application it found. - * @return Option[KubernetesApplication] or the failure. - */ + * Find the corresponding KubernetesApplication from an application tag. + * + * @param appTag The application tag tagged on the target application. + * If the tag is not unique, it returns the first application it found. + * @return Option[KubernetesApplication] or the failure. + */ private def getAppFromTag( appTag: String, pollInterval: duration.Duration, - deadline: Deadline): Option[KubernetesApplication] = { + deadline: Deadline, namespace: String): Option[KubernetesApplication] = { import KubernetesExtensions._ - - withRetry(kubernetesClient.getApplications().find(_.getApplicationTag.contains(appTag))) + withRetry(kubernetesClient + .inNamespace(namespace) + .getApplications() + .find(_.getApplicationTag.contains(appTag)) + ) match { case Some(app) => Some(app) case None => @@ -465,9 +487,9 @@ class SparkKubernetesApp private[utils] ( // Exposed for unit test. private[utils] def mapKubernetesState( - kubernetesAppState: String, - appTag: String - ): SparkApp.State.Value = { + kubernetesAppState: String, + appTag: String + ): SparkApp.State.Value = { import KubernetesApplicationState._ kubernetesAppState.toLowerCase match { case PENDING | CONTAINER_CREATING => @@ -566,7 +588,7 @@ private[utils] case class KubernetesAppReport(driver: Option[Pod], executors: Se import scala.collection.JavaConverters._ val sparkContainerName = livyConf.get(LivyConf.KUBERNETES_SPARK_CONTAINER_NAME) for (c <- podStatus.getContainerStatuses.asScala) { - if (c.getName == sparkContainerName && c.getState.getTerminated != null) { + if (c.getName == sparkContainerName && c.getState.getTerminated != null) { val exitCode = c.getState.getTerminated.getExitCode if (exitCode == 0) { return Some(KubernetesApplicationState.SUCCEEDED) @@ -668,9 +690,11 @@ private[utils] case class KubernetesAppReport(driver: Option[Pod], executors: Se } private[utils] object KubernetesExtensions { + import KubernetesConstants._ implicit class KubernetesClientExtensions(client: KubernetesClient) { + import scala.collection.JavaConverters._ private val NGINX_CONFIG_SNIPPET: String = @@ -682,12 +706,11 @@ private[utils] object KubernetesExtensions { """.stripMargin def getApplications( - labels: Map[String, String] = Map(SPARK_ROLE_LABEL -> SPARK_ROLE_DRIVER), - appTagLabel: String = SPARK_APP_TAG_LABEL, - appIdLabel: String = SPARK_APP_ID_LABEL - ): Seq[KubernetesApplication] = { - client.pods.inAnyNamespace - .withLabels(labels.asJava) + labels: Map[String, String] = Map(SPARK_ROLE_LABEL -> SPARK_ROLE_DRIVER), + appTagLabel: String = SPARK_APP_TAG_LABEL, + appIdLabel: String = SPARK_APP_ID_LABEL + ): Seq[KubernetesApplication] = { + client.pods.withLabels(labels.asJava) .withLabel(appTagLabel) .withLabel(appIdLabel) .list.getItems.asScala.map(new KubernetesApplication(_)) @@ -698,11 +721,11 @@ private[utils] object KubernetesExtensions { } def getApplicationReport( - livyConf: LivyConf, - app: KubernetesApplication, - cacheLogSize: Int, - appTagLabel: String = SPARK_APP_TAG_LABEL - ): KubernetesAppReport = { + livyConf: LivyConf, + app: KubernetesApplication, + cacheLogSize: Int, + appTagLabel: String = SPARK_APP_TAG_LABEL + ): KubernetesAppReport = { val pods = client.pods.inNamespace(app.getApplicationNamespace) .withLabels(Map(appTagLabel -> app.getApplicationTag).asJava) .list.getItems.asScala @@ -746,7 +769,8 @@ private[utils] object KubernetesExtensions { private[utils] def buildSparkUIIngress( app: KubernetesApplication, className: String, protocol: String, host: String, - tlsSecretName: String, additionalConfSnippet: String, additionalAnnotations: (String, String)* + tlsSecretName: String, additionalConfSnippet: String, + additionalAnnotations: (String, String)* ): Ingress = { val appTag = app.getApplicationTag val serviceHost = s"${getServiceName(app)}.${app.getApplicationNamespace}.svc.cluster.local" @@ -823,6 +847,7 @@ private[utils] object KubernetesExtensions { } private[utils] object KubernetesClientFactory { + import java.io.File import com.google.common.base.Charsets import com.google.common.io.Files @@ -831,6 +856,7 @@ private[utils] object KubernetesClientFactory { def toOption: Option[String] = if (string == null || string.isEmpty) None else Option(string) } + def createKubernetesClient(livyConf: LivyConf): DefaultKubernetesClient = { val masterUrl = sparkMasterToKubernetesApi(livyConf.sparkMaster()) @@ -878,7 +904,7 @@ private[utils] object KubernetesClientFactory { val configBuilder: ConfigBuilder) extends AnyVal { def withOption[T] (option: Option[T]) - (configurator: (T, ConfigBuilder) => ConfigBuilder): ConfigBuilder = { + (configurator: (T, ConfigBuilder) => ConfigBuilder): ConfigBuilder = { option.map { opt => configurator(opt, configBuilder) }.getOrElse(configBuilder) diff --git a/server/src/test/scala/org/apache/livy/server/batch/BatchSessionSpec.scala b/server/src/test/scala/org/apache/livy/server/batch/BatchSessionSpec.scala index 401a8beb1..481fcae3d 100644 --- a/server/src/test/scala/org/apache/livy/server/batch/BatchSessionSpec.scala +++ b/server/src/test/scala/org/apache/livy/server/batch/BatchSessionSpec.scala @@ -143,7 +143,7 @@ class BatchSessionSpec val req = new CreateBatchRequest() val name = Some("Test Batch Session") val mockApp = mock[SparkApp] - val m = BatchRecoveryMetadata(99, name, None, "appTag", null, None) + val m = BatchRecoveryMetadata(99, name, None, "appTag", null, None, "") val batch = BatchSession.recover(m, conf, sessionStore, Some(mockApp)) batch.state shouldBe (SessionState.Recovering) diff --git a/server/src/test/scala/org/apache/livy/server/interactive/InteractiveSessionSpec.scala b/server/src/test/scala/org/apache/livy/server/interactive/InteractiveSessionSpec.scala index e7d651f89..a7f30293a 100644 --- a/server/src/test/scala/org/apache/livy/server/interactive/InteractiveSessionSpec.scala +++ b/server/src/test/scala/org/apache/livy/server/interactive/InteractiveSessionSpec.scala @@ -18,10 +18,12 @@ package org.apache.livy.server.interactive import java.net.URI +import java.nio.ByteBuffer import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.postfixOps +import scala.math.random import org.apache.spark.launcher.SparkLauncher import org.json4s.{DefaultFormats, Extraction, JValue} @@ -33,9 +35,10 @@ import org.scalatest.{BeforeAndAfterAll, FunSpec, Matchers} import org.scalatest.concurrent.Eventually._ import org.scalatestplus.mockito.MockitoSugar.mock -import org.apache.livy.{ExecuteRequest, JobHandle, LivyBaseUnitTestSuite, LivyConf} +import org.apache.livy._ +import org.apache.livy.client.common.Serializer import org.apache.livy.rsc.{PingJob, RSCClient, RSCConf} -import org.apache.livy.rsc.driver.StatementState +import org.apache.livy.rsc.driver.{StatementState, TaskState} import org.apache.livy.server.AccessManager import org.apache.livy.server.recovery.SessionStore import org.apache.livy.sessions.{PySpark, SessionState, Spark} @@ -262,6 +265,69 @@ class InteractiveSessionSpec extends FunSpec } } + withSession("should get task progress along with task result") { session => + val job = new Job[Int]() { + override def call(jc: JobContext): Int = 1234567 + } + val serializer = new Serializer() + val task = session.executeTask(serializer.serialize(job).array()) + task.progress should be (0.0) + + eventually(timeout(10 seconds), interval(100 millis)) { + val t = session.getTask(task.id).get + t.state.get() shouldBe TaskState.Available + t.progress should be (1.0) + } + + val result = serializer.deserialize(ByteBuffer.wrap(session.getTask(task.id).get.output)) + .asInstanceOf[Int] + result should be (1234567) + } + + withSession("should return None for non-existent tasks") { session => + session.getTask(9999999) should be (None) + } + + withSession("should cancel tasks") { session => + val computePi: Job[Double] = new Job[Double]() { + override def call(jc: JobContext): Double = { + val slices = 100 + val n = math.min(1000000L * slices, Int.MaxValue).toInt + val xs = 1 until n + val rdd = jc.sc.parallelize(xs, slices) + .setName("'Initial rdd'") + val sample = rdd.map { _ => + val x = random * 2 - 1 + val y = random * 2 - 1 + (x, y) + }.setName("'Random points sample'") + + val inside = sample.filter { case (x, y) => x * x + y * y < 1 } + .setName("'Random points inside circle'") + + val count = inside.count() + + 4.0 * count / n + } + } + val serializer = new Serializer() + val task = session.executeTask(serializer.serialize(computePi).array()) + + // Wait for the task to start. + eventually(timeout(10 seconds), interval(100 millis)) { + val t = session.getTask(task.id).get + t.state.get() shouldBe TaskState.Running + } + + session.cancelTask(task.id) + + // Wait for the task to be canceled. + eventually(timeout(30 seconds), interval(100 millis)) { + val t = session.getTask(task.id).get + t.state.get() shouldBe TaskState.Cancelled + } + } + withSession("should error out the session if the interpreter dies") { session => session.executeStatement(ExecuteRequest("import os; os._exit(666)", None)) eventually(timeout(30 seconds), interval(100 millis)) { @@ -279,7 +345,7 @@ class InteractiveSessionSpec extends FunSpec val m = InteractiveRecoveryMetadata( 78, Some("Test session"), None, "appTag", Spark, 0, null, None, None, None, None, None, None, Map.empty[String, String], List.empty[String], List.empty[String], - List.empty[String], None, List.empty[String], None, None, Some(URI.create(""))) + List.empty[String], None, List.empty[String], None, None, Some(URI.create("")), "") val s = InteractiveSession.recover(m, conf, sessionStore, None, Some(mockClient)) s.start() @@ -298,7 +364,7 @@ class InteractiveSessionSpec extends FunSpec val m = InteractiveRecoveryMetadata( 78, None, None, "appTag", Spark, 0, null, None, None, None, None, None, None, Map.empty[String, String], List.empty[String], List.empty[String], - List.empty[String], None, List.empty[String], None, None, Some(URI.create(""))) + List.empty[String], None, List.empty[String], None, None, Some(URI.create("")), "") val s = InteractiveSession.recover(m, conf, sessionStore, None, Some(mockClient)) s.start() @@ -315,11 +381,12 @@ class InteractiveSessionSpec extends FunSpec val m = InteractiveRecoveryMetadata( 78, None, Some("appId"), "appTag", Spark, 0, null, None, None, None, None, None, None, Map.empty[String, String], List.empty[String], List.empty[String], - List.empty[String], None, List.empty[String], None, None, None) + List.empty[String], None, List.empty[String], None, None, None, "") val s = InteractiveSession.recover(m, conf, sessionStore, None) s.start() s.state shouldBe a[SessionState.Dead] s.logLines().mkString should include("RSCDriver URI is unknown") } } + } diff --git a/server/src/test/scala/org/apache/livy/sessions/SessionManagerSpec.scala b/server/src/test/scala/org/apache/livy/sessions/SessionManagerSpec.scala index 363b01f89..62b35b6e0 100644 --- a/server/src/test/scala/org/apache/livy/sessions/SessionManagerSpec.scala +++ b/server/src/test/scala/org/apache/livy/sessions/SessionManagerSpec.scala @@ -215,7 +215,7 @@ class SessionManagerSpec extends FunSpec with Matchers with LivyBaseUnitTestSuit implicit def executor: ExecutionContext = ExecutionContext.global def makeMetadata(id: Int, appTag: String): BatchRecoveryMetadata = { - BatchRecoveryMetadata(id, Some(s"test-session-$id"), None, appTag, null, None) + BatchRecoveryMetadata(id, Some(s"test-session-$id"), None, appTag, null, None, "") } def mockSession(id: Int): BatchSession = {