
Commit de84899

rxin authored and gatorsmile committed
[SPARK-26140] Enable custom metrics implementation in shuffle reader
## What changes were proposed in this pull request?

This patch defines an internal Spark interface for reporting shuffle metrics and uses it in the shuffle reader. Before this patch, shuffle metrics were tied to a specific implementation (a thread-local temporary data structure plus accumulators). After this patch, callers that define their own shuffle RDDs can supply a custom metrics implementation. With this patch we will be able to build better metrics for the SQL layer, e.g. reporting shuffle metrics in the SQL UI for each exchange operator.

Note that I'm separating the read-side and write-side implementations, as they are very different, to simplify code review. The write-side change is at apache#23106.

## How was this patch tested?

No behavior change is expected, as this is a straightforward refactoring. Updated all existing test cases.

Closes apache#23105 from rxin/SPARK-26140.

Authored-by: Reynold Xin <[email protected]>
Signed-off-by: gatorsmile <[email protected]>
1 parent ecb785f commit de84899
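
The description says that callers defining their own shuffle RDDs can plug in a custom metrics implementation. A minimal sketch of what such an implementation could look like against the ShuffleMetricsReporter trait added in this patch; the package and class name below are illustrative and not part of the patch:

package org.apache.spark.shuffle.example  // hypothetical; must live under org.apache.spark since the trait is private[spark]

import org.apache.spark.shuffle.ShuffleMetricsReporter

// Illustrative reporter that folds everything into two local counters. A real caller,
// e.g. the SQL layer, would forward these values into its own accumulators so they can
// be shown per exchange operator in the UI.
private[spark] class CountingShuffleMetricsReporter extends ShuffleMetricsReporter {
  private[this] var _recordsRead = 0L
  private[this] var _bytesRead = 0L

  override def incRemoteBlocksFetched(v: Long): Unit = ()
  override def incLocalBlocksFetched(v: Long): Unit = ()
  override def incRemoteBytesRead(v: Long): Unit = _bytesRead += v
  override def incRemoteBytesReadToDisk(v: Long): Unit = ()
  override def incLocalBytesRead(v: Long): Unit = _bytesRead += v
  override def incFetchWaitTime(v: Long): Unit = ()
  override def incRecordsRead(v: Long): Unit = _recordsRead += v

  def recordsRead: Long = _recordsRead
  def bytesRead: Long = _bytesRead
}

Such a reporter would be handed to the reader through the new fifth parameter of ShuffleManager.getReader, the same way the RDD changes below pass in the TempShuffleReadMetrics returned by context.taskMetrics().createTempShuffleReadMetrics().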

File tree: 15 files changed, +155 -36 lines


core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala

Lines changed: 10 additions & 8 deletions
@@ -18,6 +18,7 @@
 package org.apache.spark.executor
 
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.shuffle.ShuffleMetricsReporter
 import org.apache.spark.util.LongAccumulator
 
 
@@ -123,12 +124,13 @@ class ShuffleReadMetrics private[spark] () extends Serializable {
   }
 }
 
+
 /**
  * A temporary shuffle read metrics holder that is used to collect shuffle read metrics for each
  * shuffle dependency, and all temporary metrics will be merged into the [[ShuffleReadMetrics]] at
  * last.
  */
-private[spark] class TempShuffleReadMetrics {
+private[spark] class TempShuffleReadMetrics extends ShuffleMetricsReporter {
   private[this] var _remoteBlocksFetched = 0L
   private[this] var _localBlocksFetched = 0L
   private[this] var _remoteBytesRead = 0L
@@ -137,13 +139,13 @@ private[spark] class TempShuffleReadMetrics {
   private[this] var _fetchWaitTime = 0L
   private[this] var _recordsRead = 0L
 
-  def incRemoteBlocksFetched(v: Long): Unit = _remoteBlocksFetched += v
-  def incLocalBlocksFetched(v: Long): Unit = _localBlocksFetched += v
-  def incRemoteBytesRead(v: Long): Unit = _remoteBytesRead += v
-  def incRemoteBytesReadToDisk(v: Long): Unit = _remoteBytesReadToDisk += v
-  def incLocalBytesRead(v: Long): Unit = _localBytesRead += v
-  def incFetchWaitTime(v: Long): Unit = _fetchWaitTime += v
-  def incRecordsRead(v: Long): Unit = _recordsRead += v
+  override def incRemoteBlocksFetched(v: Long): Unit = _remoteBlocksFetched += v
+  override def incLocalBlocksFetched(v: Long): Unit = _localBlocksFetched += v
+  override def incRemoteBytesRead(v: Long): Unit = _remoteBytesRead += v
+  override def incRemoteBytesReadToDisk(v: Long): Unit = _remoteBytesReadToDisk += v
+  override def incLocalBytesRead(v: Long): Unit = _localBytesRead += v
+  override def incFetchWaitTime(v: Long): Unit = _fetchWaitTime += v
+  override def incRecordsRead(v: Long): Unit = _recordsRead += v
 
   def remoteBlocksFetched: Long = _remoteBlocksFetched
   def localBlocksFetched: Long = _localBlocksFetched

core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala

Lines changed: 3 additions & 1 deletion
@@ -143,8 +143,10 @@ class CoGroupedRDD[K: ClassTag](
 
       case shuffleDependency: ShuffleDependency[_, _, _] =>
         // Read map outputs of shuffle
+        val metrics = context.taskMetrics().createTempShuffleReadMetrics()
         val it = SparkEnv.get.shuffleManager
-          .getReader(shuffleDependency.shuffleHandle, split.index, split.index + 1, context)
+          .getReader(
+            shuffleDependency.shuffleHandle, split.index, split.index + 1, context, metrics)
           .read()
         rddIterators += ((it, depNum))
     }

core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala

Lines changed: 3 additions & 1 deletion
@@ -101,7 +101,9 @@ class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag](
 
   override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
     val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
-    SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
+    val metrics = context.taskMetrics().createTempShuffleReadMetrics()
+    SparkEnv.get.shuffleManager.getReader(
+      dep.shuffleHandle, split.index, split.index + 1, context, metrics)
       .read()
       .asInstanceOf[Iterator[(K, C)]]
   }

core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala

Lines changed: 6 additions & 1 deletion
@@ -107,9 +107,14 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
           .asInstanceOf[Iterator[Product2[K, V]]].foreach(op)
 
       case shuffleDependency: ShuffleDependency[_, _, _] =>
+        val metrics = context.taskMetrics().createTempShuffleReadMetrics()
         val iter = SparkEnv.get.shuffleManager
           .getReader(
-            shuffleDependency.shuffleHandle, partition.index, partition.index + 1, context)
+            shuffleDependency.shuffleHandle,
+            partition.index,
+            partition.index + 1,
+            context,
+            metrics)
           .read()
         iter.foreach(op)
     }

core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala

Lines changed: 3 additions & 2 deletions
@@ -33,6 +33,7 @@ private[spark] class BlockStoreShuffleReader[K, C](
     startPartition: Int,
     endPartition: Int,
     context: TaskContext,
+    readMetrics: ShuffleMetricsReporter,
     serializerManager: SerializerManager = SparkEnv.get.serializerManager,
     blockManager: BlockManager = SparkEnv.get.blockManager,
     mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
@@ -53,7 +54,8 @@ private[spark] class BlockStoreShuffleReader[K, C](
       SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
       SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
       SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
-      SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))
+      SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true),
+      readMetrics)
 
     val serializerInstance = dep.serializer.newInstance()
 
@@ -66,7 +68,6 @@ private[spark] class BlockStoreShuffleReader[K, C](
     }
 
     // Update the context task metrics for each record read.
-    val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
     val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
       recordIter.map { record =>
         readMetrics.incRecordsRead(1)

core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala

Lines changed: 2 additions & 1 deletion
@@ -48,7 +48,8 @@ private[spark] trait ShuffleManager {
       handle: ShuffleHandle,
       startPartition: Int,
       endPartition: Int,
-      context: TaskContext): ShuffleReader[K, C]
+      context: TaskContext,
+      metrics: ShuffleMetricsReporter): ShuffleReader[K, C]
 
   /**
    * Remove a shuffle's metadata from the ShuffleManager.

core/src/main/scala/org/apache/spark/shuffle/ShuffleMetricsReporter.scala

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+/**
+ * An interface for reporting shuffle information, for each shuffle. This interface assumes
+ * all the methods are called on a single-threaded, i.e. concrete implementations would not need
+ * to synchronize anything.
+ */
+private[spark] trait ShuffleMetricsReporter {
+  def incRemoteBlocksFetched(v: Long): Unit
+  def incLocalBlocksFetched(v: Long): Unit
+  def incRemoteBytesRead(v: Long): Unit
+  def incRemoteBytesReadToDisk(v: Long): Unit
+  def incLocalBytesRead(v: Long): Unit
+  def incFetchWaitTime(v: Long): Unit
+  def incRecordsRead(v: Long): Unit
+}
Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+/**
+ * An interface for reporting shuffle read metrics, for each shuffle. This interface assumes
+ * all the methods are called on a single-threaded, i.e. concrete implementations would not need
+ * to synchronize.
+ *
+ * All methods have additional Spark visibility modifier to allow public, concrete implementations
+ * that still have these methods marked as private[spark].
+ */
+private[spark] trait ShuffleReadMetricsReporter {
+  private[spark] def incRemoteBlocksFetched(v: Long): Unit
+  private[spark] def incLocalBlocksFetched(v: Long): Unit
+  private[spark] def incRemoteBytesRead(v: Long): Unit
+  private[spark] def incRemoteBytesReadToDisk(v: Long): Unit
+  private[spark] def incLocalBytesRead(v: Long): Unit
+  private[spark] def incFetchWaitTime(v: Long): Unit
+  private[spark] def incRecordsRead(v: Long): Unit
+}
+
+
+/**
+ * An interface for reporting shuffle write metrics. This interface assumes all the methods are
+ * called on a single-threaded, i.e. concrete implementations would not need to synchronize.
+ *
+ * All methods have additional Spark visibility modifier to allow public, concrete implementations
+ * that still have these methods marked as private[spark].
+ */
+private[spark] trait ShuffleWriteMetricsReporter {
+  private[spark] def incBytesWritten(v: Long): Unit
+  private[spark] def incRecordsWritten(v: Long): Unit
+  private[spark] def incWriteTime(v: Long): Unit
+  private[spark] def decBytesWritten(v: Long): Unit
+  private[spark] def decRecordsWritten(v: Long): Unit
+}
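
The doc comments above note that each method carries an explicit private[spark] modifier so that a concrete implementation can itself be a public class while its mutators stay Spark-internal. A minimal sketch of that pattern, assuming a class defined somewhere under the org.apache.spark namespace; the package and class name are illustrative, not part of the patch:

package org.apache.spark.shuffle.example  // hypothetical

import org.apache.spark.shuffle.ShuffleReadMetricsReporter

// The class is public, but the inherited mutators keep the private[spark] modifier, so only
// Spark-internal code (e.g. the shuffle reader) can increment them; outside code only sees
// the read-only accessors.
class UiShuffleReadMetrics extends ShuffleReadMetricsReporter {
  private[this] var _bytesRead = 0L
  private[this] var _recordsRead = 0L

  override private[spark] def incRemoteBlocksFetched(v: Long): Unit = ()
  override private[spark] def incLocalBlocksFetched(v: Long): Unit = ()
  override private[spark] def incRemoteBytesRead(v: Long): Unit = _bytesRead += v
  override private[spark] def incRemoteBytesReadToDisk(v: Long): Unit = ()
  override private[spark] def incLocalBytesRead(v: Long): Unit = _bytesRead += v
  override private[spark] def incFetchWaitTime(v: Long): Unit = ()
  override private[spark] def incRecordsRead(v: Long): Unit = _recordsRead += v

  // Public read-only accessors that, say, a UI layer could expose.
  def bytesRead: Long = _bytesRead
  def recordsRead: Long = _recordsRead
}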

core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala

Lines changed: 4 additions & 2 deletions
@@ -114,9 +114,11 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
       handle: ShuffleHandle,
       startPartition: Int,
       endPartition: Int,
-      context: TaskContext): ShuffleReader[K, C] = {
+      context: TaskContext,
+      metrics: ShuffleMetricsReporter): ShuffleReader[K, C] = {
     new BlockStoreShuffleReader(
-      handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
+      handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
+      startPartition, endPartition, context, metrics)
   }
 
   /** Get a writer for a given partition. Called on executors by map tasks. */

core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala

Lines changed: 5 additions & 5 deletions
@@ -30,7 +30,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
 import org.apache.spark.network.shuffle._
 import org.apache.spark.network.util.TransportConf
-import org.apache.spark.shuffle.FetchFailedException
+import org.apache.spark.shuffle.{FetchFailedException, ShuffleMetricsReporter}
 import org.apache.spark.util.Utils
 import org.apache.spark.util.io.ChunkedByteBufferOutputStream
 
@@ -51,14 +51,15 @@ import org.apache.spark.util.io.ChunkedByteBufferOutputStream
  *                        For each block we also require the size (in bytes as a long field) in
  *                        order to throttle the memory usage. Note that zero-sized blocks are
  *                        already excluded, which happened in
- *                        [[MapOutputTracker.convertMapStatuses]].
+ *                        [[org.apache.spark.MapOutputTracker.convertMapStatuses]].
  * @param streamWrapper A function to wrap the returned input stream.
  * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
  * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point.
  * @param maxBlocksInFlightPerAddress max number of shuffle blocks being fetched at any given point
  *                                    for a given remote host:port.
  * @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory.
  * @param detectCorrupt whether to detect any corruption in fetched blocks.
+ * @param shuffleMetrics used to report shuffle metrics.
  */
 private[spark]
 final class ShuffleBlockFetcherIterator(
@@ -71,7 +72,8 @@ final class ShuffleBlockFetcherIterator(
     maxReqsInFlight: Int,
     maxBlocksInFlightPerAddress: Int,
     maxReqSizeShuffleToMem: Long,
-    detectCorrupt: Boolean)
+    detectCorrupt: Boolean,
+    shuffleMetrics: ShuffleMetricsReporter)
   extends Iterator[(BlockId, InputStream)] with DownloadFileManager with Logging {
 
   import ShuffleBlockFetcherIterator._
@@ -137,8 +139,6 @@ final class ShuffleBlockFetcherIterator(
    */
   private[this] val corruptedBlocks = mutable.HashSet[BlockId]()
 
-  private[this] val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics()
-
   /**
    * Whether the iterator is still active. If isZombie is true, the callback interface will no
    * longer place fetched blocks into [[results]].
