
Commit c1e3e93

fix DSv2 custom metrics

1 parent 4ce05c9 commit c1e3e93

File tree: 1 file changed (+59, −14)

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala

Lines changed: 59 additions & 14 deletions
@@ -24,6 +24,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.metric.CustomTaskMetric
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric}
@@ -33,6 +34,19 @@ import org.apache.spark.util.ArrayImplicits._
 class DataSourceRDDPartition(val index: Int, val inputPartition: Option[InputPartition])
   extends Partition with Serializable
 
+/**
+ * Holds the state in a thread, used by the completion listener to access the most recently created
+ * reader and iterator for final metrics updates and cleanup.
+ *
+ * @param reader The partition reader
+ * @param iterator The metrics iterator wrapping the reader
+ * @param metrics Optional array of custom task metrics from the previous reader
+ */
+private case class State(
+    reader: PartitionReader[_],
+    iterator: MetricsIterator[_],
+    metrics: Option[Array[CustomTaskMetric]])
+
 // TODO: we should have 2 RDDs: an RDD[InternalRow] for row-based scan, an `RDD[ColumnarBatch]` for
 // columnar scan.
 class DataSourceRDD(
@@ -43,6 +57,10 @@ class DataSourceRDD(
     customMetrics: Map[String, SQLMetric])
   extends RDD[InternalRow](sc, Nil) {
 
+  // ThreadLocal to store the last state for this thread.
+  // A null value indicates that no completion listener has been added yet.
+  @transient lazy private val lastThreadLocal = new ThreadLocal[State]()
+
   override protected def getPartitions: Array[Partition] = {
     inputPartitions.zipWithIndex.map {
       case (inputPartition, index) => new DataSourceRDDPartition(index, inputPartition)
@@ -59,21 +77,39 @@
     val (iter, reader) = if (columnarReads) {
       val batchReader = partitionReaderFactory.createColumnarReader(inputPartition)
       val iter = new MetricsBatchIterator(
-        new PartitionIterator[ColumnarBatch](batchReader, customMetrics))
+        new PartitionIterator[ColumnarBatch](batchReader, customMetrics), lastThreadLocal)
       (iter, batchReader)
     } else {
       val rowReader = partitionReaderFactory.createReader(inputPartition)
       val iter = new MetricsRowIterator(
-        new PartitionIterator[InternalRow](rowReader, customMetrics))
+        new PartitionIterator[InternalRow](rowReader, customMetrics), lastThreadLocal)
       (iter, rowReader)
     }
-    context.addTaskCompletionListener[Unit] { _ =>
-      // In case of early stopping before consuming the entire iterator,
-      // we need to do one more metric update at the end of the task.
-      CustomMetrics.updateMetrics(reader.currentMetricsValues.toImmutableArraySeq, customMetrics)
-      iter.forceUpdateMetrics()
-      reader.close()
+
+    // Add completion listener only once per thread (null means no listener added yet)
+    val last = lastThreadLocal.get()
+    if (last == null) {
+      context.addTaskCompletionListener[Unit] { _ =>
+        // Use the reader and iterator from ThreadLocal (the last ones created in this thread)
+        val last = lastThreadLocal.get()
+        if (last != null) {
+          // In case of early stopping before consuming the entire iterator,
+          // we need to do one more metric update at the end of the task.
+          CustomMetrics.updateMetrics(
+            last.reader.currentMetricsValues.toImmutableArraySeq, customMetrics)
+          last.iterator.forceUpdateMetrics()
+          last.reader.close()
+        }
+        lastThreadLocal.remove()
+      }
+    } else {
+      last.metrics.foreach(reader.initMetricsValues)
     }
+
+    // Store the current reader and iterator in ThreadLocal so the completion listener
+    // can access the most recently created instances
+    lastThreadLocal.set(State(reader, iter, None))
+
     // TODO: SPARK-25083 remove the type erasure hack in data source scan
     new InterruptibleIterator(context, iter.asInstanceOf[Iterator[InternalRow]])
   }
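
The core of the fix is in `compute()` above. Previously, every call registered its own task-completion listener, each flushing metrics from and closing its own captured `reader`; when several partitions are computed on the same task thread, the last listener to fire could overwrite the task's custom metrics with a stale reader's values. Now one listener is registered per thread, and it acts on whatever reader the `ThreadLocal` held last. A standalone sketch of that once-per-thread pattern, using hypothetical `Reader`/`hooks` stand-ins rather than Spark's `TaskContext` API:

```scala
import scala.collection.mutable.ArrayBuffer

// Sketch only: `hooks` plays the role of TaskContext completion listeners,
// `Reader` the role of PartitionReader. Names are illustrative, not Spark API.
object OncePerThreadListener {
  final class Reader(val name: String) {
    def close(): Unit = println(s"closed $name")
  }

  // null means no listener has been registered on this thread yet.
  private val lastReader = new ThreadLocal[Reader]()
  private val hooks = ArrayBuffer.empty[() => Unit]

  def openReader(name: String): Reader = {
    if (lastReader.get() == null) {
      // First reader on this thread: register the cleanup hook exactly once.
      hooks += { () =>
        Option(lastReader.get()).foreach(_.close()) // act on the latest reader only
        lastReader.remove()
      }
    }
    val reader = new Reader(name)
    lastReader.set(reader) // later opens simply replace the slot
    reader
  }

  def main(args: Array[String]): Unit = {
    openReader("partition-0")
    openReader("partition-1") // same thread: no second hook registered
    hooks.foreach(_.apply())  // "task end": prints "closed partition-1" once
  }
}
```
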
@@ -113,7 +149,7 @@ private class PartitionIterator[T](
   }
 }
 
-private class MetricsHandler extends Logging with Serializable {
+private[spark] class MetricsHandler extends Logging with Serializable {
   private val inputMetrics = TaskContext.get().taskMetrics().inputMetrics
   private val startingBytesRead = inputMetrics.bytesRead
   private val getBytesRead = SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()
@@ -128,13 +164,18 @@ private class MetricsHandler extends Logging with Serializable {
   }
 }
 
-private abstract class MetricsIterator[I](iter: Iterator[I]) extends Iterator[I] {
+private[spark] abstract class MetricsIterator[I](
+    iter: Iterator[I],
+    lastThreadLocal: ThreadLocal[State]
+  ) extends Iterator[I] {
   protected val metricsHandler = new MetricsHandler
 
   override def hasNext: Boolean = {
     if (iter.hasNext) {
       true
     } else {
+      val last = lastThreadLocal.get()
+      lastThreadLocal.set(last.copy(metrics = Some(last.reader.currentMetricsValues())))
       forceUpdateMetrics()
       false
     }
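
The new branch in `hasNext` above snapshots the exhausted reader's final `currentMetricsValues()` into the thread-local `State`. On the next `compute()` in the same thread, that snapshot is handed to the new reader via `initMetricsValues` (the `else` branch earlier in the diff), the idea presumably being that the new reader continues counting from the previous totals, so polling only the latest reader still yields task-wide values. A simplified sketch of that hand-off, with a plain `Long` counter standing in for the metric array and all names hypothetical:

```scala
// Hand-off sketch: each reader continues from the totals of the one before it,
// mirroring State.metrics + initMetricsValues. Names here are illustrative.
object MetricHandOff {
  final class CountingReader(seed: Long) {
    private var count = seed         // start from the previous reader's total
    def read(): Unit = count += 1
    def currentValue: Long = count
  }

  def main(args: Array[String]): Unit = {
    var carried: Option[Long] = None // plays the role of State.metrics

    val r1 = new CountingReader(carried.getOrElse(0L))
    r1.read(); r1.read()
    carried = Some(r1.currentValue)  // snapshot when r1's iterator is exhausted

    val r2 = new CountingReader(carried.getOrElse(0L))
    r2.read()
    println(r2.currentValue)         // 3: the latest reader reflects the whole task
  }
}
```
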
@@ -143,17 +184,21 @@ private abstract class MetricsIterator[I](iter: Iterator[I]) extends Iterator[I]
   def forceUpdateMetrics(): Unit = metricsHandler.updateMetrics(0, force = true)
 }
 
-private class MetricsRowIterator(
-    iter: Iterator[InternalRow]) extends MetricsIterator[InternalRow](iter) {
+private[spark] class MetricsRowIterator(
+    iter: Iterator[InternalRow],
+    lastThreadLocal: ThreadLocal[State]
+  ) extends MetricsIterator[InternalRow](iter, lastThreadLocal) {
   override def next(): InternalRow = {
     val item = iter.next()
     metricsHandler.updateMetrics(1)
     item
   }
 }
 
-private class MetricsBatchIterator(
-    iter: Iterator[ColumnarBatch]) extends MetricsIterator[ColumnarBatch](iter) {
+private[spark] class MetricsBatchIterator(
+    iter: Iterator[ColumnarBatch],
+    lastThreadLocal: ThreadLocal[State]
+  ) extends MetricsIterator[ColumnarBatch](iter, lastThreadLocal) {
   override def next(): ColumnarBatch = {
     val batch: ColumnarBatch = iter.next()
     metricsHandler.updateMetrics(batch.numRows)
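
For background, the custom metrics being fixed here come from the DSv2 connector metric API: a connector declares a `CustomMetric` (aggregated on the driver and exposed through the scan's `supportedCustomMetrics()`, not shown here) and its `PartitionReader` reports running per-task values as `CustomTaskMetric`s, which `DataSourceRDD` polls via `currentMetricsValues()` during iteration and once more in the completion listener. A minimal sketch of that connector side; the metric name `bytesScanned` and the classes below are illustrative, not part of this commit:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}
import org.apache.spark.sql.connector.read.PartitionReader

// Driver side: declares the metric and how per-task values are aggregated.
class BytesScannedMetric extends CustomMetric {
  override def name(): String = "bytesScanned"
  override def description(): String = "number of bytes scanned"
  override def aggregateTaskMetrics(taskMetrics: Array[Long]): String =
    s"total: ${taskMetrics.sum}"
}

// Executor side: the value a reader reports for one task.
class BytesScannedTaskMetric(v: Long) extends CustomTaskMetric {
  override def name(): String = "bytesScanned"
  override def value(): Long = v
}

// A reader exposing a running counter; DataSourceRDD polls
// currentMetricsValues() during iteration and at task completion.
class ExampleReader(rows: Iterator[InternalRow]) extends PartitionReader[InternalRow] {
  private var bytesScanned = 0L
  private var current: InternalRow = _

  override def next(): Boolean = {
    if (rows.hasNext) {
      current = rows.next()
      bytesScanned += 8 // illustrative fixed cost per row
      true
    } else {
      false
    }
  }
  override def get(): InternalRow = current
  override def close(): Unit = ()
  override def currentMetricsValues(): Array[CustomTaskMetric] =
    Array(new BytesScannedTaskMetric(bytesScanned))
}
```
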
