
Commit e3486e1

jiangxb1987 authored and gatorsmile committed
[SPARK-24795][CORE] Implement barrier execution mode
## What changes were proposed in this pull request?

Propose new APIs and modify job/task scheduling to support barrier execution mode, which requires all tasks in the same barrier stage to start at the same time, and retries all tasks in case some of them fail in the middle. The barrier execution mode is useful for some ML/DL workloads.

The proposed API changes include:

- `RDDBarrier`, which marks an RDD as barrier (Spark must launch all the tasks of the current stage together).
- `BarrierTaskContext`, which supports a global sync of all tasks in a barrier stage and provides extra `BarrierTaskInfo`s.

In DAGScheduler, we retry all tasks of a barrier stage in case some tasks fail in the middle. This is achieved by unregistering all map outputs for a shuffleId (for a ShuffleMapStage) or clearing the finished partitions in an active job (for a ResultStage).

## How was this patch tested?

- Added `RDDBarrierSuite` to ensure we convert RDDs correctly.
- Added new test cases in `DAGSchedulerSuite` to ensure we do task scheduling correctly.
- Added new test cases in `SparkContextSuite` to ensure the barrier execution mode actually works (under both local mode and local-cluster mode).
- Added new test cases in `TaskSchedulerImplSuite` to ensure we schedule the tasks of a barrier taskSet together.

Author: Xingbo Jiang <[email protected]>

Closes apache#21758 from jiangxb1987/barrier-execution-mode.
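For context, a minimal usage sketch of the proposed API (hedged: `sc` is an existing SparkContext, and the RDD contents and the work done inside the function are illustrative; note that in this commit `BarrierTaskContext.barrier()` is still a no-op, with the actual global sync tracked under SPARK-24817):

```scala
val rdd = sc.parallelize(1 to 100, 4)

val result = rdd.barrier().mapPartitions { (iter, context: BarrierTaskContext) =>
  // Partition-local work first, e.g. loading one shard of a training dataset.
  val shard = iter.toArray
  // Wait until every task in this barrier stage reaches this point,
  // similar to MPI_Barrier in MPI.
  context.barrier()
  // All tasks have passed the barrier; continue with coordinated work.
  Iterator(shard.sum)
}
```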
1 parent 5ed7660 commit e3486e1

28 files changed: +673 −64 lines
core/src/main/scala/org/apache/spark/BarrierTaskContext.scala

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.apache.spark.annotation.{Experimental, Since}
+
+/** A [[TaskContext]] with extra info and tooling for a barrier stage. */
+trait BarrierTaskContext extends TaskContext {
+
+  /**
+   * :: Experimental ::
+   * Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to
+   * the MPI_Barrier function in MPI, the barrier() call blocks until all tasks in the same
+   * stage have reached this routine.
+   */
+  @Experimental
+  @Since("2.4.0")
+  def barrier(): Unit
+
+  /**
+   * :: Experimental ::
+   * Returns all task infos in this barrier stage; the task infos are ordered by partitionId.
+   */
+  @Experimental
+  @Since("2.4.0")
+  def getTaskInfos(): Array[BarrierTaskInfo]
+}
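For illustration, a hedged sketch of how a barrier task might use `getTaskInfos()` (hypothetical snippet, not part of the commit; it assumes the scheduler has populated the task addresses, and the ring-style wiring is only an example):

```scala
rdd.barrier().mapPartitions { (iter, context: BarrierTaskContext) =>
  // Discover all peer tasks in this barrier stage, e.g. to set up
  // an all-reduce ring for distributed training.
  val addresses = context.getTaskInfos().map(_.address) // "host:port", ordered by partitionId
  val rank = context.partitionId // this task's rank among its peers
  // Connect to addresses((rank + 1) % addresses.length), exchange data, etc.
  iter
}
```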
core/src/main/scala/org/apache/spark/BarrierTaskContextImpl.scala

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
[15-line Apache License header, identical to the one in BarrierTaskContext.scala above]
+
+package org.apache.spark
+
+import java.util.Properties
+
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.memory.TaskMemoryManager
+import org.apache.spark.metrics.MetricsSystem
+
+/** A [[BarrierTaskContext]] implementation. */
+private[spark] class BarrierTaskContextImpl(
+    override val stageId: Int,
+    override val stageAttemptNumber: Int,
+    override val partitionId: Int,
+    override val taskAttemptId: Long,
+    override val attemptNumber: Int,
+    override val taskMemoryManager: TaskMemoryManager,
+    localProperties: Properties,
+    @transient private val metricsSystem: MetricsSystem,
+    // The default value is only used in tests.
+    override val taskMetrics: TaskMetrics = TaskMetrics.empty)
+  extends TaskContextImpl(stageId, stageAttemptNumber, partitionId, taskAttemptId, attemptNumber,
+    taskMemoryManager, localProperties, metricsSystem, taskMetrics)
+    with BarrierTaskContext {
+
+  // TODO SPARK-24817 implement global barrier.
+  override def barrier(): Unit = {}
+
+  override def getTaskInfos(): Array[BarrierTaskInfo] = {
+    val addressesStr = localProperties.getProperty("addresses", "")
+    addressesStr.split(",").map(_.trim()).map(new BarrierTaskInfo(_))
+  }
+}
core/src/main/scala/org/apache/spark/BarrierTaskInfo.scala

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
[15-line Apache License header, identical to the one in BarrierTaskContext.scala above]
+
+package org.apache.spark
+
+import org.apache.spark.annotation.{Experimental, Since}
+
+
+/**
+ * :: Experimental ::
+ * Carries all task infos of a barrier task.
+ *
+ * @param address the IPv4 address (host:port) of the executor that a barrier task is running on
+ */
+@Experimental
+@Since("2.4.0")
+class BarrierTaskInfo(val address: String)

core/src/main/scala/org/apache/spark/MapOutputTracker.scala

Lines changed: 12 additions & 0 deletions
@@ -434,6 +434,18 @@ private[spark] class MapOutputTrackerMaster(
     }
   }

+  /** Unregister all map output information of the given shuffle. */
+  def unregisterAllMapOutput(shuffleId: Int) {
+    shuffleStatuses.get(shuffleId) match {
+      case Some(shuffleStatus) =>
+        shuffleStatus.removeOutputsByFilter(x => true)
+        incrementEpoch()
+      case None =>
+        throw new SparkException(
+          s"unregisterAllMapOutput called for nonexistent shuffle ID $shuffleId.")
+    }
+  }
+
   /** Unregister shuffle data */
   def unregisterShuffle(shuffleId: Int) {
     shuffleStatuses.remove(shuffleId).foreach { shuffleStatus =>

core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala

Lines changed: 14 additions & 1 deletion
@@ -23,11 +23,21 @@ import org.apache.spark.{Partition, TaskContext}

 /**
  * An RDD that applies the provided function to every partition of the parent RDD.
+ *
+ * @param prev the parent RDD.
+ * @param f The function used to map a tuple of (TaskContext, partition index, input iterator) to
+ *          an output iterator.
+ * @param preservesPartitioning Whether the input function preserves the partitioner, which should
+ *                              be `false` unless `prev` is a pair RDD and the input function
+ *                              doesn't modify the keys.
+ * @param isFromBarrier Indicates whether this RDD is transformed from an RDDBarrier; a stage
+ *                      containing at least one RDDBarrier shall be turned into a barrier stage.
  */
 private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
     var prev: RDD[T],
     f: (TaskContext, Int, Iterator[T]) => Iterator[U],  // (TaskContext, partition index, iterator)
-    preservesPartitioning: Boolean = false)
+    preservesPartitioning: Boolean = false,
+    isFromBarrier: Boolean = false)
   extends RDD[U](prev) {

   override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None

@@ -41,4 +51,7 @@ private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
     super.clearDependencies()
     prev = null
   }
+
+  @transient protected lazy override val isBarrier_ : Boolean =
+    isFromBarrier || dependencies.exists(_.rdd.isBarrier())
 }

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 26 additions & 1 deletion
@@ -33,7 +33,7 @@ import org.apache.hadoop.mapred.TextOutputFormat

 import org.apache.spark._
 import org.apache.spark.Partitioner._
-import org.apache.spark.annotation.{DeveloperApi, Since}
+import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.internal.Logging
 import org.apache.spark.partial.BoundedDouble

@@ -1647,6 +1647,14 @@ abstract class RDD[T: ClassTag](
     }
   }

+  /**
+   * :: Experimental ::
+   * Indicates that Spark must launch the tasks together for the current stage.
+   */
+  @Experimental
+  @Since("2.4.0")
+  def barrier(): RDDBarrier[T] = withScope(new RDDBarrier[T](this))
+
   // =======================================================================
   // Other internal methods and fields
   // =======================================================================

@@ -1839,6 +1847,23 @@ abstract class RDD[T: ClassTag](
   def toJavaRDD() : JavaRDD[T] = {
     new JavaRDD(this)(elementClassTag)
   }
+
+  /**
+   * Whether the RDD is in a barrier stage. Spark must launch all the tasks at the same time for a
+   * barrier stage.
+   *
+   * An RDD is in a barrier stage if at least one of its parent RDD(s), or itself, is mapped from
+   * an [[RDDBarrier]]. This function always returns false for a [[ShuffledRDD]], since a
+   * [[ShuffledRDD]] indicates the start of a new stage.
+   *
+   * A [[MapPartitionsRDD]] can be transformed from an [[RDDBarrier]], in which case the
+   * [[MapPartitionsRDD]] shall be marked as barrier.
+   */
+  private[spark] def isBarrier(): Boolean = isBarrier_
+
+  // For performance concerns, cache the value to avoid repeatedly computing `isBarrier()` on a
+  // long RDD chain.
+  @transient protected lazy val isBarrier_ : Boolean = dependencies.exists(_.rdd.isBarrier())
 }
core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
[15-line Apache License header, identical to the one in BarrierTaskContext.scala above]
+
+package org.apache.spark.rdd
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.BarrierTaskContext
+import org.apache.spark.TaskContext
+import org.apache.spark.annotation.{Experimental, Since}
+
+/** Represents an RDD barrier, which forces Spark to launch tasks of this stage together. */
+class RDDBarrier[T: ClassTag](rdd: RDD[T]) {
+
+  /**
+   * :: Experimental ::
+   * Maps partitions together with a provided BarrierTaskContext.
+   *
+   * `preservesPartitioning` indicates whether the input function preserves the partitioner, which
+   * should be `false` unless `rdd` is a pair RDD and the input function doesn't modify the keys.
+   */
+  @Experimental
+  @Since("2.4.0")
+  def mapPartitions[S: ClassTag](
+      f: (Iterator[T], BarrierTaskContext) => Iterator[S],
+      preservesPartitioning: Boolean = false): RDD[S] = rdd.withScope {
+    val cleanedF = rdd.sparkContext.clean(f)
+    new MapPartitionsRDD(
+      rdd,
+      (context: TaskContext, index: Int, iter: Iterator[T]) =>
+        cleanedF(iter, context.asInstanceOf[BarrierTaskContext]),
+      preservesPartitioning,
+      isFromBarrier = true
+    )
+  }
+
+  /** TODO extra conf (e.g. timeout) */
+}
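To make the barrier-flag propagation concrete, a sketch in the spirit of the new `RDDBarrierSuite` (hypothetical assertions, not code from this commit; `isBarrier()` is `private[spark]`, so this only compiles inside the `org.apache.spark` package):

```scala
val rdd = sc.parallelize(1 to 10, 4)
assert(!rdd.isBarrier()) // a plain RDD is not in a barrier stage

// Mapping through RDDBarrier marks the resulting MapPartitionsRDD as barrier,
// and the flag propagates down the narrow-dependency chain.
val barrierRdd = rdd.barrier().mapPartitions((iter, context) => iter)
assert(barrierRdd.isBarrier())
assert(barrierRdd.map(_ + 1).isBarrier())

// A shuffle starts a new stage, so a ShuffledRDD always reports false.
assert(!barrierRdd.map(i => (i, i)).reduceByKey(_ + _).isBarrier())
```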

core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala

Lines changed: 2 additions & 0 deletions
@@ -110,4 +110,6 @@ class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag](
     super.clearDependencies()
     prev = null
   }
+
+  private[spark] override def isBarrier(): Boolean = false
 }

core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala

Lines changed: 6 additions & 0 deletions
@@ -60,4 +60,10 @@ private[spark] class ActiveJob(
   val finished = Array.fill[Boolean](numPartitions)(false)

   var numFinished = 0
+
+  /** Resets the status of all partitions in this stage so they are marked as not finished. */
+  def resetAllPartitions(): Unit = {
+    (0 until numPartitions).foreach(finished.update(_, false))
+    numFinished = 0
+  }
 }
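To connect the pieces, a hedged sketch of the retry path the commit message describes, using the two hooks added above (`unregisterAllMapOutput` and `resetAllPartitions`); the method name and the surrounding `DAGScheduler` wiring here are illustrative, not code from this commit:

```scala
// Illustrative sketch inside DAGScheduler: when a task of a barrier stage
// fails, roll back the whole stage so that all of its tasks retry together.
private def rollBackBarrierStage(stage: Stage): Unit = stage match {
  case sms: ShuffleMapStage =>
    // Forget every map output registered for this shuffle; all map tasks rerun.
    mapOutputTracker.unregisterAllMapOutput(sms.shuffleDep.shuffleId)
  case rs: ResultStage =>
    // Mark every partition of the active job as unfinished; all result tasks rerun.
    rs.activeJob.foreach(_.resetAllPartitions())
}
```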
