diff --git a/core/src/main/java/org/apache/spark/shuffle/checksum/RowBasedChecksum.scala b/core/src/main/java/org/apache/spark/shuffle/checksum/RowBasedChecksum.scala new file mode 100644 index 0000000000000..886296dc8a828 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/checksum/RowBasedChecksum.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.checksum + +import scala.util.control.NonFatal + +import org.apache.spark.internal.Logging + +/** + * A class for computing a checksum for input (key, value) pairs. The checksum is independent of + * the order of the input (key, value) pairs. This is done by computing a checksum for each row + * first, then computing the XOR and SUM of all the row checksums and mixing these two values + * into the final checksum. + */ +abstract class RowBasedChecksum() extends Serializable with Logging { + private val ROTATE_POSITIONS = 27 + private var hasError: Boolean = false + private var checksumXor: Long = 0 + private var checksumSum: Long = 0 + + /** + * Returns the checksum value. It returns the default checksum value (0) if there + * are any errors encountered during the checksum computation. + */ + def getValue: Long = { + if (!hasError) { + // Here we rotate the `checksumSum` to transform these two values into a single, strong + // composite checksum by ensuring their bit patterns are thoroughly mixed. + checksumXor ^ rotateLeft(checksumSum) + } else { + 0 + } + } + + /** Updates the row-based checksum with the given (key, value) pair. Not thread safe. */ + def update(key: Any, value: Any): Unit = { + if (!hasError) { + try { + val rowChecksumValue = calculateRowChecksum(key, value) + checksumXor = checksumXor ^ rowChecksumValue + checksumSum += rowChecksumValue + } catch { + case NonFatal(e) => + logError("Checksum computation encountered error: ", e) + hasError = true + } + } + } + + /** Computes and returns the checksum value for the given (key, value) pair. */ + protected def calculateRowChecksum(key: Any, value: Any): Long + + // Rotate the value by shifting the bits by `ROTATE_POSITIONS` positions to the left.
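
For reference, a REPL-style sketch of the mixing scheme described above and implemented by the rotateLeft helper that follows; the standalone combine helper below is illustrative only and not part of this patch.

  // REPL-style sketch: combine per-row checksums so that the result ignores row order.
  def combine(rowChecksums: Seq[Long]): Long = {
    val xor = rowChecksums.foldLeft(0L)(_ ^ _)
    val sum = rowChecksums.foldLeft(0L)(_ + _)
    // Rotate the sum so the two accumulators contribute different bit positions before mixing.
    xor ^ java.lang.Long.rotateLeft(sum, 27)
  }
  val rows = Seq(0x1234L, 0x5678L, 0x9abcL)
  // XOR and SUM are both commutative, so reordering the rows does not change the result.
  assert(combine(rows) == combine(rows.reverse))
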
+ private def rotateLeft(value: Long): Long = { + (value << ROTATE_POSITIONS) | (value >>> (64 - ROTATE_POSITIONS)) + } +} + +object RowBasedChecksum { + def getAggregatedChecksumValue(rowBasedChecksums: Array[RowBasedChecksum]): Long = { + Option(rowBasedChecksums) + .map(_.foldLeft(0L)((acc, c) => acc * 31L + c.getValue)) + .getOrElse(0L) + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 8072a432ab110..5acc66e120630 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -32,6 +32,7 @@ import scala.Tuple2; import scala.collection.Iterator; +import com.google.common.annotations.VisibleForTesting; import com.google.common.io.Closeables; import org.apache.spark.internal.SparkLogger; @@ -53,6 +54,7 @@ import org.apache.spark.scheduler.MapStatus$; import org.apache.spark.serializer.Serializer; import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.shuffle.checksum.RowBasedChecksum; import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; import org.apache.spark.shuffle.ShuffleWriter; import org.apache.spark.storage.*; @@ -104,6 +106,13 @@ final class BypassMergeSortShuffleWriter private long[] partitionLengths; /** Checksum calculator for each partition. Empty when shuffle checksum disabled. */ private final Checksum[] partitionChecksums; + /** + * Checksum calculator for each partition. Different from the above Checksum, + * RowBasedChecksum is independent of the input row order, which is used to + * detect whether different task attempts of the same partition produce different + * output data or not. + */ + private final RowBasedChecksum[] rowBasedChecksums; /** * Are we in the process of stopping? 
Because map tasks can call stop() with success = true @@ -132,6 +141,7 @@ final class BypassMergeSortShuffleWriter this.serializer = dep.serializer(); this.shuffleExecutorComponents = shuffleExecutorComponents; this.partitionChecksums = createPartitionChecksums(numPartitions, conf); + this.rowBasedChecksums = dep.rowBasedChecksums(); } @Override @@ -144,7 +154,7 @@ public void write(Iterator> records) throws IOException { partitionLengths = mapOutputWriter.commitAllPartitions( ShuffleChecksumHelper.EMPTY_CHECKSUM_VALUE).getPartitionLengths(); mapStatus = MapStatus$.MODULE$.apply( - blockManager.shuffleServerId(), partitionLengths, mapId); + blockManager.shuffleServerId(), partitionLengths, mapId, getAggregatedChecksumValue()); return; } final SerializerInstance serInstance = serializer.newInstance(); @@ -171,7 +181,11 @@ public void write(Iterator> records) throws IOException { while (records.hasNext()) { final Product2 record = records.next(); final K key = record._1(); - partitionWriters[partitioner.getPartition(key)].write(key, record._2()); + final int partitionId = partitioner.getPartition(key); + partitionWriters[partitionId].write(key, record._2()); + if (rowBasedChecksums.length > 0) { + rowBasedChecksums[partitionId].update(key, record._2()); + } } for (int i = 0; i < numPartitions; i++) { @@ -182,7 +196,7 @@ public void write(Iterator> records) throws IOException { partitionLengths = writePartitionedData(mapOutputWriter); mapStatus = MapStatus$.MODULE$.apply( - blockManager.shuffleServerId(), partitionLengths, mapId); + blockManager.shuffleServerId(), partitionLengths, mapId, getAggregatedChecksumValue()); } catch (Exception e) { try { mapOutputWriter.abort(e); @@ -199,6 +213,17 @@ public long[] getPartitionLengths() { return partitionLengths; } + // For test only. + @VisibleForTesting + RowBasedChecksum[] getRowBasedChecksums() { + return rowBasedChecksums; + } + + @VisibleForTesting + long getAggregatedChecksumValue() { + return RowBasedChecksum.getAggregatedChecksumValue(rowBasedChecksums); + } + /** * Concatenate all of the per-partition files into a single combined file. * diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 36a1487627367..e3ecfed323481 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -59,9 +59,11 @@ import org.apache.spark.shuffle.api.ShufflePartitionWriter; import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter; import org.apache.spark.shuffle.api.WritableByteChannelWrapper; +import org.apache.spark.shuffle.checksum.RowBasedChecksum; import org.apache.spark.storage.BlockManager; import org.apache.spark.storage.TimeTrackingOutputStream; import org.apache.spark.unsafe.Platform; +import org.apache.spark.util.ExposedBufferByteArrayOutputStream; import org.apache.spark.util.Utils; @Private @@ -93,15 +95,16 @@ public class UnsafeShuffleWriter extends ShuffleWriter { @Nullable private long[] partitionLengths; private long peakMemoryUsedBytes = 0; - /** Subclass of ByteArrayOutputStream that exposes `buf` directly. 
*/ - private static final class MyByteArrayOutputStream extends ByteArrayOutputStream { - MyByteArrayOutputStream(int size) { super(size); } - public byte[] getBuf() { return buf; } - } - - private MyByteArrayOutputStream serBuffer; + private ExposedBufferByteArrayOutputStream serBuffer; private SerializationStream serOutputStream; + /** + * RowBasedChecksum calculator for each partition. RowBasedChecksum is independent + * of the input row order, which is used to detect whether different task attempts + * of the same partition produce different output data or not. + */ + private final RowBasedChecksum[] rowBasedChecksums; + /** * Are we in the process of stopping? Because map tasks can call stop() with success = true * and then call stop() with success = false if they get an exception, we want to make sure @@ -141,6 +144,7 @@ public UnsafeShuffleWriter( (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE()); this.mergeBufferSizeInBytes = (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_FILE_MERGE_BUFFER_SIZE()) * 1024; + this.rowBasedChecksums = dep.rowBasedChecksums(); open(); } @@ -162,6 +166,17 @@ public long getPeakMemoryUsedBytes() { return peakMemoryUsedBytes; } + // For test only. + @VisibleForTesting + RowBasedChecksum[] getRowBasedChecksums() { + return rowBasedChecksums; + } + + @VisibleForTesting + long getAggregatedChecksumValue() { + return RowBasedChecksum.getAggregatedChecksumValue(rowBasedChecksums); + } + /** * This convenience method should only be called in test code. */ @@ -210,7 +225,7 @@ private void open() throws SparkException { partitioner.numPartitions(), sparkConf, writeMetrics); - serBuffer = new MyByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE); + serBuffer = new ExposedBufferByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE); serOutputStream = serializer.serializeStream(serBuffer); } @@ -233,7 +248,7 @@ void closeAndWriteOutput() throws IOException { } } mapStatus = MapStatus$.MODULE$.apply( - blockManager.shuffleServerId(), partitionLengths, mapId); + blockManager.shuffleServerId(), partitionLengths, mapId, getAggregatedChecksumValue()); } @VisibleForTesting @@ -251,6 +266,9 @@ void insertRecordIntoSorter(Product2 record) throws IOException { sorter.insertRecord( serBuffer.getBuf(), Platform.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId); + if (rowBasedChecksums.length > 0) { + rowBasedChecksums[partitionId].update(key, record._2()); + } } @VisibleForTesting diff --git a/core/src/main/java/org/apache/spark/util/ExposedBufferByteArrayOutputStream.java b/core/src/main/java/org/apache/spark/util/ExposedBufferByteArrayOutputStream.java new file mode 100644 index 0000000000000..bd59bd176fb93 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/ExposedBufferByteArrayOutputStream.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util; + +import java.io.ByteArrayOutputStream; + +/** Subclass of ByteArrayOutputStream that exposes `buf` directly. */ +public final class ExposedBufferByteArrayOutputStream extends ByteArrayOutputStream { + public ExposedBufferByteArrayOutputStream(int size) { super(size); } + public byte[] getBuf() { return buf; } +} diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index 745faf866cebf..93a2bbe25157b 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -29,6 +29,7 @@ import org.apache.spark.internal.LogKeys._ import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{ShuffleHandle, ShuffleWriteProcessor} +import org.apache.spark.shuffle.checksum.RowBasedChecksum import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.Utils @@ -59,6 +60,9 @@ abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] { override def rdd: RDD[T] = _rdd } +object ShuffleDependency { + private[spark] val EMPTY_ROW_BASED_CHECKSUMS: Array[RowBasedChecksum] = Array.empty +} /** * :: DeveloperApi :: @@ -74,6 +78,7 @@ abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] { * @param aggregator map/reduce-side aggregator for RDD's shuffle * @param mapSideCombine whether to perform partial aggregation (also known as map-side combine) * @param shuffleWriterProcessor the processor to control the write behavior in ShuffleMapTask + * @param rowBasedChecksums the row-based checksums for each shuffle partition */ @DeveloperApi class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( @@ -83,9 +88,30 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( val keyOrdering: Option[Ordering[K]] = None, val aggregator: Option[Aggregator[K, V, C]] = None, val mapSideCombine: Boolean = false, - val shuffleWriterProcessor: ShuffleWriteProcessor = new ShuffleWriteProcessor) + val shuffleWriterProcessor: ShuffleWriteProcessor = new ShuffleWriteProcessor, + val rowBasedChecksums: Array[RowBasedChecksum] = ShuffleDependency.EMPTY_ROW_BASED_CHECKSUMS) extends Dependency[Product2[K, V]] with Logging { + def this( + rdd: RDD[_ <: Product2[K, V]], + partitioner: Partitioner, + serializer: Serializer, + keyOrdering: Option[Ordering[K]], + aggregator: Option[Aggregator[K, V, C]], + mapSideCombine: Boolean, + shuffleWriterProcessor: ShuffleWriteProcessor) = { + this( + rdd, + partitioner, + serializer, + keyOrdering, + aggregator, + mapSideCombine, + shuffleWriterProcessor, + ShuffleDependency.EMPTY_ROW_BASED_CHECKSUMS + ) + } + if (mapSideCombine) { require(aggregator.isDefined, "Map-side combine without Aggregator specified!") } diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 9b2d3d748ed4d..3f823b60156ad 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -23,7 +23,7 @@ import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, ThreadPoolE import java.util.concurrent.locks.ReentrantReadWriteLock import scala.collection -import scala.collection.mutable.{HashMap, ListBuffer, Map} +import scala.collection.mutable.{HashMap, ListBuffer, Map, Set} import 
scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration import scala.jdk.CollectionConverters._ @@ -99,6 +99,12 @@ private class ShuffleStatus( */ val mapStatusesDeleted = new Array[MapStatus](numPartitions) + /** + * Keep the indices of the Map tasks whose checksums are different across retries. + * Exposed for testing. + */ + private[spark] val checksumMismatchIndices: Set[Int] = Set() + /** * MergeStatus for each shuffle partition when push-based shuffle is enabled. The index of the * array is the shuffle partition id (reduce id). Each value in the array is the MergeStatus for @@ -169,6 +175,15 @@ private class ShuffleStatus( } else { mapIdToMapIndex.remove(currentMapStatus.mapId) } + logDebug(s"Checksum of map output for task ${status.mapId} is ${status.checksumValue}") + + val preStatus = + if (mapStatuses(mapIndex) != null) mapStatuses(mapIndex) else mapStatusesDeleted(mapIndex) + if (preStatus != null && preStatus.checksumValue != status.checksumValue) { + logInfo(s"Checksum of map output changes from ${preStatus.checksumValue} to " + + s"${status.checksumValue} for task ${status.mapId}.") + checksumMismatchIndices.add(mapIndex) + } mapStatuses(mapIndex) = status mapIdToMapIndex(status.mapId) = mapIndex } diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 113521453ad7b..e348b6a5f1493 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -58,6 +58,12 @@ private[spark] sealed trait MapStatus extends ShuffleOutputStatus { * partitionId of the task or taskContext.taskAttemptId is used. */ def mapId: Long + + /** + * The checksum value of this shuffle map task, which can be used to evaluate whether the + * output data has changed across different map task retries. + */ + def checksumValue: Long = 0 } @@ -74,11 +80,12 @@ private[spark] object MapStatus { def apply( loc: BlockManagerId, uncompressedSizes: Array[Long], - mapTaskId: Long): MapStatus = { + mapTaskId: Long, + checksumVal: Long = 0): MapStatus = { if (uncompressedSizes.length > minPartitionsToUseHighlyCompressMapStatus) { - HighlyCompressedMapStatus(loc, uncompressedSizes, mapTaskId) + HighlyCompressedMapStatus(loc, uncompressedSizes, mapTaskId, checksumVal) } else { - new CompressedMapStatus(loc, uncompressedSizes, mapTaskId) + new CompressedMapStatus(loc, uncompressedSizes, mapTaskId, checksumVal) } } @@ -119,18 +126,24 @@ private[spark] object MapStatus { * @param loc location where the task is being executed. * @param compressedSizes size of the blocks, indexed by reduce partition id. 
* @param _mapTaskId unique task id for the task + * @param _checksumVal the checksum value for the task */ private[spark] class CompressedMapStatus( private[this] var loc: BlockManagerId, private[this] var compressedSizes: Array[Byte], - private[this] var _mapTaskId: Long) + private[this] var _mapTaskId: Long, + private[this] var _checksumVal: Long = 0) extends MapStatus with Externalizable { // For deserialization only - protected def this() = this(null, null.asInstanceOf[Array[Byte]], -1) + protected def this() = this(null, null.asInstanceOf[Array[Byte]], -1, 0) - def this(loc: BlockManagerId, uncompressedSizes: Array[Long], mapTaskId: Long) = { - this(loc, uncompressedSizes.map(MapStatus.compressSize), mapTaskId) + def this( + loc: BlockManagerId, + uncompressedSizes: Array[Long], + mapTaskId: Long, + checksumVal: Long) = { + this(loc, uncompressedSizes.map(MapStatus.compressSize), mapTaskId, checksumVal) } override def location: BlockManagerId = loc @@ -145,11 +158,14 @@ private[spark] class CompressedMapStatus( override def mapId: Long = _mapTaskId + override def checksumValue: Long = _checksumVal + override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { loc.writeExternal(out) out.writeInt(compressedSizes.length) out.write(compressedSizes) out.writeLong(_mapTaskId) + out.writeLong(_checksumVal) } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { @@ -158,6 +174,7 @@ private[spark] class CompressedMapStatus( compressedSizes = new Array[Byte](len) in.readFully(compressedSizes) _mapTaskId = in.readLong() + _checksumVal = in.readLong() } } @@ -172,6 +189,7 @@ private[spark] class CompressedMapStatus( * @param avgSize average size of the non-empty and non-huge blocks * @param hugeBlockSizes sizes of huge blocks by their reduceId. 
* @param _mapTaskId unique task id for the task + * @param _checksumVal checksum value for the task */ private[spark] class HighlyCompressedMapStatus private ( private[this] var loc: BlockManagerId, @@ -179,7 +197,8 @@ private[spark] class HighlyCompressedMapStatus private ( private[this] var emptyBlocks: RoaringBitmap, private[this] var avgSize: Long, private[this] var hugeBlockSizes: scala.collection.Map[Int, Byte], - private[this] var _mapTaskId: Long) + private[this] var _mapTaskId: Long, + private[this] var _checksumVal: Long = 0) extends MapStatus with Externalizable { // loc could be null when the default constructor is called during deserialization @@ -187,7 +206,7 @@ private[spark] class HighlyCompressedMapStatus private ( || numNonEmptyBlocks == 0 || _mapTaskId > 0, "Average size can only be zero for map stages that produced no output") - protected def this() = this(null, -1, null, -1, null, -1) // For deserialization only + protected def this() = this(null, -1, null, -1, null, -1, 0) // For deserialization only override def location: BlockManagerId = loc @@ -209,6 +228,8 @@ private[spark] class HighlyCompressedMapStatus private ( override def mapId: Long = _mapTaskId + override def checksumValue: Long = _checksumVal + override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { loc.writeExternal(out) emptyBlocks.serialize(out) @@ -219,6 +240,7 @@ private[spark] class HighlyCompressedMapStatus private ( out.writeByte(kv._2) } out.writeLong(_mapTaskId) + out.writeLong(_checksumVal) } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { @@ -236,6 +258,7 @@ private[spark] class HighlyCompressedMapStatus private ( } hugeBlockSizes = hugeBlockSizesImpl _mapTaskId = in.readLong() + _checksumVal = in.readLong() } } @@ -243,7 +266,8 @@ private[spark] object HighlyCompressedMapStatus { def apply( loc: BlockManagerId, uncompressedSizes: Array[Long], - mapTaskId: Long): HighlyCompressedMapStatus = { + mapTaskId: Long, + checksumVal: Long = 0): HighlyCompressedMapStatus = { // We must keep track of which blocks are empty so that we don't report a zero-sized // block as being non-empty (or vice-versa) when using the average block size. 
var i = 0 @@ -310,6 +334,6 @@ private[spark] object HighlyCompressedMapStatus { emptyBlocks.trim() emptyBlocks.runOptimize() new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize, - hugeBlockSizes, mapTaskId) + hugeBlockSizes, mapTaskId, checksumVal) } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 3be7d24f7e4ec..a7ac20016a0ec 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -23,6 +23,7 @@ import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriter} import org.apache.spark.shuffle.ShuffleWriteMetricsReporter import org.apache.spark.shuffle.api.ShuffleExecutorComponents +import org.apache.spark.shuffle.checksum.RowBasedChecksum import org.apache.spark.util.collection.ExternalSorter private[spark] class SortShuffleWriter[K, V, C]( @@ -48,17 +49,31 @@ private[spark] class SortShuffleWriter[K, V, C]( private var partitionLengths: Array[Long] = _ + def getRowBasedChecksums: Array[RowBasedChecksum] = { + if (sorter != null) { + sorter.getRowBasedChecksums + } else { + ShuffleDependency.EMPTY_ROW_BASED_CHECKSUMS + } + } + + def getAggregatedChecksumValue: Long = { + if (sorter != null) sorter.getAggregatedChecksumValue else 0 + } + /** Write a bunch of records to this task's output */ override def write(records: Iterator[Product2[K, V]]): Unit = { sorter = if (dep.mapSideCombine) { new ExternalSorter[K, V, C]( - context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer) + context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, + dep.serializer, dep.rowBasedChecksums) } else { // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't // care whether the keys get sorted in each partition; that will be done on the reduce side // if the operation being run is sortByKey. 
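
For reference, a REPL-style sketch of how per-partition checksum values are folded into the single map-task value that getAggregatedChecksumValue returns, mirroring the acc * 31 + value fold in the RowBasedChecksum companion object above; the aggregateChecksums name is illustrative only.

  // REPL-style sketch: fold per-partition checksum values into a single map-task value.
  def aggregateChecksums(partitionChecksums: Array[Long]): Long =
    partitionChecksums.foldLeft(0L)((acc, c) => acc * 31L + c)
  val perPartition = Array(0x1111L, 0x2222L, 0x3333L)
  // Each per-partition value is order-independent within its partition, but the aggregate
  // is sensitive to which partition produced which value.
  assert(aggregateChecksums(perPartition) != aggregateChecksums(perPartition.reverse))
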
new ExternalSorter[K, V, V]( - context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer) + context, aggregator = None, Some(dep.partitioner), ordering = None, + dep.serializer, dep.rowBasedChecksums) } sorter.insertAll(records) @@ -69,7 +84,8 @@ private[spark] class SortShuffleWriter[K, V, C]( dep.shuffleId, mapId, dep.partitioner.numPartitions) sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter, writeMetrics) partitionLengths = mapOutputWriter.commitAllPartitions(sorter.getChecksums).getPartitionLengths - mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId) + mapStatus = + MapStatus(blockManager.shuffleServerId, partitionLengths, mapId, getAggregatedChecksumValue) } /** Close this writer, passing along whether the map completed */ diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 8dd207b25bb94..4da89a94201a9 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -33,7 +33,7 @@ import org.apache.spark.internal.LogKeys.{NUM_BYTES, TASK_ATTEMPT_ID} import org.apache.spark.serializer._ import org.apache.spark.shuffle.{ShufflePartitionPairsWriter, ShuffleWriteMetricsReporter} import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, ShufflePartitionWriter} -import org.apache.spark.shuffle.checksum.ShuffleChecksumSupport +import org.apache.spark.shuffle.checksum.{RowBasedChecksum, ShuffleChecksumSupport} import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, ShuffleBlockId} import org.apache.spark.util.{CompletionIterator, Utils => TryUtils} @@ -97,7 +97,8 @@ private[spark] class ExternalSorter[K, V, C]( aggregator: Option[Aggregator[K, V, C]] = None, partitioner: Option[Partitioner] = None, ordering: Option[Ordering[K]] = None, - serializer: Serializer = SparkEnv.get.serializer) + serializer: Serializer = SparkEnv.get.serializer, + rowBasedChecksums: Array[RowBasedChecksum] = Array.empty) extends Spillable[WritablePartitionedPairCollection[K, C]](context.taskMemoryManager()) with Logging with ShuffleChecksumSupport { @@ -142,10 +143,16 @@ private[spark] class ExternalSorter[K, V, C]( private val forceSpillFiles = new ArrayBuffer[SpilledFile] @volatile private var readingIterator: SpillableIterator = null + /** Checksum calculator for each partition. Empty when shuffle checksum disabled. */ private val partitionChecksums = createPartitionChecksums(numPartitions, conf) def getChecksums: Array[Long] = getChecksumValues(partitionChecksums) + def getRowBasedChecksums: Array[RowBasedChecksum] = rowBasedChecksums + + def getAggregatedChecksumValue: Long = + RowBasedChecksum.getAggregatedChecksumValue(rowBasedChecksums) + // A comparator for keys K that orders them within a partition to allow aggregation or sorting. // Can be a partial ordering by hash code if a total ordering is not provided through by the // user. 
(A partial ordering means that equal keys have comparator.compare(k, k) = 0, but some @@ -197,16 +204,24 @@ private[spark] class ExternalSorter[K, V, C]( while (records.hasNext) { addElementsRead() kv = records.next() - map.changeValue((actualPartitioner.getPartition(kv._1), kv._1), update) + val partitionId = actualPartitioner.getPartition(kv._1) + map.changeValue((partitionId, kv._1), update) maybeSpillCollection(usingMap = true) + if (rowBasedChecksums.nonEmpty) { + rowBasedChecksums(partitionId).update(kv._1, kv._2) + } } } else { // Stick values into our buffer while (records.hasNext) { addElementsRead() val kv = records.next() - buffer.insert(actualPartitioner.getPartition(kv._1), kv._1, kv._2.asInstanceOf[C]) + val partitionId = actualPartitioner.getPartition(kv._1) + buffer.insert(partitionId, kv._1, kv._2.asInstanceOf[C]) maybeSpillCollection(usingMap = false) + if (rowBasedChecksums.nonEmpty) { + rowBasedChecksums(partitionId).update(kv._1, kv._2) + } } } } diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index c55254e04f401..b13d8982ad0df 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -49,6 +49,7 @@ import org.apache.spark.scheduler.MapStatus; import org.apache.spark.security.CryptoStreamUtils; import org.apache.spark.serializer.*; +import org.apache.spark.shuffle.checksum.RowBasedChecksum; import org.apache.spark.shuffle.IndexShuffleBlockResolver; import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents; import org.apache.spark.storage.*; @@ -174,11 +175,18 @@ public void setUp() throws Exception { File file = (File) invocationOnMock.getArguments()[0]; return Utils.tempFileWith(file); }); - + resetDependency(false); when(taskContext.taskMetrics()).thenReturn(taskMetrics); + when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager); + } + + private void resetDependency(boolean rowBasedChecksumEnabled) { when(shuffleDep.serializer()).thenReturn(serializer); when(shuffleDep.partitioner()).thenReturn(hashPartitioner); - when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager); + final int checksumSize = rowBasedChecksumEnabled ? 
NUM_PARTITIONS : 0; + final RowBasedChecksum[] rowBasedChecksums = + createPartitionRowBasedChecksums(checksumSize); + when(shuffleDep.rowBasedChecksums()).thenReturn(rowBasedChecksums); + } private UnsafeShuffleWriter<Object, Object> createWriter(boolean transferToEnabled) @@ -613,6 +621,43 @@ public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException, Spa assertSpillFilesWereCleanedUp(); } + @Test + public void testRowBasedChecksum() throws IOException, SparkException { + final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>(); + for (int i = 0; i < NUM_PARTITIONS; i++) { + for (int j = 0; j < 5; j++) { + dataToWrite.add(new Tuple2<>(i, i + j)); + } + } + + long[] checksumValues = new long[0]; + long aggregatedChecksumValue = 0; + try { + for (int i = 0; i < 100; i++) { + resetDependency(true); + final UnsafeShuffleWriter<Object, Object> writer = createWriter(false); + Collections.shuffle(dataToWrite); + writer.write(dataToWrite.iterator()); + writer.stop(true); + + if (i == 0) { + checksumValues = getRowBasedChecksumValues(writer.getRowBasedChecksums()); + assertEquals(NUM_PARTITIONS, checksumValues.length); + assert(Arrays.stream(checksumValues).allMatch(v -> v > 0)); + + aggregatedChecksumValue = writer.getAggregatedChecksumValue(); + assert(aggregatedChecksumValue != 0); + } else { + assertArrayEquals(checksumValues, + getRowBasedChecksumValues(writer.getRowBasedChecksums())); + assertEquals(aggregatedChecksumValue, writer.getAggregatedChecksumValue()); + } + } + } finally { + resetDependency(false); + } + } + @Test public void testPeakMemoryUsed() throws Exception { final long recordLengthBytes = 8; diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 68e366e9ad107..d2344b4e72911 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -272,7 +272,7 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { masterTracker.registerShuffle(20, 100, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES) (0 until 100).foreach { i => masterTracker.registerMapOutput(20, i, new CompressedMapStatus( - BlockManagerId("999", "mps", 1000), createArray(4000000, 0L), 5)) + BlockManagerId("999", "mps", 1000), createArray(4000000, 0L), 5, 100)) } val senderAddress = RpcAddress("localhost", 12345) val rpcCallContext = mock(classOf[RpcCallContext]) @@ -579,7 +579,7 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { masterTracker.registerShuffle(20, 100, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES) (0 until 100).foreach { i => masterTracker.registerMapOutput(20, i, new CompressedMapStatus( - BlockManagerId("999", "mps", 1000), createArray(4000000, 0L), 5)) + BlockManagerId("999", "mps", 1000), createArray(4000000, 0L), 5, 100)) } val mapWorkerRpcEnv = createRpcEnv("spark-worker", "localhost", 0, new SecurityManager(conf)) @@ -626,7 +626,7 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { masterTracker.registerShuffle(20, 100, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES) (0 until 100).foreach { i => masterTracker.registerMapOutput(20, i, new CompressedMapStatus( - BlockManagerId("999", "mps", 1000), createArray(4000000, 0L), 5)) + BlockManagerId("999", "mps", 1000), createArray(4000000, 0L), 5, 100)) } masterTracker.registerMergeResult(20, 0, MergeStatus(BlockManagerId("999", "mps", 1000), 0, bitmap1, 1000L)) diff --git 
a/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala b/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala index 75f952d063d33..bd8766bd260e9 100644 --- a/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala @@ -59,7 +59,7 @@ object MapStatusesSerDeserBenchmark extends BenchmarkBase { Array.fill(blockSize) { // Creating block size ranging from 0byte to 1GB (r.nextDouble() * 1024 * 1024 * 1024).toLong - }, i)) + }, i, i * 100)) } val shuffleStatus = tracker.shuffleStatuses.get(shuffleId).head diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index bf38c629f700b..1ada81cbdd0ee 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -1166,7 +1166,8 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti stageId: Int, attemptIdx: Int, numShufflePartitions: Int, - hostNames: Seq[String] = Seq.empty[String]): Unit = { + hostNames: Seq[String] = Seq.empty[String], + checksumVal: Long = 0): Unit = { def compareStageAttempt(taskSet: TaskSet): Boolean = { taskSet.stageId == stageId && taskSet.stageAttemptId == attemptIdx } @@ -1181,7 +1182,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti } else { s"host${('A' + idx).toChar}" } - (Success, makeMapStatus(hostName, numShufflePartitions)) + (Success, makeMapStatus(hostName, numShufflePartitions, checksumVal = checksumVal)) }.toSeq) } @@ -4852,6 +4853,44 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti assert(mapStatuses.count(s => s != null && s.location.executorId == "hostB-exec") === 1) } + /** + * In this test, we simulate a job where some tasks in a stage fail, and it triggers the retry + * of the task in its previous stage. The two attempts of the same task in the previous stage + * produce different shuffle checksums. + */ + test("Tasks that produce different checksum across retries") { + setupStageAbortTest(sc) + + val parts = 8 + val shuffleMapRdd = new MyRDD(sc, parts, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(parts)) + val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = mapOutputTracker) + submit(reduceRdd, (0 until parts).toArray) + + // Complete stage 0 and then fail stage 1, and tasks in stage 0 produce a checksum of 100. + completeShuffleMapStageSuccessfully(0, 0, numShufflePartitions = parts, checksumVal = 100) + completeNextStageWithFetchFailure(1, 0, shuffleDep) + + // Resubmit and confirm that now all is well. + scheduler.resubmitFailedStages() + assert(scheduler.runningStages.nonEmpty) + assert(!ended) + + // Complete stage 0 and then stage 1, and the retried task in stage 0 produces a different + // checksum of 200. + completeShuffleMapStageSuccessfully(0, 1, numShufflePartitions = parts, checksumVal = 200) + completeNextResultStageWithSuccess(1, 1) + + // Confirm job finished successfully. 
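
For reference, a REPL-style sketch of the bookkeeping this test exercises, using toy stand-ins (ToyMapStatus, ToyShuffleStatus) rather than the real Spark classes: re-registering a map index with a different checksum records that index as mismatched.

  import scala.collection.mutable

  // REPL-style sketch with toy stand-ins: remember the map indices whose checksum changed
  // when their output was re-registered by a retried attempt.
  final case class ToyMapStatus(mapId: Long, checksumValue: Long)
  class ToyShuffleStatus(numMaps: Int) {
    private val statuses = new Array[ToyMapStatus](numMaps)
    val checksumMismatchIndices: mutable.Set[Int] = mutable.Set()
    def addMapOutput(mapIndex: Int, status: ToyMapStatus): Unit = {
      val previous = statuses(mapIndex)
      if (previous != null && previous.checksumValue != status.checksumValue) {
        checksumMismatchIndices += mapIndex
      }
      statuses(mapIndex) = status
    }
  }
  // The first attempt reports checksum 100, the retry reports 200: one mismatch is recorded.
  val toyStatus = new ToyShuffleStatus(numMaps = 8)
  toyStatus.addMapOutput(0, ToyMapStatus(mapId = 1L, checksumValue = 100L))
  toyStatus.addMapOutput(0, ToyMapStatus(mapId = 9L, checksumValue = 200L))
  assert(toyStatus.checksumMismatchIndices.toSet == Set(0))
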
+ sc.listenerBus.waitUntilEmpty() + assert(ended) + assert(results == (0 until parts).map { idx => idx -> 42 }.toMap) + assert( + mapOutputTracker.shuffleStatuses(shuffleDep.shuffleId).checksumMismatchIndices.size == 1) + assertDataStructuresEmpty() + mapOutputTracker.unregisterShuffle(shuffleDep.shuffleId) + } + Seq(true, false).foreach { registerMergeResults => test("SPARK-40096: Send finalize events even if shuffle merger blocks indefinitely " + s"with registerMergeResults is ${registerMergeResults}") { @@ -5245,8 +5284,13 @@ class DAGSchedulerAbortStageOffSuite extends DAGSchedulerSuite { object DAGSchedulerSuite { val mergerLocs = ArrayBuffer[BlockManagerId]() - def makeMapStatus(host: String, reduces: Int, sizes: Byte = 2, mapTaskId: Long = -1): MapStatus = - MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes), mapTaskId) + def makeMapStatus( + host: String, + reduces: Int, + sizes: Byte = 2, + mapTaskId: Long = -1, + checksumVal: Long = 0): MapStatus = + MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes), mapTaskId, checksumVal) def makeBlockManagerId(host: String, execId: Option[String] = None): BlockManagerId = { BlockManagerId(execId.getOrElse(host + "-exec"), host, 12345) diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala index 8be103b7be860..439d75ee9364b 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala @@ -19,6 +19,8 @@ package org.apache.spark.shuffle import java.io.File +import org.apache.spark.shuffle.checksum.{OutputStreamRowBasedChecksum, RowBasedChecksum} + trait ShuffleChecksumTestHelper { /** @@ -37,4 +39,16 @@ trait ShuffleChecksumTestHelper { assert(ShuffleChecksumUtils.compareChecksums(numPartition, algorithm, checksum, data, index), "checksum must be consistent at both write and read sides") } + + def getRowBasedChecksumValues(rowBasedChecksums: Array[RowBasedChecksum]): Array[Long] = { + if (rowBasedChecksums.isEmpty) { + Array.empty + } else { + rowBasedChecksums.map(_.getValue) + } + } + + def createPartitionRowBasedChecksums(numPartitions: Int): Array[RowBasedChecksum] = { + Array.tabulate(numPartitions)(_ => new OutputStreamRowBasedChecksum("ADLER32")) + } } diff --git a/core/src/test/scala/org/apache/spark/shuffle/checksum/OutputStreamRowBasedChecksum.scala b/core/src/test/scala/org/apache/spark/shuffle/checksum/OutputStreamRowBasedChecksum.scala new file mode 100644 index 0000000000000..3abec5f4bd656 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/checksum/OutputStreamRowBasedChecksum.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.checksum + +import java.io.ObjectOutputStream +import java.util.zip.Checksum + +import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper +import org.apache.spark.util.ExposedBufferByteArrayOutputStream + +/** + * A Concrete implementation of RowBasedChecksum. The checksum for each row is + * computed by first converting the (key, value) pair to byte array using OutputStreams, + * and then computing the checksum for the byte array. + * Note that this checksum computation is very expensive, and it is used only in tests + * in the core component. A much cheaper implementation of RowBasedChecksum is in + * UnsafeRowChecksum. + * + * @param checksumAlgorithm the algorithm used for computing checksum. + */ +class OutputStreamRowBasedChecksum(checksumAlgorithm: String) + extends RowBasedChecksum() { + + private val DEFAULT_INITIAL_SER_BUFFER_SIZE = 32 * 1024 + + @transient private lazy val serBuffer = + new ExposedBufferByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE) + @transient private lazy val objOut = new ObjectOutputStream(serBuffer) + + @transient + protected lazy val checksum: Checksum = + ShuffleChecksumHelper.getChecksumByAlgorithm(checksumAlgorithm) + + override protected def calculateRowChecksum(key: Any, value: Any): Long = { + assert(checksum != null, "Checksum is null") + + // Converts the (key, value) pair into byte array. + objOut.reset() + serBuffer.reset() + objOut.writeObject((key, value)) + objOut.flush() + serBuffer.flush() + + // Computes and returns the checksum for the byte array. + checksum.reset() + checksum.update(serBuffer.getBuf, 0, serBuffer.size()) + checksum.getValue + } +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index ce2aefa74229a..c908c06b399dd 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -22,6 +22,7 @@ import java.util.UUID import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import scala.util.Random import org.mockito.{Mock, MockitoAnnotations} import org.mockito.Answers.RETURNS_SMART_NULLS @@ -74,8 +75,7 @@ class BypassMergeSortShuffleWriterSuite ) val memoryManager = new TestMemoryManager(conf) val taskMemoryManager = new TaskMemoryManager(memoryManager, 0) - when(dependency.partitioner).thenReturn(new HashPartitioner(7)) - when(dependency.serializer).thenReturn(new JavaSerializer(conf)) + resetDependency(conf, rowBasedChecksumEnabled = false) when(taskContext.taskMetrics()).thenReturn(taskMetrics) when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile) when(blockManager.diskBlockManager).thenReturn(diskBlockManager) @@ -145,6 +145,20 @@ class BypassMergeSortShuffleWriterSuite } } + private def resetDependency(sc: SparkConf, rowBasedChecksumEnabled: Boolean): Unit = { + reset(dependency) + val numPartitions = 7 + when(dependency.partitioner).thenReturn(new HashPartitioner(numPartitions)) + when(dependency.serializer).thenReturn(new JavaSerializer(sc)) + val checksumSize = if (rowBasedChecksumEnabled) { + numPartitions + } else { + 0 + } + val rowBasedChecksums = createPartitionRowBasedChecksums(checksumSize) + 
when(dependency.rowBasedChecksums).thenReturn(rowBasedChecksums) + } + test("write empty iterator") { val writer = new BypassMergeSortShuffleWriter[Int, Int]( blockManager, @@ -294,4 +308,44 @@ class BypassMergeSortShuffleWriterSuite assert(checksumFile.length() === 8 * numPartition) compareChecksums(numPartition, checksumAlgorithm, checksumFile, dataFile, indexFile) } + + test("Row-based checksums are independent of input row order") { + val records: List[(Int, Int)] = List( + (1, 1), (1, 2), (1, 3), (1, 4), (1, 5), + (2, 2), (2, 3), (2, 4), (2, 5), (2, 6), + (3, 3), (3, 4), (3, 5), (3, 6), (3, 7), + (4, 4), (4, 5), (4, 6), (4, 7), (4, 8), + (5, 5), (5, 6), (5, 7), (5, 8), (5, 9), + (6, 6), (6, 7), (6, 8), (6, 9), (6, 10), + (7, 7), (7, 8), (7, 9), (7, 10), (7, 11)) + + var checksumValues : Array[Long] = Array[Long]() + var aggregatedChecksumValue = 0L + for (i <- 1 to 100) { + resetDependency(conf, rowBasedChecksumEnabled = true) + val writer = new BypassMergeSortShuffleWriter[Int, Int]( + blockManager, + shuffleHandle, + 0L, // MapId + conf, + taskContext.taskMetrics().shuffleWriteMetrics, + shuffleExecutorComponents) + + writer.write(Random.shuffle(records).iterator) + writer.stop(/* success = */ true) + + if(i == 1) { + checksumValues = getRowBasedChecksumValues(writer.getRowBasedChecksums) + assert(checksumValues.length > 0) + assert(checksumValues.forall(_ > 0)) + + aggregatedChecksumValue = writer.getAggregatedChecksumValue() + assert(aggregatedChecksumValue != 0) + } else { + assert(checksumValues.sameElements( + getRowBasedChecksumValues(writer.getRowBasedChecksums))) + assert(aggregatedChecksumValue == writer.getAggregatedChecksumValue()) + } + } + } } diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala index 99402abb16cac..9d4b0625f762d 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.shuffle.sort +import scala.util.Random + import org.mockito.{Mock, MockitoAnnotations} import org.mockito.Answers.RETURNS_SMART_NULLS import org.mockito.Mockito._ @@ -50,6 +52,7 @@ class SortShuffleWriterSuite private val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) private val serializer = new JavaSerializer(conf) private var shuffleExecutorComponents: ShuffleExecutorComponents = _ + @Mock(answer = RETURNS_SMART_NULLS) private var dependency: ShuffleDependency[Int, Int, Int] = _ private val partitioner = new Partitioner() { def numPartitions = numMaps @@ -60,13 +63,9 @@ class SortShuffleWriterSuite super.beforeEach() MockitoAnnotations.openMocks(this).close() shuffleHandle = { - val dependency = mock(classOf[ShuffleDependency[Int, Int, Int]]) - when(dependency.partitioner).thenReturn(partitioner) - when(dependency.serializer).thenReturn(serializer) - when(dependency.aggregator).thenReturn(None) - when(dependency.keyOrdering).thenReturn(None) new BaseShuffleHandle(shuffleId, dependency) } + resetDependency(rowBasedChecksumEnabled = false) shuffleExecutorComponents = new LocalDiskShuffleExecutorComponents( conf, blockManager, shuffleBlockResolver) } @@ -79,6 +78,21 @@ class SortShuffleWriterSuite } } + private def resetDependency(rowBasedChecksumEnabled: Boolean): Unit = { + reset(dependency) + when(dependency.partitioner).thenReturn(partitioner) + when(dependency.serializer).thenReturn(serializer) 
+ when(dependency.aggregator).thenReturn(None) + when(dependency.keyOrdering).thenReturn(None) + val checksumSize = if (rowBasedChecksumEnabled) { + numMaps + } else { + 0 + } + val rowBasedChecksums = createPartitionRowBasedChecksums(checksumSize) + when(dependency.rowBasedChecksums).thenReturn(rowBasedChecksums) + } + test("write empty iterator") { val context = MemoryTestingUtils.fakeTaskContext(sc.env) val writer = new SortShuffleWriter[Int, Int, Int]( @@ -114,6 +128,44 @@ class SortShuffleWriterSuite assert(records.size === writeMetrics.recordsWritten) } + test("Row-based checksums are independent of input row order") { + val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) + val context = MemoryTestingUtils.fakeTaskContext(sc.env) + val records: List[(Int, Int)] = List( + (1, 1), (1, 2), (1, 3), (1, 4), (1, 5), + (2, 2), (2, 3), (2, 4), (2, 5), (2, 6), + (3, 3), (3, 4), (3, 5), (3, 6), (3, 7), + (4, 4), (4, 5), (4, 6), (4, 7), (4, 8), + (5, 5), (5, 6), (5, 7), (5, 8), (5, 9)) + + var checksumValues : Array[Long] = Array[Long]() + var aggregatedChecksumValue = 0L + for (i <- 1 to 100) { + resetDependency(rowBasedChecksumEnabled = true) + val writer = new SortShuffleWriter[Int, Int, Int]( + shuffleHandle, + mapId = 2, + context, + context.taskMetrics().shuffleWriteMetrics, + new LocalDiskShuffleExecutorComponents( + conf, shuffleBlockResolver._blockManager, shuffleBlockResolver)) + writer.write(Random.shuffle(records).iterator) + if(i == 1) { + checksumValues = getRowBasedChecksumValues(writer.getRowBasedChecksums) + assert(checksumValues.length > 0) + assert(checksumValues.forall(_ > 0)) + + aggregatedChecksumValue = writer.getAggregatedChecksumValue + assert(aggregatedChecksumValue != 0) + } else { + assert(checksumValues.sameElements( + getRowBasedChecksumValues(writer.getRowBasedChecksums))) + assert(aggregatedChecksumValue == writer.getAggregatedChecksumValue) + } + writer.stop(success = true) + } + } + Seq((true, false, false), (true, true, false), (true, false, true), @@ -141,6 +193,7 @@ class SortShuffleWriterSuite when(dependency.serializer).thenReturn(serializer) when(dependency.aggregator).thenReturn(aggregator) when(dependency.keyOrdering).thenReturn(order) + when(dependency.rowBasedChecksums).thenReturn(Array.empty) new BaseShuffleHandle[Int, Int, Int](shuffleId, dependency) } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowChecksum.scala b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowChecksum.scala new file mode 100644 index 0000000000000..2be675070eb10 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowChecksum.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.shuffle.checksum.RowBasedChecksum + +/** + * A concrete implementation of RowBasedChecksum for computing checksums of UnsafeRows. + * The checksum for each row is computed by hashing the row's underlying bytes (base object, + * offset and size in bytes) directly with XXH64. + * + * Note that the input key is ignored in the checksum computation. As the Spark shuffle + * currently uses a PartitionIdPassthrough partitioner, the keys are already the partition + * IDs for sending the data, and they are the same for all rows in the same partition. + */ +class UnsafeRowChecksum extends RowBasedChecksum() { + + override protected def calculateRowChecksum(key: Any, value: Any): Long = { + assert( + value.isInstanceOf[UnsafeRow], + "Expecting UnsafeRow but got " + value.getClass.getName) + + // Hashes the UnsafeRow's underlying bytes in place with XXH64. + val unsafeRow = value.asInstanceOf[UnsafeRow] + XXH64.hashUnsafeBytes( + unsafeRow.getBaseObject, + unsafeRow.getBaseOffset, + unsafeRow.getSizeInBytes, + 0 + ) + } +} + +object UnsafeRowChecksum { + def createUnsafeRowChecksums(numPartitions: Int): Array[RowBasedChecksum] = { + Array.tabulate(numPartitions)(_ => new UnsafeRowChecksum()) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 48bbe09f16431..3f0156c75f727 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -875,6 +875,21 @@ object SQLConf { .checkValue(_ > 0, "The value of spark.sql.shuffle.partitions must be positive") .createWithDefault(200) + val SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED = + buildConf("spark.sql.shuffle.orderIndependentChecksum.enabled") + .doc("Whether to calculate an order-independent checksum for the shuffle data or not. If " + "enabled, Spark will calculate a checksum that is independent of the input row order for " + "each mapper and return the checksums from the executors to the driver. This is different " + "from the checksum computed when spark.shuffle.checksum.enabled is enabled, which is " + "sensitive to shuffle data ordering and is meant to detect file corruption. In contrast, " + "this checksum stays the same even if the shuffle row order changes, and it is used to " + "detect whether different task attempts of the same partition produce the same set of " + "key-value pairs or not. 
In case the output data has changed across retries, Spark will need to " + + "retry all tasks of the consumer stages to avoid correctness issues.") + .version("4.1.0") + .booleanConf + .createWithDefault(false) + val SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE = buildConf("spark.sql.adaptive.shuffle.targetPostShuffleInputSize") .internal() @@ -6613,6 +6628,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf { } } + def shuffleOrderIndependentChecksumEnabled: Boolean = + getConf(SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED) + def allowCollationsInMapKeys: Boolean = getConf(ALLOW_COLLATIONS_IN_MAP_KEYS) def objectLevelCollationsEnabled: Boolean = getConf(OBJECT_LEVEL_COLLATIONS_ENABLED) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 31a3f53eb7191..fb157b4141d05 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -30,7 +30,7 @@ import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{ShuffleWriteMetricsReporter, ShuffleWriteProcessor} import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection, UnsafeRow, UnsafeRowChecksum} import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.logical.Statistics @@ -471,12 +471,19 @@ object ShuffleExchangeExec { // Now, we manually create a ShuffleDependency. Because pairs in rddWithPartitionIds // are in the form of (partitionId, row) and every partitionId is in the expected range // [0, part.numPartitions - 1]. The partitioner of this is a PartitionIdPassthrough. + val checksumSize = + if (SQLConf.get.shuffleOrderIndependentChecksumEnabled) { + part.numPartitions + } else { + 0 + } val dependency = new ShuffleDependency[Int, InternalRow, InternalRow]( rddWithPartitionIds, new PartitionIdPassthrough(part.numPartitions), serializer, - shuffleWriterProcessor = createShuffleWriteProcessor(writeMetrics)) + shuffleWriterProcessor = createShuffleWriteProcessor(writeMetrics), + rowBasedChecksums = UnsafeRowChecksum.createUnsafeRowChecksums(checksumSize)) dependency } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala new file mode 100644 index 0000000000000..0fe6603122103 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.{MapOutputTrackerMaster, SparkFunSuite} +import org.apache.spark.sql.classic.SparkSession +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SQLTestUtils + +class MapStatusEndToEndSuite extends SparkFunSuite with SQLTestUtils { + override def spark: SparkSession = SparkSession.builder() + .master("local") + .config(SQLConf.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED.key, value = true) + .config(SQLConf.LEAF_NODE_DEFAULT_PARALLELISM.key, value = 5) + .config(SQLConf.CLASSIC_SHUFFLE_DEPENDENCY_FILE_CLEANUP_ENABLED.key, value = false) + .getOrCreate() + + override def afterAll(): Unit = { + // This suite should not interfere with the other test suites. + SparkSession.getActiveSession.foreach(_.stop()) + SparkSession.clearActiveSession() + SparkSession.getDefaultSession.foreach(_.stop()) + SparkSession.clearDefaultSession() + } + + test("Propagate checksum from executor to driver") { + assert(spark.sparkContext.conf + .get("spark.sql.shuffle.orderIndependentChecksum.enabled") == "true") + assert(spark.conf.get("spark.sql.shuffle.orderIndependentChecksum.enabled") == "true") + assert(spark.sparkContext.conf.get("spark.sql.leafNodeDefaultParallelism") == "5") + assert(spark.conf.get("spark.sql.leafNodeDefaultParallelism") == "5") + assert(spark.sparkContext.conf.get("spark.sql.classic.shuffleDependency.fileCleanup.enabled") + == "false") + assert(spark.conf.get("spark.sql.classic.shuffleDependency.fileCleanup.enabled") == "false") + + withTable("t") { + spark.range(1000).repartition(10).write.mode("overwrite"). + saveAsTable("t") + } + + val shuffleStatuses = spark.sparkContext.env.mapOutputTracker. + asInstanceOf[MapOutputTrackerMaster].shuffleStatuses + assert(shuffleStatuses.size == 1) + + val mapStatuses = shuffleStatuses(0).mapStatuses + assert(mapStatuses.length == 5) + assert(mapStatuses.forall(_.checksumValue != 0)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowChecksumSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowChecksumSuite.scala new file mode 100644 index 0000000000000..07941ad626332 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowChecksumSuite.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql
+
+import java.nio.ByteBuffer
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeRowChecksum}
+import org.apache.spark.sql.types._
+
+class UnsafeRowChecksumSuite extends SparkFunSuite {
+  private val schema = new StructType().add("value", IntegerType)
+  private val toUnsafeRow = ExpressionEncoder(schema).createSerializer()
+
+  private val schemaComplex = new StructType()
+    .add("stringCol", StringType)
+    .add("doubleCol", DoubleType)
+    .add("longCol", LongType)
+    .add("int32Col", IntegerType)
+    .add("int16Col", ShortType)
+    .add("int8Col", ByteType)
+    .add("boolCol", BooleanType)
+  private val toUnsafeRowComplex = ExpressionEncoder(schemaComplex).createSerializer()
+
+  private def setUnsafeRowValue(
+      stringCol: String,
+      doubleCol: Double,
+      longCol: Long,
+      int32Col: Int,
+      int16Col: Short,
+      int8Col: Byte,
+      boolCol: Boolean,
+      unsafeRowOffheap: UnsafeRow): Unit = {
+    unsafeRowOffheap.writeFieldTo(0, ByteBuffer.wrap(stringCol.getBytes))
+    unsafeRowOffheap.setDouble(1, doubleCol)
+    unsafeRowOffheap.setLong(2, longCol)
+    unsafeRowOffheap.setInt(3, int32Col)
+    unsafeRowOffheap.setShort(4, int16Col)
+    unsafeRowOffheap.setByte(5, int8Col)
+    unsafeRowOffheap.setBoolean(6, boolCol)
+  }
+
+  test("Non-UnsafeRow value should fail") {
+    val rowBasedChecksum = new UnsafeRowChecksum()
+    rowBasedChecksum.update(1, Long.box(20))
+    // We fail to compute the checksum, and getValue returns 0.
+    assert(rowBasedChecksum.getValue == 0)
+  }
+
+  test("Two identical rows should not have a checksum of zero") {
+    val rowBasedChecksum = new UnsafeRowChecksum()
+    assert(rowBasedChecksum.getValue == 0)
+
+    // Updates the checksum with one row.
+    rowBasedChecksum.update(1, toUnsafeRow(Row(20)))
+    assert(rowBasedChecksum.getValue == -9094624449814316735L)
+
+    // Updates the checksum with the same row again. Since the final value mixes the XOR
+    // and SUM of the row checksums, the result will not be 0.
+    rowBasedChecksum.update(1, toUnsafeRow(Row(20)))
+    assert(rowBasedChecksum.getValue == -1240577858172431653L)
+  }
+
+  test("The checksum is independent of row order - two rows") {
+    val rowBasedChecksum1 = new UnsafeRowChecksum()
+    val rowBasedChecksum2 = new UnsafeRowChecksum()
+    assert(rowBasedChecksum1.getValue == 0)
+    assert(rowBasedChecksum2.getValue == 0)
+
+    rowBasedChecksum1.update(1, toUnsafeRow(Row(20)))
+    rowBasedChecksum2.update(1, toUnsafeRow(Row(40)))
+    assert(rowBasedChecksum1.getValue != rowBasedChecksum2.getValue)
+
+    rowBasedChecksum1.update(2, toUnsafeRow(Row(40)))
+    rowBasedChecksum2.update(2, toUnsafeRow(Row(20)))
+    assert(rowBasedChecksum1.getValue == rowBasedChecksum2.getValue)
+
+    assert(rowBasedChecksum1.getValue != 0)
+    assert(rowBasedChecksum2.getValue != 0)
+  }
+
+  test("The checksum is independent of row order - multiple rows") {
+    val rowBasedChecksum1 = new UnsafeRowChecksum()
+    val rowBasedChecksum2 = new UnsafeRowChecksum()
+    assert(rowBasedChecksum1.getValue == 0)
+    assert(rowBasedChecksum2.getValue == 0)
+
+    rowBasedChecksum1.update(1, toUnsafeRow(Row(20)))
+    rowBasedChecksum2.update(1, toUnsafeRow(Row(100)))
+    assert(rowBasedChecksum1.getValue != rowBasedChecksum2.getValue)
+
+    rowBasedChecksum1.update(2, toUnsafeRow(Row(40)))
+    rowBasedChecksum2.update(2, toUnsafeRow(Row(80)))
+    assert(rowBasedChecksum1.getValue != rowBasedChecksum2.getValue)
+
+    rowBasedChecksum1.update(3, toUnsafeRow(Row(60)))
+    rowBasedChecksum2.update(3, toUnsafeRow(Row(60)))
+    assert(rowBasedChecksum1.getValue != rowBasedChecksum2.getValue)
+
+    rowBasedChecksum1.update(4, toUnsafeRow(Row(80)))
+    rowBasedChecksum2.update(4, toUnsafeRow(Row(40)))
+    assert(rowBasedChecksum1.getValue != rowBasedChecksum2.getValue)
+
+    rowBasedChecksum1.update(5, toUnsafeRow(Row(100)))
+    rowBasedChecksum2.update(5, toUnsafeRow(Row(20)))
+    assert(rowBasedChecksum1.getValue == rowBasedChecksum2.getValue)
+
+    assert(rowBasedChecksum1.getValue != 0)
+    assert(rowBasedChecksum2.getValue != 0)
+  }
+
+  test("The checksum is independent of row order - complex rows") {
+    val rowBasedChecksum1 = new UnsafeRowChecksum()
+    val rowBasedChecksum2 = new UnsafeRowChecksum()
+    assert(rowBasedChecksum1.getValue == 0)
+    assert(rowBasedChecksum2.getValue == 0)
+
+    rowBasedChecksum1.update(1, toUnsafeRowComplex(Row(
+      "Some string", 0.99, 10000L, 1000, 100.toShort, 10.toByte, true)))
+    rowBasedChecksum2.update(1, toUnsafeRowComplex(Row(
+      "Some other string", 10.88, 20000L, 2000, 200.toShort, 20.toByte, false)))
+    assert(rowBasedChecksum1.getValue != rowBasedChecksum2.getValue)
+
+    rowBasedChecksum1.update(2, toUnsafeRowComplex(Row(
+      "Some other string", 10.88, 20000L, 2000, 200.toShort, 20.toByte, false)))
+    rowBasedChecksum2.update(2, toUnsafeRowComplex(Row(
+      "Some string", 0.99, 10000L, 1000, 100.toShort, 10.toByte, true)))
+    assert(rowBasedChecksum1.getValue == rowBasedChecksum2.getValue)
+
+    assert(rowBasedChecksum1.getValue != 0)
+    assert(rowBasedChecksum2.getValue != 0)
+  }
+}
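
Reviewer note: the order-independence property that UnsafeRowChecksumSuite asserts above can be illustrated with a small, self-contained Scala sketch. This is not the PR's UnsafeRowChecksum implementation; the object name OrderIndependentChecksumSketch, the toyRowHash helper, and the rotation distance are illustrative assumptions only. The sketch folds per-row hashes into an XOR accumulator and a SUM accumulator (both operations are commutative, so row order cannot affect the result) and then mixes the two accumulators so that duplicate rows, whose XOR contributions cancel, still leave a non-zero value, matching the behaviour the test comments describe.

import java.lang.Long.rotateLeft

object OrderIndependentChecksumSketch {
  // Hypothetical per-row hash used only for this sketch; Spark's UnsafeRowChecksum
  // computes its per-row values from UnsafeRow instances (see the tests above).
  private def toyRowHash(key: Any, value: Any): Long =
    31L * key.hashCode() + value.hashCode()

  // Fold all rows into XOR and SUM accumulators. Both operations are commutative and
  // associative, so the final value does not depend on the order of the rows.
  def checksum(rows: Seq[(Any, Any)]): Long = {
    var xorAcc = 0L
    var sumAcc = 0L
    rows.foreach { case (k, v) =>
      val h = toyRowHash(k, v)
      xorAcc ^= h
      sumAcc += h
    }
    // Mix the two accumulators; the rotation distance 13 is an arbitrary choice for this
    // sketch. Mixing keeps the value non-zero even when duplicate rows cancel out in XOR.
    xorAcc ^ rotateLeft(sumAcc, 13)
  }

  def main(args: Array[String]): Unit = {
    val forward = Seq((1, 20), (2, 40), (3, 60))
    val shuffled = Seq((3, 60), (1, 20), (2, 40))
    assert(checksum(forward) == checksum(shuffled))  // reordering rows does not change the value
    assert(checksum(Seq((1, 20), (1, 20))) != 0L)    // duplicate rows do not collapse to zero
    assert(checksum(forward) != checksum(Seq((1, 99), (2, 40), (3, 60))))  // changed data is visible
  }
}

The end-to-end path exercised by MapStatusEndToEndSuite builds on the same property: each map task ends up with a checksum value that is carried to the driver in its MapStatus, where differing values across task attempts can signal that the shuffle output changed.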