Commit 04d30a0

mccheah authored and ifilonenko committed
[SPARK-25299] Use the shuffle writer plugin for the SortShuffleWriter. (apache-spark-on-k8s#532)
* [SPARK-25299] Use the shuffle writer plugin for the SortShuffleWriter.
* Remove unused imports.
* Handle empty partitions properly.
* Adjust formatting.
* Don't close streams twice, because compressed output streams don't like it.
* Clarify comment.
1 parent 83bcfcf commit 04d30a0

8 files changed: +270 −28 lines
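
The diff below is written against the pluggable shuffle-writer API from SPARK-25299 (org.apache.spark.api.shuffle), whose interface files are not part of this commit. As a reading aid only, here is a minimal Scala sketch of that surface, inferred from the call sites visible in the diff (createMapOutputWriter, getNextPartitionWriter, toStream, getNumBytesWritten, commitAllPartitions); the real interfaces in the fork are Java and may differ in names and signatures.

import java.io.{ByteArrayOutputStream, OutputStream}

// Hypothetical Scala stand-ins for the plugin interfaces this commit calls into.
// Shapes are inferred from the call sites below, not copied from the real
// org.apache.spark.api.shuffle interfaces.
trait ShufflePartitionWriter {
  def toStream: OutputStream          // stream for the bytes of one reduce partition
  def getNumBytesWritten: Long        // bytes written to this partition so far
  def close(): Unit
}

trait ShuffleMapOutputWriter {
  def getNextPartitionWriter: ShufflePartitionWriter  // requested once per partition, in order
  def commitAllPartitions(): Unit
}

trait ShuffleWriteSupport {
  def createMapOutputWriter(shuffleId: Int, mapId: Int, numPartitions: Int): ShuffleMapOutputWriter
}

// A toy in-memory implementation, useful only to illustrate the contract:
// every partition gets a writer (even empty ones), in ascending partition order.
class InMemoryMapOutputWriter extends ShuffleMapOutputWriter {
  val partitions = scala.collection.mutable.ArrayBuffer.empty[ByteArrayOutputStream]

  override def getNextPartitionWriter: ShufflePartitionWriter = {
    val buf = new ByteArrayOutputStream()
    partitions += buf
    new ShufflePartitionWriter {
      override def toStream: OutputStream = buf
      override def getNumBytesWritten: Long = buf.size().toLong
      override def close(): Unit = buf.flush()
    }
  }

  override def commitAllPartitions(): Unit = ()  // nothing to move for in-memory buffers
}

The in-memory implementation is purely illustrative; the point is the contract the new SortShuffleWriter and ExternalSorter code relies on: one partition writer per reduce partition, requested in ascending order, followed by a single commit.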

core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala

Lines changed: 2 additions & 1 deletion
@@ -157,7 +157,8 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
           metrics,
           shuffleExecutorComponents.writes())
       case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
-        new SortShuffleWriter(shuffleBlockResolver, other, mapId, context)
+        new SortShuffleWriter(
+          shuffleBlockResolver, other, mapId, context, shuffleExecutorComponents.writes())
     }
   }

core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala

Lines changed: 8 additions & 15 deletions
@@ -18,18 +18,18 @@
 package org.apache.spark.shuffle.sort
 
 import org.apache.spark._
+import org.apache.spark.api.shuffle.ShuffleWriteSupport
 import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, ShuffleWriter}
-import org.apache.spark.storage.ShuffleBlockId
-import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.ExternalSorter
 
 private[spark] class SortShuffleWriter[K, V, C](
     shuffleBlockResolver: IndexShuffleBlockResolver,
     handle: BaseShuffleHandle[K, V, C],
     mapId: Int,
-    context: TaskContext)
+    context: TaskContext,
+    writeSupport: ShuffleWriteSupport)
   extends ShuffleWriter[K, V] with Logging {
 
   private val dep = handle.dependency
@@ -64,18 +64,11 @@ private[spark] class SortShuffleWriter[K, V, C](
     // Don't bother including the time to open the merged output file in the shuffle write time,
     // because it just opens a single file, so is typically too fast to measure accurately
     // (see SPARK-3570).
-    val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
-    val tmp = Utils.tempFileWith(output)
-    try {
-      val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
-      val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
-      shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
-      mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
-    } finally {
-      if (tmp.exists() && !tmp.delete()) {
-        logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
-      }
-    }
+    val mapOutputWriter = writeSupport.createMapOutputWriter(
+      dep.shuffleId, mapId, dep.partitioner.numPartitions)
+    val partitionLengths = sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter)
+    mapOutputWriter.commitAllPartitions()
+    mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
   }
 
   /** Close this writer, passing along whether the map completed */

core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala

Lines changed: 3 additions & 1 deletion
@@ -24,6 +24,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager}
 import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
 import org.apache.spark.util.Utils
+import org.apache.spark.util.collection.PairsWriter
 
 /**
  * A class for writing JVM objects directly to a file on disk. This class allows data to be appended
@@ -46,7 +47,8 @@ private[spark] class DiskBlockObjectWriter(
     writeMetrics: ShuffleWriteMetricsReporter,
     val blockId: BlockId = null)
   extends OutputStream
-  with Logging {
+  with Logging
+  with PairsWriter {
 
   /**
    * Guards against close calls, e.g. from a wrapping stream.

core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala

Lines changed: 123 additions & 7 deletions
@@ -26,10 +26,11 @@ import scala.collection.mutable.ArrayBuffer
 import com.google.common.io.ByteStreams
 
 import org.apache.spark._
+import org.apache.spark.api.shuffle.{ShuffleMapOutputWriter, ShufflePartitionWriter}
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.serializer._
-import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter}
+import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, ShuffleBlockId}
 
 /**
  * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner
@@ -674,11 +675,9 @@ private[spark] class ExternalSorter[K, V, C](
   }
 
   /**
-   * Write all the data added into this ExternalSorter into a file in the disk store. This is
-   * called by the SortShuffleWriter.
-   *
-   * @param blockId block ID to write to. The index file will be blockId.name + ".index".
-   * @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
+   * TODO remove this, as this is only used by UnsafeRowSerializerSuite in the SQL project.
+   * We should figure out an alternative way to test that so that we can remove this otherwise
+   * unused code path.
    */
   def writePartitionedFile(
       blockId: BlockId,
@@ -722,6 +721,123 @@ private[spark] class ExternalSorter[K, V, C](
     lengths
   }
 
+  private def writeEmptyPartition(mapOutputWriter: ShuffleMapOutputWriter): Unit = {
+    var partitionWriter: ShufflePartitionWriter = null
+    try {
+      partitionWriter = mapOutputWriter.getNextPartitionWriter
+    } finally {
+      if (partitionWriter != null) {
+        partitionWriter.close()
+      }
+    }
+  }
+
+  /**
+   * Write all the data added into this ExternalSorter into a map output writer that pushes bytes
+   * to some arbitrary backing store. This is called by the SortShuffleWriter.
+   *
+   * @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
+   */
+  def writePartitionedMapOutput(
+      shuffleId: Int, mapId: Int, mapOutputWriter: ShuffleMapOutputWriter): Array[Long] = {
+    // Track location of each range in the map output
+    val lengths = new Array[Long](numPartitions)
+    var nextPartitionId = 0
+    if (spills.isEmpty) {
+      // Case where we only have in-memory data
+      val collection = if (aggregator.isDefined) map else buffer
+      val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
+      while (it.hasNext()) {
+        val partitionId = it.nextPartition()
+        // The contract for the plugin is that we will ask for a writer for every partition
+        // even if it's empty. However, the external sorter will return non-contiguous
+        // partition ids. So this loop "backfills" the empty partitions that form the gaps.
+
+        // The algorithm as a whole is correct because the partition ids are returned by the
+        // iterator in ascending order.
+        for (emptyPartition <- nextPartitionId until partitionId) {
+          writeEmptyPartition(mapOutputWriter)
+        }
+        var partitionWriter: ShufflePartitionWriter = null
+        var partitionPairsWriter: ShufflePartitionPairsWriter = null
+        try {
+          partitionWriter = mapOutputWriter.getNextPartitionWriter
+          val blockId = ShuffleBlockId(shuffleId, mapId, partitionId)
+          partitionPairsWriter = new ShufflePartitionPairsWriter(
+            partitionWriter,
+            serializerManager,
+            serInstance,
+            blockId,
+            context.taskMetrics().shuffleWriteMetrics)
+          while (it.hasNext && it.nextPartition() == partitionId) {
+            it.writeNext(partitionPairsWriter)
+          }
+        } finally {
+          if (partitionPairsWriter != null) {
+            partitionPairsWriter.close()
+          }
+          if (partitionWriter != null) {
+            partitionWriter.close()
+          }
+        }
+        if (partitionWriter != null) {
+          lengths(partitionId) = partitionWriter.getNumBytesWritten
+        }
+        nextPartitionId = partitionId + 1
+      }
+    } else {
+      // We must perform merge-sort; get an iterator by partition and write everything directly.
+      for ((id, elements) <- this.partitionedIterator) {
+        // The contract for the plugin is that we will ask for a writer for every partition
+        // even if it's empty. However, the external sorter will return non-contiguous
+        // partition ids. So this loop "backfills" the empty partitions that form the gaps.
+
+        // The algorithm as a whole is correct because the partition ids are returned by the
+        // iterator in ascending order.
+        for (emptyPartition <- nextPartitionId until id) {
+          writeEmptyPartition(mapOutputWriter)
+        }
+        val blockId = ShuffleBlockId(shuffleId, mapId, id)
+        var partitionWriter: ShufflePartitionWriter = null
+        var partitionPairsWriter: ShufflePartitionPairsWriter = null
+        try {
+          partitionWriter = mapOutputWriter.getNextPartitionWriter
+          partitionPairsWriter = new ShufflePartitionPairsWriter(
+            partitionWriter,
+            serializerManager,
+            serInstance,
+            blockId,
+            context.taskMetrics().shuffleWriteMetrics)
+          if (elements.hasNext) {
+            for (elem <- elements) {
+              partitionPairsWriter.write(elem._1, elem._2)
+            }
+          }
+        } finally {
+          if (partitionPairsWriter != null) {
+            partitionPairsWriter.close()
+          }
+        }
+        if (partitionWriter != null) {
+          lengths(id) = partitionWriter.getNumBytesWritten
+        }
+        nextPartitionId = id + 1
+      }
+    }
+
+    // The iterator may have stopped short of opening a writer for every partition. So fill in the
+    // remaining empty partitions.
+    for (emptyPartition <- nextPartitionId until numPartitions) {
+      writeEmptyPartition(mapOutputWriter)
+    }
+
+    context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled)
+    context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
+    context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes)
+
+    lengths
+  }
+
   def stop(): Unit = {
     spills.foreach(s => s.file.delete())
     spills.clear()
@@ -785,7 +901,7 @@ private[spark] class ExternalSorter[K, V, C](
     val inMemoryIterator = new WritablePartitionedIterator {
       private[this] var cur = if (upstream.hasNext) upstream.next() else null
 
-      def writeNext(writer: DiskBlockObjectWriter): Unit = {
+      def writeNext(writer: PairsWriter): Unit = {
         writer.write(cur._1._2, cur._2)
         cur = if (upstream.hasNext) upstream.next() else null
       }
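
The subtle part of writePartitionedMapOutput above is the backfilling of empty partitions: the sorter's iterators only produce partition ids that actually contain records (in ascending order), while the plugin contract requires a writer to be requested for every partition. A stripped-down, Spark-free sketch of that loop structure follows; all names in it are illustrative, not part of the commit.

object BackfillExample extends App {
  // `occupied` yields (partitionId, records) with ids in ascending order, mirroring
  // ExternalSorter's iterators; a writer must still be opened for every partition id
  // in [0, numPartitions), even the empty ones.
  def writeAllPartitions[T](
      numPartitions: Int,
      occupied: Iterator[(Int, Iterator[T])],
      openWriter: Int => Unit,          // stands in for mapOutputWriter.getNextPartitionWriter
      writeRecord: (Int, T) => Unit): Unit = {
    var nextPartitionId = 0
    for ((id, records) <- occupied) {
      // Backfill writers for the empty partitions between the previous id and this one.
      for (emptyPartition <- nextPartitionId until id) {
        openWriter(emptyPartition)
      }
      openWriter(id)
      records.foreach(writeRecord(id, _))
      nextPartitionId = id + 1
    }
    // The iterator may stop before the last partition; open the remaining empty ones too.
    for (emptyPartition <- nextPartitionId until numPartitions) {
      openWriter(emptyPartition)
    }
  }

  // With 5 partitions and data only in 1 and 3, writers are still opened for 0, 2 and 4.
  writeAllPartitions[String](
    numPartitions = 5,
    occupied = Iterator(1 -> Iterator("a"), 3 -> Iterator("b", "c")),
    openWriter = id => println(s"open writer for partition $id"),
    writeRecord = (id, rec) => println(s"write $rec to partition $id"))
}

Running the example opens writers for partitions 0 through 4 but only writes records for 1 and 3, which is the same shape of output the method above produces.
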
core/src/main/scala/org/apache/spark/util/collection/PairsWriter.scala

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+private[spark] trait PairsWriter {
+
+  def write(key: Any, value: Any): Unit
+}
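
PairsWriter is the small abstraction that lets WritablePartitionedIterator.writeNext (changed further below) accept either the existing DiskBlockObjectWriter or the new ShufflePartitionPairsWriter. A self-contained sketch of how code can target the trait without caring about the destination; the class and helper names here are made up for illustration.

// Toy illustration of the abstraction: any sink that can accept key/value pairs.
// (The real trait is private[spark]; this copy just mirrors its single method.)
trait PairsWriter {
  def write(key: Any, value: Any): Unit
}

// A trivial implementation that collects pairs in memory, e.g. for tests.
class BufferingPairsWriter extends PairsWriter {
  val pairs = scala.collection.mutable.ArrayBuffer.empty[(Any, Any)]
  override def write(key: Any, value: Any): Unit = pairs += ((key, value))
}

// Code like WritablePartitionedIterator.writeNext can now target the trait, staying
// agnostic of whether the destination is a local file (DiskBlockObjectWriter) or a
// plugin partition writer (ShufflePartitionPairsWriter).
def drainTo(records: Iterator[(Any, Any)], writer: PairsWriter): Unit =
  records.foreach { case (k, v) => writer.write(k, v) }
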
core/src/main/scala/org/apache/spark/util/collection/ShufflePartitionPairsWriter.scala

Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import java.io.{Closeable, FilterOutputStream, OutputStream}
+
+import org.apache.spark.api.shuffle.ShufflePartitionWriter
+import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager}
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
+import org.apache.spark.storage.BlockId
+
+/**
+ * A key-value writer inspired by {@link DiskBlockObjectWriter} that pushes the bytes to an
+ * arbitrary partition writer instead of writing to local disk through the block manager.
+ */
+private[spark] class ShufflePartitionPairsWriter(
+    partitionWriter: ShufflePartitionWriter,
+    serializerManager: SerializerManager,
+    serializerInstance: SerializerInstance,
+    blockId: BlockId,
+    writeMetrics: ShuffleWriteMetricsReporter)
+  extends PairsWriter with Closeable {
+
+  private var isOpen = false
+  private var partitionStream: OutputStream = _
+  private var wrappedStream: OutputStream = _
+  private var objOut: SerializationStream = _
+  private var numRecordsWritten = 0
+  private var curNumBytesWritten = 0L
+
+  override def write(key: Any, value: Any): Unit = {
+    if (!isOpen) {
+      open()
+      isOpen = true
+    }
+    objOut.writeKey(key)
+    objOut.writeValue(value)
+    writeMetrics.incRecordsWritten(1)
+  }
+
+  private def open(): Unit = {
+    // The contract is that the partition writer is expected to close its own streams, but
+    // the compressor will only flush the stream when it is specifically closed. So we want to
+    // close objOut to flush the compressed bytes to the partition writer stream, but we don't want
+    // to close the partition output stream in the process.
+    partitionStream = new CloseShieldOutputStream(partitionWriter.toStream)
+    wrappedStream = serializerManager.wrapStream(blockId, partitionStream)
+    objOut = serializerInstance.serializeStream(wrappedStream)
+  }
+
+  override def close(): Unit = {
+    if (isOpen) {
+      // Closing objOut should propagate close to all inner layers.
+      // We can't close wrappedStream explicitly because closing objOut and closing wrappedStream
+      // causes problems when closing compressed output streams twice.
+      objOut.close()
+      objOut = null
+      wrappedStream = null
+      partitionStream = null
+      partitionWriter.close()
+      isOpen = false
+      updateBytesWritten()
+    }
+  }
+
+  /**
+   * Notify the writer that a record worth of bytes has been written with OutputStream#write.
+   */
+  private def recordWritten(): Unit = {
+    numRecordsWritten += 1
+    writeMetrics.incRecordsWritten(1)
+
+    if (numRecordsWritten % 16384 == 0) {
+      updateBytesWritten()
+    }
+  }
+
+  private def updateBytesWritten(): Unit = {
+    val numBytesWritten = partitionWriter.getNumBytesWritten
+    val bytesWrittenDiff = numBytesWritten - curNumBytesWritten
+    writeMetrics.incBytesWritten(bytesWrittenDiff)
+    curNumBytesWritten = numBytesWritten
+  }
+
+  private class CloseShieldOutputStream(delegate: OutputStream)
+    extends FilterOutputStream(delegate) {
+
+    override def close(): Unit = flush()
+  }
+}
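
The "Don't close streams twice" bullet in the commit message is what CloseShieldOutputStream above addresses: the serialization/compression stream has to be closed so the compressor flushes its trailer, but the underlying partition stream belongs to the plugin and must only be closed via partitionWriter.close(). The same pattern in isolation, using a plain GZIPOutputStream as a stand-in compressor (illustrative only):

import java.io.{ByteArrayOutputStream, FilterOutputStream, OutputStream}
import java.util.zip.GZIPOutputStream

object CloseShieldDemo extends App {
  // A close on this wrapper only flushes; the wrapped stream stays open,
  // mirroring ShufflePartitionPairsWriter's CloseShieldOutputStream above.
  class CloseShield(delegate: OutputStream) extends FilterOutputStream(delegate) {
    override def close(): Unit = flush()
  }

  val underlying = new ByteArrayOutputStream()   // stands in for partitionWriter.toStream
  val compressed = new GZIPOutputStream(new CloseShield(underlying))
  compressed.write("some shuffle bytes".getBytes("UTF-8"))
  compressed.close()  // flushes the gzip trailer through the shield but leaves `underlying` open

  // The caller (in the real code, the plugin's partition writer) still owns `underlying`
  // and closes it later, so nothing is closed twice.
  println(s"compressed bytes captured: ${underlying.size()}")
}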

core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala

Lines changed: 2 additions & 2 deletions
@@ -52,7 +52,7 @@ private[spark] trait WritablePartitionedPairCollection[K, V] {
     new WritablePartitionedIterator {
       private[this] var cur = if (it.hasNext) it.next() else null
 
-      def writeNext(writer: DiskBlockObjectWriter): Unit = {
+      def writeNext(writer: PairsWriter): Unit = {
         writer.write(cur._1._2, cur._2)
         cur = if (it.hasNext) it.next() else null
       }
@@ -96,7 +96,7 @@ private[spark] object WritablePartitionedPairCollection {
  * has an associated partition.
  */
 private[spark] trait WritablePartitionedIterator {
-  def writeNext(writer: DiskBlockObjectWriter): Unit
+  def writeNext(writer: PairsWriter): Unit
 
   def hasNext(): Boolean
