@@ -34,7 +34,7 @@ class DataSourceRDDPartition(val index: Int, val inputPartition: Option[InputPar
3434 extends Partition with Serializable
3535
3636/**
37- * Holds the state for a reader in a thread, used by the completion listener to access the most
37+ * Holds the state for a reader in a thread/task, used by the completion listener to access the most
3838 * recently created reader and iterator for final metrics updates and cleanup.
3939 *
4040 * When `compute()` is called multiple times on the same thread (e.g., different input partitions
@@ -69,7 +69,7 @@ class DataSourceRDD(
6969 customMetrics : Map [String , SQLMetric ])
7070 extends RDD [InternalRow ](sc, Nil ) {
7171
72- // ThreadLocal to store the last reader state for this thread.
72+ // ThreadLocal to store the last reader state for this thread/task.
7373 // A null value indicates that no completion listener has been added yet.
7474 @ transient lazy private val readerStateThreadLocal = new ThreadLocal [ReaderState ]()
7575
@@ -85,6 +85,24 @@ class DataSourceRDD(
8585 }
8686
8787 override def compute (split : Partition , context : TaskContext ): Iterator [InternalRow ] = {
88+ // In case of early stopping before consuming the entire iterator, we need to do one more metric
89+ // update at the end of the task.
90+ // Add completion listener only once per thread (null means no listener added yet)
91+ val readerState = readerStateThreadLocal.get()
92+ if (readerState == null ) {
93+ context.addTaskCompletionListener[Unit ] { _ =>
94+ // Use the reader and iterator from ThreadLocal (the last ones created in this thread/task)
95+ val readerState = readerStateThreadLocal.get()
96+ if (readerState != null ) {
97+ CustomMetrics .updateMetrics(
98+ readerState.reader.currentMetricsValues.toImmutableArraySeq, customMetrics)
99+ readerState.iterator.forceUpdateMetrics()
100+ readerState.reader.close()
101+ }
102+ readerStateThreadLocal.remove()
103+ }
104+ }
105+
88106 castPartition(split).inputPartition.iterator.flatMap { inputPartition =>
89107 val (iter, reader) = if (columnarReads) {
90108 val batchReader = partitionReaderFactory.createColumnarReader(inputPartition)
@@ -100,22 +118,10 @@ class DataSourceRDD(
100118
101119 // If this thread already has a reader, publish its metrics, seed the new reader, and close it
102120 val readerState = readerStateThreadLocal.get()
103- if (readerState == null ) {
104- context.addTaskCompletionListener[Unit ] { _ =>
105- // Use the reader and iterator from ThreadLocal (the last ones created in this thread)
106- val readerState = readerStateThreadLocal.get()
107- if (readerState != null ) {
108- // In case of early stopping before consuming the entire iterator,
109- // we need to do one more metric update at the end of the task.
110- CustomMetrics .updateMetrics(
111- readerState.reader.currentMetricsValues.toImmutableArraySeq, customMetrics)
112- readerState.iterator.forceUpdateMetrics()
113- readerState.reader.close()
114- }
115- readerStateThreadLocal.remove()
116- }
117- } else {
118- reader.initMetricsValues(readerState.reader.currentMetricsValues())
121+ if (readerState != null ) {
122+ val metrics = readerState.reader.currentMetricsValues
123+ CustomMetrics .updateMetrics(metrics.toImmutableArraySeq, customMetrics)
124+ reader.initMetricsValues(metrics)
119125 readerState.reader.close()
120126 }
121127
0 commit comments