
Commit 061bb01

[SPARK-25248][CORE] Audit barrier Scala APIs for 2.4
## What changes were proposed in this pull request?

I made one pass over the barrier APIs added to Spark 2.4 and updated some scopes and docs. I will update the Python docs once the Scala docs are reviewed.

One major issue is that `BarrierTaskContext` extended `TaskContextImpl`, which exposes some public methods, and internally there were several direct references to `TaskContextImpl` methods instead of `TaskContext`. This PR moves some methods from `TaskContextImpl` up to `TaskContext`, keeping them package private, and uses delegate methods so that `BarrierTaskContext` no longer inherits `TaskContextImpl` and exposes unnecessary APIs.

TODOs:
- [x] Scala doc
- [x] Python doc (apache#22261)

Closes apache#22240 from mengxr/SPARK-25248.

Authored-by: Xiangrui Meng <[email protected]>
Signed-off-by: Xiangrui Meng <[email protected]>
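At the core of the change is a switch from inheritance to delegation: `BarrierTaskContext` now wraps a `TaskContext` and forwards calls to it instead of extending `TaskContextImpl`, so it only exposes what it explicitly defines. Below is a minimal, self-contained sketch of that pattern; `Context`, `ContextImpl`, and `BarrierContext` are illustrative stand-ins, not the real Spark classes.

```scala
// Sketch of the wrap-and-delegate pattern used by this PR (illustrative names only).
trait Context {
  def partitionId(): Int
  def getLocalProperty(key: String): String
}

class ContextImpl(partition: Int, props: Map[String, String]) extends Context {
  override def partitionId(): Int = partition
  override def getLocalProperty(key: String): String = props.getOrElse(key, null)
}

// The barrier-specific context wraps an existing context and forwards calls,
// rather than extending the concrete implementation and inheriting its API.
class BarrierContext(underlying: Context) extends Context {
  // barrier-specific helper built on top of the delegated API
  def hosts(): Array[String] =
    Option(underlying.getLocalProperty("addresses")).getOrElse("")
      .split(",").map(_.trim).filter(_.nonEmpty)

  // delegate methods
  override def partitionId(): Int = underlying.partitionId()
  override def getLocalProperty(key: String): String = underlying.getLocalProperty(key)
}

object DelegationSketch extends App {
  val base = new ContextImpl(0, Map("addresses" -> "host1:1234, host2:1234"))
  val barrier = new BarrierContext(base)
  println(barrier.partitionId())          // delegated: 0
  println(barrier.hosts().mkString(" "))  // barrier-specific: host1:1234 host2:1234
}
```

Delegation keeps the wrapper's public surface limited to what it deliberately overrides, which is why the PR prefers it over inheriting `TaskContextImpl`.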
1 parent 3aa6028 commit 061bb01

File tree

11 files changed: +163, -64 lines

core/src/main/scala/org/apache/spark/BarrierTaskContext.scala

Lines changed: 93 additions & 21 deletions
@@ -24,25 +24,22 @@ import scala.language.postfixOps
 
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.internal.Logging
 import org.apache.spark.memory.TaskMemoryManager
-import org.apache.spark.metrics.MetricsSystem
+import org.apache.spark.metrics.source.Source
 import org.apache.spark.rpc.{RpcEndpointRef, RpcTimeout}
-import org.apache.spark.util.{RpcUtils, Utils}
-
-/** A [[TaskContext]] with extra info and tooling for a barrier stage. */
-class BarrierTaskContext(
-    override val stageId: Int,
-    override val stageAttemptNumber: Int,
-    override val partitionId: Int,
-    override val taskAttemptId: Long,
-    override val attemptNumber: Int,
-    override val taskMemoryManager: TaskMemoryManager,
-    localProperties: Properties,
-    @transient private val metricsSystem: MetricsSystem,
-    // The default value is only used in tests.
-    override val taskMetrics: TaskMetrics = TaskMetrics.empty)
-  extends TaskContextImpl(stageId, stageAttemptNumber, partitionId, taskAttemptId, attemptNumber,
-    taskMemoryManager, localProperties, metricsSystem, taskMetrics) {
+import org.apache.spark.shuffle.FetchFailedException
+import org.apache.spark.util._
+
+/**
+ * :: Experimental ::
+ * A [[TaskContext]] with extra contextual info and tooling for tasks in a barrier stage.
+ * Use [[BarrierTaskContext#get]] to obtain the barrier context for a running barrier task.
+ */
+@Experimental
+@Since("2.4.0")
+class BarrierTaskContext private[spark] (
+    taskContext: TaskContext) extends TaskContext with Logging {
 
   // Find the driver side RPCEndpointRef of the coordinator that handles all the barrier() calls.
   private val barrierCoordinator: RpcEndpointRef = {

@@ -68,7 +65,7 @@ class BarrierTaskContext(
    *
    * CAUTION! In a barrier stage, each task must have the same number of barrier() calls, in all
    * possible code branches. Otherwise, you may get the job hanging or a SparkException after
-   * timeout. Some examples of misuses listed below:
+   * timeout. Some examples of '''misuses''' are listed below:
    * 1. Only call barrier() function on a subset of all the tasks in the same barrier stage, it
    *    shall lead to timeout of the function call.
    * {{{

@@ -146,20 +143,95 @@
 
   /**
    * :: Experimental ::
-   * Returns the all task infos in this barrier stage, the task infos are ordered by partitionId.
+   * Returns [[BarrierTaskInfo]] for all tasks in this barrier stage, ordered by partition ID.
    */
   @Experimental
   @Since("2.4.0")
   def getTaskInfos(): Array[BarrierTaskInfo] = {
-    val addressesStr = localProperties.getProperty("addresses", "")
+    val addressesStr = Option(taskContext.getLocalProperty("addresses")).getOrElse("")
     addressesStr.split(",").map(_.trim()).map(new BarrierTaskInfo(_))
   }
+
+  // delegate methods
+
+  override def isCompleted(): Boolean = taskContext.isCompleted()
+
+  override def isInterrupted(): Boolean = taskContext.isInterrupted()
+
+  override def isRunningLocally(): Boolean = taskContext.isRunningLocally()
+
+  override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
+    taskContext.addTaskCompletionListener(listener)
+    this
+  }
+
+  override def addTaskFailureListener(listener: TaskFailureListener): this.type = {
+    taskContext.addTaskFailureListener(listener)
+    this
+  }
+
+  override def stageId(): Int = taskContext.stageId()
+
+  override def stageAttemptNumber(): Int = taskContext.stageAttemptNumber()
+
+  override def partitionId(): Int = taskContext.partitionId()
+
+  override def attemptNumber(): Int = taskContext.attemptNumber()
+
+  override def taskAttemptId(): Long = taskContext.taskAttemptId()
+
+  override def getLocalProperty(key: String): String = taskContext.getLocalProperty(key)
+
+  override def taskMetrics(): TaskMetrics = taskContext.taskMetrics()
+
+  override def getMetricsSources(sourceName: String): Seq[Source] = {
+    taskContext.getMetricsSources(sourceName)
+  }
+
+  override private[spark] def killTaskIfInterrupted(): Unit = taskContext.killTaskIfInterrupted()
+
+  override private[spark] def getKillReason(): Option[String] = taskContext.getKillReason()
+
+  override private[spark] def taskMemoryManager(): TaskMemoryManager = {
+    taskContext.taskMemoryManager()
+  }
+
+  override private[spark] def registerAccumulator(a: AccumulatorV2[_, _]): Unit = {
+    taskContext.registerAccumulator(a)
+  }
+
+  override private[spark] def setFetchFailed(fetchFailed: FetchFailedException): Unit = {
+    taskContext.setFetchFailed(fetchFailed)
+  }
+
+  override private[spark] def markInterrupted(reason: String): Unit = {
+    taskContext.markInterrupted(reason)
+  }
+
+  override private[spark] def markTaskFailed(error: Throwable): Unit = {
+    taskContext.markTaskFailed(error)
+  }
+
+  override private[spark] def markTaskCompleted(error: Option[Throwable]): Unit = {
+    taskContext.markTaskCompleted(error)
+  }
+
+  override private[spark] def fetchFailed: Option[FetchFailedException] = {
+    taskContext.fetchFailed
+  }
+
+  override private[spark] def getLocalProperties: Properties = taskContext.getLocalProperties
 }
 
+@Experimental
+@Since("2.4.0")
 object BarrierTaskContext {
   /**
-   * Return the currently active BarrierTaskContext. This can be called inside of user functions to
+   * :: Experimental ::
+   * Returns the currently active BarrierTaskContext. This can be called inside of user functions to
    * access contextual information about running barrier tasks.
    */
+  @Experimental
+  @Since("2.4.0")
   def get(): BarrierTaskContext = TaskContext.get().asInstanceOf[BarrierTaskContext]
 }
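Taken together with `RDD.barrier()` and `RDDBarrier.mapPartitions` in the files below, the audited API is meant to be used roughly as in the following sketch. This is a usage example based on the docs in this diff, not code from the PR, and it assumes a cluster with at least as many free slots as partitions, since a barrier stage launches all of its tasks together.

```scala
import org.apache.spark.{BarrierTaskContext, SparkContext}

object BarrierUsageSketch {
  def main(args: Array[String]): Unit = {
    val sc = SparkContext.getOrCreate()
    val rdd = sc.parallelize(1 to 100, numSlices = 4)

    // Every task in the barrier stage must reach barrier(), otherwise the stage
    // times out (see the CAUTION note in the scaladoc above).
    val hosts = rdd.barrier().mapPartitions { iter =>
      val context = BarrierTaskContext.get()   // the running barrier context
      context.barrier()                         // global sync point across all tasks
      val infos = context.getTaskInfos()        // BarrierTaskInfo, ordered by partition ID
      Iterator.single(infos.map(_.address).mkString(","))
    }.collect()

    hosts.foreach(println)
    sc.stop()
  }
}
```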

core/src/main/scala/org/apache/spark/BarrierTaskInfo.scala

Lines changed: 1 addition & 1 deletion
@@ -28,4 +28,4 @@ import org.apache.spark.annotation.{Experimental, Since}
  */
 @Experimental
 @Since("2.4.0")
-class BarrierTaskInfo(val address: String)
+class BarrierTaskInfo private[spark] (val address: String)

core/src/main/scala/org/apache/spark/TaskContext.scala

Lines changed: 14 additions & 0 deletions
@@ -221,4 +221,18 @@ abstract class TaskContext extends Serializable {
    */
   private[spark] def setFetchFailed(fetchFailed: FetchFailedException): Unit
 
+  /** Marks the task for interruption, i.e. cancellation. */
+  private[spark] def markInterrupted(reason: String): Unit
+
+  /** Marks the task as failed and triggers the failure listeners. */
+  private[spark] def markTaskFailed(error: Throwable): Unit
+
+  /** Marks the task as completed and triggers the completion listeners. */
+  private[spark] def markTaskCompleted(error: Option[Throwable]): Unit
+
+  /** Optionally returns the stored fetch failure in the task. */
+  private[spark] def fetchFailed: Option[FetchFailedException]
+
+  /** Gets local properties set upstream in the driver. */
+  private[spark] def getLocalProperties: Properties
 }

core/src/main/scala/org/apache/spark/TaskContextImpl.scala

Lines changed: 6 additions & 9 deletions
@@ -30,6 +30,7 @@ import org.apache.spark.metrics.source.Source
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.util._
 
+
 /**
  * A [[TaskContext]] implementation.
  *

@@ -98,9 +99,8 @@ private[spark] class TaskContextImpl(
     this
   }
 
-  /** Marks the task as failed and triggers the failure listeners. */
   @GuardedBy("this")
-  private[spark] def markTaskFailed(error: Throwable): Unit = synchronized {
+  private[spark] override def markTaskFailed(error: Throwable): Unit = synchronized {
     if (failed) return
     failed = true
     failure = error

@@ -109,9 +109,8 @@
     }
   }
 
-  /** Marks the task as completed and triggers the completion listeners. */
   @GuardedBy("this")
-  private[spark] def markTaskCompleted(error: Option[Throwable]): Unit = synchronized {
+  private[spark] override def markTaskCompleted(error: Option[Throwable]): Unit = synchronized {
     if (completed) return
     completed = true
     invokeListeners(onCompleteCallbacks, "TaskCompletionListener", error) {

@@ -140,8 +139,7 @@
     }
   }
 
-  /** Marks the task for interruption, i.e. cancellation. */
-  private[spark] def markInterrupted(reason: String): Unit = {
+  private[spark] override def markInterrupted(reason: String): Unit = {
     reasonIfKilled = Some(reason)
   }
 

@@ -176,8 +174,7 @@
     this._fetchFailedException = Option(fetchFailed)
   }
 
-  private[spark] def fetchFailed: Option[FetchFailedException] = _fetchFailedException
+  private[spark] override def fetchFailed: Option[FetchFailedException] = _fetchFailedException
 
-  // TODO: shall we publish it and define it in `TaskContext`?
-  private[spark] def getLocalProperties(): Properties = localProperties
+  private[spark] override def getLocalProperties(): Properties = localProperties
 }

core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala

Lines changed: 1 addition & 1 deletion
@@ -270,7 +270,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
         dataOut.writeInt(context.partitionId())
         dataOut.writeInt(context.attemptNumber())
         dataOut.writeLong(context.taskAttemptId())
-        val localProps = context.asInstanceOf[TaskContextImpl].getLocalProperties.asScala
+        val localProps = context.getLocalProperties.asScala
         dataOut.writeInt(localProps.size)
         localProps.foreach { case (k, v) =>
           PythonRDD.writeUTF(k, dataOut)

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 9 additions & 1 deletion
@@ -1649,7 +1649,15 @@ abstract class RDD[T: ClassTag](
 
   /**
    * :: Experimental ::
-   * Indicates that Spark must launch the tasks together for the current stage.
+   * Marks the current stage as a barrier stage, where Spark must launch all tasks together.
+   * In case of a task failure, instead of only restarting the failed task, Spark will abort the
+   * entire stage and re-launch all tasks for this stage.
+   * The barrier execution mode feature is experimental and it only handles limited scenarios.
+   * Please read the linked SPIP and design docs to understand the limitations and future plans.
+   * @return an [[RDDBarrier]] instance that provides actions within a barrier stage
+   * @see [[org.apache.spark.BarrierTaskContext]]
+   * @see <a href="https://jira.apache.org/jira/browse/SPARK-24374">SPIP: Barrier Execution Mode</a>
+   * @see <a href="https://jira.apache.org/jira/browse/SPARK-24582">Design Doc</a>
    */
   @Experimental
   @Since("2.4.0")

core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala

Lines changed: 15 additions & 7 deletions
@@ -22,15 +22,23 @@ import scala.reflect.ClassTag
 import org.apache.spark.TaskContext
 import org.apache.spark.annotation.{Experimental, Since}
 
-/** Represents an RDD barrier, which forces Spark to launch tasks of this stage together. */
-class RDDBarrier[T: ClassTag](rdd: RDD[T]) {
+/**
+ * :: Experimental ::
+ * Wraps an RDD in a barrier stage, which forces Spark to launch tasks of this stage together.
+ * [[org.apache.spark.rdd.RDDBarrier]] instances are created by
+ * [[org.apache.spark.rdd.RDD#barrier]].
+ */
+@Experimental
+@Since("2.4.0")
+class RDDBarrier[T: ClassTag] private[spark] (rdd: RDD[T]) {
 
   /**
    * :: Experimental ::
-   * Generate a new barrier RDD by applying a function to each partitions of the prev RDD.
-   *
-   * `preservesPartitioning` indicates whether the input function preserves the partitioner, which
-   * should be `false` unless `rdd` is a pair RDD and the input function doesn't modify the keys.
+   * Returns a new RDD by applying a function to each partition of the wrapped RDD,
+   * where tasks are launched together in a barrier stage.
+   * The interface is the same as [[org.apache.spark.rdd.RDD#mapPartitions]].
+   * Please see the API doc there.
+   * @see [[org.apache.spark.BarrierTaskContext]]
    */
   @Experimental
   @Since("2.4.0")

@@ -46,5 +54,5 @@ class RDDBarrier[T: ClassTag](rdd: RDD[T]) {
     )
   }
 
-  /** TODO extra conf(e.g. timeout) */
+  // TODO: [SPARK-25247] add extra conf to RDDBarrier, e.g., timeout.
 }
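The updated scaladoc says `RDDBarrier.mapPartitions` keeps the same interface as `RDD.mapPartitions`, which includes the `preservesPartitioning` flag mentioned in the old doc text. A small hypothetical sketch (not from the PR) of passing that flag on a pair RDD whose keys are left untouched:

```scala
import org.apache.spark.SparkContext

object BarrierPreservePartitioningSketch {
  def main(args: Array[String]): Unit = {
    val sc = SparkContext.getOrCreate()
    val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("c", 3)), numSlices = 2)

    // Keys are not modified, so the upstream partitioner can safely be preserved.
    val scaled = pairs.barrier().mapPartitions(
      iter => iter.map { case (k, v) => (k, v * 10) },
      preservesPartitioning = true)

    scaled.collect().foreach(println)
    sc.stop()
  }
}
```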

core/src/main/scala/org/apache/spark/scheduler/Task.scala

Lines changed: 14 additions & 21 deletions
@@ -82,28 +82,21 @@ private[spark] abstract class Task[T](
     SparkEnv.get.blockManager.registerTask(taskAttemptId)
     // TODO SPARK-24874 Allow create BarrierTaskContext based on partitions, instead of whether
     // the stage is barrier.
+    val taskContext = new TaskContextImpl(
+      stageId,
+      stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal
+      partitionId,
+      taskAttemptId,
+      attemptNumber,
+      taskMemoryManager,
+      localProperties,
+      metricsSystem,
+      metrics)
+
     context = if (isBarrier) {
-      new BarrierTaskContext(
-        stageId,
-        stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal
-        partitionId,
-        taskAttemptId,
-        attemptNumber,
-        taskMemoryManager,
-        localProperties,
-        metricsSystem,
-        metrics)
+      new BarrierTaskContext(taskContext)
     } else {
-      new TaskContextImpl(
-        stageId,
-        stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal
-        partitionId,
-        taskAttemptId,
-        attemptNumber,
-        taskMemoryManager,
-        localProperties,
-        metricsSystem,
-        metrics)
+      taskContext
     }
 
     TaskContext.setTaskContext(context)

@@ -180,7 +173,7 @@ private[spark] abstract class Task[T](
   var epoch: Long = -1
 
   // Task context, to be initialized in run().
-  @transient var context: TaskContextImpl = _
+  @transient var context: TaskContext = _
 
   // The actual Thread on which the task is running, if any. Initialized in run().
   @volatile @transient private var taskThread: Thread = _

core/src/main/scala/org/apache/spark/util/Utils.scala

Lines changed: 1 addition & 1 deletion
@@ -1387,7 +1387,7 @@ private[spark] object Utils extends Logging {
         originalThrowable = cause
         try {
           logError("Aborting task", originalThrowable)
-          TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(originalThrowable)
+          TaskContext.get().markTaskFailed(originalThrowable)
           catchBlock
         } catch {
           case t: Throwable =>

project/MimaExcludes.scala

Lines changed: 7 additions & 0 deletions
@@ -36,6 +36,13 @@ object MimaExcludes {
 
   // Exclude rules for 2.4.x
   lazy val v24excludes = v23excludes ++ Seq(
+    // [SPARK-25248] add package private methods to TaskContext
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.markTaskFailed"),
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.markInterrupted"),
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.fetchFailed"),
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.markTaskCompleted"),
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.getLocalProperties"),
+
     // [SPARK-10697][ML] Add lift to Association rules
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.fpm.FPGrowthModel.this"),
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.fpm.AssociationRules#Rule.this"),
