Commit 2f80341

[SPARK-55619][SQL] Fix custom metrics in case of coalesced partitions

1 parent 7643e09

File tree

2 files changed: +39 -18 lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala

Lines changed: 24 additions & 18 deletions

@@ -34,7 +34,7 @@ class DataSourceRDDPartition(val index: Int, val inputPartition: Option[InputPar
   extends Partition with Serializable

 /**
- * Holds the state for a reader in a thread, used by the completion listener to access the most
+ * Holds the state for a reader in a thread/task, used by the completion listener to access the most
  * recently created reader and iterator for final metrics updates and cleanup.
  *
  * When `compute()` is called multiple times on the same thread (e.g., different input partitions
@@ -69,7 +69,7 @@ class DataSourceRDD(
     customMetrics: Map[String, SQLMetric])
   extends RDD[InternalRow](sc, Nil) {

-  // ThreadLocal to store the last reader state for this thread.
+  // ThreadLocal to store the last reader state for this thread/task.
   // A null value indicates that no completion listener has been added yet.
   @transient lazy private val readerStateThreadLocal = new ThreadLocal[ReaderState]()

@@ -85,6 +85,24 @@ class DataSourceRDD(
   }

   override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
+    // In case of early stopping before consuming the entire iterator, we need to do one more metric
+    // update at the end of the task.
+    // Add completion listener only once per thread (null means no listener added yet)
+    val readerState = readerStateThreadLocal.get()
+    if (readerState == null) {
+      context.addTaskCompletionListener[Unit] { _ =>
+        // Use the reader and iterator from ThreadLocal (the last ones created in this thread/task)
+        val readerState = readerStateThreadLocal.get()
+        if (readerState != null) {
+          CustomMetrics.updateMetrics(
+            readerState.reader.currentMetricsValues.toImmutableArraySeq, customMetrics)
+          readerState.iterator.forceUpdateMetrics()
+          readerState.reader.close()
+        }
+        readerStateThreadLocal.remove()
+      }
+    }
+
     castPartition(split).inputPartition.iterator.flatMap { inputPartition =>
       val (iter, reader) = if (columnarReads) {
         val batchReader = partitionReaderFactory.createColumnarReader(inputPartition)
@@ -100,22 +118,10 @@

       // Add completion listener only once per thread (null means no listener added yet)
       val readerState = readerStateThreadLocal.get()
-      if (readerState == null) {
-        context.addTaskCompletionListener[Unit] { _ =>
-          // Use the reader and iterator from ThreadLocal (the last ones created in this thread)
-          val readerState = readerStateThreadLocal.get()
-          if (readerState != null) {
-            // In case of early stopping before consuming the entire iterator,
-            // we need to do one more metric update at the end of the task.
-            CustomMetrics.updateMetrics(
-              readerState.reader.currentMetricsValues.toImmutableArraySeq, customMetrics)
-            readerState.iterator.forceUpdateMetrics()
-            readerState.reader.close()
-          }
-          readerStateThreadLocal.remove()
-        }
-      } else {
-        reader.initMetricsValues(readerState.reader.currentMetricsValues())
+      if (readerState != null) {
+        val metrics = readerState.reader.currentMetricsValues
+        CustomMetrics.updateMetrics(metrics.toImmutableArraySeq, customMetrics)
+        reader.initMetricsValues(metrics)
         readerState.reader.close()
       }
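Why the listener registration moved: with coalesced partitions, a single task can run `compute()` for several input partitions on the same thread, and the consumer may stop before the iterator is exhausted. Registering the completion listener at the top of `compute()`, rather than inside the per-reader path, guarantees exactly one listener per task that flushes and closes whatever reader was created last. Below is a minimal, self-contained sketch of that pattern; `SimpleReader` and `SimpleTaskContext` are hypothetical stand-ins for `PartitionReader` and `TaskContext`, not Spark code:

object ListenerOncePerTask {
  // Hypothetical stand-in for a DSv2 PartitionReader with one custom metric.
  final class SimpleReader(val id: Int) {
    var rowsRead: Long = 0L
    def close(): Unit = println(s"reader $id closed, rowsRead = $rowsRead")
  }

  // Hypothetical stand-in for Spark's TaskContext completion-listener API.
  final class SimpleTaskContext {
    private var listeners: List[() => Unit] = Nil
    def addTaskCompletionListener(f: () => Unit): Unit = listeners ::= f
    def markTaskCompleted(): Unit = listeners.foreach(_.apply())
  }

  // One slot per task thread; null means no completion listener registered yet.
  private val lastReader = new ThreadLocal[SimpleReader]()

  // Mirrors the fixed compute(): the listener is registered once, up front.
  def compute(inputPartitions: Seq[Int], ctx: SimpleTaskContext): Unit = {
    if (lastReader.get() == null) {
      ctx.addTaskCompletionListener { () =>
        // Flush and close whatever reader was created last in this task,
        // even if the consumer stopped before exhausting the iterator.
        val r = lastReader.get()
        if (r != null) r.close()
        lastReader.remove()
      }
    }
    inputPartitions.foreach { p =>
      val prev = lastReader.get()
      val reader = new SimpleReader(p)
      if (prev != null) {
        reader.rowsRead = prev.rowsRead // mirrors reader.initMetricsValues(metrics)
        prev.close()                    // mirrors readerState.reader.close()
      }
      lastReader.set(reader)
      reader.rowsRead += 1 // pretend each input partition yields one row
    }
  }

  def main(args: Array[String]): Unit = {
    val ctx = new SimpleTaskContext
    compute(Seq(0, 1, 2), ctx) // one coalesced task over three input partitions
    ctx.markTaskCompleted()    // last line printed: "reader 2 closed, rowsRead = 3"
  }
}

The second hunk mirrors the hand-over between readers: before a new reader replaces the previous one, the outgoing reader's metrics are folded into the SQL metrics via `CustomMetrics.updateMetrics` and seeded into the new reader via `initMetricsValues`, so counts accumulate across input partitions instead of resetting.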

sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala

Lines changed: 15 additions & 0 deletions

@@ -2861,6 +2861,21 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
     assert(metrics("number of rows read") == "3")
   }

+  test("SPARK-55619: Custom metrics of coalesced partitions") {
+    val items_partitions = Array(identity("id"))
+    createTable(items, itemsColumns, items_partitions)
+
+    sql(s"INSERT INTO testcat.ns.$items VALUES " +
+      "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+      "(2, 'bb', 10.0, cast('2021-01-01' as timestamp))")
+
+    val metrics = runAndFetchMetrics {
+      val df = sql(s"SELECT * FROM testcat.ns.$items").coalesce(1)
+      df.collect()
+    }
+    assert(metrics("number of rows read") == "2")
+  }
+
   test("SPARK-55411: Fix ArrayIndexOutOfBoundsException when join keys " +
     "are less than cluster keys") {
     withSQLConf(
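The test coalesces the scan so that both input partitions are read by one task, which is exactly the situation the fix targets: the metric must report both inserted rows, not just those seen by the last reader. As a hedged illustration of the coalesce mechanics only (`spark.range` is not a DSv2 source with custom metrics; assumes a spark-shell SparkSession named `spark`):

val df = spark.range(0L, 100L, 1L, numPartitions = 4) // four input partitions
assert(df.rdd.getNumPartitions == 4)
val coalesced = df.coalesce(1)                        // one task, four parent partitions
assert(coalesced.rdd.getNumPartitions == 1)
coalesced.collect()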
