
Commit 6a064ba

[SPARK-26141] Enable custom metrics implementation in shuffle write
## What changes were proposed in this pull request?

This is the write side counterpart to apache#23105.

## How was this patch tested?

No behavior change expected, as it is a straightforward refactoring. Updated all existing test cases.

Closes apache#23106 from rxin/SPARK-26141.

Authored-by: Reynold Xin <[email protected]>
Signed-off-by: Reynold Xin <[email protected]>
1 parent 85383d2 commit 6a064ba

File tree

15 files changed (+79, -54 lines)

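The new `ShuffleWriteMetricsReporter` type that the diff switches to is not itself shown on this page. Judging from the methods that `ShuffleWriteMetrics` overrides in the ShuffleWriteMetrics.scala hunk further down, the interface presumably declares something close to the following sketch (inferred, not the committed source):

```scala
package org.apache.spark.shuffle

// Sketch of the reporter interface the shuffle write path now codes against,
// inferred from the overrides visible in ShuffleWriteMetrics.scala below.
// The committed file may differ in documentation and visibility details.
private[spark] trait ShuffleWriteMetricsReporter {
  private[spark] def incBytesWritten(v: Long): Unit
  private[spark] def incRecordsWritten(v: Long): Unit
  private[spark] def incWriteTime(v: Long): Unit
  private[spark] def decBytesWritten(v: Long): Unit
  private[spark] def decRecordsWritten(v: Long): Unit
}
```

`ShuffleWriteMetrics` then becomes just one implementation of this interface, which is what allows a custom metrics implementation on the write side (mirroring the read-side change in apache#23105) to be plugged in.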

core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java

Lines changed: 5 additions & 6 deletions
@@ -37,12 +37,11 @@
 import org.apache.spark.Partitioner;
 import org.apache.spark.ShuffleDependency;
 import org.apache.spark.SparkConf;
-import org.apache.spark.TaskContext;
-import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.scheduler.MapStatus;
 import org.apache.spark.scheduler.MapStatus$;
 import org.apache.spark.serializer.Serializer;
 import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
 import org.apache.spark.shuffle.IndexShuffleBlockResolver;
 import org.apache.spark.shuffle.ShuffleWriter;
 import org.apache.spark.storage.*;
@@ -79,7 +78,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
   private final int numPartitions;
   private final BlockManager blockManager;
   private final Partitioner partitioner;
-  private final ShuffleWriteMetrics writeMetrics;
+  private final ShuffleWriteMetricsReporter writeMetrics;
   private final int shuffleId;
   private final int mapId;
   private final Serializer serializer;
@@ -103,8 +102,8 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
       IndexShuffleBlockResolver shuffleBlockResolver,
       BypassMergeSortShuffleHandle<K, V> handle,
       int mapId,
-      TaskContext taskContext,
-      SparkConf conf) {
+      SparkConf conf,
+      ShuffleWriteMetricsReporter writeMetrics) {
     // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
     this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
     this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true);
@@ -114,7 +113,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
     this.shuffleId = dep.shuffleId();
     this.partitioner = dep.partitioner();
     this.numPartitions = partitioner.numPartitions();
-    this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics();
+    this.writeMetrics = writeMetrics;
     this.serializer = dep.serializer();
     this.shuffleBlockResolver = shuffleBlockResolver;
   }

core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java

Lines changed: 12 additions & 6 deletions
@@ -38,6 +38,7 @@
 import org.apache.spark.memory.TooLargePageException;
 import org.apache.spark.serializer.DummySerializerInstance;
 import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
 import org.apache.spark.storage.BlockManager;
 import org.apache.spark.storage.DiskBlockObjectWriter;
 import org.apache.spark.storage.FileSegment;
@@ -75,7 +76,7 @@ final class ShuffleExternalSorter extends MemoryConsumer {
   private final TaskMemoryManager taskMemoryManager;
   private final BlockManager blockManager;
   private final TaskContext taskContext;
-  private final ShuffleWriteMetrics writeMetrics;
+  private final ShuffleWriteMetricsReporter writeMetrics;
 
   /**
    * Force this sorter to spill when there are this many elements in memory.
@@ -113,7 +114,7 @@ final class ShuffleExternalSorter extends MemoryConsumer {
       int initialSize,
       int numPartitions,
       SparkConf conf,
-      ShuffleWriteMetrics writeMetrics) {
+      ShuffleWriteMetricsReporter writeMetrics) {
     super(memoryManager,
       (int) Math.min(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, memoryManager.pageSizeBytes()),
       memoryManager.getTungstenMemoryMode());
@@ -144,7 +145,7 @@ final class ShuffleExternalSorter extends MemoryConsumer {
    */
   private void writeSortedFile(boolean isLastFile) {
 
-    final ShuffleWriteMetrics writeMetricsToUse;
+    final ShuffleWriteMetricsReporter writeMetricsToUse;
 
     if (isLastFile) {
       // We're writing the final non-spill file, so we _do_ want to count this as shuffle bytes.
@@ -241,9 +242,14 @@ private void writeSortedFile(boolean isLastFile) {
       //
       // Note that we intentionally ignore the value of `writeMetricsToUse.shuffleWriteTime()`.
       // Consistent with ExternalSorter, we do not count this IO towards shuffle write time.
-      // This means that this IO time is not accounted for anywhere; SPARK-3577 will fix this.
-      writeMetrics.incRecordsWritten(writeMetricsToUse.recordsWritten());
-      taskContext.taskMetrics().incDiskBytesSpilled(writeMetricsToUse.bytesWritten());
+      // SPARK-3577 tracks the spill time separately.
+
+      // This is guaranteed to be a ShuffleWriteMetrics based on the if check in the beginning
+      // of this method.
+      writeMetrics.incRecordsWritten(
+        ((ShuffleWriteMetrics)writeMetricsToUse).recordsWritten());
+      taskContext.taskMetrics().incDiskBytesSpilled(
+        ((ShuffleWriteMetrics)writeMetricsToUse).bytesWritten());
     }
   }
 
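For context on the cast added above: `writeSortedFile` chooses its metrics sink at the top of the method, and only the spill branch uses a private `ShuffleWriteMetrics` instance, which is what makes the later downcast safe. A paraphrased Scala sketch of that selection (the original is Java and is not shown in this hunk; the spill-branch instantiation is inferred, not copied from the commit):

```scala
// Paraphrase of the selection at the top of writeSortedFile, not the committed Java code.
val writeMetricsToUse: ShuffleWriteMetricsReporter =
  if (isLastFile) {
    // Final non-spill output: bytes and records count as shuffle write, so report
    // directly into the reporter injected into the sorter.
    writeMetrics
  } else {
    // Spill file: accumulate into a throwaway ShuffleWriteMetrics so spill IO is not
    // counted as shuffle write; its totals are later folded into records written and
    // disk bytes spilled, hence the cast back to ShuffleWriteMetrics shown above.
    new ShuffleWriteMetrics()
  }
```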

core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java

Lines changed: 5 additions & 4 deletions
@@ -37,7 +37,6 @@
 
 import org.apache.spark.*;
 import org.apache.spark.annotation.Private;
-import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.io.CompressionCodec;
 import org.apache.spark.io.CompressionCodec$;
 import org.apache.spark.io.NioBufferedFileInputStream;
@@ -47,6 +46,7 @@
 import org.apache.spark.network.util.LimitedInputStream;
 import org.apache.spark.scheduler.MapStatus;
 import org.apache.spark.scheduler.MapStatus$;
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
 import org.apache.spark.serializer.SerializationStream;
 import org.apache.spark.serializer.SerializerInstance;
 import org.apache.spark.shuffle.IndexShuffleBlockResolver;
@@ -73,7 +73,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
   private final TaskMemoryManager memoryManager;
   private final SerializerInstance serializer;
   private final Partitioner partitioner;
-  private final ShuffleWriteMetrics writeMetrics;
+  private final ShuffleWriteMetricsReporter writeMetrics;
   private final int shuffleId;
   private final int mapId;
   private final TaskContext taskContext;
@@ -122,7 +122,8 @@ public UnsafeShuffleWriter(
       SerializedShuffleHandle<K, V> handle,
       int mapId,
       TaskContext taskContext,
-      SparkConf sparkConf) throws IOException {
+      SparkConf sparkConf,
+      ShuffleWriteMetricsReporter writeMetrics) throws IOException {
     final int numPartitions = handle.dependency().partitioner().numPartitions();
     if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) {
       throw new IllegalArgumentException(
@@ -138,7 +139,7 @@ public UnsafeShuffleWriter(
     this.shuffleId = dep.shuffleId();
     this.serializer = dep.serializer().newInstance();
     this.partitioner = dep.partitioner();
-    this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics();
+    this.writeMetrics = writeMetrics;
     this.taskContext = taskContext;
     this.sparkConf = sparkConf;
     this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true);

core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java

Lines changed: 4 additions & 3 deletions
@@ -21,7 +21,7 @@
 import java.io.OutputStream;
 
 import org.apache.spark.annotation.Private;
-import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
 
 /**
  * Intercepts write calls and tracks total time spent writing in order to update shuffle write
@@ -30,10 +30,11 @@
 @Private
 public final class TimeTrackingOutputStream extends OutputStream {
 
-  private final ShuffleWriteMetrics writeMetrics;
+  private final ShuffleWriteMetricsReporter writeMetrics;
   private final OutputStream outputStream;
 
-  public TimeTrackingOutputStream(ShuffleWriteMetrics writeMetrics, OutputStream outputStream) {
+  public TimeTrackingOutputStream(
+      ShuffleWriteMetricsReporter writeMetrics, OutputStream outputStream) {
     this.writeMetrics = writeMetrics;
     this.outputStream = outputStream;
   }

core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala

Lines changed: 7 additions & 6 deletions
@@ -18,6 +18,7 @@
 package org.apache.spark.executor
 
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
 import org.apache.spark.util.LongAccumulator
 
 
@@ -27,7 +28,7 @@ import org.apache.spark.util.LongAccumulator
  * Operations are not thread-safe.
  */
 @DeveloperApi
-class ShuffleWriteMetrics private[spark] () extends Serializable {
+class ShuffleWriteMetrics private[spark] () extends ShuffleWriteMetricsReporter with Serializable {
   private[executor] val _bytesWritten = new LongAccumulator
   private[executor] val _recordsWritten = new LongAccumulator
   private[executor] val _writeTime = new LongAccumulator
@@ -47,13 +48,13 @@ class ShuffleWriteMetrics private[spark] () extends Serializable {
    */
   def writeTime: Long = _writeTime.sum
 
-  private[spark] def incBytesWritten(v: Long): Unit = _bytesWritten.add(v)
-  private[spark] def incRecordsWritten(v: Long): Unit = _recordsWritten.add(v)
-  private[spark] def incWriteTime(v: Long): Unit = _writeTime.add(v)
-  private[spark] def decBytesWritten(v: Long): Unit = {
+  private[spark] override def incBytesWritten(v: Long): Unit = _bytesWritten.add(v)
+  private[spark] override def incRecordsWritten(v: Long): Unit = _recordsWritten.add(v)
+  private[spark] override def incWriteTime(v: Long): Unit = _writeTime.add(v)
+  private[spark] override def decBytesWritten(v: Long): Unit = {
     _bytesWritten.setValue(bytesWritten - v)
   }
-  private[spark] def decRecordsWritten(v: Long): Unit = {
+  private[spark] override def decRecordsWritten(v: Long): Unit = {
     _recordsWritten.setValue(recordsWritten - v)
   }
 }

core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala

Lines changed: 2 additions & 1 deletion
@@ -95,7 +95,8 @@ private[spark] class ShuffleMapTask(
     var writer: ShuffleWriter[Any, Any] = null
     try {
       val manager = SparkEnv.get.shuffleManager
-      writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
+      writer = manager.getWriter[Any, Any](
+        dep.shuffleHandle, partitionId, context, context.taskMetrics().shuffleWriteMetrics)
       writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
       writer.stop(success = true).get
     } catch {

core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala

Lines changed: 5 additions & 1 deletion
@@ -38,7 +38,11 @@ private[spark] trait ShuffleManager {
       dependency: ShuffleDependency[K, V, C]): ShuffleHandle
 
   /** Get a writer for a given partition. Called on executors by map tasks. */
-  def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext): ShuffleWriter[K, V]
+  def getWriter[K, V](
+      handle: ShuffleHandle,
+      mapId: Int,
+      context: TaskContext,
+      metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V]
 
   /**
    * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive).
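Because `getWriter` now takes the reporter as an argument instead of reading it off the `TaskContext`, a caller can hand in any implementation. As a hypothetical illustration (none of these names appear in the commit), a reporter that feeds both the task's `ShuffleWriteMetrics` and a caller-owned sink might look like this:

```scala
package org.apache.spark.shuffle

import org.apache.spark.executor.ShuffleWriteMetrics

// Hypothetical reporter, not part of this commit: mirrors every update into the
// task's ShuffleWriteMetrics and into a second, caller-owned reporter
// (for example, per-operator SQL metrics on the write side).
private[spark] class ForwardingWriteMetricsReporter(
    taskMetrics: ShuffleWriteMetrics,
    other: ShuffleWriteMetricsReporter) extends ShuffleWriteMetricsReporter {
  private[spark] override def incBytesWritten(v: Long): Unit = {
    taskMetrics.incBytesWritten(v)
    other.incBytesWritten(v)
  }
  private[spark] override def incRecordsWritten(v: Long): Unit = {
    taskMetrics.incRecordsWritten(v)
    other.incRecordsWritten(v)
  }
  private[spark] override def incWriteTime(v: Long): Unit = {
    taskMetrics.incWriteTime(v)
    other.incWriteTime(v)
  }
  private[spark] override def decBytesWritten(v: Long): Unit = {
    taskMetrics.decBytesWritten(v)
    other.decBytesWritten(v)
  }
  private[spark] override def decRecordsWritten(v: Long): Unit = {
    taskMetrics.decRecordsWritten(v)
    other.decRecordsWritten(v)
  }
}
```

Such a reporter would be passed straight through `ShuffleManager.getWriter` into the writers shown above, which only ever talk to the `ShuffleWriteMetricsReporter` interface.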

core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala

Lines changed: 6 additions & 4 deletions
@@ -125,7 +125,8 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
   override def getWriter[K, V](
       handle: ShuffleHandle,
       mapId: Int,
-      context: TaskContext): ShuffleWriter[K, V] = {
+      context: TaskContext,
+      metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
     numMapsForShuffle.putIfAbsent(
       handle.shuffleId, handle.asInstanceOf[BaseShuffleHandle[_, _, _]].numMaps)
     val env = SparkEnv.get
@@ -138,15 +139,16 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
           unsafeShuffleHandle,
           mapId,
           context,
-          env.conf)
+          env.conf,
+          metrics)
       case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
         new BypassMergeSortShuffleWriter(
          env.blockManager,
          shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
          bypassMergeSortHandle,
          mapId,
-          context,
-          env.conf)
+          env.conf,
+          metrics)
       case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
         new SortShuffleWriter(shuffleBlockResolver, other, mapId, context)
     }

core/src/main/scala/org/apache/spark/storage/BlockManager.scala

Lines changed: 3 additions & 4 deletions
@@ -33,10 +33,9 @@ import scala.util.Random
 import scala.util.control.NonFatal
 
 import com.codahale.metrics.{MetricRegistry, MetricSet}
-import com.google.common.io.CountingOutputStream
 
 import org.apache.spark._
-import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics}
+import org.apache.spark.executor.DataReadMethod
 import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.memory.{MemoryManager, MemoryMode}
 import org.apache.spark.metrics.source.Source
@@ -50,7 +49,7 @@ import org.apache.spark.network.util.TransportConf
 import org.apache.spark.rpc.RpcEnv
 import org.apache.spark.scheduler.ExecutorCacheTaskLocation
 import org.apache.spark.serializer.{SerializerInstance, SerializerManager}
-import org.apache.spark.shuffle.ShuffleManager
+import org.apache.spark.shuffle.{ShuffleManager, ShuffleWriteMetricsReporter}
 import org.apache.spark.storage.memory._
 import org.apache.spark.unsafe.Platform
 import org.apache.spark.util._
@@ -932,7 +931,7 @@ private[spark] class BlockManager(
       file: File,
       serializerInstance: SerializerInstance,
       bufferSize: Int,
-      writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = {
+      writeMetrics: ShuffleWriteMetricsReporter): DiskBlockObjectWriter = {
     val syncWrites = conf.getBoolean("spark.shuffle.sync", false)
     new DiskBlockObjectWriter(file, serializerManager, serializerInstance, bufferSize,
       syncWrites, writeMetrics, blockId)

core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala

Lines changed: 2 additions & 2 deletions
@@ -20,9 +20,9 @@ package org.apache.spark.storage
 import java.io.{BufferedOutputStream, File, FileOutputStream, OutputStream}
 import java.nio.channels.FileChannel
 
-import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.internal.Logging
 import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager}
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
 import org.apache.spark.util.Utils
 
 /**
@@ -43,7 +43,7 @@ private[spark] class DiskBlockObjectWriter(
     syncWrites: Boolean,
     // These write metrics concurrently shared with other active DiskBlockObjectWriters who
     // are themselves performing writes. All updates must be relative.
-    writeMetrics: ShuffleWriteMetrics,
+    writeMetrics: ShuffleWriteMetricsReporter,
     val blockId: BlockId = null)
   extends OutputStream
   with Logging {
