
Commit 8f5fb60

[SPARK-25299] Use the shuffle writer plugin for the SortShuffleWriter. (apache-spark-on-k8s#532)
* [SPARK-25299] Use the shuffle writer plugin for the SortShuffleWriter.
* Remove unused
* Handle empty partitions properly.
* Adjust formatting
* Don't close streams twice, because compressed output streams don't like it.
* Clarify comment
1 parent d13037f commit 8f5fb60

7 files changed: +266, -26 lines

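The changes below route the sort-based shuffle through the shuffle writer plugin API in org.apache.spark.api.shuffle. For orientation, here is a minimal Scala sketch of that API as it is exercised by this commit; it is inferred from the call sites in the diff only (the real definitions are Java interfaces and may carry more members and checked exceptions):

import java.io.{Closeable, OutputStream}

// Sketch inferred from the call sites in this diff, not the authoritative definitions.
trait ShuffleWriteSupport {
  // One writer per map task, spanning all of its reduce partitions.
  def createMapOutputWriter(shuffleId: Int, mapId: Int, numPartitions: Int): ShuffleMapOutputWriter
}

trait ShuffleMapOutputWriter {
  // Requested once per partition, in ascending partition order, even for empty partitions.
  def getNextPartitionWriter: ShufflePartitionWriter
  // Called after every partition writer has been obtained and closed.
  def commitAllPartitions(): Unit
}

trait ShufflePartitionWriter extends Closeable {
  // Raw byte sink for this partition; the plugin owns the stream's lifecycle.
  def toStream: OutputStream
  // Bytes pushed to this partition so far, used to build the MapStatus lengths.
  def getNumBytesWritten: Long
}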

core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala

Lines changed: 2 additions & 1 deletion
@@ -158,7 +158,8 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
           metrics,
           shuffleExecutorComponents.writes())
       case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
-        new SortShuffleWriter(shuffleBlockResolver, other, mapId, context)
+        new SortShuffleWriter(
+          shuffleBlockResolver, other, mapId, context, shuffleExecutorComponents.writes())
     }
   }

core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala

Lines changed: 8 additions & 15 deletions
@@ -18,18 +18,18 @@
 package org.apache.spark.shuffle.sort

 import org.apache.spark._
+import org.apache.spark.api.shuffle.ShuffleWriteSupport
 import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, ShuffleWriter}
-import org.apache.spark.storage.ShuffleBlockId
-import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.ExternalSorter

 private[spark] class SortShuffleWriter[K, V, C](
     shuffleBlockResolver: IndexShuffleBlockResolver,
     handle: BaseShuffleHandle[K, V, C],
     mapId: Int,
-    context: TaskContext)
+    context: TaskContext,
+    writeSupport: ShuffleWriteSupport)
   extends ShuffleWriter[K, V] with Logging {

   private val dep = handle.dependency

@@ -64,18 +64,11 @@ private[spark] class SortShuffleWriter[K, V, C](
     // Don't bother including the time to open the merged output file in the shuffle write time,
     // because it just opens a single file, so is typically too fast to measure accurately
     // (see SPARK-3570).
-    val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
-    val tmp = Utils.tempFileWith(output)
-    try {
-      val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
-      val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
-      shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
-      mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
-    } finally {
-      if (tmp.exists() && !tmp.delete()) {
-        logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
-      }
-    }
+    val mapOutputWriter = writeSupport.createMapOutputWriter(
+      dep.shuffleId, mapId, dep.partitioner.numPartitions)
+    val partitionLengths = sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter)
+    mapOutputWriter.commitAllPartitions()
+    mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
   }

   /** Close this writer, passing along whether the map completed */
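To make the other side of createMapOutputWriter concrete, here is a hypothetical in-memory plugin written against the interface sketch near the top of this page. The class and field names are invented for illustration only and are not part of this commit:

import java.io.{ByteArrayOutputStream, OutputStream}
import scala.collection.mutable.ArrayBuffer

// Illustrative only: keeps every partition in a heap buffer instead of a file or remote store.
class InMemoryMapOutputWriter(numPartitions: Int) extends ShuffleMapOutputWriter {
  private val partitions = ArrayBuffer.empty[ByteArrayOutputStream]

  override def getNextPartitionWriter: ShufflePartitionWriter = {
    val buffer = new ByteArrayOutputStream()
    partitions += buffer
    new ShufflePartitionWriter {
      override def toStream: OutputStream = buffer
      override def getNumBytesWritten: Long = buffer.size().toLong
      override def close(): Unit = buffer.flush()
    }
  }

  override def commitAllPartitions(): Unit = {
    // SortShuffleWriter calls this after ExternalSorter has asked for every partition writer.
    require(partitions.length == numPartitions,
      s"expected $numPartitions partition writers, got ${partitions.length}")
    // A real plugin would publish the bytes here (local data/index files, remote storage, ...).
  }
}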

core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala

Lines changed: 3 additions & 1 deletion
@@ -24,6 +24,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager}
 import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
 import org.apache.spark.util.Utils
+import org.apache.spark.util.collection.PairsWriter

 /**
  * A class for writing JVM objects directly to a file on disk. This class allows data to be appended
@@ -46,7 +47,8 @@ private[spark] class DiskBlockObjectWriter(
     writeMetrics: ShuffleWriteMetricsReporter,
     val blockId: BlockId = null)
   extends OutputStream
-  with Logging {
+  with Logging
+  with PairsWriter {

   /**
    * Guards against close calls, e.g. from a wrapping stream.

core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala

Lines changed: 123 additions & 7 deletions
@@ -26,10 +26,11 @@ import scala.collection.mutable.ArrayBuffer
 import com.google.common.io.ByteStreams

 import org.apache.spark._
+import org.apache.spark.api.shuffle.{ShuffleMapOutputWriter, ShufflePartitionWriter}
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.serializer._
-import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter}
+import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, ShuffleBlockId}

 /**
  * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner

@@ -670,11 +671,9 @@ private[spark] class ExternalSorter[K, V, C](
   }

   /**
-   * Write all the data added into this ExternalSorter into a file in the disk store. This is
-   * called by the SortShuffleWriter.
-   *
-   * @param blockId block ID to write to. The index file will be blockId.name + ".index".
-   * @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
+   * TODO remove this, as this is only used by UnsafeRowSerializerSuite in the SQL project.
+   * We should figure out an alternative way to test that so that we can remove this otherwise
+   * unused code path.
    */
   def writePartitionedFile(
       blockId: BlockId,

@@ -718,6 +717,123 @@ private[spark] class ExternalSorter[K, V, C](
     lengths
   }

+  private def writeEmptyPartition(mapOutputWriter: ShuffleMapOutputWriter): Unit = {
+    var partitionWriter: ShufflePartitionWriter = null
+    try {
+      partitionWriter = mapOutputWriter.getNextPartitionWriter
+    } finally {
+      if (partitionWriter != null) {
+        partitionWriter.close()
+      }
+    }
+  }
+
+  /**
+   * Write all the data added into this ExternalSorter into a map output writer that pushes bytes
+   * to some arbitrary backing store. This is called by the SortShuffleWriter.
+   *
+   * @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
+   */
+  def writePartitionedMapOutput(
+      shuffleId: Int, mapId: Int, mapOutputWriter: ShuffleMapOutputWriter): Array[Long] = {
+    // Track location of each range in the map output
+    val lengths = new Array[Long](numPartitions)
+    var nextPartitionId = 0
+    if (spills.isEmpty) {
+      // Case where we only have in-memory data
+      val collection = if (aggregator.isDefined) map else buffer
+      val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
+      while (it.hasNext()) {
+        val partitionId = it.nextPartition()
+        // The contract for the plugin is that we will ask for a writer for every partition
+        // even if it's empty. However, the external sorter will return non-contiguous
+        // partition ids. So this loop "backfills" the empty partitions that form the gaps.
+
+        // The algorithm as a whole is correct because the partition ids are returned by the
+        // iterator in ascending order.
+        for (emptyPartition <- nextPartitionId until partitionId) {
+          writeEmptyPartition(mapOutputWriter)
+        }
+        var partitionWriter: ShufflePartitionWriter = null
+        var partitionPairsWriter: ShufflePartitionPairsWriter = null
+        try {
+          partitionWriter = mapOutputWriter.getNextPartitionWriter
+          val blockId = ShuffleBlockId(shuffleId, mapId, partitionId)
+          partitionPairsWriter = new ShufflePartitionPairsWriter(
+            partitionWriter,
+            serializerManager,
+            serInstance,
+            blockId,
+            context.taskMetrics().shuffleWriteMetrics)
+          while (it.hasNext && it.nextPartition() == partitionId) {
+            it.writeNext(partitionPairsWriter)
+          }
+        } finally {
+          if (partitionPairsWriter != null) {
+            partitionPairsWriter.close()
+          }
+          if (partitionWriter != null) {
+            partitionWriter.close()
+          }
+        }
+        if (partitionWriter != null) {
+          lengths(partitionId) = partitionWriter.getNumBytesWritten
+        }
+        nextPartitionId = partitionId + 1
+      }
+    } else {
+      // We must perform merge-sort; get an iterator by partition and write everything directly.
+      for ((id, elements) <- this.partitionedIterator) {
+        // The contract for the plugin is that we will ask for a writer for every partition
+        // even if it's empty. However, the external sorter will return non-contiguous
+        // partition ids. So this loop "backfills" the empty partitions that form the gaps.
+
+        // The algorithm as a whole is correct because the partition ids are returned by the
+        // iterator in ascending order.
+        for (emptyPartition <- nextPartitionId until id) {
+          writeEmptyPartition(mapOutputWriter)
+        }
+        val blockId = ShuffleBlockId(shuffleId, mapId, id)
+        var partitionWriter: ShufflePartitionWriter = null
+        var partitionPairsWriter: ShufflePartitionPairsWriter = null
+        try {
+          partitionWriter = mapOutputWriter.getNextPartitionWriter
+          partitionPairsWriter = new ShufflePartitionPairsWriter(
+            partitionWriter,
+            serializerManager,
+            serInstance,
+            blockId,
+            context.taskMetrics().shuffleWriteMetrics)
+          if (elements.hasNext) {
+            for (elem <- elements) {
+              partitionPairsWriter.write(elem._1, elem._2)
+            }
+          }
+        } finally {
+          if (partitionPairsWriter != null) {
+            partitionPairsWriter.close()
+          }
+        }
+        if (partitionWriter != null) {
+          lengths(id) = partitionWriter.getNumBytesWritten
+        }
+        nextPartitionId = id + 1
+      }
+    }
+
+    // The iterator may have stopped short of opening a writer for every partition. So fill in the
+    // remaining empty partitions.
+    for (emptyPartition <- nextPartitionId until numPartitions) {
+      writeEmptyPartition(mapOutputWriter)
+    }
+
+    context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled)
+    context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
+    context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes)
+
+    lengths
+  }
+
   def stop(): Unit = {
     spills.foreach(s => s.file.delete())
     spills.clear()

@@ -781,7 +897,7 @@ private[spark] class ExternalSorter[K, V, C](
     val inMemoryIterator = new WritablePartitionedIterator {
       private[this] var cur = if (upstream.hasNext) upstream.next() else null

-      def writeNext(writer: DiskBlockObjectWriter): Unit = {
+      def writeNext(writer: PairsWriter): Unit = {
         writer.write(cur._1._2, cur._2)
         cur = if (upstream.hasNext) upstream.next() else null
       }
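The backfill loops above guarantee the plugin sees exactly one partition writer per reduce partition, in ascending order, even though the sorter only yields the non-empty partition ids. A self-contained toy version of that bookkeeping (all names here are invented for illustration):

object BackfillIllustration {
  def main(args: Array[String]): Unit = {
    val numPartitions = 6
    val nonEmptyPartitions = Seq(1, 3, 4) // ids the sorter's iterator would yield, in ascending order
    val opened = scala.collection.mutable.ArrayBuffer.empty[String]
    var nextPartitionId = 0

    for (id <- nonEmptyPartitions) {
      // Backfill the empty partitions that form the gap before `id`.
      for (empty <- nextPartitionId until id) {
        opened += s"empty writer for partition $empty"
      }
      opened += s"data writer for partition $id"
      nextPartitionId = id + 1
    }
    // Fill in any trailing empty partitions the iterator never reached.
    for (empty <- nextPartitionId until numPartitions) {
      opened += s"empty writer for partition $empty"
    }

    opened.foreach(println) // exactly one writer per partition, 0 through 5, in order
  }
}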

core/src/main/scala/org/apache/spark/util/collection/PairsWriter.scala

Lines changed: 23 additions & 0 deletions (new file)
@@ -0,0 +1,23 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.util.collection

private[spark] trait PairsWriter {

  def write(key: Any, value: Any): Unit
}
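With DiskBlockObjectWriter also mixing in this trait (see the change above), both the local-disk writer and the new ShufflePartitionPairsWriter can sit behind WritablePartitionedIterator.writeNext. A small sketch of code written against the abstraction; the helper name and parameters are illustrative, not part of the patch:

// Works for either writer implementation; `drainTo` and `pairs` are invented names.
def drainTo(pairs: Iterator[(Any, Any)], writer: PairsWriter): Unit = {
  pairs.foreach { case (key, value) => writer.write(key, value) }
}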

core/src/main/scala/org/apache/spark/util/collection/ShufflePartitionPairsWriter.scala

Lines changed: 105 additions & 0 deletions (new file)
@@ -0,0 +1,105 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.util.collection

import java.io.{Closeable, FilterOutputStream, OutputStream}

import org.apache.spark.api.shuffle.ShufflePartitionWriter
import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager}
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
import org.apache.spark.storage.BlockId

/**
 * A key-value writer inspired by {@link DiskBlockObjectWriter} that pushes the bytes to an
 * arbitrary partition writer instead of writing to local disk through the block manager.
 */
private[spark] class ShufflePartitionPairsWriter(
    partitionWriter: ShufflePartitionWriter,
    serializerManager: SerializerManager,
    serializerInstance: SerializerInstance,
    blockId: BlockId,
    writeMetrics: ShuffleWriteMetricsReporter)
  extends PairsWriter with Closeable {

  private var isOpen = false
  private var partitionStream: OutputStream = _
  private var wrappedStream: OutputStream = _
  private var objOut: SerializationStream = _
  private var numRecordsWritten = 0
  private var curNumBytesWritten = 0L

  override def write(key: Any, value: Any): Unit = {
    if (!isOpen) {
      open()
      isOpen = true
    }
    objOut.writeKey(key)
    objOut.writeValue(value)
    writeMetrics.incRecordsWritten(1)
  }

  private def open(): Unit = {
    // The contract is that the partition writer is expected to close its own streams, but
    // the compressor will only flush the stream when it is specifically closed. So we want to
    // close objOut to flush the compressed bytes to the partition writer stream, but we don't want
    // to close the partition output stream in the process.
    partitionStream = new CloseShieldOutputStream(partitionWriter.toStream)
    wrappedStream = serializerManager.wrapStream(blockId, partitionStream)
    objOut = serializerInstance.serializeStream(wrappedStream)
  }

  override def close(): Unit = {
    if (isOpen) {
      // Closing objOut should propagate close to all inner layers.
      // We can't close wrappedStream explicitly because closing objOut and closing wrappedStream
      // causes problems when closing compressed output streams twice.
      objOut.close()
      objOut = null
      wrappedStream = null
      partitionStream = null
      partitionWriter.close()
      isOpen = false
      updateBytesWritten()
    }
  }

  /**
   * Notify the writer that a record worth of bytes has been written with OutputStream#write.
   */
  private def recordWritten(): Unit = {
    numRecordsWritten += 1
    writeMetrics.incRecordsWritten(1)

    if (numRecordsWritten % 16384 == 0) {
      updateBytesWritten()
    }
  }

  private def updateBytesWritten(): Unit = {
    val numBytesWritten = partitionWriter.getNumBytesWritten
    val bytesWrittenDiff = numBytesWritten - curNumBytesWritten
    writeMetrics.incBytesWritten(bytesWrittenDiff)
    curNumBytesWritten = numBytesWritten
  }

  private class CloseShieldOutputStream(delegate: OutputStream)
    extends FilterOutputStream(delegate) {

    override def close(): Unit = flush()
  }
}
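The CloseShieldOutputStream above is what backs the commit-message bullet about not closing streams twice: the serialization stream must be closed so the compressor flushes its trailing bytes, but that close must not cascade into the partition stream the plugin still owns. A stand-alone sketch of the same layering with stand-in stream types (GZIP is only a placeholder here for whatever serializerManager.wrapStream would apply; all names are invented for illustration):

import java.io.{ByteArrayOutputStream, FilterOutputStream, OutputStream}
import java.util.zip.GZIPOutputStream

object CloseShieldIllustration {
  def main(args: Array[String]): Unit = {
    // Stand-in for ShufflePartitionWriter#toStream; the plugin owns this stream's lifecycle.
    val partitionOwnedStream = new ByteArrayOutputStream()

    // Same trick as CloseShieldOutputStream: forward everything, but turn close() into flush().
    val shield: OutputStream = new FilterOutputStream(partitionOwnedStream) {
      override def close(): Unit = flush()
    }

    // Stand-in for the wrapped, compressed serialization stream; it must be closed to flush fully.
    val compressed = new GZIPOutputStream(shield)
    compressed.write("record bytes".getBytes("UTF-8"))
    compressed.close() // flushes the compressed bytes through the shield...

    // ...while the partition-owned stream stays open for the plugin to commit or close itself.
    println(s"partition stream holds ${partitionOwnedStream.size()} bytes and remains open")
  }
}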

core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala

Lines changed: 2 additions & 2 deletions
@@ -52,7 +52,7 @@ private[spark] trait WritablePartitionedPairCollection[K, V] {
     new WritablePartitionedIterator {
       private[this] var cur = if (it.hasNext) it.next() else null

-      def writeNext(writer: DiskBlockObjectWriter): Unit = {
+      def writeNext(writer: PairsWriter): Unit = {
         writer.write(cur._1._2, cur._2)
         cur = if (it.hasNext) it.next() else null
       }
@@ -89,7 +89,7 @@ private[spark] object WritablePartitionedPairCollection {
  * has an associated partition.
  */
 private[spark] trait WritablePartitionedIterator {
-  def writeNext(writer: DiskBlockObjectWriter): Unit
+  def writeNext(writer: PairsWriter): Unit

   def hasNext(): Boolean
