Commit 3f0c131

[SPARK-25299] Move shuffle writers back to being given specific partition ids (apache-spark-on-k8s#540)
We originally made the shuffle map output writer API behave like an iterator, fetching the "next" partition writer. However, shuffle writer implementations tend to skip opening empty partitions, and with an iterator-like API we would be tied to opening a partition writer for every single partition, even the empty ones. Here, we go back to passing specific partition identifiers, which gives us the freedom to avoid creating writers for empty partitions.
1 parent e17c7ea commit 3f0c131
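
To make the new contract concrete, here is a minimal caller-side sketch (not part of the patch) of what the partition-id based API enables. ShuffleMapOutputWriter, ShufflePartitionWriter, getPartitionWriter(int), toStream(), getNumBytesWritten() and close() come from the interfaces touched by this commit; the helper class, its name, and the byte[][] bookkeeping are hypothetical, and the import path for ShufflePartitionWriter is assumed to match ShuffleMapOutputWriter's package.

import java.io.IOException;
import java.io.OutputStream;

import org.apache.spark.api.shuffle.ShuffleMapOutputWriter;
import org.apache.spark.api.shuffle.ShufflePartitionWriter;

final class NonEmptyPartitionCopier {
  // Hypothetical helper: only partitions that actually hold data request a writer,
  // so empty partitions never cause a writer to be opened and closed.
  static long[] copy(ShuffleMapOutputWriter mapWriter, byte[][] dataPerPartition)
      throws IOException {
    long[] lengths = new long[dataPerPartition.length];
    for (int p = 0; p < dataPerPartition.length; p++) {
      if (dataPerPartition[p].length == 0) {
        continue; // skip empty partitions entirely
      }
      // Partition ids are passed explicitly and requested in ascending order.
      ShufflePartitionWriter writer = mapWriter.getPartitionWriter(p);
      try {
        OutputStream out = writer.toStream();
        out.write(dataPerPartition[p]);
        out.close();
        lengths[p] = writer.getNumBytesWritten();
      } finally {
        writer.close();
      }
    }
    return lengths;
  }
}

Under the old iterator-style getNextPartitionWriter(), the loop above would instead have had to request (and immediately close) a writer for every index, including the empty ones.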

6 files changed: +17 −57 lines changed


core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@
  */
 @Experimental
 public interface ShuffleMapOutputWriter {
-  ShufflePartitionWriter getNextPartitionWriter() throws IOException;
+  ShufflePartitionWriter getPartitionWriter(int partitionId) throws IOException;
 
   Optional<MapShuffleLocations> commitAllPartitions() throws IOException;

core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java

Lines changed: 1 addition & 1 deletion
@@ -216,7 +216,7 @@ private long[] writePartitionedData(ShuffleMapOutputWriter mapOutputWriter) thro
       boolean copyThrewException = true;
       ShufflePartitionWriter writer = null;
       try {
-        writer = mapOutputWriter.getNextPartitionWriter();
+        writer = mapOutputWriter.getPartitionWriter(i);
         if (!file.exists()) {
           copyThrewException = false;
         } else {

core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java

Lines changed: 2 additions & 14 deletions
@@ -291,18 +291,6 @@ private long[] mergeSpills(SpillInfo[] spills,
     long[] partitionLengths = new long[numPartitions];
     try {
       if (spills.length == 0) {
-        // The contract we are working under states that we will open a partition writer for
-        // each partition, regardless of number of spills
-        for (int i = 0; i < numPartitions; i++) {
-          ShufflePartitionWriter writer = null;
-          try {
-            writer = mapWriter.getNextPartitionWriter();
-          } finally {
-            if (writer != null) {
-              writer.close();
-            }
-          }
-        }
         return partitionLengths;
       } else {
         // There are multiple spills to merge, so none of these spill files' lengths were counted
@@ -378,7 +366,7 @@ private long[] mergeSpillsWithFileStream(
       boolean copyThrewExecption = true;
       ShufflePartitionWriter writer = null;
       try {
-        writer = mapWriter.getNextPartitionWriter();
+        writer = mapWriter.getPartitionWriter(partition);
         OutputStream partitionOutput = null;
         try {
           // Shield the underlying output stream from close() calls, so that we can close the
@@ -457,7 +445,7 @@ private long[] mergeSpillsWithTransferTo(
       boolean copyThrewExecption = true;
       ShufflePartitionWriter writer = null;
       try {
-        writer = mapWriter.getNextPartitionWriter();
+        writer = mapWriter.getPartitionWriter(partition);
         WritableByteChannel channel = writer.toChannel();
         for (int i = 0; i < spills.length; i++) {
           long partitionLengthInSpill = 0L;
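
In the zero-spill branch above, the removed code existed only to satisfy the old one-writer-per-partition contract. A hedged restatement of the simplified fast path (same variable names as the surrounding method, not new code in the patch):

  if (spills.length == 0) {
    // Every partition is empty: report an all-zero length array without ever
    // requesting a partition writer from mapWriter.
    return partitionLengths;
  }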

core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java

Lines changed: 7 additions & 3 deletions
@@ -51,7 +51,7 @@ public class DefaultShuffleMapOutputWriter implements ShuffleMapOutputWriter {
   private final IndexShuffleBlockResolver blockResolver;
   private final long[] partitionLengths;
   private final int bufferSize;
-  private int currPartitionId = 0;
+  private int lastPartitionId = -1;
   private long currChannelPosition;
   private final BlockManagerId shuffleServerId;
 
@@ -84,7 +84,11 @@ public DefaultShuffleMapOutputWriter(
   }
 
   @Override
-  public ShufflePartitionWriter getNextPartitionWriter() throws IOException {
+  public ShufflePartitionWriter getPartitionWriter(int partitionId) throws IOException {
+    if (partitionId <= lastPartitionId) {
+      throw new IllegalArgumentException("Partitions should be requested in increasing order.");
+    }
+    lastPartitionId = partitionId;
     if (outputTempFile == null) {
       outputTempFile = Utils.tempFileWith(outputFile);
     }
@@ -93,7 +97,7 @@ public ShufflePartitionWriter getNextPartitionWriter() throws IOException {
     } else {
       currChannelPosition = 0L;
     }
-    return new DefaultShufflePartitionWriter(currPartitionId++);
+    return new DefaultShufflePartitionWriter(partitionId);
   }
 
   @Override
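
The guard added above means callers may skip partition ids but can never revisit or reorder them. A small hypothetical fragment illustrating the contract (assumes a DefaultShuffleMapOutputWriter instance named mapOutputWriter):

  ShufflePartitionWriter w0 = mapOutputWriter.getPartitionWriter(0); // ok
  ShufflePartitionWriter w3 = mapOutputWriter.getPartitionWriter(3); // ok: partitions 1 and 2 skipped
  ShufflePartitionWriter w2 = mapOutputWriter.getPartitionWriter(2); // throws IllegalArgumentException

Requesting the same id twice fails the same check, since the comparison is partitionId <= lastPartitionId.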

core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala

Lines changed: 2 additions & 34 deletions
@@ -717,17 +717,6 @@ private[spark] class ExternalSorter[K, V, C](
     lengths
   }
 
-  private def writeEmptyPartition(mapOutputWriter: ShuffleMapOutputWriter): Unit = {
-    var partitionWriter: ShufflePartitionWriter = null
-    try {
-      partitionWriter = mapOutputWriter.getNextPartitionWriter
-    } finally {
-      if (partitionWriter != null) {
-        partitionWriter.close()
-      }
-    }
-  }
-
   /**
    * Write all the data added into this ExternalSorter into a map output writer that pushes bytes
    * to some arbitrary backing store. This is called by the SortShuffleWriter.
@@ -738,26 +727,16 @@
       shuffleId: Int, mapId: Int, mapOutputWriter: ShuffleMapOutputWriter): Array[Long] = {
     // Track location of each range in the map output
     val lengths = new Array[Long](numPartitions)
-    var nextPartitionId = 0
     if (spills.isEmpty) {
       // Case where we only have in-memory data
       val collection = if (aggregator.isDefined) map else buffer
       val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
       while (it.hasNext()) {
         val partitionId = it.nextPartition()
-        // The contract for the plugin is that we will ask for a writer for every partition
-        // even if it's empty. However, the external sorter will return non-contiguous
-        // partition ids. So this loop "backfills" the empty partitions that form the gaps.
-
-        // The algorithm as a whole is correct because the partition ids are returned by the
-        // iterator in ascending order.
-        for (emptyPartition <- nextPartitionId until partitionId) {
-          writeEmptyPartition(mapOutputWriter)
-        }
         var partitionWriter: ShufflePartitionWriter = null
         var partitionPairsWriter: ShufflePartitionPairsWriter = null
         try {
-          partitionWriter = mapOutputWriter.getNextPartitionWriter
+          partitionWriter = mapOutputWriter.getPartitionWriter(partitionId)
           val blockId = ShuffleBlockId(shuffleId, mapId, partitionId)
           partitionPairsWriter = new ShufflePartitionPairsWriter(
             partitionWriter,
@@ -779,7 +758,6 @@
         if (partitionWriter != null) {
           lengths(partitionId) = partitionWriter.getNumBytesWritten
         }
-        nextPartitionId = partitionId + 1
       }
     } else {
       // We must perform merge-sort; get an iterator by partition and write everything directly.
@@ -790,14 +768,11 @@
 
         // The algorithm as a whole is correct because the partition ids are returned by the
         // iterator in ascending order.
-        for (emptyPartition <- nextPartitionId until id) {
-          writeEmptyPartition(mapOutputWriter)
-        }
         val blockId = ShuffleBlockId(shuffleId, mapId, id)
         var partitionWriter: ShufflePartitionWriter = null
         var partitionPairsWriter: ShufflePartitionPairsWriter = null
         try {
-          partitionWriter = mapOutputWriter.getNextPartitionWriter
+          partitionWriter = mapOutputWriter.getPartitionWriter(id)
           partitionPairsWriter = new ShufflePartitionPairsWriter(
             partitionWriter,
             serializerManager,
@@ -817,16 +792,9 @@
         if (partitionWriter != null) {
           lengths(id) = partitionWriter.getNumBytesWritten
         }
-        nextPartitionId = id + 1
       }
     }
 
-    // The iterator may have stopped short of opening a writer for every partition. So fill in the
-    // remaining empty partitions.
-    for (emptyPartition <- nextPartitionId until numPartitions) {
-      writeEmptyPartition(mapOutputWriter)
-    }
-
     context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled)
     context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
     context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes)

core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala

Lines changed: 4 additions & 4 deletions
@@ -140,7 +140,7 @@ class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAft
 
   test("writing to an outputstream") {
     (0 until NUM_PARTITIONS).foreach{ p =>
-      val writer = mapOutputWriter.getNextPartitionWriter
+      val writer = mapOutputWriter.getPartitionWriter(p)
       val stream = writer.toStream()
       data(p).foreach { i => stream.write(i)}
       stream.close()
@@ -159,7 +159,7 @@ class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAft
 
   test("writing to a channel") {
     (0 until NUM_PARTITIONS).foreach{ p =>
-      val writer = mapOutputWriter.getNextPartitionWriter
+      val writer = mapOutputWriter.getPartitionWriter(p)
       val channel = writer.toChannel()
       val byteBuffer = ByteBuffer.allocate(D_LEN * 4)
       val intBuffer = byteBuffer.asIntBuffer()
@@ -179,7 +179,7 @@ class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAft
 
   test("copyStreams with an outputstream") {
     (0 until NUM_PARTITIONS).foreach{ p =>
-      val writer = mapOutputWriter.getNextPartitionWriter
+      val writer = mapOutputWriter.getPartitionWriter(p)
       val stream = writer.toStream()
       val byteBuffer = ByteBuffer.allocate(D_LEN * 4)
       val intBuffer = byteBuffer.asIntBuffer()
@@ -200,7 +200,7 @@ class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAft
 
   test("copyStreamsWithNIO with a channel") {
     (0 until NUM_PARTITIONS).foreach{ p =>
-      val writer = mapOutputWriter.getNextPartitionWriter
+      val writer = mapOutputWriter.getPartitionWriter(p)
      val channel = writer.toChannel()
       val byteBuffer = ByteBuffer.allocate(D_LEN * 4)
       val intBuffer = byteBuffer.asIntBuffer()
