Skip to content

Commit 922adad

Browse files
ivosoncloud-fan
authored andcommitted
[SPARK-53575][CORE] Retry entire consumer stages when checksum mismatch detected for a retried shuffle map task
### What changes were proposed in this pull request? This PR proposes to retry all tasks of the consumer stages, when checksum mismatches are detected on their producer stages. In the case that we can't rollback and retry all tasks of a consumer stage, we will have to abort the stage (thus the job). How do we detect and handle nondeterministic before: - Stages are labeled as indeterminate at planning time, prior to query execution - When a task completes and `FetchFailed` is detected, we will abort all unrollbackable succeeding stages of the map stage, and resubmit failed stages. - In `submitMissingTasks()`, if a stage itself is isIndeterminate, we will call `unregisterAllMapAndMergeOutput()` and retry all tasks for stage. How do we detect and handle nondeterministic now: - During query execution, we keep track on the checksums produced by each map task. - When a task completes and checksum mismatch is detected, we will abort unrollbackable succeeding stages of the stage with checksum mismatches. The failed stages resubmission still happen in the same places as before. - In `submitMissingTasks()`, if the parent of a stage has checksum mismatches, we will call `unregisterAllMapAndMergeOutput()` and retry all tasks for stage. Note that (1) if a stage `isReliablyCheckpointed`, the consumer stages don't need to have whole stage retry, and (2) when mismatches are detected for a stage in a chain (e.g., the first stage in stage_i -> stage_i+1 -> stage_i+2 -> ...), the direct consumer (e.g., stage_i+1) of the stage will have a whole stage retry, and an indirect consumer (e.g., stage_i+2) will have a whole stage retry when its parent detects checksum mismatches. ### Why are the changes needed? Handle nondeterministic issues caused by the retry of shuffle map task. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? UTs added. ### Was this patch authored or co-authored using generative AI tooling? No Closes #52336 from ivoson/SPARK-53575. Authored-by: Tengfei Huang <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent 573c7da commit 922adad

File tree

9 files changed

+407
-65
lines changed

9 files changed

+407
-65
lines changed

core/src/main/scala/org/apache/spark/Dependency.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,8 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
8989
val aggregator: Option[Aggregator[K, V, C]] = None,
9090
val mapSideCombine: Boolean = false,
9191
val shuffleWriterProcessor: ShuffleWriteProcessor = new ShuffleWriteProcessor,
92-
val rowBasedChecksums: Array[RowBasedChecksum] = ShuffleDependency.EMPTY_ROW_BASED_CHECKSUMS)
92+
val rowBasedChecksums: Array[RowBasedChecksum] = ShuffleDependency.EMPTY_ROW_BASED_CHECKSUMS,
93+
val checksumMismatchFullRetryEnabled: Boolean = false)
9394
extends Dependency[Product2[K, V]] with Logging {
9495

9596
def this(

core/src/main/scala/org/apache/spark/MapOutputTracker.scala

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -165,9 +165,11 @@ private class ShuffleStatus(
165165

166166
/**
167167
* Register a map output. If there is already a registered location for the map output then it
168-
* will be replaced by the new location.
168+
* will be replaced by the new location. Returns true if the checksum in the new MapStatus is
169+
* different from a previous registered MapStatus. Otherwise, returns false.
169170
*/
170-
def addMapOutput(mapIndex: Int, status: MapStatus): Unit = withWriteLock {
171+
def addMapOutput(mapIndex: Int, status: MapStatus): Boolean = withWriteLock {
172+
var isChecksumMismatch: Boolean = false
171173
val currentMapStatus = mapStatuses(mapIndex)
172174
if (currentMapStatus == null) {
173175
_numAvailableMapOutputs += 1
@@ -183,9 +185,11 @@ private class ShuffleStatus(
183185
logInfo(s"Checksum of map output changes from ${preStatus.checksumValue} to " +
184186
s"${status.checksumValue} for task ${status.mapId}.")
185187
checksumMismatchIndices.add(mapIndex)
188+
isChecksumMismatch = true
186189
}
187190
mapStatuses(mapIndex) = status
188191
mapIdToMapIndex(status.mapId) = mapIndex
192+
isChecksumMismatch
189193
}
190194

191195
/**
@@ -853,7 +857,7 @@ private[spark] class MapOutputTrackerMaster(
853857
}
854858
}
855859

856-
def registerMapOutput(shuffleId: Int, mapIndex: Int, status: MapStatus): Unit = {
860+
def registerMapOutput(shuffleId: Int, mapIndex: Int, status: MapStatus): Boolean = {
857861
shuffleStatuses(shuffleId).addMapOutput(mapIndex, status)
858862
}
859863

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1773,7 +1773,7 @@ abstract class RDD[T: ClassTag](
17731773
/**
17741774
* Return whether this RDD is reliably checkpointed and materialized.
17751775
*/
1776-
private[rdd] def isReliablyCheckpointed: Boolean = {
1776+
private[spark] def isReliablyCheckpointed: Boolean = {
17771777
checkpointData match {
17781778
case Some(reliable: ReliableRDDCheckpointData[_]) if reliable.isCheckpointed => true
17791779
case _ => false

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 73 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1551,29 +1551,46 @@ private[spark] class DAGScheduler(
15511551
// The operation here can make sure for the partially completed intermediate stage,
15521552
// `findMissingPartitions()` returns all partitions every time.
15531553
stage match {
1554-
case sms: ShuffleMapStage if stage.isIndeterminate && !sms.isAvailable =>
1555-
// already executed at least once
1556-
if (sms.getNextAttemptId > 0) {
1557-
// While we previously validated possible rollbacks during the handling of a FetchFailure,
1558-
// where we were fetching from an indeterminate source map stages, this later check
1559-
// covers additional cases like recalculating an indeterminate stage after an executor
1560-
// loss. Moreover, because this check occurs later in the process, if a result stage task
1561-
// has successfully completed, we can detect this and abort the job, as rolling back a
1562-
// result stage is not possible.
1563-
val stagesToRollback = collectSucceedingStages(sms)
1564-
abortStageWithInvalidRollBack(stagesToRollback)
1565-
// stages which cannot be rolled back were aborted which leads to removing the
1566-
// the dependant job(s) from the active jobs set
1567-
val numActiveJobsWithStageAfterRollback =
1568-
activeJobs.count(job => stagesToRollback.contains(job.finalStage))
1569-
if (numActiveJobsWithStageAfterRollback == 0) {
1570-
logInfo(log"All jobs depending on the indeterminate stage " +
1571-
log"(${MDC(STAGE_ID, stage.id)}) were aborted so this stage is not needed anymore.")
1572-
return
1554+
case sms: ShuffleMapStage if !sms.isAvailable =>
1555+
val needFullStageRetry = if (sms.shuffleDep.checksumMismatchFullRetryEnabled) {
1556+
// When the parents of this stage are indeterminate (e.g., some parents are not
1557+
// checkpointed and checksum mismatches are detected), the output data of the parents
1558+
// may have changed due to task retries. For correctness reason, we need to
1559+
// retry all tasks of the current stage. The legacy way of using current stage's
1560+
// deterministic level to trigger full stage retry is not accurate.
1561+
stage.isParentIndeterminate
1562+
} else {
1563+
if (stage.isIndeterminate) {
1564+
// already executed at least once
1565+
if (sms.getNextAttemptId > 0) {
1566+
// While we previously validated possible rollbacks during the handling of a FetchFailure,
1567+
// where we were fetching from an indeterminate source map stages, this later check
1568+
// covers additional cases like recalculating an indeterminate stage after an executor
1569+
// loss. Moreover, because this check occurs later in the process, if a result stage task
1570+
// has successfully completed, we can detect this and abort the job, as rolling back a
1571+
// result stage is not possible.
1572+
val stagesToRollback = collectSucceedingStages(sms)
1573+
abortStageWithInvalidRollBack(stagesToRollback)
1574+
// stages which cannot be rolled back were aborted which leads to removing the
1575+
// the dependant job(s) from the active jobs set
1576+
val numActiveJobsWithStageAfterRollback =
1577+
activeJobs.count(job => stagesToRollback.contains(job.finalStage))
1578+
if (numActiveJobsWithStageAfterRollback == 0) {
1579+
logInfo(log"All jobs depending on the indeterminate stage " +
1580+
log"(${MDC(STAGE_ID, stage.id)}) were aborted so this stage is not needed anymore.")
1581+
return
1582+
}
1583+
}
1584+
true
1585+
} else {
1586+
false
15731587
}
15741588
}
1575-
mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId)
1576-
sms.shuffleDep.newShuffleMergeState()
1589+
1590+
if (needFullStageRetry) {
1591+
mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId)
1592+
sms.shuffleDep.newShuffleMergeState()
1593+
}
15771594
case _ =>
15781595
}
15791596

@@ -1886,6 +1903,20 @@ private[spark] class DAGScheduler(
18861903
}
18871904
}
18881905

1906+
/**
1907+
* If a map stage is non-deterministic, the map tasks of the stage may return different result
1908+
* when re-try. To make sure data correctness, we need to re-try all the tasks of its succeeding
1909+
* stages, as the input data may be changed after the map tasks are re-tried. For stages where
1910+
* rollback and retry all tasks are not possible, we will need to abort the stages.
1911+
*/
1912+
private[scheduler] def abortUnrollbackableStages(mapStage: ShuffleMapStage): Unit = {
1913+
val stagesToRollback = collectSucceedingStages(mapStage)
1914+
val rollingBackStages = abortStageWithInvalidRollBack(stagesToRollback)
1915+
logInfo(log"The shuffle map stage ${MDC(SHUFFLE_ID, mapStage)} with indeterminate output " +
1916+
log"was failed, we will roll back and rerun below stages which include itself and all its " +
1917+
log"indeterminate child stages: ${MDC(STAGES, rollingBackStages)}")
1918+
}
1919+
18891920
/**
18901921
* Responds to a task finishing. This is called inside the event loop so it assumes that it can
18911922
* modify the scheduler's internal state. Use taskEnded() to post a task end event from outside.
@@ -2022,8 +2053,26 @@ private[spark] class DAGScheduler(
20222053
// The epoch of the task is acceptable (i.e., the task was launched after the most
20232054
// recent failure we're aware of for the executor), so mark the task's output as
20242055
// available.
2025-
mapOutputTracker.registerMapOutput(
2056+
val isChecksumMismatched = mapOutputTracker.registerMapOutput(
20262057
shuffleStage.shuffleDep.shuffleId, smt.partitionId, status)
2058+
if (isChecksumMismatched) {
2059+
shuffleStage.isChecksumMismatched = isChecksumMismatched
2060+
// There could be multiple checksum mismatches detected for a single stage attempt.
2061+
// We check for stage abortion once and only once when we first detect checksum
2062+
// mismatch for each stage attempt. For example, assume that we have
2063+
// stage1 -> stage2, and we encounter checksum mismatch during the retry of stage1.
2064+
// In this case, we need to call abortUnrollbackableStages() for the succeeding
2065+
// stages. Assume that when stage2 is retried, some tasks finish and some tasks
2066+
// failed again with FetchFailed. In case that we encounter checksum mismatch again
2067+
// during the retry of stage1, we need to call abortUnrollbackableStages() again.
2068+
if (shuffleStage.maxChecksumMismatchedId < smt.stageAttemptId) {
2069+
shuffleStage.maxChecksumMismatchedId = smt.stageAttemptId
2070+
if (shuffleStage.shuffleDep.checksumMismatchFullRetryEnabled
2071+
&& shuffleStage.isStageIndeterminate) {
2072+
abortUnrollbackableStages(shuffleStage)
2073+
}
2074+
}
2075+
}
20272076
}
20282077
} else {
20292078
logInfo(log"Ignoring ${MDC(TASK_NAME, smt)} completion from an older attempt of indeterminate stage")
@@ -2148,12 +2197,8 @@ private[spark] class DAGScheduler(
21482197
// Note that, if map stage is UNORDERED, we are fine. The shuffle partitioner is
21492198
// guaranteed to be determinate, so the input data of the reducers will not change
21502199
// even if the map tasks are re-tried.
2151-
if (mapStage.isIndeterminate) {
2152-
val stagesToRollback = collectSucceedingStages(mapStage)
2153-
val rollingBackStages = abortStageWithInvalidRollBack(stagesToRollback)
2154-
logInfo(log"The shuffle map stage ${MDC(SHUFFLE_ID, mapStage)} with indeterminate output was failed, " +
2155-
log"we will roll back and rerun below stages which include itself and all its " +
2156-
log"indeterminate child stages: ${MDC(STAGES, rollingBackStages)}")
2200+
if (mapStage.isIndeterminate && !mapStage.shuffleDep.checksumMismatchFullRetryEnabled) {
2201+
abortUnrollbackableStages(mapStage)
21572202
}
21582203

21592204
// We expect one executor failure to trigger many FetchFailures in rapid succession,

core/src/main/scala/org/apache/spark/scheduler/Stage.scala

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,18 @@ private[scheduler] abstract class Stage(
7272
private var nextAttemptId: Int = 0
7373
private[scheduler] def getNextAttemptId: Int = nextAttemptId
7474

75+
/**
76+
* Whether checksum mismatches have been detected across different attempt of the stage, where
77+
* checksum mismatches typically indicates that different stage attempts have produced different
78+
* data.
79+
*/
80+
private[scheduler] var isChecksumMismatched: Boolean = false
81+
82+
/**
83+
* The maximum of task attempt id where checksum mismatches are detected.
84+
*/
85+
private[scheduler] var maxChecksumMismatchedId: Int = nextAttemptId
86+
7587
val name: String = callSite.shortForm
7688
val details: String = callSite.longForm
7789

@@ -131,4 +143,14 @@ private[scheduler] abstract class Stage(
131143
def isIndeterminate: Boolean = {
132144
rdd.outputDeterministicLevel == DeterministicLevel.INDETERMINATE
133145
}
146+
147+
// Returns true if any parents of this stage are indeterminate.
148+
def isParentIndeterminate: Boolean = {
149+
parents.exists(_.isStageIndeterminate)
150+
}
151+
152+
// Returns true if the stage itself is indeterminate.
153+
def isStageIndeterminate: Boolean = {
154+
!rdd.isReliablyCheckpointed && isChecksumMismatched
155+
}
134156
}

0 commit comments

Comments
 (0)