Commit 31dd8a4

mccheah authored and ifilonenko committed
[SPARK-25299] Move shuffle writers back to being given specific partition ids (apache-spark-on-k8s#540)
We originally made the shuffle map output writer API behave like an iterator, fetching the "next" partition writer. However, shuffle writer implementations tend to skip opening empty partitions, and an iterator-like API would tie us to opening a partition writer for every single partition, even the empty ones. Here, we go back to using specific partition identifiers, which gives us the freedom to avoid creating writers for empty partitions.
1 parent b9a2bc9 commit 31dd8a4
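
For illustration, a minimal sketch of the before/after caller pattern, not taken from the patch itself; numPartitions, partitionHasData, and writeRecordsTo are hypothetical stand-ins:

    // Before: the iterator-style API forced a writer to be opened for every
    // partition, in ascending order, even when a partition held no data.
    for (int i = 0; i < numPartitions; i++) {
      ShufflePartitionWriter writer = mapOutputWriter.getNextPartitionWriter();
      try {
        if (partitionHasData(i)) {      // hypothetical helper
          writeRecordsTo(writer, i);    // hypothetical helper
        }
      } finally {
        writer.close();
      }
    }

    // After: writers are requested by partition id (still in ascending order),
    // so the caller can simply skip partitions that hold no data.
    for (int i = 0; i < numPartitions; i++) {
      if (!partitionHasData(i)) {
        continue;                       // no writer is ever opened for an empty partition
      }
      ShufflePartitionWriter writer = mapOutputWriter.getPartitionWriter(i);
      try {
        writeRecordsTo(writer, i);
      } finally {
        writer.close();
      }
    }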

File tree: 6 files changed (+17 −57 lines)

core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@
  */
 @Experimental
 public interface ShuffleMapOutputWriter {
-  ShufflePartitionWriter getNextPartitionWriter() throws IOException;
+  ShufflePartitionWriter getPartitionWriter(int partitionId) throws IOException;
 
   Optional<MapShuffleLocations> commitAllPartitions() throws IOException;

core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java

Lines changed: 1 addition & 1 deletion
@@ -216,7 +216,7 @@ private long[] writePartitionedData(ShuffleMapOutputWriter mapOutputWriter) thro
       boolean copyThrewException = true;
       ShufflePartitionWriter writer = null;
       try {
-        writer = mapOutputWriter.getNextPartitionWriter();
+        writer = mapOutputWriter.getPartitionWriter(i);
         if (!file.exists()) {
           copyThrewException = false;
         } else {

core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java

Lines changed: 2 additions & 14 deletions
@@ -291,18 +291,6 @@ private long[] mergeSpills(SpillInfo[] spills,
     long[] partitionLengths = new long[numPartitions];
     try {
       if (spills.length == 0) {
-        // The contract we are working under states that we will open a partition writer for
-        // each partition, regardless of number of spills
-        for (int i = 0; i < numPartitions; i++) {
-          ShufflePartitionWriter writer = null;
-          try {
-            writer = mapWriter.getNextPartitionWriter();
-          } finally {
-            if (writer != null) {
-              writer.close();
-            }
-          }
-        }
         return partitionLengths;
       } else {
         // There are multiple spills to merge, so none of these spill files' lengths were counted

@@ -378,7 +366,7 @@ private long[] mergeSpillsWithFileStream(
         boolean copyThrewExecption = true;
         ShufflePartitionWriter writer = null;
         try {
-          writer = mapWriter.getNextPartitionWriter();
+          writer = mapWriter.getPartitionWriter(partition);
           OutputStream partitionOutput = null;
           try {
             // Shield the underlying output stream from close() calls, so that we can close the

@@ -457,7 +445,7 @@ private long[] mergeSpillsWithTransferTo(
         boolean copyThrewExecption = true;
         ShufflePartitionWriter writer = null;
         try {
-          writer = mapWriter.getNextPartitionWriter();
+          writer = mapWriter.getPartitionWriter(partition);
           WritableByteChannel channel = writer.toChannel();
           for (int i = 0; i < spills.length; i++) {
             long partitionLengthInSpill = 0L;

core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java

Lines changed: 7 additions & 3 deletions
@@ -51,7 +51,7 @@ public class DefaultShuffleMapOutputWriter implements ShuffleMapOutputWriter {
   private final IndexShuffleBlockResolver blockResolver;
   private final long[] partitionLengths;
   private final int bufferSize;
-  private int currPartitionId = 0;
+  private int lastPartitionId = -1;
   private long currChannelPosition;
   private final BlockManagerId shuffleServerId;

@@ -84,7 +84,11 @@ public DefaultShuffleMapOutputWriter(
   }
 
   @Override
-  public ShufflePartitionWriter getNextPartitionWriter() throws IOException {
+  public ShufflePartitionWriter getPartitionWriter(int partitionId) throws IOException {
+    if (partitionId <= lastPartitionId) {
+      throw new IllegalArgumentException("Partitions should be requested in increasing order.");
+    }
+    lastPartitionId = partitionId;
     if (outputTempFile == null) {
       outputTempFile = Utils.tempFileWith(outputFile);
     }

@@ -93,7 +97,7 @@ public ShufflePartitionWriter getNextPartitionWriter() throws IOException {
     } else {
       currChannelPosition = 0L;
     }
-    return new DefaultShufflePartitionWriter(currPartitionId++);
+    return new DefaultShufflePartitionWriter(partitionId);
   }
 
   @Override
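
The new guard enforces the ascending-order contract at the default writer. A brief usage sketch, illustrative only and not part of the patch:

    // Ids must strictly increase; gaps are allowed and mean "empty partition".
    ShufflePartitionWriter w0 = mapOutputWriter.getPartitionWriter(0); // ok
    w0.close();
    ShufflePartitionWriter w5 = mapOutputWriter.getPartitionWriter(5); // ok: partitions 1-4 skipped as empty
    w5.close();
    mapOutputWriter.getPartitionWriter(3); // throws IllegalArgumentException: ids must increase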

core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala

Lines changed: 2 additions & 34 deletions
@@ -721,17 +721,6 @@ private[spark] class ExternalSorter[K, V, C](
     lengths
   }
 
-  private def writeEmptyPartition(mapOutputWriter: ShuffleMapOutputWriter): Unit = {
-    var partitionWriter: ShufflePartitionWriter = null
-    try {
-      partitionWriter = mapOutputWriter.getNextPartitionWriter
-    } finally {
-      if (partitionWriter != null) {
-        partitionWriter.close()
-      }
-    }
-  }
-
   /**
    * Write all the data added into this ExternalSorter into a map output writer that pushes bytes
    * to some arbitrary backing store. This is called by the SortShuffleWriter.

@@ -742,26 +731,16 @@
       shuffleId: Int, mapId: Int, mapOutputWriter: ShuffleMapOutputWriter): Array[Long] = {
     // Track location of each range in the map output
     val lengths = new Array[Long](numPartitions)
-    var nextPartitionId = 0
     if (spills.isEmpty) {
       // Case where we only have in-memory data
       val collection = if (aggregator.isDefined) map else buffer
       val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
       while (it.hasNext()) {
         val partitionId = it.nextPartition()
-        // The contract for the plugin is that we will ask for a writer for every partition
-        // even if it's empty. However, the external sorter will return non-contiguous
-        // partition ids. So this loop "backfills" the empty partitions that form the gaps.
-
-        // The algorithm as a whole is correct because the partition ids are returned by the
-        // iterator in ascending order.
-        for (emptyPartition <- nextPartitionId until partitionId) {
-          writeEmptyPartition(mapOutputWriter)
-        }
         var partitionWriter: ShufflePartitionWriter = null
         var partitionPairsWriter: ShufflePartitionPairsWriter = null
         try {
-          partitionWriter = mapOutputWriter.getNextPartitionWriter
+          partitionWriter = mapOutputWriter.getPartitionWriter(partitionId)
           val blockId = ShuffleBlockId(shuffleId, mapId, partitionId)
           partitionPairsWriter = new ShufflePartitionPairsWriter(
             partitionWriter,

@@ -783,7 +762,6 @@
         if (partitionWriter != null) {
           lengths(partitionId) = partitionWriter.getNumBytesWritten
         }
-        nextPartitionId = partitionId + 1
       }
     } else {
       // We must perform merge-sort; get an iterator by partition and write everything directly.

@@ -794,14 +772,11 @@
 
         // The algorithm as a whole is correct because the partition ids are returned by the
         // iterator in ascending order.
-        for (emptyPartition <- nextPartitionId until id) {
-          writeEmptyPartition(mapOutputWriter)
-        }
         val blockId = ShuffleBlockId(shuffleId, mapId, id)
         var partitionWriter: ShufflePartitionWriter = null
         var partitionPairsWriter: ShufflePartitionPairsWriter = null
         try {
-          partitionWriter = mapOutputWriter.getNextPartitionWriter
+          partitionWriter = mapOutputWriter.getPartitionWriter(id)
           partitionPairsWriter = new ShufflePartitionPairsWriter(
             partitionWriter,
             serializerManager,

@@ -821,16 +796,9 @@
         if (partitionWriter != null) {
           lengths(id) = partitionWriter.getNumBytesWritten
         }
-        nextPartitionId = id + 1
       }
     }
 
-    // The iterator may have stopped short of opening a writer for every partition. So fill in the
-    // remaining empty partitions.
-    for (emptyPartition <- nextPartitionId until numPartitions) {
-      writeEmptyPartition(mapOutputWriter)
-    }
-
     context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled)
     context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
    context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes)
core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala

Lines changed: 4 additions & 4 deletions
@@ -140,7 +140,7 @@ class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAft
 
   test("writing to an outputstream") {
     (0 until NUM_PARTITIONS).foreach{ p =>
-      val writer = mapOutputWriter.getNextPartitionWriter
+      val writer = mapOutputWriter.getPartitionWriter(p)
       val stream = writer.toStream()
       data(p).foreach { i => stream.write(i)}
       stream.close()

@@ -159,7 +159,7 @@
 
   test("writing to a channel") {
     (0 until NUM_PARTITIONS).foreach{ p =>
-      val writer = mapOutputWriter.getNextPartitionWriter
+      val writer = mapOutputWriter.getPartitionWriter(p)
       val channel = writer.toChannel()
       val byteBuffer = ByteBuffer.allocate(D_LEN * 4)
       val intBuffer = byteBuffer.asIntBuffer()

@@ -179,7 +179,7 @@
 
   test("copyStreams with an outputstream") {
     (0 until NUM_PARTITIONS).foreach{ p =>
-      val writer = mapOutputWriter.getNextPartitionWriter
+      val writer = mapOutputWriter.getPartitionWriter(p)
       val stream = writer.toStream()
       val byteBuffer = ByteBuffer.allocate(D_LEN * 4)
       val intBuffer = byteBuffer.asIntBuffer()

@@ -200,7 +200,7 @@
 
   test("copyStreamsWithNIO with a channel") {
     (0 until NUM_PARTITIONS).foreach{ p =>
-      val writer = mapOutputWriter.getNextPartitionWriter
+      val writer = mapOutputWriter.getPartitionWriter(p)
       val channel = writer.toChannel()
       val byteBuffer = ByteBuffer.allocate(D_LEN * 4)
       val intBuffer = byteBuffer.asIntBuffer()
