From 4cd559f17643697c1e53f27d9c54c091cfd62418 Mon Sep 17 00:00:00 2001 From: Jiexing Li Date: Mon, 10 Mar 2025 14:12:38 -0700 Subject: [PATCH 01/29] compute checksum for shuffle --- .../shuffle/checksum/RowBasedChecksum.scala | 109 ++++++++++++++++ .../sort/BypassMergeSortShuffleWriter.java | 30 ++++- .../shuffle/sort/UnsafeShuffleWriter.java | 21 +++- .../scala/org/apache/spark/Dependency.scala | 4 +- .../org/apache/spark/MapOutputTracker.scala | 13 +- .../spark/internal/config/package.scala | 15 +++ .../apache/spark/scheduler/MapStatus.scala | 50 ++++++-- .../shuffle/sort/SortShuffleWriter.scala | 18 ++- .../util/collection/ExternalSorter.scala | 23 +++- .../checksum/RowBasedChecksumSuite.scala | 116 ++++++++++++++++++ .../sort/UnsafeShuffleWriterSuite.java | 48 +++++++- .../apache/spark/MapOutputTrackerSuite.scala | 6 +- .../spark/MapStatusesSerDeserBenchmark.scala | 2 +- .../spark/scheduler/DAGSchedulerSuite.scala | 50 +++++++- .../shuffle/ShuffleChecksumTestHelper.scala | 10 ++ .../BypassMergeSortShuffleWriterSuite.scala | 60 ++++++++- .../shuffle/sort/SortShuffleWriterSuite.scala | 74 ++++++++++- .../exchange/ShuffleExchangeExec.scala | 12 +- .../spark/sql/MapStatusEndToEndSuite.scala | 60 +++++++++ 19 files changed, 677 insertions(+), 44 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/shuffle/checksum/RowBasedChecksum.scala create mode 100644 core/src/test/java/org/apache/spark/shuffle/checksum/RowBasedChecksumSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala diff --git a/core/src/main/java/org/apache/spark/shuffle/checksum/RowBasedChecksum.scala b/core/src/main/java/org/apache/spark/shuffle/checksum/RowBasedChecksum.scala new file mode 100644 index 0000000000000..8d32126605a89 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/checksum/RowBasedChecksum.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.checksum + +import java.io.{ByteArrayOutputStream, ObjectOutputStream} +import java.util.zip.Checksum + +import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper + +/** + * A class for computing checksum for input (key, value) pairs. The checksum is independent of + * the order of the input (key, value) pairs. It is done by computing a checksum for each row + * first, and then computing the XOR for all the row checksums. 
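+ *
+ * Because XOR is commutative and associative, two task attempts that produce the same set of
+ * rows in different orders end up with the same checksum value. An illustrative sketch (the
+ * algorithm name is any value accepted by ShuffleChecksumHelper, e.g. "ADLER32"):
+ * {{{
+ *   val c1 = new OutputStreamRowBasedChecksum("ADLER32")
+ *   val c2 = new OutputStreamRowBasedChecksum("ADLER32")
+ *   c1.update(Long.box(10), Long.box(20)); c1.update(Long.box(30), Long.box(40))
+ *   c2.update(Long.box(30), Long.box(40)); c2.update(Long.box(10), Long.box(20))
+ *   assert(c1.getValue == c2.getValue)
+ * }}}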
+ */ +abstract class RowBasedChecksum() extends Serializable { + private var checksumValue: Long = 0 + /** Returns the checksum value computed */ + def getValue: Long = checksumValue + + /** Updates the row-based checksum with the given (key, value) pair */ + def update(key: Any, value: Any): Unit = { + val rowChecksumValue = calculateRowChecksum(key, value) + checksumValue = checksumValue ^ rowChecksumValue + } + + /** Computes and returns the checksum value for the given (key, value) pair */ + protected def calculateRowChecksum(key: Any, value: Any): Long +} + +/** + * A Concrete implementation of RowBasedChecksum. The checksum for each row is + * computed by first converting the (key, value) pair to byte array using OutputStreams, + * and then computing the checksum for the byte array. + * + * @param checksumAlgorithm the algorithm used for computing checksum. + */ +class OutputStreamRowBasedChecksum(checksumAlgorithm: String) + extends RowBasedChecksum() { + + /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */ + final private class MyByteArrayOutputStream(size: Int) + extends ByteArrayOutputStream(size) { + def getBuf: Array[Byte] = buf + } + + private val DEFAULT_INITIAL_SER_BUFFER_SIZE = 32 * 1024 + + @transient private lazy val serBuffer = + new MyByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE) + @transient private lazy val objOut = new ObjectOutputStream(serBuffer) + + @transient + protected lazy val checksum: Checksum = + ShuffleChecksumHelper.getChecksumByAlgorithm(checksumAlgorithm) + + override protected def calculateRowChecksum(key: Any, value: Any): Long = { + assert(checksum != null, "Checksum is null") + + // Converts the (key, value) pair into byte array. + objOut.reset() + serBuffer.reset() + objOut.writeObject((key, value)) + objOut.flush() + serBuffer.flush() + + // Computes and returns the checksum for the byte array. 
+ checksum.reset() + checksum.update(serBuffer.getBuf, 0, serBuffer.size()) + checksum.getValue + } +} + +object RowBasedChecksum { + def createPartitionRowBasedChecksums( + numPartitions: Int, + checksumAlgorithm: String): Array[RowBasedChecksum] = { + val rowBasedChecksums: Array[RowBasedChecksum] = new Array[RowBasedChecksum](numPartitions) + for (i <- 0 until numPartitions) { + rowBasedChecksums(i) = new OutputStreamRowBasedChecksum(checksumAlgorithm) + } + rowBasedChecksums + } + + def getAggregatedChecksumValue(rowBasedChecksums: Array[RowBasedChecksum]): Long = { + val numPartitions: Int = if (rowBasedChecksums != null) rowBasedChecksums.length else 0 + var aggregatedChecksum: Long = 0 + if (numPartitions > 0) { + for (i <- 0 until numPartitions) { + aggregatedChecksum = aggregatedChecksum * 31 + rowBasedChecksums(i).getValue + } + } + return aggregatedChecksum + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 86f7d5143eff5..170df7d10d453 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -53,6 +53,7 @@ import org.apache.spark.scheduler.MapStatus$; import org.apache.spark.serializer.Serializer; import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.shuffle.checksum.RowBasedChecksum; import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; import org.apache.spark.shuffle.ShuffleWriter; import org.apache.spark.storage.*; @@ -104,6 +105,14 @@ final class BypassMergeSortShuffleWriter private long[] partitionLengths; /** Checksum calculator for each partition. Empty when shuffle checksum disabled. */ private final Checksum[] partitionChecksums; + /** + * Checksum calculator for each partition. Different from the above Checksum, + * RowBasedChecksum is independent of the input row order, which is used to + * detect whether different task attempts of the same partition produce different + * output data or not. + */ + private final RowBasedChecksum[] rowBasedChecksums; + private final SparkConf conf; /** * Are we in the process of stopping? 
Because map tasks can call stop() with success = true @@ -132,6 +141,8 @@ final class BypassMergeSortShuffleWriter this.serializer = dep.serializer(); this.shuffleExecutorComponents = shuffleExecutorComponents; this.partitionChecksums = createPartitionChecksums(numPartitions, conf); + this.rowBasedChecksums = dep.rowBasedChecksums(); + this.conf = conf; } @Override @@ -144,7 +155,7 @@ public void write(Iterator> records) throws IOException { partitionLengths = mapOutputWriter.commitAllPartitions( ShuffleChecksumHelper.EMPTY_CHECKSUM_VALUE).getPartitionLengths(); mapStatus = MapStatus$.MODULE$.apply( - blockManager.shuffleServerId(), partitionLengths, mapId); + blockManager.shuffleServerId(), partitionLengths, mapId, getAggregatedChecksumValue()); return; } final SerializerInstance serInstance = serializer.newInstance(); @@ -171,7 +182,11 @@ public void write(Iterator> records) throws IOException { while (records.hasNext()) { final Product2 record = records.next(); final K key = record._1(); - partitionWriters[partitioner.getPartition(key)].write(key, record._2()); + final int partitionId = partitioner.getPartition(key); + partitionWriters[partitionId].write(key, record._2()); + if (rowBasedChecksums.length > 0) { + rowBasedChecksums[partitionId].update(key, record._2()); + } } for (int i = 0; i < numPartitions; i++) { @@ -182,7 +197,7 @@ public void write(Iterator> records) throws IOException { partitionLengths = writePartitionedData(mapOutputWriter); mapStatus = MapStatus$.MODULE$.apply( - blockManager.shuffleServerId(), partitionLengths, mapId); + blockManager.shuffleServerId(), partitionLengths, mapId, getAggregatedChecksumValue()); } catch (Exception e) { try { mapOutputWriter.abort(e); @@ -199,6 +214,15 @@ public long[] getPartitionLengths() { return partitionLengths; } + public RowBasedChecksum[] getRowBasedChecksums() { + return rowBasedChecksums; + } + + public long getAggregatedChecksumValue() { + final long checksum = RowBasedChecksum.getAggregatedChecksumValue(rowBasedChecksums); + return checksum; + } + /** * Concatenate all of the per-partition files into a single combined file. * diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index ac9d335d63591..93cd6004cc317 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -60,6 +60,7 @@ import org.apache.spark.shuffle.api.ShufflePartitionWriter; import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter; import org.apache.spark.shuffle.api.WritableByteChannelWrapper; +import org.apache.spark.shuffle.checksum.RowBasedChecksum; import org.apache.spark.storage.BlockManager; import org.apache.spark.storage.TimeTrackingOutputStream; import org.apache.spark.unsafe.Platform; @@ -103,6 +104,13 @@ private static final class MyByteArrayOutputStream extends ByteArrayOutputStream private MyByteArrayOutputStream serBuffer; private SerializationStream serOutputStream; + /** + * RowBasedChecksum calculator for each partition. RowBasedChecksum is independent + * of the input row order, which is used to detect whether different task attempts + * of the same partition produce different output data or not. + */ + private final RowBasedChecksum[] rowBasedChecksums; + /** * Are we in the process of stopping? 
Because map tasks can call stop() with success = true * and then call stop() with success = false if they get an exception, we want to make sure @@ -142,6 +150,7 @@ public UnsafeShuffleWriter( (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE()); this.mergeBufferSizeInBytes = (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_FILE_MERGE_BUFFER_SIZE()) * 1024; + this.rowBasedChecksums = dep.rowBasedChecksums(); open(); } @@ -163,6 +172,13 @@ public long getPeakMemoryUsedBytes() { return peakMemoryUsedBytes; } + public RowBasedChecksum[] getRowBasedChecksums() { + return rowBasedChecksums; + } + public long getAggregatedChecksumValue() { + return RowBasedChecksum.getAggregatedChecksumValue(rowBasedChecksums); + } + /** * This convenience method should only be called in test code. */ @@ -234,7 +250,7 @@ void closeAndWriteOutput() throws IOException { } } mapStatus = MapStatus$.MODULE$.apply( - blockManager.shuffleServerId(), partitionLengths, mapId); + blockManager.shuffleServerId(), partitionLengths, mapId, getAggregatedChecksumValue()); } @VisibleForTesting @@ -252,6 +268,9 @@ void insertRecordIntoSorter(Product2 record) throws IOException { sorter.insertRecord( serBuffer.getBuf(), Platform.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId); + if (rowBasedChecksums.length > 0) { + rowBasedChecksums[partitionId].update(key, record._2()); + } } @VisibleForTesting diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index 573608c4327e0..de8ad991152cc 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -29,6 +29,7 @@ import org.apache.spark.internal.LogKeys._ import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{ShuffleHandle, ShuffleWriteProcessor} +import org.apache.spark.shuffle.checksum.RowBasedChecksum import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.Utils @@ -83,7 +84,8 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( val keyOrdering: Option[Ordering[K]] = None, val aggregator: Option[Aggregator[K, V, C]] = None, val mapSideCombine: Boolean = false, - val shuffleWriterProcessor: ShuffleWriteProcessor = new ShuffleWriteProcessor) + val shuffleWriterProcessor: ShuffleWriteProcessor = new ShuffleWriteProcessor, + val rowBasedChecksums: Array[RowBasedChecksum] = Array.empty) extends Dependency[Product2[K, V]] with Logging { if (mapSideCombine) { diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index a660bccd2e68f..259115167706b 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -23,7 +23,7 @@ import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, ThreadPoolE import java.util.concurrent.locks.ReentrantReadWriteLock import scala.collection -import scala.collection.mutable.{HashMap, ListBuffer, Map} +import scala.collection.mutable.{HashMap, ListBuffer, Map, Set} import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration import scala.jdk.CollectionConverters._ @@ -99,6 +99,11 @@ private class ShuffleStatus( */ val mapStatusesDeleted = new Array[MapStatus](numPartitions) + /** + * Keep the indices of the Map tasks whose checksums are different across retries. 
+ */ + private[this] val checksumMismatchIndices : Set[Int] = Set() + /** * MergeStatus for each shuffle partition when push-based shuffle is enabled. The index of the * array is the shuffle partition id (reduce id). Each value in the array is the MergeStatus for @@ -169,6 +174,12 @@ private class ShuffleStatus( } else { mapIdToMapIndex.remove(currentMapStatus.mapId) } + + val preStatus = + if (mapStatuses(mapIndex) != null) mapStatuses(mapIndex) else mapStatusesDeleted(mapIndex) + if (preStatus != null && preStatus.checksumValue != status.checksumValue) { + checksumMismatchIndices.add(mapIndex) + } mapStatuses(mapIndex) = status mapIdToMapIndex(status.mapId) = mapIndex } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 3ce374d0477d8..9b95dddc5a856 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -1633,6 +1633,21 @@ package object config { .booleanConf .createWithDefault(true) + private[spark] val SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED = + ConfigBuilder("spark.shuffle.orderIndependentChecksum.enabled") + .doc("Whether to calculate order independent checksum for the shuffle data or not. If " + + "enabled, Spark will calculate a checksum that is independent of the input row order for " + + "each mapper and returns the checksums from executors to driver. Different from the above" + + "checksum, the order independent remains the same even if the shuffle row order changes. " + + "While the above checksum is sensitive to shuffle data ordering to detect file " + + "corruption. This checksum is used to detect whether different task attempts of the same " + + "partition produce different output data or not (same set of keyValue pairs). In case " + + "the output data has changed across retries, Spark will need to retry all tasks of the " + + "consumer stages to avoid correctness issues.") + .version("3.4.0") + .booleanConf + .createWithDefault(false) + private[spark] val SHUFFLE_CHECKSUM_ALGORITHM = ConfigBuilder("spark.shuffle.checksum.algorithm") .doc("The algorithm is used to calculate the shuffle checksum. Currently, it only supports " + diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 113521453ad7b..64e12555caae3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -24,7 +24,7 @@ import scala.collection.mutable import org.roaringbitmap.RoaringBitmap import org.apache.spark.SparkEnv -import org.apache.spark.internal.config +import org.apache.spark.internal.{config, Logging} import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.Utils @@ -39,7 +39,7 @@ private[spark] trait ShuffleOutputStatus * task has shuffle files stored on as well as the sizes of outputs for each reducer, for passing * on to the reduce tasks. */ -private[spark] sealed trait MapStatus extends ShuffleOutputStatus { +private[spark] sealed trait MapStatus extends ShuffleOutputStatus with Logging { /** Location where this task output is. */ def location: BlockManagerId @@ -58,6 +58,12 @@ private[spark] sealed trait MapStatus extends ShuffleOutputStatus { * partitionId of the task or taskContext.taskAttemptId is used. 
*/ def mapId: Long + + /** + * The checksum value of this shuffle map task, which can be used to evaluate whether the + * output data have changed across different map task retries. + */ + def checksumValue: Long = 0 } @@ -74,11 +80,12 @@ private[spark] object MapStatus { def apply( loc: BlockManagerId, uncompressedSizes: Array[Long], - mapTaskId: Long): MapStatus = { + mapTaskId: Long, + checksumVal: Long = 0): MapStatus = { if (uncompressedSizes.length > minPartitionsToUseHighlyCompressMapStatus) { - HighlyCompressedMapStatus(loc, uncompressedSizes, mapTaskId) + HighlyCompressedMapStatus(loc, uncompressedSizes, mapTaskId, checksumVal) } else { - new CompressedMapStatus(loc, uncompressedSizes, mapTaskId) + new CompressedMapStatus(loc, uncompressedSizes, mapTaskId, checksumVal) } } @@ -119,18 +126,24 @@ private[spark] object MapStatus { * @param loc location where the task is being executed. * @param compressedSizes size of the blocks, indexed by reduce partition id. * @param _mapTaskId unique task id for the task + * @param _checksumVal the checksum value for the task */ private[spark] class CompressedMapStatus( private[this] var loc: BlockManagerId, private[this] var compressedSizes: Array[Byte], - private[this] var _mapTaskId: Long) + private[this] var _mapTaskId: Long, + private[this] var _checksumVal: Long = 0) extends MapStatus with Externalizable { // For deserialization only - protected def this() = this(null, null.asInstanceOf[Array[Byte]], -1) + protected def this() = this(null, null.asInstanceOf[Array[Byte]], -1, 0) - def this(loc: BlockManagerId, uncompressedSizes: Array[Long], mapTaskId: Long) = { - this(loc, uncompressedSizes.map(MapStatus.compressSize), mapTaskId) + def this( + loc: BlockManagerId, + uncompressedSizes: Array[Long], + mapTaskId: Long, + checksumVal: Long) = { + this(loc, uncompressedSizes.map(MapStatus.compressSize), mapTaskId, checksumVal) } override def location: BlockManagerId = loc @@ -145,11 +158,14 @@ private[spark] class CompressedMapStatus( override def mapId: Long = _mapTaskId + override def checksumValue: Long = _checksumVal + override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { loc.writeExternal(out) out.writeInt(compressedSizes.length) out.write(compressedSizes) out.writeLong(_mapTaskId) + out.writeLong(_checksumVal) } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { @@ -158,6 +174,7 @@ private[spark] class CompressedMapStatus( compressedSizes = new Array[Byte](len) in.readFully(compressedSizes) _mapTaskId = in.readLong() + _checksumVal = in.readLong() } } @@ -172,6 +189,7 @@ private[spark] class CompressedMapStatus( * @param avgSize average size of the non-empty and non-huge blocks * @param hugeBlockSizes sizes of huge blocks by their reduceId. 
* @param _mapTaskId unique task id for the task + * @param _checksumVal checksum value for the task */ private[spark] class HighlyCompressedMapStatus private ( private[this] var loc: BlockManagerId, @@ -179,7 +197,8 @@ private[spark] class HighlyCompressedMapStatus private ( private[this] var emptyBlocks: RoaringBitmap, private[this] var avgSize: Long, private[this] var hugeBlockSizes: scala.collection.Map[Int, Byte], - private[this] var _mapTaskId: Long) + private[this] var _mapTaskId: Long, + private[this] var _checksumVal: Long = 0) extends MapStatus with Externalizable { // loc could be null when the default constructor is called during deserialization @@ -187,7 +206,7 @@ private[spark] class HighlyCompressedMapStatus private ( || numNonEmptyBlocks == 0 || _mapTaskId > 0, "Average size can only be zero for map stages that produced no output") - protected def this() = this(null, -1, null, -1, null, -1) // For deserialization only + protected def this() = this(null, -1, null, -1, null, -1, 0) // For deserialization only override def location: BlockManagerId = loc @@ -209,6 +228,8 @@ private[spark] class HighlyCompressedMapStatus private ( override def mapId: Long = _mapTaskId + override def checksumValue: Long = _checksumVal + override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { loc.writeExternal(out) emptyBlocks.serialize(out) @@ -219,6 +240,7 @@ private[spark] class HighlyCompressedMapStatus private ( out.writeByte(kv._2) } out.writeLong(_mapTaskId) + out.writeLong(_checksumVal) } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { @@ -236,6 +258,7 @@ private[spark] class HighlyCompressedMapStatus private ( } hugeBlockSizes = hugeBlockSizesImpl _mapTaskId = in.readLong() + _checksumVal = in.readLong() } } @@ -243,7 +266,8 @@ private[spark] object HighlyCompressedMapStatus { def apply( loc: BlockManagerId, uncompressedSizes: Array[Long], - mapTaskId: Long): HighlyCompressedMapStatus = { + mapTaskId: Long, + checksumVal: Long = 0): HighlyCompressedMapStatus = { // We must keep track of which blocks are empty so that we don't report a zero-sized // block as being non-empty (or vice-versa) when using the average block size. 
var i = 0 @@ -310,6 +334,6 @@ private[spark] object HighlyCompressedMapStatus { emptyBlocks.trim() emptyBlocks.runOptimize() new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize, - hugeBlockSizes, mapTaskId) + hugeBlockSizes, mapTaskId, checksumVal) } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 3be7d24f7e4ec..388ef1e82fa7e 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -23,6 +23,7 @@ import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriter} import org.apache.spark.shuffle.ShuffleWriteMetricsReporter import org.apache.spark.shuffle.api.ShuffleExecutorComponents +import org.apache.spark.shuffle.checksum.RowBasedChecksum import org.apache.spark.util.collection.ExternalSorter private[spark] class SortShuffleWriter[K, V, C]( @@ -48,17 +49,27 @@ private[spark] class SortShuffleWriter[K, V, C]( private var partitionLengths: Array[Long] = _ + def getRowBasedChecksums: Array[RowBasedChecksum] = { + if (sorter != null) sorter.getRowBasedChecksums else new Array[RowBasedChecksum](0) + } + + def getAggregatedChecksumValue: Long = { + if (sorter != null) sorter.getAggregatedChecksumValue else 0 + } + /** Write a bunch of records to this task's output */ override def write(records: Iterator[Product2[K, V]]): Unit = { sorter = if (dep.mapSideCombine) { new ExternalSorter[K, V, C]( - context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer) + context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, + dep.serializer, dep.rowBasedChecksums) } else { // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't // care whether the keys get sorted in each partition; that will be done on the reduce side // if the operation being run is sortByKey. 
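+      // As in the map-side-combine branch above, the dependency's row-based checksums are
+      // passed to the sorter, which updates them as records are inserted (see
+      // ExternalSorter.insertAll).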
new ExternalSorter[K, V, V]( - context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer) + context, aggregator = None, Some(dep.partitioner), ordering = None, + dep.serializer, dep.rowBasedChecksums) } sorter.insertAll(records) @@ -69,7 +80,8 @@ private[spark] class SortShuffleWriter[K, V, C]( dep.shuffleId, mapId, dep.partitioner.numPartitions) sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter, writeMetrics) partitionLengths = mapOutputWriter.commitAllPartitions(sorter.getChecksums).getPartitionLengths - mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId) + mapStatus = + MapStatus(blockManager.shuffleServerId, partitionLengths, mapId, getAggregatedChecksumValue) } /** Close this writer, passing along whether the map completed */ diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 393cdbbef0a5a..c7b49b24b9442 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -33,7 +33,7 @@ import org.apache.spark.internal.LogKeys.{NUM_BYTES, TASK_ATTEMPT_ID} import org.apache.spark.serializer._ import org.apache.spark.shuffle.{ShufflePartitionPairsWriter, ShuffleWriteMetricsReporter} import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, ShufflePartitionWriter} -import org.apache.spark.shuffle.checksum.ShuffleChecksumSupport +import org.apache.spark.shuffle.checksum.{RowBasedChecksum, ShuffleChecksumSupport} import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, ShuffleBlockId} import org.apache.spark.util.{CompletionIterator, Utils => TryUtils} @@ -97,7 +97,8 @@ private[spark] class ExternalSorter[K, V, C]( aggregator: Option[Aggregator[K, V, C]] = None, partitioner: Option[Partitioner] = None, ordering: Option[Ordering[K]] = None, - serializer: Serializer = SparkEnv.get.serializer) + serializer: Serializer = SparkEnv.get.serializer, + rowBasedChecksums: Array[RowBasedChecksum] = Array.empty) extends Spillable[WritablePartitionedPairCollection[K, C]](context.taskMemoryManager()) with Logging with ShuffleChecksumSupport { @@ -142,10 +143,16 @@ private[spark] class ExternalSorter[K, V, C]( private val forceSpillFiles = new ArrayBuffer[SpilledFile] @volatile private var readingIterator: SpillableIterator = null + /** Checksum calculator for each partition. Empty when shuffle checksum disabled. */ private val partitionChecksums = createPartitionChecksums(numPartitions, conf) def getChecksums: Array[Long] = getChecksumValues(partitionChecksums) + def getRowBasedChecksums: Array[RowBasedChecksum] = rowBasedChecksums + + def getAggregatedChecksumValue: Long = + RowBasedChecksum.getAggregatedChecksumValue(rowBasedChecksums) + // A comparator for keys K that orders them within a partition to allow aggregation or sorting. // Can be a partial ordering by hash code if a total ordering is not provided through by the // user. 
(A partial ordering means that equal keys have comparator.compare(k, k) = 0, but some @@ -197,16 +204,24 @@ private[spark] class ExternalSorter[K, V, C]( while (records.hasNext) { addElementsRead() kv = records.next() - map.changeValue((actualPartitioner.getPartition(kv._1), kv._1), update) + val partitionId = actualPartitioner.getPartition(kv._1) + map.changeValue((partitionId, kv._1), update) maybeSpillCollection(usingMap = true) + if (!rowBasedChecksums.isEmpty) { + rowBasedChecksums(partitionId).update(kv._1, kv._2) + } } } else { // Stick values into our buffer while (records.hasNext) { addElementsRead() val kv = records.next() - buffer.insert(actualPartitioner.getPartition(kv._1), kv._1, kv._2.asInstanceOf[C]) + val partitionId = actualPartitioner.getPartition(kv._1) + buffer.insert(partitionId, kv._1, kv._2.asInstanceOf[C]) maybeSpillCollection(usingMap = false) + if (!rowBasedChecksums.isEmpty) { + rowBasedChecksums(partitionId).update(kv._1, kv._2) + } } } } diff --git a/core/src/test/java/org/apache/spark/shuffle/checksum/RowBasedChecksumSuite.scala b/core/src/test/java/org/apache/spark/shuffle/checksum/RowBasedChecksumSuite.scala new file mode 100644 index 0000000000000..933205e28eb3c --- /dev/null +++ b/core/src/test/java/org/apache/spark/shuffle/checksum/RowBasedChecksumSuite.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.checksum + +import org.apache.spark.SparkFunSuite + +class RowBasedChecksumSuite extends SparkFunSuite { + +// test("Invalid checksum algorithm should fail") { +// val rowBasedChecksum = new OutputStreamRowBasedChecksum("invalid") +// rowBasedChecksum.update(Long.box(10), Long.box(20)) +// // We fail to compute the checksum, and getValue returns 0. +// assert(rowBasedChecksum.getValue == 0) +// } + + test("Two identical rows should have a checksum of zero with XOR - ADLER32") { + val rowBasedChecksum = new OutputStreamRowBasedChecksum("ADLER32") + assert(rowBasedChecksum.getValue == 0) + + // Updates the checksum with one row. + rowBasedChecksum.update(Long.box(10), Long.box(20)) + assert(rowBasedChecksum.getValue != 0) + // specific value: + // - Scala 2.12: 3569823494L + // - Scala 2.13: 3329306339L + + // Updates the checksum with the same row again, and the row-based checksum should become 0. + rowBasedChecksum.update(Long.box(10), Long.box(20)) + assert(rowBasedChecksum.getValue == 0) + } + + test("Two identical rows should have a checksum of zero with XOR - CRC32") { + val rowBasedChecksum = new OutputStreamRowBasedChecksum("CRC32") + assert(rowBasedChecksum.getValue == 0) + + // Updates the checksum with one row. 
+ rowBasedChecksum.update(Long.box(10), Long.box(20)) + assert(rowBasedChecksum.getValue != 0) + // specific value: + // - Scala 2.12: 846153942L + // - Scala 2.13: 2004131951L + + // Updates the checksum with the same row again, and the row-based checksum should become 0. + rowBasedChecksum.update(Long.box(10), Long.box(20)) + assert(rowBasedChecksum.getValue == 0) + } + + test("The checksum is independent of row order - two rows") { + val algorithms = Array("ADLER32", "CRC32") + for(algorithm <- algorithms) { + val rowBasedChecksum1 = new OutputStreamRowBasedChecksum(algorithm) + val rowBasedChecksum2 = new OutputStreamRowBasedChecksum(algorithm) + assert(rowBasedChecksum1.getValue == 0) + assert(rowBasedChecksum2.getValue == 0) + + rowBasedChecksum1.update(Long.box(10), Long.box(20)) + rowBasedChecksum2.update(Long.box(30), Long.box(40)) + assert(rowBasedChecksum1.getValue != rowBasedChecksum2.getValue) + + rowBasedChecksum1.update(Long.box(30), Long.box(40)) + rowBasedChecksum2.update(Long.box(10), Long.box(20)) + assert(rowBasedChecksum1.getValue == rowBasedChecksum2.getValue) + + assert(rowBasedChecksum1.getValue != 0) + assert(rowBasedChecksum2.getValue != 0) + } + } + + test("The checksum is independent of row order - multiple rows") { + val algorithms = Array("ADLER32", "CRC32") + for(algorithm <- algorithms) { + val rowBasedChecksum1 = new OutputStreamRowBasedChecksum(algorithm) + val rowBasedChecksum2 = new OutputStreamRowBasedChecksum(algorithm) + assert(rowBasedChecksum1.getValue == 0) + assert(rowBasedChecksum2.getValue == 0) + + rowBasedChecksum1.update(Long.box(10), Long.box(20)) + rowBasedChecksum1.update(Long.box(90), Long.box(100)) + assert(rowBasedChecksum1.getValue != rowBasedChecksum2.getValue) + + rowBasedChecksum1.update(Long.box(30), Long.box(40)) + rowBasedChecksum1.update(Long.box(70), Long.box(80)) + assert(rowBasedChecksum1.getValue != rowBasedChecksum2.getValue) + + rowBasedChecksum1.update(Long.box(50), Long.box(60)) + rowBasedChecksum1.update(Long.box(50), Long.box(60)) + assert(rowBasedChecksum1.getValue != rowBasedChecksum2.getValue) + + rowBasedChecksum1.update(Long.box(70), Long.box(80)) + rowBasedChecksum2.update(Long.box(30), Long.box(40)) + assert(rowBasedChecksum1.getValue != rowBasedChecksum2.getValue) + + rowBasedChecksum1.update(Long.box(90), Long.box(100)) + rowBasedChecksum2.update(Long.box(10), Long.box(20)) + assert(rowBasedChecksum1.getValue == rowBasedChecksum2.getValue) + + assert(rowBasedChecksum1.getValue != 0) + assert(rowBasedChecksum2.getValue != 0) + } + } +} diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index c55254e04f401..a8f70586e8b9c 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -49,6 +49,7 @@ import org.apache.spark.scheduler.MapStatus; import org.apache.spark.security.CryptoStreamUtils; import org.apache.spark.serializer.*; +import org.apache.spark.shuffle.checksum.RowBasedChecksum; import org.apache.spark.shuffle.IndexShuffleBlockResolver; import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents; import org.apache.spark.storage.*; @@ -174,11 +175,20 @@ public void setUp() throws Exception { File file = (File) invocationOnMock.getArguments()[0]; return Utils.tempFileWith(file); }); - + resetDependency(); 
when(taskContext.taskMetrics()).thenReturn(taskMetrics); + when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager); + } + + private void resetDependency() { when(shuffleDep.serializer()).thenReturn(serializer); when(shuffleDep.partitioner()).thenReturn(hashPartitioner); - when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager); + final int checksumSize = + (boolean) conf.get(package$.MODULE$.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED()) ? NUM_PARTITIONS : 0; + final String checksumAlgorithm = conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ALGORITHM()); + final RowBasedChecksum[] rowBasedChecksums = + RowBasedChecksum.createPartitionRowBasedChecksums(checksumSize, checksumAlgorithm); + when(shuffleDep.rowBasedChecksums()).thenReturn(rowBasedChecksums); } private UnsafeShuffleWriter createWriter(boolean transferToEnabled) @@ -613,6 +623,40 @@ public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException, Spa assertSpillFilesWereCleanedUp(); } + @Test + public void testRowBasedChecksum() throws IOException, SparkException { + conf.set(package$.MODULE$.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED(), true); + final ArrayList> dataToWrite = new ArrayList<>(); + for (int i = 0; i < NUM_PARTITIONS; i++) { + for (int j = 0; j < 5; j++) { + dataToWrite.add(new Tuple2<>(i, i + j)); + } + } + + long[] checksumValues = new long[0]; + long aggregatedChecksumValue = 0; + for (int i = 0; i < 100; i++) { + resetDependency(); + final UnsafeShuffleWriter writer = createWriter(false); + Collections.shuffle(dataToWrite); + writer.write(dataToWrite.iterator()); + writer.stop(true); + + if (i == 0) { + checksumValues = getRowBasedChecksumValues(writer.getRowBasedChecksums()); + assertEquals(checksumValues.length, NUM_PARTITIONS); + Arrays.stream(checksumValues).allMatch(v -> v > 0); + + aggregatedChecksumValue = writer.getAggregatedChecksumValue(); + assert(aggregatedChecksumValue != 0); + } else { + assertArrayEquals(checksumValues, + getRowBasedChecksumValues(writer.getRowBasedChecksums())); + assertEquals(aggregatedChecksumValue, writer.getAggregatedChecksumValue()); + } + } + } + @Test public void testPeakMemoryUsed() throws Exception { final long recordLengthBytes = 8; diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 26dc218c30c74..4453010b6470e 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -271,7 +271,7 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { masterTracker.registerShuffle(20, 100, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES) (0 until 100).foreach { i => masterTracker.registerMapOutput(20, i, new CompressedMapStatus( - BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0), 5)) + BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0), 5, 100)) } val senderAddress = RpcAddress("localhost", 12345) val rpcCallContext = mock(classOf[RpcCallContext]) @@ -578,7 +578,7 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { masterTracker.registerShuffle(20, 100, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES) (0 until 100).foreach { i => masterTracker.registerMapOutput(20, i, new CompressedMapStatus( - BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0), 5)) + BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0), 5, 100)) } val mapWorkerRpcEnv = createRpcEnv("spark-worker", 
"localhost", 0, new SecurityManager(conf)) @@ -625,7 +625,7 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { masterTracker.registerShuffle(20, 100, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES) (0 until 100).foreach { i => masterTracker.registerMapOutput(20, i, new CompressedMapStatus( - BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0), 5)) + BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0), 5, 100)) } masterTracker.registerMergeResult(20, 0, MergeStatus(BlockManagerId("999", "mps", 1000), 0, bitmap1, 1000L)) diff --git a/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala b/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala index 75f952d063d33..bd8766bd260e9 100644 --- a/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala @@ -59,7 +59,7 @@ object MapStatusesSerDeserBenchmark extends BenchmarkBase { Array.fill(blockSize) { // Creating block size ranging from 0byte to 1GB (r.nextDouble() * 1024 * 1024 * 1024).toLong - }, i)) + }, i, i * 100)) } val shuffleStatus = tracker.shuffleStatuses.get(shuffleId).head diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 3e507df706ba5..841fafd15ff89 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -1160,7 +1160,8 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti stageId: Int, attemptIdx: Int, numShufflePartitions: Int, - hostNames: Seq[String] = Seq.empty[String]): Unit = { + hostNames: Seq[String] = Seq.empty[String], + checksumVal: Long = 0): Unit = { def compareStageAttempt(taskSet: TaskSet): Boolean = { taskSet.stageId == stageId && taskSet.stageAttemptId == attemptIdx } @@ -1175,7 +1176,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti } else { s"host${('A' + idx).toChar}" } - (Success, makeMapStatus(hostName, numShufflePartitions)) + (Success, makeMapStatus(hostName, numShufflePartitions, checksumVal = checksumVal)) }.toSeq) } @@ -4757,6 +4758,42 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti assert(mapStatuses.count(s => s != null && s.location.executorId == "hostB-exec") === 1) } + /** + * In this test, we simulate a job where some tasks in a stage fail, and it triggers the retry + * of the task in its previous stage. The two attempts of the same task in the previous stage + * produce different shuffle checksums. + */ + test("Output usage log for tasks that produce different checksum across retries") { + setupStageAbortTest(sc) + + val parts = 8 + val shuffleMapRdd = new MyRDD(sc, parts, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(parts)) + val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = mapOutputTracker) + submit(reduceRdd, (0 until parts).toArray) + + // Complete stage 0 and then fail stage 1, and tasks in stage 0 produce a checksum of 100. + completeShuffleMapStageSuccessfully(0, 0, numShufflePartitions = parts, checksumVal = 100) + completeNextStageWithFetchFailure(1, 0, shuffleDep) + + // Resubmit and confirm that now all is well. 
+ scheduler.resubmitFailedStages() + assert(scheduler.runningStages.nonEmpty) + assert(!ended) + + // Complete stage 0 and then stage 1, and the retried task in stage 0 produces a different + // checksum of 200. + completeShuffleMapStageSuccessfully(0, 1, numShufflePartitions = parts, checksumVal = 200) + completeNextResultStageWithSuccess(1, 1) + + // Confirm job finished successfully. + sc.listenerBus.waitUntilEmpty() + assert(ended) + assert(results == (0 until parts).map { idx => idx -> 42 }.toMap) + assertDataStructuresEmpty() + mapOutputTracker.unregisterShuffle(shuffleDep.shuffleId) + } + Seq(true, false).foreach { registerMergeResults => test("SPARK-40096: Send finalize events even if shuffle merger blocks indefinitely " + s"with registerMergeResults is ${registerMergeResults}") { @@ -5144,8 +5181,13 @@ class DAGSchedulerAbortStageOffSuite extends DAGSchedulerSuite { object DAGSchedulerSuite { val mergerLocs = ArrayBuffer[BlockManagerId]() - def makeMapStatus(host: String, reduces: Int, sizes: Byte = 2, mapTaskId: Long = -1): MapStatus = - MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes), mapTaskId) + def makeMapStatus( + host: String, + reduces: Int, + sizes: Byte = 2, + mapTaskId: Long = -1, + checksumVal: Long = 0): MapStatus = + MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes), mapTaskId, checksumVal) def makeBlockManagerId(host: String, execId: Option[String] = None): BlockManagerId = { BlockManagerId(execId.getOrElse(host + "-exec"), host, 12345) diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala index 8be103b7be860..c5b6c9faaa1a2 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala @@ -19,6 +19,8 @@ package org.apache.spark.shuffle import java.io.File +import org.apache.spark.shuffle.checksum.RowBasedChecksum + trait ShuffleChecksumTestHelper { /** @@ -37,4 +39,12 @@ trait ShuffleChecksumTestHelper { assert(ShuffleChecksumUtils.compareChecksums(numPartition, algorithm, checksum, data, index), "checksum must be consistent at both write and read sides") } + + def getRowBasedChecksumValues(rowBasedChecksums: Array[RowBasedChecksum]): Array[Long] = { + if (rowBasedChecksums == null) { + Array.empty + } else { + rowBasedChecksums.map(_.getValue) + } + } } diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index ce2aefa74229a..e9f23e0da7eb4 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -22,6 +22,7 @@ import java.util.UUID import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import scala.util.Random import org.mockito.{Mock, MockitoAnnotations} import org.mockito.Answers.RETURNS_SMART_NULLS @@ -37,6 +38,7 @@ import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper import org.apache.spark.serializer.{JavaSerializer, SerializerInstance, SerializerManager} import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleChecksumTestHelper} import org.apache.spark.shuffle.api.ShuffleExecutorComponents +import org.apache.spark.shuffle.checksum.RowBasedChecksum import 
org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents import org.apache.spark.storage._ import org.apache.spark.util.Utils @@ -74,8 +76,7 @@ class BypassMergeSortShuffleWriterSuite ) val memoryManager = new TestMemoryManager(conf) val taskMemoryManager = new TaskMemoryManager(memoryManager, 0) - when(dependency.partitioner).thenReturn(new HashPartitioner(7)) - when(dependency.serializer).thenReturn(new JavaSerializer(conf)) + resetDependency(conf) when(taskContext.taskMetrics()).thenReturn(taskMetrics) when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile) when(blockManager.diskBlockManager).thenReturn(diskBlockManager) @@ -145,6 +146,19 @@ class BypassMergeSortShuffleWriterSuite } } + private def resetDependency(sc: SparkConf): Unit = { + reset(dependency) + val numPartitions = 7 + when(dependency.partitioner).thenReturn(new HashPartitioner(numPartitions)) + when(dependency.serializer).thenReturn(new JavaSerializer(sc)) + val checksumSize = + if (sc.get(config.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED)) numPartitions else 0 + val checksumAlgorithm = conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM) + val rowBasedChecksums = + RowBasedChecksum.createPartitionRowBasedChecksums(checksumSize, checksumAlgorithm) + when(dependency.rowBasedChecksums).thenReturn(rowBasedChecksums) + } + test("write empty iterator") { val writer = new BypassMergeSortShuffleWriter[Int, Int]( blockManager, @@ -294,4 +308,46 @@ class BypassMergeSortShuffleWriterSuite assert(checksumFile.length() === 8 * numPartition) compareChecksums(numPartition, checksumAlgorithm, checksumFile, dataFile, indexFile) } + + test("Row-based checksums are independent of input row order") { + val transferConf = + conf.clone.set(config.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED.key, true.toString) + val records: List[(Int, Int)] = List( + (1, 1), (1, 2), (1, 3), (1, 4), (1, 5), + (2, 2), (2, 3), (2, 4), (2, 5), (2, 6), + (3, 3), (3, 4), (3, 5), (3, 6), (3, 7), + (4, 4), (4, 5), (4, 6), (4, 7), (4, 8), + (5, 5), (5, 6), (5, 7), (5, 8), (5, 9), + (6, 6), (6, 7), (6, 8), (6, 9), (6, 10), + (7, 7), (7, 8), (7, 9), (7, 10), (7, 11)) + + var checksumValues : Array[Long] = Array[Long]() + var aggregatedChecksumValue = 0L + for (i <- 1 to 100) { + resetDependency(transferConf); + val writer = new BypassMergeSortShuffleWriter[Int, Int]( + blockManager, + shuffleHandle, + 0L, // MapId + transferConf, + taskContext.taskMetrics().shuffleWriteMetrics, + shuffleExecutorComponents) + + writer.write(Random.shuffle(records).iterator) + writer.stop(/* success = */ true) + + if(i == 1) { + checksumValues = getRowBasedChecksumValues(writer.getRowBasedChecksums) + assert(checksumValues.length > 0) + assert(checksumValues.forall(_ > 0)) + + aggregatedChecksumValue = writer.getAggregatedChecksumValue() + assert(aggregatedChecksumValue != 0) + } else { + assert(checksumValues.sameElements( + getRowBasedChecksumValues(writer.getRowBasedChecksums))) + assert(aggregatedChecksumValue == writer.getAggregatedChecksumValue()) + } + } + } } diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala index 99402abb16cac..cefb0d2aabde5 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.shuffle.sort +import scala.util.Random + import org.mockito.{Mock, MockitoAnnotations} import 
org.mockito.Answers.RETURNS_SMART_NULLS import org.mockito.Mockito._ @@ -29,6 +31,7 @@ import org.apache.spark.memory.MemoryTestingUtils import org.apache.spark.serializer.JavaSerializer import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, ShuffleChecksumTestHelper} import org.apache.spark.shuffle.api.ShuffleExecutorComponents +import org.apache.spark.shuffle.checksum.RowBasedChecksum import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents import org.apache.spark.storage.BlockManager import org.apache.spark.util.Utils @@ -50,6 +53,7 @@ class SortShuffleWriterSuite private val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) private val serializer = new JavaSerializer(conf) private var shuffleExecutorComponents: ShuffleExecutorComponents = _ + @Mock(answer = RETURNS_SMART_NULLS) private var dependency: ShuffleDependency[Int, Int, Int] = _ private val partitioner = new Partitioner() { def numPartitions = numMaps @@ -60,13 +64,9 @@ class SortShuffleWriterSuite super.beforeEach() MockitoAnnotations.openMocks(this).close() shuffleHandle = { - val dependency = mock(classOf[ShuffleDependency[Int, Int, Int]]) - when(dependency.partitioner).thenReturn(partitioner) - when(dependency.serializer).thenReturn(serializer) - when(dependency.aggregator).thenReturn(None) - when(dependency.keyOrdering).thenReturn(None) new BaseShuffleHandle(shuffleId, dependency) } + resetDependency() shuffleExecutorComponents = new LocalDiskShuffleExecutorComponents( conf, blockManager, shuffleBlockResolver) } @@ -79,6 +79,20 @@ class SortShuffleWriterSuite } } + private def resetDependency(): Unit = { + reset(dependency); + when(dependency.partitioner).thenReturn(partitioner) + when(dependency.serializer).thenReturn(serializer) + when(dependency.aggregator).thenReturn(None) + when(dependency.keyOrdering).thenReturn(None) + val checksumSize = + if (conf.get(config.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED)) numMaps else 0 + val checksumAlgorithm = conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM) + val rowBasedChecksums = RowBasedChecksum.createPartitionRowBasedChecksums( + checksumSize, checksumAlgorithm) + when(dependency.rowBasedChecksums).thenReturn(rowBasedChecksums) + } + test("write empty iterator") { val context = MemoryTestingUtils.fakeTaskContext(sc.env) val writer = new SortShuffleWriter[Int, Int, Int]( @@ -114,6 +128,50 @@ class SortShuffleWriterSuite assert(records.size === writeMetrics.recordsWritten) } + test("Row-based checksums are independent of input row order") { + conf.set("spark.shuffle.orderIndependentChecksum.enabled", true.toString) + // FIXME: this can affect other tests (if any) after this set of tests + // since `sc` is global. 
+ sc.stop() + val localSC = new SparkContext("local[4]", "test", conf) + val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) + val context = MemoryTestingUtils.fakeTaskContext(localSC.env) + val records: List[(Int, Int)] = List( + (1, 1), (1, 2), (1, 3), (1, 4), (1, 5), + (2, 2), (2, 3), (2, 4), (2, 5), (2, 6), + (3, 3), (3, 4), (3, 5), (3, 6), (3, 7), + (4, 4), (4, 5), (4, 6), (4, 7), (4, 8), + (5, 5), (5, 6), (5, 7), (5, 8), (5, 9)) + + var checksumValues : Array[Long] = Array[Long]() + var aggregatedChecksumValue = 0L + for (i <- 1 to 100) { + resetDependency() + val writer = new SortShuffleWriter[Int, Int, Int]( + shuffleHandle, + mapId = 2, + context, + context.taskMetrics().shuffleWriteMetrics, + new LocalDiskShuffleExecutorComponents( + conf, shuffleBlockResolver._blockManager, shuffleBlockResolver)) + writer.write(Random.shuffle(records).iterator) + if(i == 1) { + checksumValues = getRowBasedChecksumValues(writer.getRowBasedChecksums) + assert(checksumValues.length > 0) + assert(checksumValues.forall(_ > 0)) + + aggregatedChecksumValue = writer.getAggregatedChecksumValue + assert(aggregatedChecksumValue != 0) + } else { + assert(checksumValues.sameElements( + getRowBasedChecksumValues(writer.getRowBasedChecksums))) + assert(aggregatedChecksumValue == writer.getAggregatedChecksumValue) + } + writer.stop(success = true) + } + localSC.stop() + } + Seq((true, false, false), (true, true, false), (true, false, true), @@ -141,6 +199,12 @@ class SortShuffleWriterSuite when(dependency.serializer).thenReturn(serializer) when(dependency.aggregator).thenReturn(aggregator) when(dependency.keyOrdering).thenReturn(order) + val checksumSize = + if (conf.get(config.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED)) numMaps else 0 + val checksumAlgorithm = conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM) + val rowBasedChecksums: Array[RowBasedChecksum] = + RowBasedChecksum.createPartitionRowBasedChecksums(checksumSize, checksumAlgorithm) + when(dependency.rowBasedChecksums).thenReturn(rowBasedChecksums) new BaseShuffleHandle[Int, Int, Int](shuffleId, dependency) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 31a3f53eb7191..57d57640a1d8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -28,6 +28,7 @@ import org.apache.spark.internal.config import org.apache.spark.rdd.{RDD, RDDOperationScope} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{ShuffleWriteMetricsReporter, ShuffleWriteProcessor} +import org.apache.spark.shuffle.checksum.RowBasedChecksum import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection, UnsafeRow} @@ -471,12 +472,21 @@ object ShuffleExchangeExec { // Now, we manually create a ShuffleDependency. Because pairs in rddWithPartitionIds // are in the form of (partitionId, row) and every partitionId is in the expected range // [0, part.numPartitions - 1]. The partitioner of this is a PartitionIdPassthrough. 
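+      // Row-based checksums are allocated only when the order-independent checksum is enabled;
+      // with the feature off, an empty array is passed and the shuffle writers skip the
+      // per-row checksum updates.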
+ val checksumSize = + if (SparkEnv.get.conf.get(config.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED)) { + part.numPartitions + } else { + 0 + } + val checksumAlgorithm = SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM) val dependency = new ShuffleDependency[Int, InternalRow, InternalRow]( rddWithPartitionIds, new PartitionIdPassthrough(part.numPartitions), serializer, - shuffleWriterProcessor = createShuffleWriteProcessor(writeMetrics)) + shuffleWriterProcessor = createShuffleWriteProcessor(writeMetrics), + rowBasedChecksums = RowBasedChecksum.createPartitionRowBasedChecksums( + checksumSize, checksumAlgorithm)) dependency } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala new file mode 100644 index 0000000000000..c9e47b3791b6e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.{MapOutputTrackerMaster, SparkFunSuite} +import org.apache.spark.internal.config.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED +import org.apache.spark.sql.classic.SparkSession +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SQLTestUtils + +class MapStatusEndToEndSuite extends SparkFunSuite with SQLTestUtils { + override def spark: SparkSession = SparkSession.builder() + .master("local") + .config(SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED.key, value = true) + .config(SQLConf.LEAF_NODE_DEFAULT_PARALLELISM.key, value = 5) + .getOrCreate() + + override def afterEach(): Unit = { + // This suite should not interfere with the other test suites. + SparkSession.getActiveSession.foreach(_.stop()) + SparkSession.clearActiveSession() + SparkSession.getDefaultSession.foreach(_.stop()) + SparkSession.clearDefaultSession() + } + + test("Propagate checksum from executor to driver") { + assert(spark.sparkContext.conf.get("spark.shuffle.orderIndependentChecksum.enabled") == "true") + assert(spark.conf.get("spark.shuffle.orderIndependentChecksum.enabled") == "true") + assert(spark.sparkContext.conf.get("spark.sql.leafNodeDefaultParallelism") == "5") + assert(spark.conf.get("spark.sql.leafNodeDefaultParallelism") == "5") + + withTable("t") { + spark.range(1000).repartition(10).write.mode("overwrite"). + saveAsTable("t") + } + + val shuffleStatuses = spark.sparkContext.env.mapOutputTracker. 
+      asInstanceOf[MapOutputTrackerMaster].shuffleStatuses
+    assert(shuffleStatuses.size == 1)
+
+    val mapStatuses = shuffleStatuses(0).mapStatuses
+    assert(mapStatuses.length == 5)
+    assert(mapStatuses.forall(_.checksumValue != 0))
+  }
+}

From 53d11af7a7b14ea500d947da1c9003cf23bdd223 Mon Sep 17 00:00:00 2001
From: Jiexing Li
Date: Wed, 12 Mar 2025 11:50:07 -0700
Subject: [PATCH 02/29] faster checksum

---
 .../shuffle/checksum/RowBasedChecksum.scala   |  26 ++-
 .../spark/internal/config/package.scala       |  15 --
 .../apache/spark/scheduler/MapStatus.scala    |   4 +-
 .../checksum/RowBasedChecksumSuite.scala      | 116 ----------
 .../sort/UnsafeShuffleWriterSuite.java        |  10 +-
 .../BypassMergeSortShuffleWriterSuite.scala   |  15 +-
 .../shuffle/sort/SortShuffleWriterSuite.scala |  12 +-
 .../expressions/UnsafeRowChecksum.scala       |  57 +++++
 .../apache/spark/sql/internal/SQLConf.scala   |  15 ++
 .../exchange/ShuffleExchangeExec.scala        |   9 +-
 .../spark/sql/MapStatusEndToEndSuite.scala    |   3 +-
 .../spark/sql/UnsafeRowChecksumSuite.scala    | 208 ++++++++++++++++++
 12 files changed, 321 insertions(+), 169 deletions(-)
 delete mode 100644 core/src/test/java/org/apache/spark/shuffle/checksum/RowBasedChecksumSuite.scala
 create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowChecksum.scala
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowChecksumSuite.scala

diff --git a/core/src/main/java/org/apache/spark/shuffle/checksum/RowBasedChecksum.scala b/core/src/main/java/org/apache/spark/shuffle/checksum/RowBasedChecksum.scala
index 8d32126605a89..c1f50a93be03c 100644
--- a/core/src/main/java/org/apache/spark/shuffle/checksum/RowBasedChecksum.scala
+++ b/core/src/main/java/org/apache/spark/shuffle/checksum/RowBasedChecksum.scala
@@ -20,6 +20,9 @@ package org.apache.spark.shuffle.checksum
 import java.io.{ByteArrayOutputStream, ObjectOutputStream}
 import java.util.zip.Checksum
 
+import scala.util.control.NonFatal
+
+import org.apache.spark.internal.Logging
 import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper
 
 /**
@@ -27,15 +30,28 @@ import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper
  * the order of the input (key, value) pairs. It is done by computing a checksum for each row
  * first, and then computing the XOR for all the row checksums.
  */
-abstract class RowBasedChecksum() extends Serializable {
+abstract class RowBasedChecksum() extends Serializable with Logging {
+  private var hasError: Boolean = false
   private var checksumValue: Long = 0
-  /** Returns the checksum value computed */
-  def getValue: Long = checksumValue
+  /** Returns the checksum value computed. It returns the default checksum value (0) if there
+   * are any errors encountered during the checksum computation.
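+   * Note that 0 is also the value reported before any row has been added, so a zero checksum
+   * alone does not distinguish an empty partition from a failed computation.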
+ */ + def getValue: Long = { + if (!hasError) checksumValue else 0 + } /** Updates the row-based checksum with the given (key, value) pair */ def update(key: Any, value: Any): Unit = { - val rowChecksumValue = calculateRowChecksum(key, value) - checksumValue = checksumValue ^ rowChecksumValue + if (!hasError) { + try { + val rowChecksumValue = calculateRowChecksum(key, value) + checksumValue = checksumValue ^ rowChecksumValue + } catch { + case NonFatal(e) => + logInfo("Checksum computation encountered error: ", e) + hasError = true + } + } } /** Computes and returns the checksum value for the given (key, value) pair */ diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 9b95dddc5a856..3ce374d0477d8 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -1633,21 +1633,6 @@ package object config { .booleanConf .createWithDefault(true) - private[spark] val SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED = - ConfigBuilder("spark.shuffle.orderIndependentChecksum.enabled") - .doc("Whether to calculate order independent checksum for the shuffle data or not. If " + - "enabled, Spark will calculate a checksum that is independent of the input row order for " + - "each mapper and returns the checksums from executors to driver. Different from the above" + - "checksum, the order independent remains the same even if the shuffle row order changes. " + - "While the above checksum is sensitive to shuffle data ordering to detect file " + - "corruption. This checksum is used to detect whether different task attempts of the same " + - "partition produce different output data or not (same set of keyValue pairs). In case " + - "the output data has changed across retries, Spark will need to retry all tasks of the " + - "consumer stages to avoid correctness issues.") - .version("3.4.0") - .booleanConf - .createWithDefault(false) - private[spark] val SHUFFLE_CHECKSUM_ALGORITHM = ConfigBuilder("spark.shuffle.checksum.algorithm") .doc("The algorithm is used to calculate the shuffle checksum. Currently, it only supports " + diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 64e12555caae3..f7490d0182883 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -24,7 +24,7 @@ import scala.collection.mutable import org.roaringbitmap.RoaringBitmap import org.apache.spark.SparkEnv -import org.apache.spark.internal.{config, Logging} +import org.apache.spark.internal.config import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.Utils @@ -39,7 +39,7 @@ private[spark] trait ShuffleOutputStatus * task has shuffle files stored on as well as the sizes of outputs for each reducer, for passing * on to the reduce tasks. */ -private[spark] sealed trait MapStatus extends ShuffleOutputStatus with Logging { +private[spark] sealed trait MapStatus extends ShuffleOutputStatus { /** Location where this task output is. 
*/ def location: BlockManagerId diff --git a/core/src/test/java/org/apache/spark/shuffle/checksum/RowBasedChecksumSuite.scala b/core/src/test/java/org/apache/spark/shuffle/checksum/RowBasedChecksumSuite.scala deleted file mode 100644 index 933205e28eb3c..0000000000000 --- a/core/src/test/java/org/apache/spark/shuffle/checksum/RowBasedChecksumSuite.scala +++ /dev/null @@ -1,116 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.checksum - -import org.apache.spark.SparkFunSuite - -class RowBasedChecksumSuite extends SparkFunSuite { - -// test("Invalid checksum algorithm should fail") { -// val rowBasedChecksum = new OutputStreamRowBasedChecksum("invalid") -// rowBasedChecksum.update(Long.box(10), Long.box(20)) -// // We fail to compute the checksum, and getValue returns 0. -// assert(rowBasedChecksum.getValue == 0) -// } - - test("Two identical rows should have a checksum of zero with XOR - ADLER32") { - val rowBasedChecksum = new OutputStreamRowBasedChecksum("ADLER32") - assert(rowBasedChecksum.getValue == 0) - - // Updates the checksum with one row. - rowBasedChecksum.update(Long.box(10), Long.box(20)) - assert(rowBasedChecksum.getValue != 0) - // specific value: - // - Scala 2.12: 3569823494L - // - Scala 2.13: 3329306339L - - // Updates the checksum with the same row again, and the row-based checksum should become 0. - rowBasedChecksum.update(Long.box(10), Long.box(20)) - assert(rowBasedChecksum.getValue == 0) - } - - test("Two identical rows should have a checksum of zero with XOR - CRC32") { - val rowBasedChecksum = new OutputStreamRowBasedChecksum("CRC32") - assert(rowBasedChecksum.getValue == 0) - - // Updates the checksum with one row. - rowBasedChecksum.update(Long.box(10), Long.box(20)) - assert(rowBasedChecksum.getValue != 0) - // specific value: - // - Scala 2.12: 846153942L - // - Scala 2.13: 2004131951L - - // Updates the checksum with the same row again, and the row-based checksum should become 0. 
- rowBasedChecksum.update(Long.box(10), Long.box(20)) - assert(rowBasedChecksum.getValue == 0) - } - - test("The checksum is independent of row order - two rows") { - val algorithms = Array("ADLER32", "CRC32") - for(algorithm <- algorithms) { - val rowBasedChecksum1 = new OutputStreamRowBasedChecksum(algorithm) - val rowBasedChecksum2 = new OutputStreamRowBasedChecksum(algorithm) - assert(rowBasedChecksum1.getValue == 0) - assert(rowBasedChecksum2.getValue == 0) - - rowBasedChecksum1.update(Long.box(10), Long.box(20)) - rowBasedChecksum2.update(Long.box(30), Long.box(40)) - assert(rowBasedChecksum1.getValue != rowBasedChecksum2.getValue) - - rowBasedChecksum1.update(Long.box(30), Long.box(40)) - rowBasedChecksum2.update(Long.box(10), Long.box(20)) - assert(rowBasedChecksum1.getValue == rowBasedChecksum2.getValue) - - assert(rowBasedChecksum1.getValue != 0) - assert(rowBasedChecksum2.getValue != 0) - } - } - - test("The checksum is independent of row order - multiple rows") { - val algorithms = Array("ADLER32", "CRC32") - for(algorithm <- algorithms) { - val rowBasedChecksum1 = new OutputStreamRowBasedChecksum(algorithm) - val rowBasedChecksum2 = new OutputStreamRowBasedChecksum(algorithm) - assert(rowBasedChecksum1.getValue == 0) - assert(rowBasedChecksum2.getValue == 0) - - rowBasedChecksum1.update(Long.box(10), Long.box(20)) - rowBasedChecksum1.update(Long.box(90), Long.box(100)) - assert(rowBasedChecksum1.getValue != rowBasedChecksum2.getValue) - - rowBasedChecksum1.update(Long.box(30), Long.box(40)) - rowBasedChecksum1.update(Long.box(70), Long.box(80)) - assert(rowBasedChecksum1.getValue != rowBasedChecksum2.getValue) - - rowBasedChecksum1.update(Long.box(50), Long.box(60)) - rowBasedChecksum1.update(Long.box(50), Long.box(60)) - assert(rowBasedChecksum1.getValue != rowBasedChecksum2.getValue) - - rowBasedChecksum1.update(Long.box(70), Long.box(80)) - rowBasedChecksum2.update(Long.box(30), Long.box(40)) - assert(rowBasedChecksum1.getValue != rowBasedChecksum2.getValue) - - rowBasedChecksum1.update(Long.box(90), Long.box(100)) - rowBasedChecksum2.update(Long.box(10), Long.box(20)) - assert(rowBasedChecksum1.getValue == rowBasedChecksum2.getValue) - - assert(rowBasedChecksum1.getValue != 0) - assert(rowBasedChecksum2.getValue != 0) - } - } -} diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index a8f70586e8b9c..930e5721ce4bb 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -175,16 +175,15 @@ public void setUp() throws Exception { File file = (File) invocationOnMock.getArguments()[0]; return Utils.tempFileWith(file); }); - resetDependency(); + resetDependency(false); when(taskContext.taskMetrics()).thenReturn(taskMetrics); when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager); } - private void resetDependency() { + private void resetDependency(boolean rowbasedChecksumEnabled) { when(shuffleDep.serializer()).thenReturn(serializer); when(shuffleDep.partitioner()).thenReturn(hashPartitioner); - final int checksumSize = - (boolean) conf.get(package$.MODULE$.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED()) ? NUM_PARTITIONS : 0; + final int checksumSize = rowbasedChecksumEnabled ? 
NUM_PARTITIONS : 0; final String checksumAlgorithm = conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ALGORITHM()); final RowBasedChecksum[] rowBasedChecksums = RowBasedChecksum.createPartitionRowBasedChecksums(checksumSize, checksumAlgorithm); @@ -625,7 +624,6 @@ public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException, Spa @Test public void testRowBasedChecksum() throws IOException, SparkException { - conf.set(package$.MODULE$.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED(), true); final ArrayList> dataToWrite = new ArrayList<>(); for (int i = 0; i < NUM_PARTITIONS; i++) { for (int j = 0; j < 5; j++) { @@ -636,7 +634,7 @@ public void testRowBasedChecksum() throws IOException, SparkException { long[] checksumValues = new long[0]; long aggregatedChecksumValue = 0; for (int i = 0; i < 100; i++) { - resetDependency(); + resetDependency(true); final UnsafeShuffleWriter writer = createWriter(false); Collections.shuffle(dataToWrite); writer.write(dataToWrite.iterator()); diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index e9f23e0da7eb4..408e60e5e90d8 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -76,7 +76,7 @@ class BypassMergeSortShuffleWriterSuite ) val memoryManager = new TestMemoryManager(conf) val taskMemoryManager = new TaskMemoryManager(memoryManager, 0) - resetDependency(conf) + resetDependency() when(taskContext.taskMetrics()).thenReturn(taskMetrics) when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile) when(blockManager.diskBlockManager).thenReturn(diskBlockManager) @@ -146,13 +146,12 @@ class BypassMergeSortShuffleWriterSuite } } - private def resetDependency(sc: SparkConf): Unit = { + private def resetDependency(rowbasedChecksumEnabled : Boolean = false): Unit = { reset(dependency) val numPartitions = 7 when(dependency.partitioner).thenReturn(new HashPartitioner(numPartitions)) - when(dependency.serializer).thenReturn(new JavaSerializer(sc)) - val checksumSize = - if (sc.get(config.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED)) numPartitions else 0 + when(dependency.serializer).thenReturn(new JavaSerializer(conf)) + val checksumSize = if (rowbasedChecksumEnabled) numPartitions else 0 val checksumAlgorithm = conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM) val rowBasedChecksums = RowBasedChecksum.createPartitionRowBasedChecksums(checksumSize, checksumAlgorithm) @@ -310,8 +309,6 @@ class BypassMergeSortShuffleWriterSuite } test("Row-based checksums are independent of input row order") { - val transferConf = - conf.clone.set(config.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED.key, true.toString) val records: List[(Int, Int)] = List( (1, 1), (1, 2), (1, 3), (1, 4), (1, 5), (2, 2), (2, 3), (2, 4), (2, 5), (2, 6), @@ -324,12 +321,12 @@ class BypassMergeSortShuffleWriterSuite var checksumValues : Array[Long] = Array[Long]() var aggregatedChecksumValue = 0L for (i <- 1 to 100) { - resetDependency(transferConf); + resetDependency(true); val writer = new BypassMergeSortShuffleWriter[Int, Int]( blockManager, shuffleHandle, 0L, // MapId - transferConf, + conf, taskContext.taskMetrics().shuffleWriteMetrics, shuffleExecutorComponents) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala 
b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala index cefb0d2aabde5..c64a437568d1b 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala @@ -79,14 +79,13 @@ class SortShuffleWriterSuite } } - private def resetDependency(): Unit = { + private def resetDependency(rowbasedChecksumEnabled : Boolean = false): Unit = { reset(dependency); when(dependency.partitioner).thenReturn(partitioner) when(dependency.serializer).thenReturn(serializer) when(dependency.aggregator).thenReturn(None) when(dependency.keyOrdering).thenReturn(None) - val checksumSize = - if (conf.get(config.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED)) numMaps else 0 + val checksumSize = if (rowbasedChecksumEnabled) numMaps else 0 val checksumAlgorithm = conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM) val rowBasedChecksums = RowBasedChecksum.createPartitionRowBasedChecksums( checksumSize, checksumAlgorithm) @@ -129,7 +128,6 @@ class SortShuffleWriterSuite } test("Row-based checksums are independent of input row order") { - conf.set("spark.shuffle.orderIndependentChecksum.enabled", true.toString) // FIXME: this can affect other tests (if any) after this set of tests // since `sc` is global. sc.stop() @@ -146,7 +144,7 @@ class SortShuffleWriterSuite var checksumValues : Array[Long] = Array[Long]() var aggregatedChecksumValue = 0L for (i <- 1 to 100) { - resetDependency() + resetDependency(true) val writer = new SortShuffleWriter[Int, Int, Int]( shuffleHandle, mapId = 2, @@ -199,11 +197,9 @@ class SortShuffleWriterSuite when(dependency.serializer).thenReturn(serializer) when(dependency.aggregator).thenReturn(aggregator) when(dependency.keyOrdering).thenReturn(order) - val checksumSize = - if (conf.get(config.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED)) numMaps else 0 val checksumAlgorithm = conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM) val rowBasedChecksums: Array[RowBasedChecksum] = - RowBasedChecksum.createPartitionRowBasedChecksums(checksumSize, checksumAlgorithm) + RowBasedChecksum.createPartitionRowBasedChecksums(0, checksumAlgorithm) when(dependency.rowBasedChecksums).thenReturn(rowBasedChecksums) new BaseShuffleHandle[Int, Int, Int](shuffleId, dependency) } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowChecksum.scala b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowChecksum.scala new file mode 100644 index 0000000000000..10c808ebe403f --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowChecksum.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.shuffle.checksum.RowBasedChecksum + +/** + * A concrete implementation of RowBasedChecksum for computing checksum for UnsafeRow. + * The checksum for each row is computed by first casting or converting the baseObject + * in the UnsafeRow to a byte array, and then computing the checksum for the byte array. + * + * Note that the input key is ignored in the checksum computation. As the Spark shuffle + * currently uses a PartitionIdPassthrough partitioner, the keys are already the partition + * IDs for sending the data, and they are the same for all rows in the same partition. + */ +class UnsafeRowChecksum extends RowBasedChecksum() { + + override protected def calculateRowChecksum(key: Any, value: Any): Long = { + assert( + value.isInstanceOf[UnsafeRow], + "Expecting UnsafeRow but got " + value.getClass.getName) + + // Casts or converts the baseObject in UnsafeRow to a byte array. + val unsafeRow = value.asInstanceOf[UnsafeRow] + XXH64.hashUnsafeBytes( + unsafeRow.getBaseObject, + unsafeRow.getBaseOffset, + unsafeRow.getSizeInBytes, + 0 + ) + } +} + +object UnsafeRowChecksum { + def createUnsafeRowChecksums(numPartitions: Int): Array[RowBasedChecksum] = { + val rowBasedChecksums: Array[RowBasedChecksum] = new Array[RowBasedChecksum](numPartitions) + for (i <- 0 until numPartitions) { + rowBasedChecksums(i) = new UnsafeRowChecksum() + } + rowBasedChecksums + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 93acb39944fae..89a4756dca572 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -5724,6 +5724,21 @@ object SQLConf { .booleanConf .createWithDefault(true) + val SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED = + buildConf("spark.shuffle.orderIndependentChecksum.enabled") + .doc("Whether to calculate order independent checksum for the shuffle data or not. If " + + "enabled, Spark will calculate a checksum that is independent of the input row order for " + + "each mapper and returns the checksums from executors to driver. Different from the above" + + "checksum, the order independent remains the same even if the shuffle row order changes. " + + "While the above checksum is sensitive to shuffle data ordering to detect file " + + "corruption. This checksum is used to detect whether different task attempts of the same " + + "partition produce different output data or not (same set of keyValue pairs). In case " + + "the output data has changed across retries, Spark will need to retry all tasks of the " + + "consumer stages to avoid correctness issues.") + .version("3.4.0") + .booleanConf + .createWithDefault(true) + /** * Holds information about keys that have been deprecated. 
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 57d57640a1d8c..0bffeef96589f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -28,10 +28,9 @@ import org.apache.spark.internal.config import org.apache.spark.rdd.{RDD, RDDOperationScope} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{ShuffleWriteMetricsReporter, ShuffleWriteProcessor} -import org.apache.spark.shuffle.checksum.RowBasedChecksum import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection, UnsafeRow, UnsafeRowChecksum} import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.logical.Statistics @@ -473,20 +472,18 @@ object ShuffleExchangeExec { // are in the form of (partitionId, row) and every partitionId is in the expected range // [0, part.numPartitions - 1]. The partitioner of this is a PartitionIdPassthrough. val checksumSize = - if (SparkEnv.get.conf.get(config.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED)) { + if (SparkEnv.get.conf.get(SQLConf.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED)) { part.numPartitions } else { 0 } - val checksumAlgorithm = SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM) val dependency = new ShuffleDependency[Int, InternalRow, InternalRow]( rddWithPartitionIds, new PartitionIdPassthrough(part.numPartitions), serializer, shuffleWriterProcessor = createShuffleWriteProcessor(writeMetrics), - rowBasedChecksums = RowBasedChecksum.createPartitionRowBasedChecksums( - checksumSize, checksumAlgorithm)) + rowBasedChecksums = UnsafeRowChecksum.createUnsafeRowChecksums(checksumSize)) dependency } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala index c9e47b3791b6e..33aac82a1df40 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql import org.apache.spark.{MapOutputTrackerMaster, SparkFunSuite} -import org.apache.spark.internal.config.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils @@ -26,7 +25,7 @@ import org.apache.spark.sql.test.SQLTestUtils class MapStatusEndToEndSuite extends SparkFunSuite with SQLTestUtils { override def spark: SparkSession = SparkSession.builder() .master("local") - .config(SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED.key, value = true) + .config(SQLConf.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED.key, value = true) .config(SQLConf.LEAF_NODE_DEFAULT_PARALLELISM.key, value = 5) .getOrCreate() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowChecksumSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowChecksumSuite.scala new file mode 
100644 index 0000000000000..8b40e7e424d2d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowChecksumSuite.scala @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.nio.ByteBuffer + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeRowChecksum} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.Platform + +class UnsafeRowChecksumSuite extends SparkFunSuite { + private val schema = new StructType().add("value", IntegerType) + private val toUnsafeRow = ExpressionEncoder(schema).createSerializer() + + private val schemaComplex = new StructType() + .add("stringCol", StringType) + .add("doubleCol", DoubleType) + .add("longCol", LongType) + .add("int32Col", IntegerType) + .add("int16Col", ShortType) + .add("int8Col", ByteType) + .add("boolCol", BooleanType) + private val toUnsafeRowComplex = ExpressionEncoder(schemaComplex).createSerializer() + + private def setUnsafeRowValue( + stringCol: String, + doubleCol: Double, + longCol: Long, + int32Col: Int, + int16Col: Short, + int8Col: Byte, + boolCol: Boolean, + unsafeRowOffheap: UnsafeRow): Unit = { + unsafeRowOffheap.writeFieldTo(0, ByteBuffer.wrap(stringCol.getBytes)) + unsafeRowOffheap.setDouble(1, doubleCol) + unsafeRowOffheap.setLong(2, longCol) + unsafeRowOffheap.setInt(3, int32Col) + unsafeRowOffheap.setShort(4, int16Col) + unsafeRowOffheap.setByte(5, int8Col) + unsafeRowOffheap.setBoolean(6, boolCol) + } + + test("Non-UnsafeRow value should fail") { + val rowBasedChecksum = new UnsafeRowChecksum() + rowBasedChecksum.update(1, Long.box(20)) + // We fail to compute the checksum, and getValue returns 0. + assert(rowBasedChecksum.getValue == 0) + } + + test("Two identical rows should have a checksum of zero with XOR") { + val rowBasedChecksum = new UnsafeRowChecksum() + assert(rowBasedChecksum.getValue == 0) + + // Updates the checksum with one row. + rowBasedChecksum.update(1, toUnsafeRow(Row(20))) + assert(rowBasedChecksum.getValue == 8551541565481898028L) + + // Updates the checksum with the same row again, and the row-based checksum should become 0. 
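+    // Since equal rows hash to equal row checksums and x ^ x == 0, the duplicate cancels out.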
+ rowBasedChecksum.update(1, toUnsafeRow(Row(20))) + assert(rowBasedChecksum.getValue == 0) + } + + test("The checksum is independent of row order - two rows") { + val rowBasedChecksum1 = new UnsafeRowChecksum() + val rowBasedChecksum2 = new UnsafeRowChecksum() + assert(rowBasedChecksum1.getValue == 0) + assert(rowBasedChecksum2.getValue == 0) + + rowBasedChecksum1.update(1, toUnsafeRow(Row(20))) + rowBasedChecksum2.update(1, toUnsafeRow(Row(40))) + assert(rowBasedChecksum1.getValue != rowBasedChecksum2.getValue) + + rowBasedChecksum1.update(2, toUnsafeRow(Row(40))) + rowBasedChecksum2.update(2, toUnsafeRow(Row(20))) + assert(rowBasedChecksum1.getValue == rowBasedChecksum2.getValue) + + assert(rowBasedChecksum1.getValue != 0) + assert(rowBasedChecksum2.getValue != 0) + } + + test("The checksum is independent of row order - multiple rows") { + val rowBasedChecksum1 = new UnsafeRowChecksum() + val rowBasedChecksum2 = new UnsafeRowChecksum() + assert(rowBasedChecksum1.getValue == 0) + assert(rowBasedChecksum2.getValue == 0) + + rowBasedChecksum1.update(1, toUnsafeRow(Row(20))) + rowBasedChecksum2.update(1, toUnsafeRow(Row(100))) + assert(rowBasedChecksum1.getValue != rowBasedChecksum2.getValue) + + rowBasedChecksum1.update(2, toUnsafeRow(Row(40))) + rowBasedChecksum2.update(2, toUnsafeRow(Row(80))) + assert(rowBasedChecksum1.getValue != rowBasedChecksum2.getValue) + + rowBasedChecksum1.update(3, toUnsafeRow(Row(60))) + rowBasedChecksum2.update(3, toUnsafeRow(Row(60))) + assert(rowBasedChecksum1.getValue != rowBasedChecksum2.getValue) + + rowBasedChecksum1.update(4, toUnsafeRow(Row(80))) + rowBasedChecksum2.update(4, toUnsafeRow(Row(40))) + assert(rowBasedChecksum1.getValue != rowBasedChecksum2.getValue) + + rowBasedChecksum1.update(5, toUnsafeRow(Row(100))) + rowBasedChecksum2.update(5, toUnsafeRow(Row(20))) + assert(rowBasedChecksum1.getValue == rowBasedChecksum2.getValue) + + assert(rowBasedChecksum1.getValue != 0) + assert(rowBasedChecksum2.getValue != 0) + } + + test("The checksum is the same for byte array and non-byte array object") { + val buffer = Platform.allocateMemory(16) + val unsafeRowOffheap = new UnsafeRow(1) + unsafeRowOffheap.pointTo(null, buffer, 16) + assert(!unsafeRowOffheap.getBaseObject.isInstanceOf[Array[Byte]]) + + val rowBasedChecksum1 = new UnsafeRowChecksum() + val rowBasedChecksum2 = new UnsafeRowChecksum() + assert(rowBasedChecksum1.getValue == 0) + assert(rowBasedChecksum2.getValue == 0) + + rowBasedChecksum1.update(1, toUnsafeRow(Row(20))) + unsafeRowOffheap.setInt(0, 20) + rowBasedChecksum2.update(1, unsafeRowOffheap) + assert(rowBasedChecksum1.getValue == rowBasedChecksum2.getValue) + + rowBasedChecksum1.update(2, toUnsafeRow(Row(40))) + unsafeRowOffheap.setInt(0, 40) + rowBasedChecksum2.update(2, unsafeRowOffheap) + assert(rowBasedChecksum1.getValue == rowBasedChecksum2.getValue) + + assert(rowBasedChecksum1.getValue != 0) + assert(rowBasedChecksum2.getValue != 0) + + Platform.freeMemory(buffer) + } + + test("The checksum is independent of row order - complex rows") { + val rowBasedChecksum1 = new UnsafeRowChecksum() + val rowBasedChecksum2 = new UnsafeRowChecksum() + assert(rowBasedChecksum1.getValue == 0) + assert(rowBasedChecksum2.getValue == 0) + + rowBasedChecksum1.update(1, toUnsafeRowComplex(Row( + "Some string", 0.99, 10000L, 1000, 100.toShort, 10.toByte, true))) + rowBasedChecksum2.update(1, toUnsafeRowComplex(Row( + "Some other string", 10.88, 20000L, 2000, 200.toShort, 20.toByte, false))) + assert(rowBasedChecksum1.getValue != 
rowBasedChecksum2.getValue) + + rowBasedChecksum1.update(2, toUnsafeRowComplex(Row( + "Some other string", 10.88, 20000L, 2000, 200.toShort, 20.toByte, false))) + rowBasedChecksum2.update(2, toUnsafeRowComplex(Row( + "Some string", 0.99, 10000L, 1000, 100.toShort, 10.toByte, true))) + assert(rowBasedChecksum1.getValue == rowBasedChecksum2.getValue) + + assert(rowBasedChecksum1.getValue != 0) + assert(rowBasedChecksum2.getValue != 0) + } + + test("The checksum is independent of row order - complex offheap rows") { + val buffer1 = Platform.allocateMemory(128) + val unsafeRowOffheap1 = new UnsafeRow(7) + unsafeRowOffheap1.pointTo(null, buffer1, 128) + setUnsafeRowValue( + "Some string", 0.99, 10000L, 1000, 100.toShort, 10.toByte, true, unsafeRowOffheap1) + assert(!unsafeRowOffheap1.getBaseObject.isInstanceOf[Array[Byte]]) + + val buffer2 = Platform.allocateMemory(128) + val unsafeRowOffheap2 = new UnsafeRow(7) + unsafeRowOffheap2.pointTo(null, buffer2, 128) + setUnsafeRowValue( + "Some other string", 10.88, 20000L, 2000, 200.toShort, 20.toByte, false, unsafeRowOffheap2) + assert(!unsafeRowOffheap2.getBaseObject.isInstanceOf[Array[Byte]]) + + val rowBasedChecksum1 = new UnsafeRowChecksum() + val rowBasedChecksum2 = new UnsafeRowChecksum() + assert(rowBasedChecksum1.getValue == 0) + assert(rowBasedChecksum2.getValue == 0) + + rowBasedChecksum1.update(1, unsafeRowOffheap1) + rowBasedChecksum2.update(1, unsafeRowOffheap2) + assert(rowBasedChecksum1.getValue != rowBasedChecksum2.getValue) + + rowBasedChecksum1.update(2, unsafeRowOffheap2) + rowBasedChecksum2.update(2, unsafeRowOffheap1) + assert(rowBasedChecksum1.getValue == rowBasedChecksum2.getValue) + + assert(rowBasedChecksum1.getValue != 0) + assert(rowBasedChecksum2.getValue != 0) + } +} From c7675b1d3e1d7affd4301d5c6283fb3d23f03722 Mon Sep 17 00:00:00 2001 From: Jiexing Li Date: Thu, 13 Mar 2025 19:26:57 -0700 Subject: [PATCH 03/29] default on --- .../apache/spark/sql/internal/SQLConf.scala | 2 +- .../spark/sql/UnsafeRowChecksumSuite.scala | 59 ------------------- 2 files changed, 1 insertion(+), 60 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 89a4756dca572..2a7e6b7527f5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -5735,7 +5735,7 @@ object SQLConf { "partition produce different output data or not (same set of keyValue pairs). 
In case " + "the output data has changed across retries, Spark will need to retry all tasks of the " + "consumer stages to avoid correctness issues.") - .version("3.4.0") + .version("4.1.0") .booleanConf .createWithDefault(true) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowChecksumSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowChecksumSuite.scala index 8b40e7e424d2d..bb523f9e3381a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowChecksumSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowChecksumSuite.scala @@ -125,33 +125,6 @@ class UnsafeRowChecksumSuite extends SparkFunSuite { assert(rowBasedChecksum2.getValue != 0) } - test("The checksum is the same for byte array and non-byte array object") { - val buffer = Platform.allocateMemory(16) - val unsafeRowOffheap = new UnsafeRow(1) - unsafeRowOffheap.pointTo(null, buffer, 16) - assert(!unsafeRowOffheap.getBaseObject.isInstanceOf[Array[Byte]]) - - val rowBasedChecksum1 = new UnsafeRowChecksum() - val rowBasedChecksum2 = new UnsafeRowChecksum() - assert(rowBasedChecksum1.getValue == 0) - assert(rowBasedChecksum2.getValue == 0) - - rowBasedChecksum1.update(1, toUnsafeRow(Row(20))) - unsafeRowOffheap.setInt(0, 20) - rowBasedChecksum2.update(1, unsafeRowOffheap) - assert(rowBasedChecksum1.getValue == rowBasedChecksum2.getValue) - - rowBasedChecksum1.update(2, toUnsafeRow(Row(40))) - unsafeRowOffheap.setInt(0, 40) - rowBasedChecksum2.update(2, unsafeRowOffheap) - assert(rowBasedChecksum1.getValue == rowBasedChecksum2.getValue) - - assert(rowBasedChecksum1.getValue != 0) - assert(rowBasedChecksum2.getValue != 0) - - Platform.freeMemory(buffer) - } - test("The checksum is independent of row order - complex rows") { val rowBasedChecksum1 = new UnsafeRowChecksum() val rowBasedChecksum2 = new UnsafeRowChecksum() @@ -173,36 +146,4 @@ class UnsafeRowChecksumSuite extends SparkFunSuite { assert(rowBasedChecksum1.getValue != 0) assert(rowBasedChecksum2.getValue != 0) } - - test("The checksum is independent of row order - complex offheap rows") { - val buffer1 = Platform.allocateMemory(128) - val unsafeRowOffheap1 = new UnsafeRow(7) - unsafeRowOffheap1.pointTo(null, buffer1, 128) - setUnsafeRowValue( - "Some string", 0.99, 10000L, 1000, 100.toShort, 10.toByte, true, unsafeRowOffheap1) - assert(!unsafeRowOffheap1.getBaseObject.isInstanceOf[Array[Byte]]) - - val buffer2 = Platform.allocateMemory(128) - val unsafeRowOffheap2 = new UnsafeRow(7) - unsafeRowOffheap2.pointTo(null, buffer2, 128) - setUnsafeRowValue( - "Some other string", 10.88, 20000L, 2000, 200.toShort, 20.toByte, false, unsafeRowOffheap2) - assert(!unsafeRowOffheap2.getBaseObject.isInstanceOf[Array[Byte]]) - - val rowBasedChecksum1 = new UnsafeRowChecksum() - val rowBasedChecksum2 = new UnsafeRowChecksum() - assert(rowBasedChecksum1.getValue == 0) - assert(rowBasedChecksum2.getValue == 0) - - rowBasedChecksum1.update(1, unsafeRowOffheap1) - rowBasedChecksum2.update(1, unsafeRowOffheap2) - assert(rowBasedChecksum1.getValue != rowBasedChecksum2.getValue) - - rowBasedChecksum1.update(2, unsafeRowOffheap2) - rowBasedChecksum2.update(2, unsafeRowOffheap1) - assert(rowBasedChecksum1.getValue == rowBasedChecksum2.getValue) - - assert(rowBasedChecksum1.getValue != 0) - assert(rowBasedChecksum2.getValue != 0) - } } From 7b89c448a5aa031a2edc1c7f188585bb6e087008 Mon Sep 17 00:00:00 2001 From: Jiexing Li Date: Fri, 14 Mar 2025 09:15:11 -0700 Subject: [PATCH 04/29] fix compile error --- 
.../test/scala/org/apache/spark/sql/UnsafeRowChecksumSuite.scala | 1 -
 1 file changed, 1 deletion(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowChecksumSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowChecksumSuite.scala
index bb523f9e3381a..ce892ba76e8de 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowChecksumSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowChecksumSuite.scala
@@ -23,7 +23,6 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeRowChecksum}
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.Platform
 
 class UnsafeRowChecksumSuite extends SparkFunSuite {
   private val schema = new StructType().add("value", IntegerType)

From 64dd36bbe749dd7fd8cc8bcee28bd322a672aada Mon Sep 17 00:00:00 2001
From: Jiexing Li
Date: Fri, 14 Mar 2025 13:43:23 -0700
Subject: [PATCH 05/29] add constructor

---
 .../scala/org/apache/spark/Dependency.scala   | 21 +++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala
index de8ad991152cc..80bc4b441e1db 100644
--- a/core/src/main/scala/org/apache/spark/Dependency.scala
+++ b/core/src/main/scala/org/apache/spark/Dependency.scala
@@ -75,6 +75,7 @@ abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] {
  * @param aggregator map/reduce-side aggregator for RDD's shuffle
  * @param mapSideCombine whether to perform partial aggregation (also known as map-side combine)
  * @param shuffleWriterProcessor the processor to control the write behavior in ShuffleMapTask
+ * @param rowBasedChecksums the row-based checksums for each shuffle partition
  */
 @DeveloperApi
 class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
@@ -88,6 +89,26 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
     val rowBasedChecksums: Array[RowBasedChecksum] = Array.empty)
   extends Dependency[Product2[K, V]] with Logging {
 
+  def this(
+      rdd: RDD[_ <: Product2[K, V]],
+      partitioner: Partitioner,
+      serializer: Serializer,
+      keyOrdering: Option[Ordering[K]],
+      aggregator: Option[Aggregator[K, V, C]],
+      mapSideCombine: Boolean,
+      shuffleWriterProcessor: ShuffleWriteProcessor) = {
+    this(
+      rdd,
+      partitioner,
+      serializer,
+      keyOrdering,
+      aggregator,
+      mapSideCombine,
+      shuffleWriterProcessor,
+      Array.empty
+    )
+  }
+
   if (mapSideCombine) {
     require(aggregator.isDefined, "Map-side combine without Aggregator specified!")
   }

From 422e3704bb1864379b2749d192e39dcde804e3ba Mon Sep 17 00:00:00 2001
From: Jiexing Li
Date: Mon, 21 Apr 2025 10:14:02 -0700
Subject: [PATCH 06/29] address comments

---
 .../shuffle/checksum/RowBasedChecksum.scala   | 22 +++++++------------
 .../sort/BypassMergeSortShuffleWriter.java    |  3 +--
 .../scala/org/apache/spark/Dependency.scala   |  5 ++++-
 .../shuffle/ShuffleChecksumTestHelper.scala   |  5 +++--
 .../apache/spark/sql/internal/SQLConf.scala   |  2 +-
 5 files changed, 17 insertions(+), 20 deletions(-)

diff --git a/core/src/main/java/org/apache/spark/shuffle/checksum/RowBasedChecksum.scala b/core/src/main/java/org/apache/spark/shuffle/checksum/RowBasedChecksum.scala
index c1f50a93be03c..b4ed7c57114f1 100644
--- a/core/src/main/java/org/apache/spark/shuffle/checksum/RowBasedChecksum.scala
+++ b/core/src/main/java/org/apache/spark/shuffle/checksum/RowBasedChecksum.scala
@@ -48,7 +48,7 @@ abstract class RowBasedChecksum() extends
Serializable with Logging { checksumValue = checksumValue ^ rowChecksumValue } catch { case NonFatal(e) => - logInfo("Checksum computation encountered error: ", e) + logError("Checksum computation encountered error: ", e) hasError = true } } @@ -62,6 +62,9 @@ abstract class RowBasedChecksum() extends Serializable with Logging { * A Concrete implementation of RowBasedChecksum. The checksum for each row is * computed by first converting the (key, value) pair to byte array using OutputStreams, * and then computing the checksum for the byte array. + * Note that this checksum computation is very expensive, and it is used only in tests + * in the core component. A much cheaper implementation of RowBasedChecksum is in + * UnsafeRowChecksum. * * @param checksumAlgorithm the algorithm used for computing checksum. */ @@ -105,21 +108,12 @@ object RowBasedChecksum { def createPartitionRowBasedChecksums( numPartitions: Int, checksumAlgorithm: String): Array[RowBasedChecksum] = { - val rowBasedChecksums: Array[RowBasedChecksum] = new Array[RowBasedChecksum](numPartitions) - for (i <- 0 until numPartitions) { - rowBasedChecksums(i) = new OutputStreamRowBasedChecksum(checksumAlgorithm) - } - rowBasedChecksums + Array.tabulate(numPartitions)(_ => new OutputStreamRowBasedChecksum(checksumAlgorithm)) } def getAggregatedChecksumValue(rowBasedChecksums: Array[RowBasedChecksum]): Long = { - val numPartitions: Int = if (rowBasedChecksums != null) rowBasedChecksums.length else 0 - var aggregatedChecksum: Long = 0 - if (numPartitions > 0) { - for (i <- 0 until numPartitions) { - aggregatedChecksum = aggregatedChecksum * 31 + rowBasedChecksums(i).getValue - } - } - return aggregatedChecksum + Option(rowBasedChecksums) + .map(_.foldLeft(0L)((acc, c) => acc * 31L + c.getValue)) + .getOrElse(0L) } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 170df7d10d453..8011ec77589da 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -219,8 +219,7 @@ public RowBasedChecksum[] getRowBasedChecksums() { } public long getAggregatedChecksumValue() { - final long checksum = RowBasedChecksum.getAggregatedChecksumValue(rowBasedChecksums); - return checksum; + return RowBasedChecksum.getAggregatedChecksumValue(rowBasedChecksums); } /** diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index 80bc4b441e1db..1c9834fcf4a3d 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -60,6 +60,9 @@ abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] { override def rdd: RDD[T] = _rdd } +object ShuffleDependency { + private val EmptyRowBasedChecksums: Array[RowBasedChecksum] = Array.empty +} /** * :: DeveloperApi :: @@ -86,7 +89,7 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( val aggregator: Option[Aggregator[K, V, C]] = None, val mapSideCombine: Boolean = false, val shuffleWriterProcessor: ShuffleWriteProcessor = new ShuffleWriteProcessor, - val rowBasedChecksums: Array[RowBasedChecksum] = Array.empty) + val rowBasedChecksums: Array[RowBasedChecksum] = ShuffleDependency.EmptyRowBasedChecksums) extends Dependency[Product2[K, V]] with Logging { def this( diff --git 
a/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala index c5b6c9faaa1a2..af878fb2de67b 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala @@ -40,8 +40,9 @@ trait ShuffleChecksumTestHelper { "checksum must be consistent at both write and read sides") } - def getRowBasedChecksumValues(rowBasedChecksums: Array[RowBasedChecksum]): Array[Long] = { - if (rowBasedChecksums == null) { + def getRowBasedChecksumValues( + rowBasedChecksums: Array[RowBasedChecksum] = Array.empty): Array[Long] = { + if (rowBasedChecksums.nonEmpty) { Array.empty } else { rowBasedChecksums.map(_.getValue) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 2a7e6b7527f5f..fab912d356c2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -5737,7 +5737,7 @@ object SQLConf { "consumer stages to avoid correctness issues.") .version("4.1.0") .booleanConf - .createWithDefault(true) + .createWithDefault(false) /** * Holds information about keys that have been deprecated. From 89901ca9f5a5c5c929f134b78591a0e7788d70de Mon Sep 17 00:00:00 2001 From: Jiexing Li Date: Mon, 21 Apr 2025 12:19:35 -0700 Subject: [PATCH 07/29] move config --- .../apache/spark/internal/config/package.scala | 15 +++++++++++++++ .../shuffle/sort/UnsafeShuffleWriterSuite.java | 12 +++++++----- .../spark/shuffle/ShuffleChecksumTestHelper.scala | 5 ++--- .../sort/BypassMergeSortShuffleWriterSuite.scala | 15 +++++++++------ .../shuffle/sort/SortShuffleWriterSuite.scala | 12 ++++++++---- .../org/apache/spark/sql/internal/SQLConf.scala | 15 --------------- .../execution/exchange/ShuffleExchangeExec.scala | 2 +- .../apache/spark/sql/MapStatusEndToEndSuite.scala | 3 ++- 8 files changed, 44 insertions(+), 35 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 3ce374d0477d8..99a086be8ac62 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -1622,6 +1622,21 @@ package object config { s"The buffer size must be greater than 0 and less than or equal to ${Int.MaxValue}.") .createWithDefault(4096) + private[spark] val SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED = + ConfigBuilder("spark.shuffle.orderIndependentChecksum.enabled") + .doc("Whether to calculate order independent checksum for the shuffle data or not. If " + + "enabled, Spark will calculate a checksum that is independent of the input row order for " + + "each mapper and returns the checksums from executors to driver. Different from the above" + + "checksum, the order independent remains the same even if the shuffle row order changes. " + + "While the above checksum is sensitive to shuffle data ordering to detect file " + + "corruption. This checksum is used to detect whether different task attempts of the same " + + "partition produce different output data or not (same set of keyValue pairs). 
In case " + + "the output data has changed across retries, Spark will need to retry all tasks of the " + + "consumer stages to avoid correctness issues.") + .version("4.1.0") + .booleanConf + .createWithDefault(false) + private[spark] val SHUFFLE_CHECKSUM_ENABLED = ConfigBuilder("spark.shuffle.checksum.enabled") .doc("Whether to calculate the checksum of shuffle data. If enabled, Spark will calculate " + diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 930e5721ce4bb..adcbba78a0a82 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -175,18 +175,19 @@ public void setUp() throws Exception { File file = (File) invocationOnMock.getArguments()[0]; return Utils.tempFileWith(file); }); - resetDependency(false); + resetDependency(); when(taskContext.taskMetrics()).thenReturn(taskMetrics); when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager); } - private void resetDependency(boolean rowbasedChecksumEnabled) { + private void resetDependency() { when(shuffleDep.serializer()).thenReturn(serializer); when(shuffleDep.partitioner()).thenReturn(hashPartitioner); - final int checksumSize = rowbasedChecksumEnabled ? NUM_PARTITIONS : 0; + final int checksumSize = + (boolean) conf.get(package$.MODULE$.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED()) ? NUM_PARTITIONS : 0; final String checksumAlgorithm = conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ALGORITHM()); final RowBasedChecksum[] rowBasedChecksums = - RowBasedChecksum.createPartitionRowBasedChecksums(checksumSize, checksumAlgorithm); + RowBasedChecksum.createPartitionRowBasedChecksums(checksumSize, checksumAlgorithm); when(shuffleDep.rowBasedChecksums()).thenReturn(rowBasedChecksums); } @@ -624,6 +625,7 @@ public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException, Spa @Test public void testRowBasedChecksum() throws IOException, SparkException { + conf.set(package$.MODULE$.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED(), true); final ArrayList> dataToWrite = new ArrayList<>(); for (int i = 0; i < NUM_PARTITIONS; i++) { for (int j = 0; j < 5; j++) { @@ -634,7 +636,7 @@ public void testRowBasedChecksum() throws IOException, SparkException { long[] checksumValues = new long[0]; long aggregatedChecksumValue = 0; for (int i = 0; i < 100; i++) { - resetDependency(true); + resetDependency(); final UnsafeShuffleWriter writer = createWriter(false); Collections.shuffle(dataToWrite); writer.write(dataToWrite.iterator()); diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala index af878fb2de67b..7ac62e910e1c5 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala @@ -40,9 +40,8 @@ trait ShuffleChecksumTestHelper { "checksum must be consistent at both write and read sides") } - def getRowBasedChecksumValues( - rowBasedChecksums: Array[RowBasedChecksum] = Array.empty): Array[Long] = { - if (rowBasedChecksums.nonEmpty) { + def getRowBasedChecksumValues(rowBasedChecksums: Array[RowBasedChecksum]): Array[Long] = { + if (rowBasedChecksums.isEmpty) { Array.empty } else { rowBasedChecksums.map(_.getValue) diff --git 
a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index 408e60e5e90d8..e9f23e0da7eb4 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -76,7 +76,7 @@ class BypassMergeSortShuffleWriterSuite ) val memoryManager = new TestMemoryManager(conf) val taskMemoryManager = new TaskMemoryManager(memoryManager, 0) - resetDependency() + resetDependency(conf) when(taskContext.taskMetrics()).thenReturn(taskMetrics) when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile) when(blockManager.diskBlockManager).thenReturn(diskBlockManager) @@ -146,12 +146,13 @@ class BypassMergeSortShuffleWriterSuite } } - private def resetDependency(rowbasedChecksumEnabled : Boolean = false): Unit = { + private def resetDependency(sc: SparkConf): Unit = { reset(dependency) val numPartitions = 7 when(dependency.partitioner).thenReturn(new HashPartitioner(numPartitions)) - when(dependency.serializer).thenReturn(new JavaSerializer(conf)) - val checksumSize = if (rowbasedChecksumEnabled) numPartitions else 0 + when(dependency.serializer).thenReturn(new JavaSerializer(sc)) + val checksumSize = + if (sc.get(config.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED)) numPartitions else 0 val checksumAlgorithm = conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM) val rowBasedChecksums = RowBasedChecksum.createPartitionRowBasedChecksums(checksumSize, checksumAlgorithm) @@ -309,6 +310,8 @@ class BypassMergeSortShuffleWriterSuite } test("Row-based checksums are independent of input row order") { + val transferConf = + conf.clone.set(config.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED.key, true.toString) val records: List[(Int, Int)] = List( (1, 1), (1, 2), (1, 3), (1, 4), (1, 5), (2, 2), (2, 3), (2, 4), (2, 5), (2, 6), @@ -321,12 +324,12 @@ class BypassMergeSortShuffleWriterSuite var checksumValues : Array[Long] = Array[Long]() var aggregatedChecksumValue = 0L for (i <- 1 to 100) { - resetDependency(true); + resetDependency(transferConf); val writer = new BypassMergeSortShuffleWriter[Int, Int]( blockManager, shuffleHandle, 0L, // MapId - conf, + transferConf, taskContext.taskMetrics().shuffleWriteMetrics, shuffleExecutorComponents) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala index c64a437568d1b..cefb0d2aabde5 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala @@ -79,13 +79,14 @@ class SortShuffleWriterSuite } } - private def resetDependency(rowbasedChecksumEnabled : Boolean = false): Unit = { + private def resetDependency(): Unit = { reset(dependency); when(dependency.partitioner).thenReturn(partitioner) when(dependency.serializer).thenReturn(serializer) when(dependency.aggregator).thenReturn(None) when(dependency.keyOrdering).thenReturn(None) - val checksumSize = if (rowbasedChecksumEnabled) numMaps else 0 + val checksumSize = + if (conf.get(config.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED)) numMaps else 0 val checksumAlgorithm = conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM) val rowBasedChecksums = RowBasedChecksum.createPartitionRowBasedChecksums( checksumSize, checksumAlgorithm) @@ -128,6 +129,7 @@ class 
SortShuffleWriterSuite } test("Row-based checksums are independent of input row order") { + conf.set("spark.shuffle.orderIndependentChecksum.enabled", true.toString) // FIXME: this can affect other tests (if any) after this set of tests // since `sc` is global. sc.stop() @@ -144,7 +146,7 @@ class SortShuffleWriterSuite var checksumValues : Array[Long] = Array[Long]() var aggregatedChecksumValue = 0L for (i <- 1 to 100) { - resetDependency(true) + resetDependency() val writer = new SortShuffleWriter[Int, Int, Int]( shuffleHandle, mapId = 2, @@ -197,9 +199,11 @@ class SortShuffleWriterSuite when(dependency.serializer).thenReturn(serializer) when(dependency.aggregator).thenReturn(aggregator) when(dependency.keyOrdering).thenReturn(order) + val checksumSize = + if (conf.get(config.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED)) numMaps else 0 val checksumAlgorithm = conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM) val rowBasedChecksums: Array[RowBasedChecksum] = - RowBasedChecksum.createPartitionRowBasedChecksums(0, checksumAlgorithm) + RowBasedChecksum.createPartitionRowBasedChecksums(checksumSize, checksumAlgorithm) when(dependency.rowBasedChecksums).thenReturn(rowBasedChecksums) new BaseShuffleHandle[Int, Int, Int](shuffleId, dependency) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index fab912d356c2a..93acb39944fae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -5724,21 +5724,6 @@ object SQLConf { .booleanConf .createWithDefault(true) - val SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED = - buildConf("spark.shuffle.orderIndependentChecksum.enabled") - .doc("Whether to calculate order independent checksum for the shuffle data or not. If " + - "enabled, Spark will calculate a checksum that is independent of the input row order for " + - "each mapper and returns the checksums from executors to driver. Different from the above" + - "checksum, the order independent remains the same even if the shuffle row order changes. " + - "While the above checksum is sensitive to shuffle data ordering to detect file " + - "corruption. This checksum is used to detect whether different task attempts of the same " + - "partition produce different output data or not (same set of keyValue pairs). In case " + - "the output data has changed across retries, Spark will need to retry all tasks of the " + - "consumer stages to avoid correctness issues.") - .version("4.1.0") - .booleanConf - .createWithDefault(false) - /** * Holds information about keys that have been deprecated. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 0bffeef96589f..13f1f117fddc8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -472,7 +472,7 @@ object ShuffleExchangeExec { // are in the form of (partitionId, row) and every partitionId is in the expected range // [0, part.numPartitions - 1]. The partitioner of this is a PartitionIdPassthrough. 
val checksumSize = - if (SparkEnv.get.conf.get(SQLConf.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED)) { + if (SparkEnv.get.conf.get(config.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED)) { part.numPartitions } else { 0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala index 33aac82a1df40..c9e47b3791b6e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import org.apache.spark.{MapOutputTrackerMaster, SparkFunSuite} +import org.apache.spark.internal.config.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils @@ -25,7 +26,7 @@ import org.apache.spark.sql.test.SQLTestUtils class MapStatusEndToEndSuite extends SparkFunSuite with SQLTestUtils { override def spark: SparkSession = SparkSession.builder() .master("local") - .config(SQLConf.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED.key, value = true) + .config(SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED.key, value = true) .config(SQLConf.LEAF_NODE_DEFAULT_PARALLELISM.key, value = 5) .getOrCreate() From a1c50fad478a5bea268651f7bc0dbbb57fd087ce Mon Sep 17 00:00:00 2001 From: Jiexing Li Date: Wed, 23 Apr 2025 10:08:37 -0700 Subject: [PATCH 08/29] address comments --- .../spark/shuffle/checksum/RowBasedChecksum.scala | 9 ++------- .../apache/spark/shuffle/sort/UnsafeShuffleWriter.java | 10 +++------- .../org/apache/spark/util/MyByteArrayOutputStream.java | 9 +++++++++ 3 files changed, 14 insertions(+), 14 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/util/MyByteArrayOutputStream.java diff --git a/core/src/main/java/org/apache/spark/shuffle/checksum/RowBasedChecksum.scala b/core/src/main/java/org/apache/spark/shuffle/checksum/RowBasedChecksum.scala index b4ed7c57114f1..67cf08aa3cbed 100644 --- a/core/src/main/java/org/apache/spark/shuffle/checksum/RowBasedChecksum.scala +++ b/core/src/main/java/org/apache/spark/shuffle/checksum/RowBasedChecksum.scala @@ -17,13 +17,14 @@ package org.apache.spark.shuffle.checksum -import java.io.{ByteArrayOutputStream, ObjectOutputStream} +import java.io.ObjectOutputStream import java.util.zip.Checksum import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper +import org.apache.spark.util.MyByteArrayOutputStream /** * A class for computing checksum for input (key, value) pairs. The checksum is independent of @@ -71,12 +72,6 @@ abstract class RowBasedChecksum() extends Serializable with Logging { class OutputStreamRowBasedChecksum(checksumAlgorithm: String) extends RowBasedChecksum() { - /** Subclass of ByteArrayOutputStream that exposes `buf` directly. 
*/ - final private class MyByteArrayOutputStream(size: Int) - extends ByteArrayOutputStream(size) { - def getBuf: Array[Byte] = buf - } - private val DEFAULT_INITIAL_SER_BUFFER_SIZE = 32 * 1024 @transient private lazy val serBuffer = diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 93cd6004cc317..e6d01d4110c44 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -64,6 +64,7 @@ import org.apache.spark.storage.BlockManager; import org.apache.spark.storage.TimeTrackingOutputStream; import org.apache.spark.unsafe.Platform; +import org.apache.spark.util.MyByteArrayOutputStream; import org.apache.spark.util.Utils; @Private @@ -95,12 +96,6 @@ public class UnsafeShuffleWriter extends ShuffleWriter { @Nullable private long[] partitionLengths; private long peakMemoryUsedBytes = 0; - /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */ - private static final class MyByteArrayOutputStream extends ByteArrayOutputStream { - MyByteArrayOutputStream(int size) { super(size); } - public byte[] getBuf() { return buf; } - } - private MyByteArrayOutputStream serBuffer; private SerializationStream serOutputStream; @@ -349,7 +344,8 @@ private long[] mergeSpillsUsingStandardWriter(SpillInfo[] spills) throws IOExcep logger.debug("Using slow merge"); mergeSpillsWithFileStream(spills, mapWriter, compressionCodec); } - partitionLengths = mapWriter.commitAllPartitions(sorter.getChecksums()).getPartitionLengths(); + partitionLengths = + mapWriter.commitAllPartitions(sorter.getChecksums()).getPartitionLengths(); } catch (Exception e) { try { mapWriter.abort(e); diff --git a/core/src/main/java/org/apache/spark/util/MyByteArrayOutputStream.java b/core/src/main/java/org/apache/spark/util/MyByteArrayOutputStream.java new file mode 100644 index 0000000000000..71c99b2015912 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/MyByteArrayOutputStream.java @@ -0,0 +1,9 @@ +package org.apache.spark.util; + +import java.io.ByteArrayOutputStream; + +/** Subclass of ByteArrayOutputStream that exposes `buf` directly. */ +public final class MyByteArrayOutputStream extends ByteArrayOutputStream { + public MyByteArrayOutputStream(int size) { super(size); } + public byte[] getBuf() { return buf; } +} From db59634edaf8e1dc587077edd5a2bc7955b3f357 Mon Sep 17 00:00:00 2001 From: Jiexing Li Date: Wed, 23 Apr 2025 21:42:23 -0700 Subject: [PATCH 09/29] add license headers --- .../spark/util/MyByteArrayOutputStream.java | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/core/src/main/java/org/apache/spark/util/MyByteArrayOutputStream.java b/core/src/main/java/org/apache/spark/util/MyByteArrayOutputStream.java index 71c99b2015912..6d17a6c699c2e 100644 --- a/core/src/main/java/org/apache/spark/util/MyByteArrayOutputStream.java +++ b/core/src/main/java/org/apache/spark/util/MyByteArrayOutputStream.java @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.util; import java.io.ByteArrayOutputStream; From 04e08ebdecc680abbde40f2872be0b9c1f4861a4 Mon Sep 17 00:00:00 2001 From: Tengfei Huang Date: Tue, 26 Aug 2025 14:04:20 +0000 Subject: [PATCH 10/29] address comments --- .../shuffle/checksum/RowBasedChecksum.scala | 10 ++-------- .../sort/BypassMergeSortShuffleWriter.java | 9 +++++---- .../spark/shuffle/sort/UnsafeShuffleWriter.java | 16 +++++++++------- ...a => ExposedBufferByteArrayOutputStream.java} | 4 ++-- .../main/scala/org/apache/spark/Dependency.scala | 4 ++-- .../shuffle/sort/UnsafeShuffleWriterSuite.java | 2 +- .../shuffle/ShuffleChecksumTestHelper.scala | 8 +++++++- .../sort/BypassMergeSortShuffleWriterSuite.scala | 4 +--- .../shuffle/sort/SortShuffleWriterSuite.scala | 5 ++--- .../catalyst/expressions/UnsafeRowChecksum.scala | 6 +----- 10 files changed, 32 insertions(+), 36 deletions(-) rename core/src/main/java/org/apache/spark/util/{MyByteArrayOutputStream.java => ExposedBufferByteArrayOutputStream.java} (86%) diff --git a/core/src/main/java/org/apache/spark/shuffle/checksum/RowBasedChecksum.scala b/core/src/main/java/org/apache/spark/shuffle/checksum/RowBasedChecksum.scala index 67cf08aa3cbed..b88118ecc0a37 100644 --- a/core/src/main/java/org/apache/spark/shuffle/checksum/RowBasedChecksum.scala +++ b/core/src/main/java/org/apache/spark/shuffle/checksum/RowBasedChecksum.scala @@ -24,7 +24,7 @@ import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper -import org.apache.spark.util.MyByteArrayOutputStream +import org.apache.spark.util.ExposedBufferByteArrayOutputStream /** * A class for computing checksum for input (key, value) pairs. 
The checksum is independent of @@ -75,7 +75,7 @@ class OutputStreamRowBasedChecksum(checksumAlgorithm: String) private val DEFAULT_INITIAL_SER_BUFFER_SIZE = 32 * 1024 @transient private lazy val serBuffer = - new MyByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE) + new ExposedBufferByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE) @transient private lazy val objOut = new ObjectOutputStream(serBuffer) @transient @@ -100,12 +100,6 @@ class OutputStreamRowBasedChecksum(checksumAlgorithm: String) } object RowBasedChecksum { - def createPartitionRowBasedChecksums( - numPartitions: Int, - checksumAlgorithm: String): Array[RowBasedChecksum] = { - Array.tabulate(numPartitions)(_ => new OutputStreamRowBasedChecksum(checksumAlgorithm)) - } - def getAggregatedChecksumValue(rowBasedChecksums: Array[RowBasedChecksum]): Long = { Option(rowBasedChecksums) .map(_.foldLeft(0L)((acc, c) => acc * 31L + c.getValue)) diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 8011ec77589da..73cabbe36cd30 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -32,6 +32,7 @@ import scala.Tuple2; import scala.collection.Iterator; +import com.google.common.annotations.VisibleForTesting; import com.google.common.io.Closeables; import org.apache.spark.internal.SparkLogger; @@ -112,7 +113,6 @@ final class BypassMergeSortShuffleWriter * output data or not. */ private final RowBasedChecksum[] rowBasedChecksums; - private final SparkConf conf; /** * Are we in the process of stopping? Because map tasks can call stop() with success = true @@ -142,7 +142,6 @@ final class BypassMergeSortShuffleWriter this.shuffleExecutorComponents = shuffleExecutorComponents; this.partitionChecksums = createPartitionChecksums(numPartitions, conf); this.rowBasedChecksums = dep.rowBasedChecksums(); - this.conf = conf; } @Override @@ -214,11 +213,13 @@ public long[] getPartitionLengths() { return partitionLengths; } - public RowBasedChecksum[] getRowBasedChecksums() { + @VisibleForTesting + RowBasedChecksum[] getRowBasedChecksums() { return rowBasedChecksums; } - public long getAggregatedChecksumValue() { + @VisibleForTesting + long getAggregatedChecksumValue() { return RowBasedChecksum.getAggregatedChecksumValue(rowBasedChecksums); } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index e6d01d4110c44..c891f33a3a6cf 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -64,7 +64,7 @@ import org.apache.spark.storage.BlockManager; import org.apache.spark.storage.TimeTrackingOutputStream; import org.apache.spark.unsafe.Platform; -import org.apache.spark.util.MyByteArrayOutputStream; +import org.apache.spark.util.ExposedBufferByteArrayOutputStream; import org.apache.spark.util.Utils; @Private @@ -96,7 +96,7 @@ public class UnsafeShuffleWriter extends ShuffleWriter { @Nullable private long[] partitionLengths; private long peakMemoryUsedBytes = 0; - private MyByteArrayOutputStream serBuffer; + private ExposedBufferByteArrayOutputStream serBuffer; private SerializationStream serOutputStream; /** @@ -167,10 +167,13 @@ public long getPeakMemoryUsedBytes() { 
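To make the aggregation shown above concrete: `getAggregatedChecksumValue` folds the per-partition values into a single Long with a base-31 polynomial, so the aggregate distinguishes which partition produced which value, while each per-partition value itself stays independent of row order. A small worked example with made-up partition values standing in for `getValue`:

  val perPartitionValues = Array(3L, 5L, 7L)          // hypothetical per-partition checksums
  val aggregated = perPartitionValues.foldLeft(0L)((acc, v) => acc * 31L + v)
  // ((0 * 31 + 3) * 31 + 5) * 31 + 7 = 3045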
return peakMemoryUsedBytes; } - public RowBasedChecksum[] getRowBasedChecksums() { + @VisibleForTesting + RowBasedChecksum[] getRowBasedChecksums() { return rowBasedChecksums; } - public long getAggregatedChecksumValue() { + + @VisibleForTesting + long getAggregatedChecksumValue() { return RowBasedChecksum.getAggregatedChecksumValue(rowBasedChecksums); } @@ -222,7 +225,7 @@ private void open() throws SparkException { partitioner.numPartitions(), sparkConf, writeMetrics); - serBuffer = new MyByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE); + serBuffer = new ExposedBufferByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE); serOutputStream = serializer.serializeStream(serBuffer); } @@ -344,8 +347,7 @@ private long[] mergeSpillsUsingStandardWriter(SpillInfo[] spills) throws IOExcep logger.debug("Using slow merge"); mergeSpillsWithFileStream(spills, mapWriter, compressionCodec); } - partitionLengths = - mapWriter.commitAllPartitions(sorter.getChecksums()).getPartitionLengths(); + partitionLengths = mapWriter.commitAllPartitions(sorter.getChecksums()).getPartitionLengths(); } catch (Exception e) { try { mapWriter.abort(e); diff --git a/core/src/main/java/org/apache/spark/util/MyByteArrayOutputStream.java b/core/src/main/java/org/apache/spark/util/ExposedBufferByteArrayOutputStream.java similarity index 86% rename from core/src/main/java/org/apache/spark/util/MyByteArrayOutputStream.java rename to core/src/main/java/org/apache/spark/util/ExposedBufferByteArrayOutputStream.java index 6d17a6c699c2e..bd59bd176fb93 100644 --- a/core/src/main/java/org/apache/spark/util/MyByteArrayOutputStream.java +++ b/core/src/main/java/org/apache/spark/util/ExposedBufferByteArrayOutputStream.java @@ -20,7 +20,7 @@ import java.io.ByteArrayOutputStream; /** Subclass of ByteArrayOutputStream that exposes `buf` directly. 
*/ -public final class MyByteArrayOutputStream extends ByteArrayOutputStream { - public MyByteArrayOutputStream(int size) { super(size); } +public final class ExposedBufferByteArrayOutputStream extends ByteArrayOutputStream { + public ExposedBufferByteArrayOutputStream(int size) { super(size); } public byte[] getBuf() { return buf; } } diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index 1c9834fcf4a3d..34c21161e3627 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -61,7 +61,7 @@ abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] { } object ShuffleDependency { - private val EmptyRowBasedChecksums: Array[RowBasedChecksum] = Array.empty + private val EMPTY_ROW_BASED_CHECKSUMS: Array[RowBasedChecksum] = Array.empty } /** @@ -89,7 +89,7 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( val aggregator: Option[Aggregator[K, V, C]] = None, val mapSideCombine: Boolean = false, val shuffleWriterProcessor: ShuffleWriteProcessor = new ShuffleWriteProcessor, - val rowBasedChecksums: Array[RowBasedChecksum] = ShuffleDependency.EmptyRowBasedChecksums) + val rowBasedChecksums: Array[RowBasedChecksum] = ShuffleDependency.EMPTY_ROW_BASED_CHECKSUMS) extends Dependency[Product2[K, V]] with Logging { def this( diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index adcbba78a0a82..ac07a5c2caf0d 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -187,7 +187,7 @@ private void resetDependency() { (boolean) conf.get(package$.MODULE$.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED()) ? 
NUM_PARTITIONS : 0; final String checksumAlgorithm = conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ALGORITHM()); final RowBasedChecksum[] rowBasedChecksums = - RowBasedChecksum.createPartitionRowBasedChecksums(checksumSize, checksumAlgorithm); + createPartitionRowBasedChecksums(checksumSize, checksumAlgorithm); when(shuffleDep.rowBasedChecksums()).thenReturn(rowBasedChecksums); } diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala index 7ac62e910e1c5..c8fe2d5c70d97 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala @@ -19,7 +19,7 @@ package org.apache.spark.shuffle import java.io.File -import org.apache.spark.shuffle.checksum.RowBasedChecksum +import org.apache.spark.shuffle.checksum.{OutputStreamRowBasedChecksum, RowBasedChecksum} trait ShuffleChecksumTestHelper { @@ -47,4 +47,10 @@ trait ShuffleChecksumTestHelper { rowBasedChecksums.map(_.getValue) } } + + def createPartitionRowBasedChecksums( + numPartitions: Int, + checksumAlgorithm: String): Array[RowBasedChecksum] = { + Array.tabulate(numPartitions)(_ => new OutputStreamRowBasedChecksum(checksumAlgorithm)) + } } diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index e9f23e0da7eb4..bc92cf09c224c 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -38,7 +38,6 @@ import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper import org.apache.spark.serializer.{JavaSerializer, SerializerInstance, SerializerManager} import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleChecksumTestHelper} import org.apache.spark.shuffle.api.ShuffleExecutorComponents -import org.apache.spark.shuffle.checksum.RowBasedChecksum import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents import org.apache.spark.storage._ import org.apache.spark.util.Utils @@ -154,8 +153,7 @@ class BypassMergeSortShuffleWriterSuite val checksumSize = if (sc.get(config.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED)) numPartitions else 0 val checksumAlgorithm = conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM) - val rowBasedChecksums = - RowBasedChecksum.createPartitionRowBasedChecksums(checksumSize, checksumAlgorithm) + val rowBasedChecksums = createPartitionRowBasedChecksums(checksumSize, checksumAlgorithm) when(dependency.rowBasedChecksums).thenReturn(rowBasedChecksums) } diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala index cefb0d2aabde5..635c775bb907b 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala @@ -88,8 +88,7 @@ class SortShuffleWriterSuite val checksumSize = if (conf.get(config.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED)) numMaps else 0 val checksumAlgorithm = conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM) - val rowBasedChecksums = RowBasedChecksum.createPartitionRowBasedChecksums( - checksumSize, checksumAlgorithm) + val rowBasedChecksums = 
createPartitionRowBasedChecksums(checksumSize, checksumAlgorithm) when(dependency.rowBasedChecksums).thenReturn(rowBasedChecksums) } @@ -203,7 +202,7 @@ class SortShuffleWriterSuite if (conf.get(config.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED)) numMaps else 0 val checksumAlgorithm = conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM) val rowBasedChecksums: Array[RowBasedChecksum] = - RowBasedChecksum.createPartitionRowBasedChecksums(checksumSize, checksumAlgorithm) + createPartitionRowBasedChecksums(checksumSize, checksumAlgorithm) when(dependency.rowBasedChecksums).thenReturn(rowBasedChecksums) new BaseShuffleHandle[Int, Int, Int](shuffleId, dependency) } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowChecksum.scala b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowChecksum.scala index 10c808ebe403f..2be675070eb10 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowChecksum.scala +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowChecksum.scala @@ -48,10 +48,6 @@ class UnsafeRowChecksum extends RowBasedChecksum() { object UnsafeRowChecksum { def createUnsafeRowChecksums(numPartitions: Int): Array[RowBasedChecksum] = { - val rowBasedChecksums: Array[RowBasedChecksum] = new Array[RowBasedChecksum](numPartitions) - for (i <- 0 until numPartitions) { - rowBasedChecksums(i) = new UnsafeRowChecksum() - } - rowBasedChecksums + Array.tabulate(numPartitions)(_ => new UnsafeRowChecksum()) } } From c9c28e6042581c76f898c61657a04e9293aee94c Mon Sep 17 00:00:00 2001 From: Tengfei Huang Date: Tue, 26 Aug 2025 14:25:32 +0000 Subject: [PATCH 11/29] address comments --- .../apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java | 1 + .../java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java | 1 + 2 files changed, 2 insertions(+) diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 73cabbe36cd30..fadc1037ea4d6 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -213,6 +213,7 @@ public long[] getPartitionLengths() { return partitionLengths; } + // For test only. @VisibleForTesting RowBasedChecksum[] getRowBasedChecksums() { return rowBasedChecksums; diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index c891f33a3a6cf..612338b2046c0 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -167,6 +167,7 @@ public long getPeakMemoryUsedBytes() { return peakMemoryUsedBytes; } + // For test only. 
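These accessors are only meant for the shuffle writer suites, which read the checksums back after a write. A hedged sketch of that pattern, with `writer` standing for any of the writers under test and the calls matching the test-only accessors added in this patch:

  // After writer.write(records.iterator) and writer.stop(true):
  val perPartition = writer.getRowBasedChecksums.map(_.getValue)
  val aggregated = writer.getAggregatedChecksumValue
  // Writing the same records again in a different order yields the same values.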
@VisibleForTesting RowBasedChecksum[] getRowBasedChecksums() { return rowBasedChecksums; From 2575d52bc648d7b36148d0a7a85980d133eb6f42 Mon Sep 17 00:00:00 2001 From: Tengfei Huang Date: Mon, 1 Sep 2025 09:07:01 +0000 Subject: [PATCH 12/29] address comments --- .../shuffle/checksum/RowBasedChecksum.scala | 63 +++++------------- .../org/apache/spark/MapOutputTracker.scala | 3 + .../spark/internal/config/package.scala | 14 ++-- .../apache/spark/scheduler/MapStatus.scala | 2 +- .../OutputStreamRowBasedChecksum.scala | 64 +++++++++++++++++++ .../spark/sql/UnsafeRowChecksumSuite.scala | 9 +-- 6 files changed, 96 insertions(+), 59 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/shuffle/checksum/OutputStreamRowBasedChecksum.scala diff --git a/core/src/main/java/org/apache/spark/shuffle/checksum/RowBasedChecksum.scala b/core/src/main/java/org/apache/spark/shuffle/checksum/RowBasedChecksum.scala index b88118ecc0a37..825aed2efa2f0 100644 --- a/core/src/main/java/org/apache/spark/shuffle/checksum/RowBasedChecksum.scala +++ b/core/src/main/java/org/apache/spark/shuffle/checksum/RowBasedChecksum.scala @@ -17,14 +17,9 @@ package org.apache.spark.shuffle.checksum -import java.io.ObjectOutputStream -import java.util.zip.Checksum - import scala.util.control.NonFatal import org.apache.spark.internal.Logging -import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper -import org.apache.spark.util.ExposedBufferByteArrayOutputStream /** * A class for computing checksum for input (key, value) pairs. The checksum is independent of @@ -32,13 +27,21 @@ import org.apache.spark.util.ExposedBufferByteArrayOutputStream * first, and then computing the XOR for all the row checksums. */ abstract class RowBasedChecksum() extends Serializable with Logging { + private val ROTATE_POSITIONS = 27 private var hasError: Boolean = false - private var checksumValue: Long = 0 - /** Returns the checksum value computed. Tt returns the default checksum value (0) if there + private var checksumXor: Long = 0 + private var checksumSum: Long = 0 + + /** Returns the checksum value. It returns the default checksum value (0) if there * are any errors encountered during the checksum computation. */ def getValue: Long = { - if (!hasError) checksumValue else 0 + if (!hasError) { + val res = checksumXor ^ rotateLeft(checksumSum) + res + } else { + 0 + } } /** Updates the row-based checksum with the given (key, value) pair */ @@ -46,7 +49,8 @@ abstract class RowBasedChecksum() extends Serializable with Logging { if (!hasError) { try { val rowChecksumValue = calculateRowChecksum(key, value) - checksumValue = checksumValue ^ rowChecksumValue + checksumXor = checksumXor ^ rowChecksumValue + checksumSum += rowChecksumValue } catch { case NonFatal(e) => logError("Checksum computation encountered error: ", e) @@ -57,45 +61,10 @@ abstract class RowBasedChecksum() extends Serializable with Logging { /** Computes and returns the checksum value for the given (key, value) pair */ protected def calculateRowChecksum(key: Any, value: Any): Long -} - -/** - * A Concrete implementation of RowBasedChecksum. The checksum for each row is - * computed by first converting the (key, value) pair to byte array using OutputStreams, - * and then computing the checksum for the byte array. - * Note that this checksum computation is very expensive, and it is used only in tests - * in the core component. A much cheaper implementation of RowBasedChecksum is in - * UnsafeRowChecksum. 
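For intuition about the XOR-plus-rotated-SUM scheme this patch introduces: with XOR alone, two identical rows cancel each other out (x ^ x == 0), so a duplicated row could go unnoticed; keeping a running SUM as well, and rotating it before mixing, preserves that signal while both accumulators remain commutative, which is what keeps the result independent of row order. A small sketch, assuming a toy subclass written inside a test body under org.apache.spark.shuffle.checksum:

  // Toy checksum that just hashes the (key, value) pair; enough to see the mixing behaviour.
  class PairHashChecksum extends RowBasedChecksum {
    override protected def calculateRowChecksum(key: Any, value: Any): Long =
      (key, value).hashCode().toLong
  }

  val a = new PairHashChecksum
  a.update(1, "x"); a.update(2, "y")
  val b = new PairHashChecksum
  b.update(2, "y"); b.update(1, "x")
  assert(a.getValue == b.getValue)   // insertion order does not matter

  val c = new PairHashChecksum
  c.update(1, "x"); c.update(1, "x") // duplicate row
  assert(c.getValue != 0L)           // would collapse to 0 under a pure XOR of row checksums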
- * - * @param checksumAlgorithm the algorithm used for computing checksum. - */ -class OutputStreamRowBasedChecksum(checksumAlgorithm: String) - extends RowBasedChecksum() { - - private val DEFAULT_INITIAL_SER_BUFFER_SIZE = 32 * 1024 - - @transient private lazy val serBuffer = - new ExposedBufferByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE) - @transient private lazy val objOut = new ObjectOutputStream(serBuffer) - - @transient - protected lazy val checksum: Checksum = - ShuffleChecksumHelper.getChecksumByAlgorithm(checksumAlgorithm) - - override protected def calculateRowChecksum(key: Any, value: Any): Long = { - assert(checksum != null, "Checksum is null") - - // Converts the (key, value) pair into byte array. - objOut.reset() - serBuffer.reset() - objOut.writeObject((key, value)) - objOut.flush() - serBuffer.flush() - // Computes and returns the checksum for the byte array. - checksum.reset() - checksum.update(serBuffer.getBuf, 0, serBuffer.size()) - checksum.getValue + // Rotate the value by shifting the bits by `ROTATE_POSITIONS` positions to the left. + private def rotateLeft(value: Long): Long = { + (value << ROTATE_POSITIONS) | (value >>> (64 - ROTATE_POSITIONS)) } } diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 075771e4e2ee3..91fabc1590c1f 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -174,10 +174,13 @@ private class ShuffleStatus( } else { mapIdToMapIndex.remove(currentMapStatus.mapId) } + logDebug(s"Checksum of map output for task ${status.mapId} is ${status.checksumValue}") val preStatus = if (mapStatuses(mapIndex) != null) mapStatuses(mapIndex) else mapStatusesDeleted(mapIndex) if (preStatus != null && preStatus.checksumValue != status.checksumValue) { + logInfo(s"Checksum of map output changes from ${preStatus.checksumValue} to " + + s"${status.checksumValue} for task ${status.mapId}.") checksumMismatchIndices.add(mapIndex) } mapStatuses(mapIndex) = status diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 488703186e43d..23a2d02a21b11 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -1654,13 +1654,13 @@ package object config { ConfigBuilder("spark.shuffle.orderIndependentChecksum.enabled") .doc("Whether to calculate order independent checksum for the shuffle data or not. If " + "enabled, Spark will calculate a checksum that is independent of the input row order for " + - "each mapper and returns the checksums from executors to driver. Different from the above" + - "checksum, the order independent remains the same even if the shuffle row order changes. " + - "While the above checksum is sensitive to shuffle data ordering to detect file " + - "corruption. This checksum is used to detect whether different task attempts of the same " + - "partition produce different output data or not (same set of keyValue pairs). In case " + - "the output data has changed across retries, Spark will need to retry all tasks of the " + - "consumer stages to avoid correctness issues.") + "each mapper and returns the checksums from executors to driver. 
This is different from " + + "the checksum computed when spark.shuffle.checksum.enabled is enabled which is sensitive " + + "to shuffle data ordering to detect file corruption. While this checksum will be the " + + "same even if the shuffle row order changes and it is used to detect whether different " + + "task attempts of the same partition produce different output data or not (same set of " + + "keyValue pairs). In case the output data has changed across retries, Spark will need to " + + "retry all tasks of the consumer stages to avoid correctness issues.") .version("4.1.0") .booleanConf .createWithDefault(false) diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index f7490d0182883..e348b6a5f1493 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -61,7 +61,7 @@ private[spark] sealed trait MapStatus extends ShuffleOutputStatus { /** * The checksum value of this shuffle map task, which can be used to evaluate whether the - * output data have changed across different map task retries. + * output data has changed across different map task retries. */ def checksumValue: Long = 0 } diff --git a/core/src/test/scala/org/apache/spark/shuffle/checksum/OutputStreamRowBasedChecksum.scala b/core/src/test/scala/org/apache/spark/shuffle/checksum/OutputStreamRowBasedChecksum.scala new file mode 100644 index 0000000000000..3abec5f4bd656 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/checksum/OutputStreamRowBasedChecksum.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.checksum + +import java.io.ObjectOutputStream +import java.util.zip.Checksum + +import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper +import org.apache.spark.util.ExposedBufferByteArrayOutputStream + +/** + * A Concrete implementation of RowBasedChecksum. The checksum for each row is + * computed by first converting the (key, value) pair to byte array using OutputStreams, + * and then computing the checksum for the byte array. + * Note that this checksum computation is very expensive, and it is used only in tests + * in the core component. A much cheaper implementation of RowBasedChecksum is in + * UnsafeRowChecksum. + * + * @param checksumAlgorithm the algorithm used for computing checksum. 
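+ *   (for example "ADLER32" or "CRC32", as accepted by
+ *   ShuffleChecksumHelper.getChecksumByAlgorithm).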
+ */ +class OutputStreamRowBasedChecksum(checksumAlgorithm: String) + extends RowBasedChecksum() { + + private val DEFAULT_INITIAL_SER_BUFFER_SIZE = 32 * 1024 + + @transient private lazy val serBuffer = + new ExposedBufferByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE) + @transient private lazy val objOut = new ObjectOutputStream(serBuffer) + + @transient + protected lazy val checksum: Checksum = + ShuffleChecksumHelper.getChecksumByAlgorithm(checksumAlgorithm) + + override protected def calculateRowChecksum(key: Any, value: Any): Long = { + assert(checksum != null, "Checksum is null") + + // Converts the (key, value) pair into byte array. + objOut.reset() + serBuffer.reset() + objOut.writeObject((key, value)) + objOut.flush() + serBuffer.flush() + + // Computes and returns the checksum for the byte array. + checksum.reset() + checksum.update(serBuffer.getBuf, 0, serBuffer.size()) + checksum.getValue + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowChecksumSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowChecksumSuite.scala index ce892ba76e8de..07941ad626332 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowChecksumSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowChecksumSuite.scala @@ -63,17 +63,18 @@ class UnsafeRowChecksumSuite extends SparkFunSuite { assert(rowBasedChecksum.getValue == 0) } - test("Two identical rows should have a checksum of zero with XOR") { + test("Two identical rows should not have a checksum of zero") { val rowBasedChecksum = new UnsafeRowChecksum() assert(rowBasedChecksum.getValue == 0) // Updates the checksum with one row. rowBasedChecksum.update(1, toUnsafeRow(Row(20))) - assert(rowBasedChecksum.getValue == 8551541565481898028L) + assert(rowBasedChecksum.getValue == -9094624449814316735L) - // Updates the checksum with the same row again, and the row-based checksum should become 0. + // Updates the checksum with the same row again, since we mix the final xor and sum + // of the row-based checksum, the result would not be 0. rowBasedChecksum.update(1, toUnsafeRow(Row(20))) - assert(rowBasedChecksum.getValue == 0) + assert(rowBasedChecksum.getValue == -1240577858172431653L) } test("The checksum is independent of row order - two rows") { From 3b99edb38cf5ecbe19f934dd2592bd7554a4fb30 Mon Sep 17 00:00:00 2001 From: Tengfei Huang Date: Tue, 2 Sep 2025 14:46:28 +0000 Subject: [PATCH 13/29] fix code stype issue --- .../apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index ac07a5c2caf0d..ee06a4818fbdb 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -184,7 +184,8 @@ private void resetDependency() { when(shuffleDep.serializer()).thenReturn(serializer); when(shuffleDep.partitioner()).thenReturn(hashPartitioner); final int checksumSize = - (boolean) conf.get(package$.MODULE$.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED()) ? NUM_PARTITIONS : 0; + (boolean) conf.get(package$.MODULE$.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED()) + ? 
NUM_PARTITIONS : 0; final String checksumAlgorithm = conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ALGORITHM()); final RowBasedChecksum[] rowBasedChecksums = createPartitionRowBasedChecksums(checksumSize, checksumAlgorithm); From 74266a5aa22f583facf73de4356863ea55893963 Mon Sep 17 00:00:00 2001 From: Tengfei Huang Date: Wed, 3 Sep 2025 02:25:16 +0000 Subject: [PATCH 14/29] debug flaky ut --- .../org/apache/spark/ContextCleaner.scala | 3 ++ .../org/apache/spark/MapOutputTracker.scala | 8 ++++ .../apache/spark/scheduler/DAGScheduler.scala | 4 ++ .../storage/BlockManagerStorageEndpoint.scala | 3 ++ .../spark/sql/MapStatusEndToEndSuite.scala | 47 +++++++++++++++---- 5 files changed, 55 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index 54ea8c94daac1..a90069f8ab1ab 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -235,6 +235,9 @@ private[spark] class ContextCleaner( def doCleanupShuffle(shuffleId: Int, blocking: Boolean): Unit = { try { if (mapOutputTrackerMaster.containsShuffle(shuffleId)) { + // scalastyle:off println + System.out.println(s"Cleaning shuffle ${shuffleId}") + // scalastyle:on println logDebug("Cleaning shuffle " + shuffleId) // Shuffle must be removed before it's unregistered from the output tracker // to find blocks served by the shuffle service on deallocated executors diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 91fabc1590c1f..9c23b9a81a31a 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -838,6 +838,10 @@ private[spark] class MapOutputTrackerMaster( throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") } } + // scalastyle:off println + System.out.println(s"${this}: shuffle $shuffleId registered and shuffle statuses size" + + s" ${shuffleStatuses.size}, ${shuffleStatuses}") + // scalastyle:on println } def updateMapOutput(shuffleId: Int, mapId: Long, bmAddress: BlockManagerId): Unit = { @@ -943,6 +947,10 @@ private[spark] class MapOutputTrackerMaster( shuffleStatus.invalidateSerializedMapOutputStatusCache() shuffleStatus.invalidateSerializedMergeOutputStatusCache() } + // scalastyle:off println + System.out.println(s"unregister shuffle $shuffleId and shuffle statuses size" + + s" ${shuffleStatuses.size}, ${shuffleStatuses}") + // scalastyle:on println } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 30eb49b0c0798..3bc620e76866e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -2022,6 +2022,10 @@ private[spark] class DAGScheduler( // The epoch of the task is acceptable (i.e., the task was launched after the most // recent failure we're aware of for the executor), so mark the task's output as // available. 
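          // Once this status is registered, the ShuffleStatus bookkeeping updated earlier in
          // this series compares its checksumValue with the value recorded by any previous
          // attempt of the same map index and tracks mismatches in checksumMismatchIndices,
          // so stages consuming this shuffle can be retried if a rerun produced different data.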
+ + // scalastyle:off println + System.out.println(s"registered map output for ${shuffleStage.shuffleDep.shuffleId}") + // scalastyle:on println mapOutputTracker.registerMapOutput( shuffleStage.shuffleDep.shuffleId, smt.partitionId, status) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerStorageEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerStorageEndpoint.scala index 54329c5b1e514..64c4698a56ce5 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerStorageEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerStorageEndpoint.scala @@ -58,6 +58,9 @@ class BlockManagerStorageEndpoint( case RemoveShuffle(shuffleId) => doAsync[Boolean](log"removing shuffle ${MDC(SHUFFLE_ID, shuffleId)}", context) { if (mapOutputTracker != null) { + // scalastyle:off println + System.out.println(s"remove shuffle $shuffleId for ${mapOutputTracker}") + // scalastyle:on println mapOutputTracker.unregisterShuffle(shuffleId) } val shuffleManager = SparkEnv.get.shuffleManager diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala index c9e47b3791b6e..a6a73cac9abc7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala @@ -17,11 +17,15 @@ package org.apache.spark.sql +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.{MapOutputTrackerMaster, SparkFunSuite} import org.apache.spark.internal.config.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED import org.apache.spark.sql.classic.SparkSession +import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.util.QueryExecutionListener class MapStatusEndToEndSuite extends SparkFunSuite with SQLTestUtils { override def spark: SparkSession = SparkSession.builder() @@ -30,7 +34,7 @@ class MapStatusEndToEndSuite extends SparkFunSuite with SQLTestUtils { .config(SQLConf.LEAF_NODE_DEFAULT_PARALLELISM.key, value = 5) .getOrCreate() - override def afterEach(): Unit = { + override def afterAll(): Unit = { // This suite should not interfere with the other test suites. SparkSession.getActiveSession.foreach(_.stop()) SparkSession.clearActiveSession() @@ -44,17 +48,40 @@ class MapStatusEndToEndSuite extends SparkFunSuite with SQLTestUtils { assert(spark.sparkContext.conf.get("spark.sql.leafNodeDefaultParallelism") == "5") assert(spark.conf.get("spark.sql.leafNodeDefaultParallelism") == "5") - withTable("t") { - spark.range(1000).repartition(10).write.mode("overwrite"). - saveAsTable("t") + val queryExecutions = new ArrayBuffer[QueryExecution]() + val listener = new QueryExecutionListener() { + override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { + queryExecutions.append(qe) + } + + override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {} } + // Register the listener to keep a reference of shuffle dependency in `QueryExecution`, + // avoid shuffle statuses removed due to clean up. + spark.listenerManager.register(listener) + + try { + // scalastyle:off println + System.out.println("start testing") - val shuffleStatuses = spark.sparkContext.env.mapOutputTracker. 
- asInstanceOf[MapOutputTrackerMaster].shuffleStatuses - assert(shuffleStatuses.size == 1) + withTable("t") { + spark.range(1000).repartition(10).write.mode("overwrite"). + saveAsTable("t") + } - val mapStatuses = shuffleStatuses(0).mapStatuses - assert(mapStatuses.length == 5) - assert(mapStatuses.forall(_.checksumValue != 0)) + System.out.println(s"MapOutputTracker is ${spark.sparkContext.env.mapOutputTracker}") + val shuffleStatuses = spark.sparkContext.env.mapOutputTracker. + asInstanceOf[MapOutputTrackerMaster].shuffleStatuses + assert(shuffleStatuses.size == 1) + + queryExecutions.foreach(System.out.println) + // scalastyle:on println + + val mapStatuses = shuffleStatuses(0).mapStatuses + assert(mapStatuses.length == 5) + assert(mapStatuses.forall(_.checksumValue != 0)) + } finally { + spark.listenerManager.unregister(listener) + } } } From 22c79c87c3e678caa9f723b2f0c438d452a974d1 Mon Sep 17 00:00:00 2001 From: Tengfei Huang Date: Wed, 3 Sep 2025 10:09:46 +0000 Subject: [PATCH 15/29] Revert "debug flaky ut" This reverts commit 74266a5aa22f583facf73de4356863ea55893963. --- .../org/apache/spark/ContextCleaner.scala | 3 -- .../org/apache/spark/MapOutputTracker.scala | 8 ---- .../apache/spark/scheduler/DAGScheduler.scala | 4 -- .../storage/BlockManagerStorageEndpoint.scala | 3 -- .../spark/sql/MapStatusEndToEndSuite.scala | 47 ++++--------------- 5 files changed, 10 insertions(+), 55 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index a90069f8ab1ab..54ea8c94daac1 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -235,9 +235,6 @@ private[spark] class ContextCleaner( def doCleanupShuffle(shuffleId: Int, blocking: Boolean): Unit = { try { if (mapOutputTrackerMaster.containsShuffle(shuffleId)) { - // scalastyle:off println - System.out.println(s"Cleaning shuffle ${shuffleId}") - // scalastyle:on println logDebug("Cleaning shuffle " + shuffleId) // Shuffle must be removed before it's unregistered from the output tracker // to find blocks served by the shuffle service on deallocated executors diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 9c23b9a81a31a..91fabc1590c1f 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -838,10 +838,6 @@ private[spark] class MapOutputTrackerMaster( throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") } } - // scalastyle:off println - System.out.println(s"${this}: shuffle $shuffleId registered and shuffle statuses size" + - s" ${shuffleStatuses.size}, ${shuffleStatuses}") - // scalastyle:on println } def updateMapOutput(shuffleId: Int, mapId: Long, bmAddress: BlockManagerId): Unit = { @@ -947,10 +943,6 @@ private[spark] class MapOutputTrackerMaster( shuffleStatus.invalidateSerializedMapOutputStatusCache() shuffleStatus.invalidateSerializedMergeOutputStatusCache() } - // scalastyle:off println - System.out.println(s"unregister shuffle $shuffleId and shuffle statuses size" + - s" ${shuffleStatuses.size}, ${shuffleStatuses}") - // scalastyle:on println } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 3bc620e76866e..30eb49b0c0798 100644 --- 
a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -2022,10 +2022,6 @@ private[spark] class DAGScheduler( // The epoch of the task is acceptable (i.e., the task was launched after the most // recent failure we're aware of for the executor), so mark the task's output as // available. - - // scalastyle:off println - System.out.println(s"registered map output for ${shuffleStage.shuffleDep.shuffleId}") - // scalastyle:on println mapOutputTracker.registerMapOutput( shuffleStage.shuffleDep.shuffleId, smt.partitionId, status) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerStorageEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerStorageEndpoint.scala index 64c4698a56ce5..54329c5b1e514 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerStorageEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerStorageEndpoint.scala @@ -58,9 +58,6 @@ class BlockManagerStorageEndpoint( case RemoveShuffle(shuffleId) => doAsync[Boolean](log"removing shuffle ${MDC(SHUFFLE_ID, shuffleId)}", context) { if (mapOutputTracker != null) { - // scalastyle:off println - System.out.println(s"remove shuffle $shuffleId for ${mapOutputTracker}") - // scalastyle:on println mapOutputTracker.unregisterShuffle(shuffleId) } val shuffleManager = SparkEnv.get.shuffleManager diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala index a6a73cac9abc7..c9e47b3791b6e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala @@ -17,15 +17,11 @@ package org.apache.spark.sql -import scala.collection.mutable.ArrayBuffer - import org.apache.spark.{MapOutputTrackerMaster, SparkFunSuite} import org.apache.spark.internal.config.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED import org.apache.spark.sql.classic.SparkSession -import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.util.QueryExecutionListener class MapStatusEndToEndSuite extends SparkFunSuite with SQLTestUtils { override def spark: SparkSession = SparkSession.builder() @@ -34,7 +30,7 @@ class MapStatusEndToEndSuite extends SparkFunSuite with SQLTestUtils { .config(SQLConf.LEAF_NODE_DEFAULT_PARALLELISM.key, value = 5) .getOrCreate() - override def afterAll(): Unit = { + override def afterEach(): Unit = { // This suite should not interfere with the other test suites. SparkSession.getActiveSession.foreach(_.stop()) SparkSession.clearActiveSession() @@ -48,40 +44,17 @@ class MapStatusEndToEndSuite extends SparkFunSuite with SQLTestUtils { assert(spark.sparkContext.conf.get("spark.sql.leafNodeDefaultParallelism") == "5") assert(spark.conf.get("spark.sql.leafNodeDefaultParallelism") == "5") - val queryExecutions = new ArrayBuffer[QueryExecution]() - val listener = new QueryExecutionListener() { - override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { - queryExecutions.append(qe) - } - - override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {} + withTable("t") { + spark.range(1000).repartition(10).write.mode("overwrite"). 
+ saveAsTable("t") } - // Register the listener to keep a reference of shuffle dependency in `QueryExecution`, - // avoid shuffle statuses removed due to clean up. - spark.listenerManager.register(listener) - - try { - // scalastyle:off println - System.out.println("start testing") - withTable("t") { - spark.range(1000).repartition(10).write.mode("overwrite"). - saveAsTable("t") - } + val shuffleStatuses = spark.sparkContext.env.mapOutputTracker. + asInstanceOf[MapOutputTrackerMaster].shuffleStatuses + assert(shuffleStatuses.size == 1) - System.out.println(s"MapOutputTracker is ${spark.sparkContext.env.mapOutputTracker}") - val shuffleStatuses = spark.sparkContext.env.mapOutputTracker. - asInstanceOf[MapOutputTrackerMaster].shuffleStatuses - assert(shuffleStatuses.size == 1) - - queryExecutions.foreach(System.out.println) - // scalastyle:on println - - val mapStatuses = shuffleStatuses(0).mapStatuses - assert(mapStatuses.length == 5) - assert(mapStatuses.forall(_.checksumValue != 0)) - } finally { - spark.listenerManager.unregister(listener) - } + val mapStatuses = shuffleStatuses(0).mapStatuses + assert(mapStatuses.length == 5) + assert(mapStatuses.forall(_.checksumValue != 0)) } } From df48158d9cfe3a989132869e2cbd9f8982b2ab9e Mon Sep 17 00:00:00 2001 From: Tengfei Huang Date: Wed, 3 Sep 2025 10:20:23 +0000 Subject: [PATCH 16/29] to resolve conclits --- .../scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala index c9e47b3791b6e..d2865d278b9ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala @@ -30,7 +30,7 @@ class MapStatusEndToEndSuite extends SparkFunSuite with SQLTestUtils { .config(SQLConf.LEAF_NODE_DEFAULT_PARALLELISM.key, value = 5) .getOrCreate() - override def afterEach(): Unit = { + override def afterAll(): Unit = { // This suite should not interfere with the other test suites. 
SparkSession.getActiveSession.foreach(_.stop()) SparkSession.clearActiveSession() @@ -38,7 +38,7 @@ class MapStatusEndToEndSuite extends SparkFunSuite with SQLTestUtils { SparkSession.clearDefaultSession() } - test("Propagate checksum from executor to driver") { + ignore("Propagate checksum from executor to driver") { assert(spark.sparkContext.conf.get("spark.shuffle.orderIndependentChecksum.enabled") == "true") assert(spark.conf.get("spark.shuffle.orderIndependentChecksum.enabled") == "true") assert(spark.sparkContext.conf.get("spark.sql.leafNodeDefaultParallelism") == "5") From 4cfaac839582b26693810041a474e5f25f2e25d0 Mon Sep 17 00:00:00 2001 From: Tengfei Huang Date: Wed, 3 Sep 2025 10:43:54 +0000 Subject: [PATCH 17/29] fix ut --- .../scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala index d2865d278b9ae..bb62d7ea74e63 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala @@ -28,6 +28,7 @@ class MapStatusEndToEndSuite extends SparkFunSuite with SQLTestUtils { .master("local") .config(SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED.key, value = true) .config(SQLConf.LEAF_NODE_DEFAULT_PARALLELISM.key, value = 5) + .config(SQLConf.CLASSIC_SHUFFLE_DEPENDENCY_FILE_CLEANUP_ENABLED.key, value = false) .getOrCreate() override def afterAll(): Unit = { @@ -38,11 +39,14 @@ class MapStatusEndToEndSuite extends SparkFunSuite with SQLTestUtils { SparkSession.clearDefaultSession() } - ignore("Propagate checksum from executor to driver") { + test("Propagate checksum from executor to driver") { assert(spark.sparkContext.conf.get("spark.shuffle.orderIndependentChecksum.enabled") == "true") assert(spark.conf.get("spark.shuffle.orderIndependentChecksum.enabled") == "true") assert(spark.sparkContext.conf.get("spark.sql.leafNodeDefaultParallelism") == "5") assert(spark.conf.get("spark.sql.leafNodeDefaultParallelism") == "5") + assert(spark.sparkContext.conf.get("spark.sql.classic.shuffleDependency.fileCleanup.enabled") + == "false") + assert(spark.conf.get("spark.sql.classic.shuffleDependency.fileCleanup.enabled") == "false") withTable("t") { spark.range(1000).repartition(10).write.mode("overwrite"). From 786fdd3a3ae067c627e977c368fde61af91a2119 Mon Sep 17 00:00:00 2001 From: Tengfei Huang Date: Thu, 4 Sep 2025 14:56:02 +0000 Subject: [PATCH 18/29] address comments --- .../shuffle/checksum/RowBasedChecksum.scala | 13 +++-- .../scala/org/apache/spark/Dependency.scala | 2 +- .../org/apache/spark/MapOutputTracker.scala | 3 +- .../sort/UnsafeShuffleWriterSuite.java | 49 ++++++++++--------- .../spark/scheduler/DAGSchedulerSuite.scala | 4 +- .../BypassMergeSortShuffleWriterSuite.scala | 19 +++---- 6 files changed, 49 insertions(+), 41 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/checksum/RowBasedChecksum.scala b/core/src/main/java/org/apache/spark/shuffle/checksum/RowBasedChecksum.scala index 825aed2efa2f0..886296dc8a828 100644 --- a/core/src/main/java/org/apache/spark/shuffle/checksum/RowBasedChecksum.scala +++ b/core/src/main/java/org/apache/spark/shuffle/checksum/RowBasedChecksum.scala @@ -24,7 +24,8 @@ import org.apache.spark.internal.Logging /** * A class for computing checksum for input (key, value) pairs. 
The checksum is independent of * the order of the input (key, value) pairs. It is done by computing a checksum for each row - * first, and then computing the XOR for all the row checksums. + * first, then computing the XOR and SUM for all the row checksums and mixing these two values + * as the final checksum. */ abstract class RowBasedChecksum() extends Serializable with Logging { private val ROTATE_POSITIONS = 27 @@ -32,19 +33,21 @@ abstract class RowBasedChecksum() extends Serializable with Logging { private var checksumXor: Long = 0 private var checksumSum: Long = 0 - /** Returns the checksum value. It returns the default checksum value (0) if there + /** + * Returns the checksum value. It returns the default checksum value (0) if there * are any errors encountered during the checksum computation. */ def getValue: Long = { if (!hasError) { - val res = checksumXor ^ rotateLeft(checksumSum) - res + // Here we rotate the `checksumSum` to transforms these two values into a single, strong + // composite checksum by ensuring their bit patterns are thoroughly mixed. + checksumXor ^ rotateLeft(checksumSum) } else { 0 } } - /** Updates the row-based checksum with the given (key, value) pair */ + /** Updates the row-based checksum with the given (key, value) pair. Not thread safe. */ def update(key: Any, value: Any): Unit = { if (!hasError) { try { diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index eca4b5a44cb31..5b83556b5144f 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -108,7 +108,7 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( aggregator, mapSideCombine, shuffleWriterProcessor, - Array.empty + ShuffleDependency.EMPTY_ROW_BASED_CHECKSUMS ) } diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 91fabc1590c1f..5780b740fe5ab 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -101,8 +101,9 @@ private class ShuffleStatus( /** * Keep the indices of the Map tasks whose checksums are different across retries. + * Exposed for testing. */ - private[this] val checksumMismatchIndices : Set[Int] = Set() + private[spark] val checksumMismatchIndices : Set[Int] = Set() /** * MergeStatus for each shuffle partition when push-based shuffle is enabled. 
The index of the diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index ee06a4818fbdb..295663702889b 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -175,17 +175,15 @@ public void setUp() throws Exception { File file = (File) invocationOnMock.getArguments()[0]; return Utils.tempFileWith(file); }); - resetDependency(); + resetDependency(false); when(taskContext.taskMetrics()).thenReturn(taskMetrics); when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager); } - private void resetDependency() { + private void resetDependency(boolean rowBasedCheckSumEnabled) { when(shuffleDep.serializer()).thenReturn(serializer); when(shuffleDep.partitioner()).thenReturn(hashPartitioner); - final int checksumSize = - (boolean) conf.get(package$.MODULE$.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED()) - ? NUM_PARTITIONS : 0; + final int checksumSize = rowBasedCheckSumEnabled ? NUM_PARTITIONS : 0; final String checksumAlgorithm = conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ALGORITHM()); final RowBasedChecksum[] rowBasedChecksums = createPartitionRowBasedChecksums(checksumSize, checksumAlgorithm); @@ -626,7 +624,6 @@ public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException, Spa @Test public void testRowBasedChecksum() throws IOException, SparkException { - conf.set(package$.MODULE$.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED(), true); final ArrayList> dataToWrite = new ArrayList<>(); for (int i = 0; i < NUM_PARTITIONS; i++) { for (int j = 0; j < 5; j++) { @@ -636,25 +633,29 @@ public void testRowBasedChecksum() throws IOException, SparkException { long[] checksumValues = new long[0]; long aggregatedChecksumValue = 0; - for (int i = 0; i < 100; i++) { - resetDependency(); - final UnsafeShuffleWriter writer = createWriter(false); - Collections.shuffle(dataToWrite); - writer.write(dataToWrite.iterator()); - writer.stop(true); - - if (i == 0) { - checksumValues = getRowBasedChecksumValues(writer.getRowBasedChecksums()); - assertEquals(checksumValues.length, NUM_PARTITIONS); - Arrays.stream(checksumValues).allMatch(v -> v > 0); - - aggregatedChecksumValue = writer.getAggregatedChecksumValue(); - assert(aggregatedChecksumValue != 0); - } else { - assertArrayEquals(checksumValues, - getRowBasedChecksumValues(writer.getRowBasedChecksums())); - assertEquals(aggregatedChecksumValue, writer.getAggregatedChecksumValue()); + try { + for (int i = 0; i < 100; i++) { + resetDependency(true); + final UnsafeShuffleWriter writer = createWriter(false); + Collections.shuffle(dataToWrite); + writer.write(dataToWrite.iterator()); + writer.stop(true); + + if (i == 0) { + checksumValues = getRowBasedChecksumValues(writer.getRowBasedChecksums()); + assertEquals(checksumValues.length, NUM_PARTITIONS); + Arrays.stream(checksumValues).allMatch(v -> v > 0); + + aggregatedChecksumValue = writer.getAggregatedChecksumValue(); + assert(aggregatedChecksumValue != 0); + } else { + assertArrayEquals(checksumValues, + getRowBasedChecksumValues(writer.getRowBasedChecksums())); + assertEquals(aggregatedChecksumValue, writer.getAggregatedChecksumValue()); + } } + } finally { + resetDependency(false); } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 
58905e6d39a76..3b7e88f439270 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -4858,7 +4858,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti * of the task in its previous stage. The two attempts of the same task in the previous stage * produce different shuffle checksums. */ - test("Output usage log for tasks that produce different checksum across retries") { + test("Tasks that produce different checksum across retries") { setupStageAbortTest(sc) val parts = 8 @@ -4885,6 +4885,8 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti sc.listenerBus.waitUntilEmpty() assert(ended) assert(results == (0 until parts).map { idx => idx -> 42 }.toMap) + assert( + mapOutputTracker.shuffleStatuses(shuffleDep.shuffleId).checksumMismatchIndices.size == 1) assertDataStructuresEmpty() mapOutputTracker.unregisterShuffle(shuffleDep.shuffleId) } diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index bc92cf09c224c..61f4a36c8fda2 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -75,7 +75,7 @@ class BypassMergeSortShuffleWriterSuite ) val memoryManager = new TestMemoryManager(conf) val taskMemoryManager = new TaskMemoryManager(memoryManager, 0) - resetDependency(conf) + resetDependency(conf, rowBasedCheckSumEnabled = false) when(taskContext.taskMetrics()).thenReturn(taskMetrics) when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile) when(blockManager.diskBlockManager).thenReturn(diskBlockManager) @@ -145,14 +145,17 @@ class BypassMergeSortShuffleWriterSuite } } - private def resetDependency(sc: SparkConf): Unit = { + private def resetDependency(sc: SparkConf, rowBasedCheckSumEnabled: Boolean): Unit = { reset(dependency) val numPartitions = 7 when(dependency.partitioner).thenReturn(new HashPartitioner(numPartitions)) when(dependency.serializer).thenReturn(new JavaSerializer(sc)) - val checksumSize = - if (sc.get(config.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED)) numPartitions else 0 - val checksumAlgorithm = conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM) + val checksumSize = if (rowBasedCheckSumEnabled) { + numPartitions + } else { + 0 + } + val checksumAlgorithm = sc.get(config.SHUFFLE_CHECKSUM_ALGORITHM) val rowBasedChecksums = createPartitionRowBasedChecksums(checksumSize, checksumAlgorithm) when(dependency.rowBasedChecksums).thenReturn(rowBasedChecksums) } @@ -308,8 +311,6 @@ class BypassMergeSortShuffleWriterSuite } test("Row-based checksums are independent of input row order") { - val transferConf = - conf.clone.set(config.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED.key, true.toString) val records: List[(Int, Int)] = List( (1, 1), (1, 2), (1, 3), (1, 4), (1, 5), (2, 2), (2, 3), (2, 4), (2, 5), (2, 6), @@ -322,12 +323,12 @@ class BypassMergeSortShuffleWriterSuite var checksumValues : Array[Long] = Array[Long]() var aggregatedChecksumValue = 0L for (i <- 1 to 100) { - resetDependency(transferConf); + resetDependency(conf, rowBasedCheckSumEnabled = true) val writer = new BypassMergeSortShuffleWriter[Int, Int]( blockManager, shuffleHandle, 0L, // MapId - transferConf, + conf, taskContext.taskMetrics().shuffleWriteMetrics, 
shuffleExecutorComponents) From 602729cb00f01fdd16153f3d4a526477ad74df9a Mon Sep 17 00:00:00 2001 From: Tengfei Huang Date: Thu, 4 Sep 2025 15:06:19 +0000 Subject: [PATCH 19/29] address comments --- .../shuffle/sort/SortShuffleWriterSuite.scala | 23 ++++++++----------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala index 635c775bb907b..9736c2e904edd 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala @@ -66,7 +66,7 @@ class SortShuffleWriterSuite shuffleHandle = { new BaseShuffleHandle(shuffleId, dependency) } - resetDependency() + resetDependency(rowBasedCheckSumEnabled = false) shuffleExecutorComponents = new LocalDiskShuffleExecutorComponents( conf, blockManager, shuffleBlockResolver) } @@ -79,14 +79,17 @@ class SortShuffleWriterSuite } } - private def resetDependency(): Unit = { - reset(dependency); + private def resetDependency(rowBasedCheckSumEnabled: Boolean): Unit = { + reset(dependency) when(dependency.partitioner).thenReturn(partitioner) when(dependency.serializer).thenReturn(serializer) when(dependency.aggregator).thenReturn(None) when(dependency.keyOrdering).thenReturn(None) - val checksumSize = - if (conf.get(config.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED)) numMaps else 0 + val checksumSize = if (rowBasedCheckSumEnabled) { + numMaps + } else { + 0 + } val checksumAlgorithm = conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM) val rowBasedChecksums = createPartitionRowBasedChecksums(checksumSize, checksumAlgorithm) when(dependency.rowBasedChecksums).thenReturn(rowBasedChecksums) @@ -128,13 +131,8 @@ class SortShuffleWriterSuite } test("Row-based checksums are independent of input row order") { - conf.set("spark.shuffle.orderIndependentChecksum.enabled", true.toString) - // FIXME: this can affect other tests (if any) after this set of tests - // since `sc` is global. 
- sc.stop() - val localSC = new SparkContext("local[4]", "test", conf) val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) - val context = MemoryTestingUtils.fakeTaskContext(localSC.env) + val context = MemoryTestingUtils.fakeTaskContext(sc.env) val records: List[(Int, Int)] = List( (1, 1), (1, 2), (1, 3), (1, 4), (1, 5), (2, 2), (2, 3), (2, 4), (2, 5), (2, 6), @@ -145,7 +143,7 @@ class SortShuffleWriterSuite var checksumValues : Array[Long] = Array[Long]() var aggregatedChecksumValue = 0L for (i <- 1 to 100) { - resetDependency() + resetDependency(rowBasedCheckSumEnabled = true) val writer = new SortShuffleWriter[Int, Int, Int]( shuffleHandle, mapId = 2, @@ -168,7 +166,6 @@ class SortShuffleWriterSuite } writer.stop(success = true) } - localSC.stop() } Seq((true, false, false), From 137f254c7b6cacdc90c97c886e8b90632c2210d5 Mon Sep 17 00:00:00 2001 From: Tengfei Huang Date: Fri, 5 Sep 2025 14:43:18 +0800 Subject: [PATCH 20/29] Update core/src/main/scala/org/apache/spark/MapOutputTracker.scala Co-authored-by: Wenchen Fan --- core/src/main/scala/org/apache/spark/MapOutputTracker.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 5780b740fe5ab..3f823b60156ad 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -103,7 +103,7 @@ private class ShuffleStatus( * Keep the indices of the Map tasks whose checksums are different across retries. * Exposed for testing. */ - private[spark] val checksumMismatchIndices : Set[Int] = Set() + private[spark] val checksumMismatchIndices: Set[Int] = Set() /** * MergeStatus for each shuffle partition when push-based shuffle is enabled. 
The index of the From dde16d4f5de6e63947adea18f2a98d3a37a21065 Mon Sep 17 00:00:00 2001 From: Tengfei Huang Date: Fri, 5 Sep 2025 14:44:40 +0800 Subject: [PATCH 21/29] Update core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala Co-authored-by: Wenchen Fan --- .../org/apache/spark/scheduler/DAGSchedulerSuite.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 3b7e88f439270..c6ae95f35d4e6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -5285,10 +5285,10 @@ object DAGSchedulerSuite { val mergerLocs = ArrayBuffer[BlockManagerId]() def makeMapStatus( - host: String, - reduces: Int, - sizes: Byte = 2, - mapTaskId: Long = -1, + host: String, + reduces: Int, + sizes: Byte = 2, + mapTaskId: Long = -1, checksumVal: Long = 0): MapStatus = MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes), mapTaskId, checksumVal) From f7d9dfa450e2f57d014de721e44fc040a111cb2e Mon Sep 17 00:00:00 2001 From: Tengfei Huang Date: Fri, 5 Sep 2025 14:45:59 +0800 Subject: [PATCH 22/29] Update core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala Co-authored-by: Wenchen Fan --- .../org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala index 9736c2e904edd..3dca919203191 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala @@ -79,7 +79,7 @@ class SortShuffleWriterSuite } } - private def resetDependency(rowBasedCheckSumEnabled: Boolean): Unit = { + private def resetDependency(rowBasedChecksumEnabled: Boolean): Unit = { reset(dependency) when(dependency.partitioner).thenReturn(partitioner) when(dependency.serializer).thenReturn(serializer) From 5aabe70c17677cc060f7f300063af2d0ed523181 Mon Sep 17 00:00:00 2001 From: Tengfei Huang Date: Fri, 5 Sep 2025 14:46:31 +0800 Subject: [PATCH 23/29] Update core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala Co-authored-by: Wenchen Fan --- .../spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index 61f4a36c8fda2..55006c8777e41 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -145,7 +145,7 @@ class BypassMergeSortShuffleWriterSuite } } - private def resetDependency(sc: SparkConf, rowBasedCheckSumEnabled: Boolean): Unit = { + private def resetDependency(sc: SparkConf, rowBasedChecksumEnabled: Boolean): Unit = { reset(dependency) val numPartitions = 7 when(dependency.partitioner).thenReturn(new HashPartitioner(numPartitions)) From 2fd0a94e1a89c2bc56faec5fa729f6ebadb5098a Mon Sep 17 00:00:00 2001 From: Tengfei Huang Date: Fri, 5 Sep 2025 14:47:13 +0800 Subject: 
[PATCH 24/29] Update core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java Co-authored-by: Wenchen Fan --- .../org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 295663702889b..b4b8accc6efb4 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -180,7 +180,7 @@ public void setUp() throws Exception { when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager); } - private void resetDependency(boolean rowBasedCheckSumEnabled) { + private void resetDependency(boolean rowBasedChecksumEnabled) { when(shuffleDep.serializer()).thenReturn(serializer); when(shuffleDep.partitioner()).thenReturn(hashPartitioner); final int checksumSize = rowBasedCheckSumEnabled ? NUM_PARTITIONS : 0; From bbe26bfd3059ec73ee1f8897d7cf300febc709a2 Mon Sep 17 00:00:00 2001 From: Tengfei Huang Date: Fri, 5 Sep 2025 06:58:06 +0000 Subject: [PATCH 25/29] address comments --- .../apache/spark/internal/config/package.scala | 15 --------------- .../shuffle/sort/UnsafeShuffleWriterSuite.java | 5 ++--- .../shuffle/ShuffleChecksumTestHelper.scala | 6 ++---- .../BypassMergeSortShuffleWriterSuite.scala | 9 ++++----- .../shuffle/sort/SortShuffleWriterSuite.scala | 17 +++++------------ .../apache/spark/sql/internal/SQLConf.scala | 18 ++++++++++++++++++ .../exchange/ShuffleExchangeExec.scala | 2 +- .../spark/sql/MapStatusEndToEndSuite.scala | 3 +-- 8 files changed, 33 insertions(+), 42 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 6891973955912..0bee708bca3c7 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -1650,21 +1650,6 @@ package object config { s"The buffer size must be greater than 0 and less than or equal to ${Int.MaxValue}.") .createWithDefault(4096) - private[spark] val SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED = - ConfigBuilder("spark.shuffle.orderIndependentChecksum.enabled") - .doc("Whether to calculate order independent checksum for the shuffle data or not. If " + - "enabled, Spark will calculate a checksum that is independent of the input row order for " + - "each mapper and returns the checksums from executors to driver. This is different from " + - "the checksum computed when spark.shuffle.checksum.enabled is enabled which is sensitive " + - "to shuffle data ordering to detect file corruption. While this checksum will be the " + - "same even if the shuffle row order changes and it is used to detect whether different " + - "task attempts of the same partition produce different output data or not (same set of " + - "keyValue pairs). In case the output data has changed across retries, Spark will need to " + - "retry all tasks of the consumer stages to avoid correctness issues.") - .version("4.1.0") - .booleanConf - .createWithDefault(false) - private[spark] val SHUFFLE_CHECKSUM_ENABLED = ConfigBuilder("spark.shuffle.checksum.enabled") .doc("Whether to calculate the checksum of shuffle data. 
If enabled, Spark will calculate " + diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index b4b8accc6efb4..b13d8982ad0df 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -183,10 +183,9 @@ public void setUp() throws Exception { private void resetDependency(boolean rowBasedChecksumEnabled) { when(shuffleDep.serializer()).thenReturn(serializer); when(shuffleDep.partitioner()).thenReturn(hashPartitioner); - final int checksumSize = rowBasedCheckSumEnabled ? NUM_PARTITIONS : 0; - final String checksumAlgorithm = conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ALGORITHM()); + final int checksumSize = rowBasedChecksumEnabled ? NUM_PARTITIONS : 0; final RowBasedChecksum[] rowBasedChecksums = - createPartitionRowBasedChecksums(checksumSize, checksumAlgorithm); + createPartitionRowBasedChecksums(checksumSize); when(shuffleDep.rowBasedChecksums()).thenReturn(rowBasedChecksums); } diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala index c8fe2d5c70d97..439d75ee9364b 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala @@ -48,9 +48,7 @@ trait ShuffleChecksumTestHelper { } } - def createPartitionRowBasedChecksums( - numPartitions: Int, - checksumAlgorithm: String): Array[RowBasedChecksum] = { - Array.tabulate(numPartitions)(_ => new OutputStreamRowBasedChecksum(checksumAlgorithm)) + def createPartitionRowBasedChecksums(numPartitions: Int): Array[RowBasedChecksum] = { + Array.tabulate(numPartitions)(_ => new OutputStreamRowBasedChecksum("ADLER32")) } } diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index 55006c8777e41..c908c06b399dd 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -75,7 +75,7 @@ class BypassMergeSortShuffleWriterSuite ) val memoryManager = new TestMemoryManager(conf) val taskMemoryManager = new TaskMemoryManager(memoryManager, 0) - resetDependency(conf, rowBasedCheckSumEnabled = false) + resetDependency(conf, rowBasedChecksumEnabled = false) when(taskContext.taskMetrics()).thenReturn(taskMetrics) when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile) when(blockManager.diskBlockManager).thenReturn(diskBlockManager) @@ -150,13 +150,12 @@ class BypassMergeSortShuffleWriterSuite val numPartitions = 7 when(dependency.partitioner).thenReturn(new HashPartitioner(numPartitions)) when(dependency.serializer).thenReturn(new JavaSerializer(sc)) - val checksumSize = if (rowBasedCheckSumEnabled) { + val checksumSize = if (rowBasedChecksumEnabled) { numPartitions } else { 0 } - val checksumAlgorithm = sc.get(config.SHUFFLE_CHECKSUM_ALGORITHM) - val rowBasedChecksums = createPartitionRowBasedChecksums(checksumSize, checksumAlgorithm) + val rowBasedChecksums = createPartitionRowBasedChecksums(checksumSize) when(dependency.rowBasedChecksums).thenReturn(rowBasedChecksums) } @@ -323,7 +322,7 @@ class 
BypassMergeSortShuffleWriterSuite var checksumValues : Array[Long] = Array[Long]() var aggregatedChecksumValue = 0L for (i <- 1 to 100) { - resetDependency(conf, rowBasedCheckSumEnabled = true) + resetDependency(conf, rowBasedChecksumEnabled = true) val writer = new BypassMergeSortShuffleWriter[Int, Int]( blockManager, shuffleHandle, diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala index 3dca919203191..9d4b0625f762d 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala @@ -31,7 +31,6 @@ import org.apache.spark.memory.MemoryTestingUtils import org.apache.spark.serializer.JavaSerializer import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, ShuffleChecksumTestHelper} import org.apache.spark.shuffle.api.ShuffleExecutorComponents -import org.apache.spark.shuffle.checksum.RowBasedChecksum import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents import org.apache.spark.storage.BlockManager import org.apache.spark.util.Utils @@ -66,7 +65,7 @@ class SortShuffleWriterSuite shuffleHandle = { new BaseShuffleHandle(shuffleId, dependency) } - resetDependency(rowBasedCheckSumEnabled = false) + resetDependency(rowBasedChecksumEnabled = false) shuffleExecutorComponents = new LocalDiskShuffleExecutorComponents( conf, blockManager, shuffleBlockResolver) } @@ -85,13 +84,12 @@ class SortShuffleWriterSuite when(dependency.serializer).thenReturn(serializer) when(dependency.aggregator).thenReturn(None) when(dependency.keyOrdering).thenReturn(None) - val checksumSize = if (rowBasedCheckSumEnabled) { + val checksumSize = if (rowBasedChecksumEnabled) { numMaps } else { 0 } - val checksumAlgorithm = conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM) - val rowBasedChecksums = createPartitionRowBasedChecksums(checksumSize, checksumAlgorithm) + val rowBasedChecksums = createPartitionRowBasedChecksums(checksumSize) when(dependency.rowBasedChecksums).thenReturn(rowBasedChecksums) } @@ -143,7 +141,7 @@ class SortShuffleWriterSuite var checksumValues : Array[Long] = Array[Long]() var aggregatedChecksumValue = 0L for (i <- 1 to 100) { - resetDependency(rowBasedCheckSumEnabled = true) + resetDependency(rowBasedChecksumEnabled = true) val writer = new SortShuffleWriter[Int, Int, Int]( shuffleHandle, mapId = 2, @@ -195,12 +193,7 @@ class SortShuffleWriterSuite when(dependency.serializer).thenReturn(serializer) when(dependency.aggregator).thenReturn(aggregator) when(dependency.keyOrdering).thenReturn(order) - val checksumSize = - if (conf.get(config.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED)) numMaps else 0 - val checksumAlgorithm = conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM) - val rowBasedChecksums: Array[RowBasedChecksum] = - createPartitionRowBasedChecksums(checksumSize, checksumAlgorithm) - when(dependency.rowBasedChecksums).thenReturn(rowBasedChecksums) + when(dependency.rowBasedChecksums).thenReturn(Array.empty) new BaseShuffleHandle[Int, Int, Int](shuffleId, dependency) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 48bbe09f16431..3f0156c75f727 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -875,6 +875,21 @@ 
object SQLConf { .checkValue(_ > 0, "The value of spark.sql.shuffle.partitions must be positive") .createWithDefault(200) + val SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED = + buildConf("spark.sql.shuffle.orderIndependentChecksum.enabled") + .doc("Whether to calculate order independent checksum for the shuffle data or not. If " + + "enabled, Spark will calculate a checksum that is independent of the input row order for " + + "each mapper and returns the checksums from executors to driver. This is different from " + + "the checksum computed when spark.shuffle.checksum.enabled is enabled which is sensitive " + + "to shuffle data ordering to detect file corruption. While this checksum will be the " + + "same even if the shuffle row order changes and it is used to detect whether different " + + "task attempts of the same partition produce different output data or not (same set of " + + "keyValue pairs). In case the output data has changed across retries, Spark will need to " + + "retry all tasks of the consumer stages to avoid correctness issues.") + .version("4.1.0") + .booleanConf + .createWithDefault(false) + val SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE = buildConf("spark.sql.adaptive.shuffle.targetPostShuffleInputSize") .internal() @@ -6613,6 +6628,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf { } } + def shuffleOrderIndependentChecksumEnabled: Boolean = + getConf(SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED) + def allowCollationsInMapKeys: Boolean = getConf(ALLOW_COLLATIONS_IN_MAP_KEYS) def objectLevelCollationsEnabled: Boolean = getConf(OBJECT_LEVEL_COLLATIONS_ENABLED) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 13f1f117fddc8..fb157b4141d05 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -472,7 +472,7 @@ object ShuffleExchangeExec { // are in the form of (partitionId, row) and every partitionId is in the expected range // [0, part.numPartitions - 1]. The partitioner of this is a PartitionIdPassthrough. 
val checksumSize = - if (SparkEnv.get.conf.get(config.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED)) { + if (SQLConf.get.shuffleOrderIndependentChecksumEnabled) { part.numPartitions } else { 0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala index bb62d7ea74e63..b89c0470e6f79 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql import org.apache.spark.{MapOutputTrackerMaster, SparkFunSuite} -import org.apache.spark.internal.config.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils @@ -26,7 +25,7 @@ import org.apache.spark.sql.test.SQLTestUtils class MapStatusEndToEndSuite extends SparkFunSuite with SQLTestUtils { override def spark: SparkSession = SparkSession.builder() .master("local") - .config(SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED.key, value = true) + .config(SQLConf.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED.key, value = true) .config(SQLConf.LEAF_NODE_DEFAULT_PARALLELISM.key, value = 5) .config(SQLConf.CLASSIC_SHUFFLE_DEPENDENCY_FILE_CLEANUP_ENABLED.key, value = false) .getOrCreate() From 1a8e9f7e672d60bb60144b8f4b581c85d25f57de Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 5 Sep 2025 16:35:50 +0800 Subject: [PATCH 26/29] Update core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala --- .../scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index c6ae95f35d4e6..1ada81cbdd0ee 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -5289,7 +5289,7 @@ object DAGSchedulerSuite { reduces: Int, sizes: Byte = 2, mapTaskId: Long = -1, - checksumVal: Long = 0): MapStatus = + checksumVal: Long = 0): MapStatus = MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes), mapTaskId, checksumVal) def makeBlockManagerId(host: String, execId: Option[String] = None): BlockManagerId = { From 97af717b7ae8bedca323f5ff08d24e9d4746f77c Mon Sep 17 00:00:00 2001 From: Tengfei Huang Date: Fri, 5 Sep 2025 10:49:57 +0000 Subject: [PATCH 27/29] fix ut --- .../scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala index b89c0470e6f79..0fe6603122103 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala @@ -39,8 +39,9 @@ class MapStatusEndToEndSuite extends SparkFunSuite with SQLTestUtils { } test("Propagate checksum from executor to driver") { - assert(spark.sparkContext.conf.get("spark.shuffle.orderIndependentChecksum.enabled") == "true") - assert(spark.conf.get("spark.shuffle.orderIndependentChecksum.enabled") == "true") + assert(spark.sparkContext.conf + .get("spark.sql.shuffle.orderIndependentChecksum.enabled") == "true") + 
assert(spark.conf.get("spark.sql.shuffle.orderIndependentChecksum.enabled") == "true") assert(spark.sparkContext.conf.get("spark.sql.leafNodeDefaultParallelism") == "5") assert(spark.conf.get("spark.sql.leafNodeDefaultParallelism") == "5") assert(spark.sparkContext.conf.get("spark.sql.classic.shuffleDependency.fileCleanup.enabled") From 5e01c528427a14d4f9516977abed334f3d3e49ca Mon Sep 17 00:00:00 2001 From: Tengfei Huang Date: Mon, 8 Sep 2025 07:58:11 +0000 Subject: [PATCH 28/29] address comments --- .../scala/org/apache/spark/Dependency.scala | 22 +------------------ .../shuffle/sort/SortShuffleWriter.scala | 6 ++++- .../util/collection/ExternalSorter.scala | 4 ++-- 3 files changed, 8 insertions(+), 24 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index 5b83556b5144f..1372385ee22f8 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -61,7 +61,7 @@ abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] { } object ShuffleDependency { - private val EMPTY_ROW_BASED_CHECKSUMS: Array[RowBasedChecksum] = Array.empty + private[spark] val EMPTY_ROW_BASED_CHECKSUMS: Array[RowBasedChecksum] = Array.empty } /** @@ -92,26 +92,6 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( val rowBasedChecksums: Array[RowBasedChecksum] = ShuffleDependency.EMPTY_ROW_BASED_CHECKSUMS) extends Dependency[Product2[K, V]] with Logging { - def this( - rdd: RDD[_ <: Product2[K, V]], - partitioner: Partitioner, - serializer: Serializer, - keyOrdering: Option[Ordering[K]], - aggregator: Option[Aggregator[K, V, C]], - mapSideCombine: Boolean, - shuffleWriterProcessor: ShuffleWriteProcessor) = { - this( - rdd, - partitioner, - serializer, - keyOrdering, - aggregator, - mapSideCombine, - shuffleWriterProcessor, - ShuffleDependency.EMPTY_ROW_BASED_CHECKSUMS - ) - } - if (mapSideCombine) { require(aggregator.isDefined, "Map-side combine without Aggregator specified!") } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 388ef1e82fa7e..a7ac20016a0ec 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -50,7 +50,11 @@ private[spark] class SortShuffleWriter[K, V, C]( private var partitionLengths: Array[Long] = _ def getRowBasedChecksums: Array[RowBasedChecksum] = { - if (sorter != null) sorter.getRowBasedChecksums else new Array[RowBasedChecksum](0) + if (sorter != null) { + sorter.getRowBasedChecksums + } else { + ShuffleDependency.EMPTY_ROW_BASED_CHECKSUMS + } } def getAggregatedChecksumValue: Long = { diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 67bd25d8d9da8..4da89a94201a9 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -207,7 +207,7 @@ private[spark] class ExternalSorter[K, V, C]( val partitionId = actualPartitioner.getPartition(kv._1) map.changeValue((partitionId, kv._1), update) maybeSpillCollection(usingMap = true) - if (!rowBasedChecksums.isEmpty) { + if (rowBasedChecksums.nonEmpty) { rowBasedChecksums(partitionId).update(kv._1, kv._2) } } @@ -219,7 
+219,7 @@ private[spark] class ExternalSorter[K, V, C]( val partitionId = actualPartitioner.getPartition(kv._1) buffer.insert(partitionId, kv._1, kv._2.asInstanceOf[C]) maybeSpillCollection(usingMap = false) - if (!rowBasedChecksums.isEmpty) { + if (rowBasedChecksums.nonEmpty) { rowBasedChecksums(partitionId).update(kv._1, kv._2) } } From ce293119449d6ff4029171c3e070df733de5ff2a Mon Sep 17 00:00:00 2001 From: Tengfei Huang Date: Mon, 8 Sep 2025 15:47:28 +0000 Subject: [PATCH 29/29] fix mima test failure --- .../scala/org/apache/spark/Dependency.scala | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index 1372385ee22f8..93a2bbe25157b 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -92,6 +92,26 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( val rowBasedChecksums: Array[RowBasedChecksum] = ShuffleDependency.EMPTY_ROW_BASED_CHECKSUMS) extends Dependency[Product2[K, V]] with Logging { + def this( + rdd: RDD[_ <: Product2[K, V]], + partitioner: Partitioner, + serializer: Serializer, + keyOrdering: Option[Ordering[K]], + aggregator: Option[Aggregator[K, V, C]], + mapSideCombine: Boolean, + shuffleWriterProcessor: ShuffleWriteProcessor) = { + this( + rdd, + partitioner, + serializer, + keyOrdering, + aggregator, + mapSideCombine, + shuffleWriterProcessor, + ShuffleDependency.EMPTY_ROW_BASED_CHECKSUMS + ) + } + if (mapSideCombine) { require(aggregator.isDefined, "Map-side combine without Aggregator specified!") }
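The getValue/update changes in patch 18/29 describe the order-independent scheme only in comments: each (key, value) pair gets its own row checksum, the row checksums are folded with both XOR and SUM, and the SUM is rotated before being mixed into the final value. A minimal, self-contained sketch of that idea follows. It is illustrative only: the object and method names are invented for this example, the per-row checksum here hashes a string rendering of the pair rather than the ObjectOutputStream serialization used by OutputStreamRowBasedChecksum, and the rotation distance of 27 simply mirrors ROTATE_POSITIONS from the patch.

import java.util.zip.Adler32

object OrderIndependentChecksumSketch {
  private val RotatePositions = 27

  // Per-row checksum over a textual rendering of the pair (the real code
  // serializes the pair with an ObjectOutputStream before checksumming).
  private def rowChecksum(key: Any, value: Any): Long = {
    val adler = new Adler32()
    val bytes = s"$key->$value".getBytes("UTF-8")
    adler.update(bytes, 0, bytes.length)
    adler.getValue
  }

  // Fold all row checksums with XOR and SUM, then mix the two accumulators.
  // Both folds are commutative, so the result is independent of row order.
  def checksum(rows: Seq[(Any, Any)]): Long = {
    var xor = 0L
    var sum = 0L
    rows.foreach { case (k, v) =>
      val c = rowChecksum(k, v)
      xor ^= c
      sum += c
    }
    xor ^ java.lang.Long.rotateLeft(sum, RotatePositions)
  }

  def main(args: Array[String]): Unit = {
    val rows = Seq[(Any, Any)]((1, "a"), (2, "b"), (3, "c"), (3, "c"))
    val shuffled = scala.util.Random.shuffle(rows)
    assert(checksum(rows) == checksum(shuffled))
    println(f"order-independent checksum = ${checksum(rows)}%x")
  }
}

Keeping both accumulators is what makes the composite stronger than XOR alone: two identical rows cancel each other under XOR, but they still move the SUM, so duplicated or dropped pairs remain visible in the final value.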
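For completeness, after patch 25/29 the feature is enabled through the SQL conf rather than the removed core conf. The snippet below is a hedged usage sketch, not code taken from the patches: it assumes only the "spark.sql.shuffle.orderIndependentChecksum.enabled" key added to SQLConf above and the standard SparkSession API; any shuffle-producing action (here a repartition, mirroring MapStatusEndToEndSuite) makes the shuffle writers populate the per-partition row-based checksums.

import org.apache.spark.sql.SparkSession

object OrderIndependentChecksumUsage {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("order-independent-checksum-demo")
      // Key introduced in patch 25/29; the feature is off by default.
      .config("spark.sql.shuffle.orderIndependentChecksum.enabled", "true")
      .getOrCreate()

    // A repartition forces a shuffle, so each map task computes its
    // order-independent checksums and reports them alongside its MapStatus.
    spark.range(1000).repartition(10).count()

    spark.stop()
  }
}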