@@ -24,6 +24,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.metric.CustomTaskMetric
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric}
@@ -33,6 +34,19 @@ import org.apache.spark.util.ArrayImplicits._
 class DataSourceRDDPartition(val index: Int, val inputPartition: Option[InputPartition])
   extends Partition with Serializable

+/**
+ * Holds per-thread state, used by the completion listener to access the most recently
+ * created reader and iterator for final metric updates and cleanup.
+ *
+ * @param reader The partition reader
+ * @param iterator The metrics iterator wrapping the reader
+ * @param metrics Optional array of custom task metrics from the previous reader
+ */
+private case class State(
+    reader: PartitionReader[_],
+    iterator: MetricsIterator[_],
+    metrics: Option[Array[CustomTaskMetric]])
+
 // TODO: we should have 2 RDDs: an RDD[InternalRow] for row-based scan, an `RDD[ColumnarBatch]` for
 // columnar scan.
 class DataSourceRDD(
@@ -43,6 +57,10 @@ class DataSourceRDD(
     customMetrics: Map[String, SQLMetric])
   extends RDD[InternalRow](sc, Nil) {

+  // ThreadLocal to store the last State for this thread.
+  // A null value indicates that no completion listener has been added yet.
+  @transient lazy private val lastThreadLocal = new ThreadLocal[State]()
+
   override protected def getPartitions: Array[Partition] = {
     inputPartitions.zipWithIndex.map {
       case (inputPartition, index) => new DataSourceRDDPartition(index, inputPartition)
@@ -59,21 +77,39 @@ class DataSourceRDD(
     val (iter, reader) = if (columnarReads) {
       val batchReader = partitionReaderFactory.createColumnarReader(inputPartition)
       val iter = new MetricsBatchIterator(
-        new PartitionIterator[ColumnarBatch](batchReader, customMetrics))
+        new PartitionIterator[ColumnarBatch](batchReader, customMetrics), lastThreadLocal)
       (iter, batchReader)
     } else {
       val rowReader = partitionReaderFactory.createReader(inputPartition)
       val iter = new MetricsRowIterator(
-        new PartitionIterator[InternalRow](rowReader, customMetrics))
+        new PartitionIterator[InternalRow](rowReader, customMetrics), lastThreadLocal)
       (iter, rowReader)
     }
-    context.addTaskCompletionListener[Unit] { _ =>
-      // In case of early stopping before consuming the entire iterator,
-      // we need to do one more metric update at the end of the task.
-      CustomMetrics.updateMetrics(reader.currentMetricsValues.toImmutableArraySeq, customMetrics)
-      iter.forceUpdateMetrics()
-      reader.close()
+
+    // Add completion listener only once per thread (null means no listener added yet)
+    val last = lastThreadLocal.get()
+    if (last == null) {
+      context.addTaskCompletionListener[Unit] { _ =>
+        // Use the reader and iterator from ThreadLocal (the last ones created in this thread)
+        val last = lastThreadLocal.get()
+        if (last != null) {
+          // In case of early stopping before consuming the entire iterator,
+          // we need to do one more metric update at the end of the task.
+          CustomMetrics.updateMetrics(
+            last.reader.currentMetricsValues.toImmutableArraySeq, customMetrics)
+          last.iterator.forceUpdateMetrics()
+          last.reader.close()
+        }
+        lastThreadLocal.remove()
+      }
+    } else {
+      last.metrics.foreach(reader.initMetricsValues)
     }
+
+    // Store the current reader and iterator in ThreadLocal so the completion listener
+    // can access the most recently created instances
+    lastThreadLocal.set(State(reader, iter, None))
+
     // TODO: SPARK-25083 remove the type erasure hack in data source scan
     new InterruptibleIterator(context, iter.asInstanceOf[Iterator[InternalRow]])
   }
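The hunk above is the heart of the change: a task completion listener is registered at most once per thread, and that single listener always acts on whatever reader and iterator were created last on the thread, tracked through the ThreadLocal. Below is a minimal, Spark-free sketch of the same pattern; every name in it (Resource, openReader, endOfTask, hooksRegistered) is illustrative and not part of Spark's API.

// Sketch: register a single end-of-task hook per thread, keyed off a ThreadLocal sentinel.
object ListenerOncePerThreadSketch {
  final class Resource(val name: String) {
    def close(): Unit = println(s"closed $name")
  }

  // Per-thread pointer to the most recently opened resource; null means no hook registered yet.
  private val last = new ThreadLocal[Resource]()
  private var hooksRegistered = 0

  def openReader(name: String): Resource = {
    val r = new Resource(name)
    if (last.get() == null) {
      hooksRegistered += 1 // first resource on this thread: register the single hook
    }
    last.set(r) // always point the thread-local state at the newest resource
    r
  }

  def endOfTask(): Unit = {
    val r = last.get() // the hook only ever sees the last resource created on this thread
    if (r != null) r.close()
    last.remove()
  }

  def main(args: Array[String]): Unit = {
    openReader("reader-1")
    openReader("reader-2") // same thread: no extra hook; state now points at reader-2
    endOfTask()            // closes only reader-2
    println(s"hooks registered: $hooksRegistered") // prints 1
  }
}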
@@ -113,7 +149,7 @@ private class PartitionIterator[T](
   }
 }

-private class MetricsHandler extends Logging with Serializable {
+private[spark] class MetricsHandler extends Logging with Serializable {
   private val inputMetrics = TaskContext.get().taskMetrics().inputMetrics
   private val startingBytesRead = inputMetrics.bytesRead
   private val getBytesRead = SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()
@@ -128,13 +164,18 @@ private class MetricsHandler extends Logging with Serializable {
   }
 }

-private abstract class MetricsIterator[I](iter: Iterator[I]) extends Iterator[I] {
+private[spark] abstract class MetricsIterator[I](
+    iter: Iterator[I],
+    lastThreadLocal: ThreadLocal[State]
+  ) extends Iterator[I] {
   protected val metricsHandler = new MetricsHandler

   override def hasNext: Boolean = {
     if (iter.hasNext) {
       true
     } else {
+      val last = lastThreadLocal.get()
+      lastThreadLocal.set(last.copy(metrics = Some(last.reader.currentMetricsValues())))
       forceUpdateMetrics()
       false
     }
@@ -143,17 +184,21 @@ private abstract class MetricsIterator[I](iter: Iterator[I]) extends Iterator[I]
   def forceUpdateMetrics(): Unit = metricsHandler.updateMetrics(0, force = true)
 }

-private class MetricsRowIterator(
-    iter: Iterator[InternalRow]) extends MetricsIterator[InternalRow](iter) {
+private[spark] class MetricsRowIterator(
+    iter: Iterator[InternalRow],
+    lastThreadLocal: ThreadLocal[State]
+  ) extends MetricsIterator[InternalRow](iter, lastThreadLocal) {
   override def next(): InternalRow = {
     val item = iter.next()
     metricsHandler.updateMetrics(1)
     item
   }
 }

-private class MetricsBatchIterator(
-    iter: Iterator[ColumnarBatch]) extends MetricsIterator[ColumnarBatch](iter) {
+private[spark] class MetricsBatchIterator(
+    iter: Iterator[ColumnarBatch],
+    lastThreadLocal: ThreadLocal[State]
+  ) extends MetricsIterator[ColumnarBatch](iter, lastThreadLocal) {
   override def next(): ColumnarBatch = {
     val batch: ColumnarBatch = iter.next()
     metricsHandler.updateMetrics(batch.numRows)
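A second, smaller piece of the change is the metric hand-off: when a reader's iterator is exhausted, MetricsIterator.hasNext stores that reader's final custom metric values into the thread-local State, and the next reader created on the same thread is seeded from that snapshot (the else branch in compute above). Below is a Spark-free sketch of that hand-off under the same caveat; the names (Holder, finishReader, seedForNextReader) are illustrative only.

// Sketch: stash a finished reader's metric values in a ThreadLocal and hand them to the next reader.
object MetricHandoffSketch {
  // Mirrors State.metrics: the finished reader's final metric values, if any.
  final case class Holder(metrics: Option[Map[String, Long]])

  private val last = new ThreadLocal[Holder]()

  // Called when an iterator is exhausted: snapshot the reader's final values.
  def finishReader(finalValues: Map[String, Long]): Unit = {
    val h = Option(last.get()).getOrElse(Holder(None))
    last.set(h.copy(metrics = Some(finalValues)))
  }

  // Called when the next reader is created on the same thread: start from the snapshot.
  def seedForNextReader(): Map[String, Long] =
    Option(last.get()).flatMap(_.metrics).getOrElse(Map.empty)

  def main(args: Array[String]): Unit = {
    finishReader(Map("bytesRead" -> 100L))
    println(seedForNextReader()) // Map(bytesRead -> 100)
  }
}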