DataSourceRDD.scala
@@ -24,7 +24,6 @@ import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.metric.CustomTaskMetric
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric}
@@ -34,6 +33,19 @@ import org.apache.spark.util.ArrayImplicits._
class DataSourceRDDPartition(val index: Int, val inputPartitions: Seq[InputPartition])
extends Partition with Serializable

/**
* Holds the reader state for a thread/task. The task completion listener uses it to access the
* most recently created reader and iterator for final metrics updates and cleanup.
*
* When `compute()` is called multiple times on the same thread (e.g., when different input
* partitions of a scan are coalesced into one task), this state is updated to track the most
* recent reader, which the task completion listener then uses for final cleanup and metrics
* reporting.
*
* @param reader The partition reader
* @param iterator The metrics iterator wrapping the reader
*/
private case class ReaderState(reader: PartitionReader[_], iterator: MetricsIterator[_])

// TODO: we should have 2 RDDs: an RDD[InternalRow] for row-based scan, an `RDD[ColumnarBatch]` for
// columnar scan.
class DataSourceRDD(
@@ -44,6 +56,10 @@ class DataSourceRDD(
customMetrics: Map[String, SQLMetric])
extends RDD[InternalRow](sc, Nil) {

// ThreadLocal to store the last reader state for this thread/task.
// A null value indicates that no completion listener has been added yet.
@transient lazy private val readerStateThreadLocal = new ThreadLocal[ReaderState]()

override protected def getPartitions: Array[Partition] = {
inputPartitions.zipWithIndex.map {
case (inputPartitions, index) => new DataSourceRDDPartition(index, inputPartitions)
@@ -57,19 +73,29 @@ class DataSourceRDD(

override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {

// In case of early stopping before consuming the entire iterator, we need to do one more metric
// update at the end of the task.
// Add the completion listener only once per thread (null means no listener has been added yet)
val readerState = readerStateThreadLocal.get()
if (readerState == null) {
context.addTaskCompletionListener[Unit] { _ =>
// Use the reader and iterator from ThreadLocal (the last ones created in this thread/task)
val readerState = readerStateThreadLocal.get()
if (readerState != null) {
CustomMetrics.updateMetrics(
readerState.reader.currentMetricsValues.toImmutableArraySeq, customMetrics)
readerState.iterator.forceUpdateMetrics()
readerState.reader.close()
}
readerStateThreadLocal.remove()
}
}

val iterator = new Iterator[Object] {
private val inputPartitions = castPartition(split).inputPartitions
private var currentIter: Option[Iterator[Object]] = None
private var currentIndex: Int = 0

private val partitionMetricCallback = new PartitionMetricCallback(customMetrics)

// In case of early stopping before consuming the entire iterator,
// we need to do one more metric update at the end of the task.
context.addTaskCompletionListener[Unit] { _ =>
partitionMetricCallback.execute()
}

override def hasNext: Boolean = currentIter.exists(_.hasNext) || advanceToNextIter()

override def next(): Object = {
@@ -97,9 +123,19 @@
(iter, rowReader)
}

// Once we advance to the next partition, update the metric callback for early finish
val previousMetrics = partitionMetricCallback.advancePartition(iter, reader)
previousMetrics.foreach(reader.initMetricsValues)
val readerState = readerStateThreadLocal.get()
if (readerState != null) {
val metrics = readerState.reader.currentMetricsValues
CustomMetrics.updateMetrics(metrics.toImmutableArraySeq, customMetrics)
reader.initMetricsValues(metrics)
readerState.reader.close()
}

// Store the current reader and iterator in ThreadLocal so the completion listener can
// access the most recently created instances. On subsequent compute() calls in the same
// thread/task, the metrics from the previous reader are preserved and can be used to
// initialize the new reader.
readerStateThreadLocal.set(ReaderState(reader, iter))

currentIter = Some(iter)
hasNext
@@ -115,35 +151,6 @@
}
}

private class PartitionMetricCallback
(customMetrics: Map[String, SQLMetric]) {
private var iter: MetricsIterator[_] = null
private var reader: PartitionReader[_] = null

def advancePartition(
iter: MetricsIterator[_],
reader: PartitionReader[_]): Option[Array[CustomTaskMetric]] = {
val metrics = execute()

this.iter = iter
this.reader = reader

metrics
}

def execute(): Option[Array[CustomTaskMetric]] = {
if (iter != null && reader != null) {
val metrics = reader.currentMetricsValues
CustomMetrics.updateMetrics(metrics.toImmutableArraySeq, customMetrics)
iter.forceUpdateMetrics()
reader.close()
Some(metrics)
} else {
None
}
}
}

private class PartitionIterator[T](
reader: PartitionReader[T],
customMetrics: Map[String, SQLMetric]) extends Iterator[T] {
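The comments above describe the per-thread handoff; below is a minimal, single-threaded sketch of that pattern in plain Scala, not part of the patch. `Reader`, `Metrics`, the hook list, and `onNewReader` are stand-ins invented here for `PartitionReader`, `CustomTaskMetric`, `TaskContext` listener registration, and the two call sites in the real diff (the listener registration at the top of `compute()` and the handoff in `advanceToNextIter()`).

// Illustrative sketch only, not part of the patch: the "latest reader per thread" handoff.
object LatestReaderSketch {
  final case class Metrics(rowsRead: Long)
  trait Reader {
    def currentMetrics: Metrics        // stand-in for PartitionReader.currentMetricsValues
    def init(m: Metrics): Unit         // stand-in for initMetricsValues
    def close(): Unit
  }

  // One slot per thread; null means no completion hook has been registered for this task yet.
  private val latest = new ThreadLocal[Reader]()
  private val taskCompletionHooks = scala.collection.mutable.ArrayBuffer.empty[() => Unit]

  /** Called whenever the scan opens a new reader, within or across compute() calls. */
  def onNewReader(newReader: Reader): Unit = {
    val previous = latest.get()
    if (previous == null) {
      // First reader on this thread: register the single per-task cleanup hook.
      taskCompletionHooks += { () =>
        Option(latest.get()).foreach { r => report(r.currentMetrics); r.close() }
        latest.remove()
      }
    } else {
      // Later reader (e.g. coalesced partitions): fold the previous reader's metrics into
      // the new one so per-task totals keep accumulating, then release the old reader.
      report(previous.currentMetrics)
      newReader.init(previous.currentMetrics)
      previous.close()
    }
    latest.set(newReader)
  }

  /** Stand-in for the task-completion callback that TaskContext would fire. */
  def endTask(): Unit = {
    taskCompletionHooks.foreach(h => h())
    taskCompletionHooks.clear()
  }

  private def report(m: Metrics): Unit =
    println(s"rows read so far: ${m.rowsRead}")  // stand-in for CustomMetrics.updateMetrics
}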
KeyGroupedPartitioningSuite.scala
@@ -2840,6 +2840,21 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
assert(metrics("number of rows read") == "3")
}

test("SPARK-55619: Custom metrics of coalesced partitions") {
val items_partitions = Array(identity("id"))
createTable(items, itemsColumns, items_partitions)

sql(s"INSERT INTO testcat.ns.$items VALUES " +
"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
"(2, 'bb', 10.0, cast('2021-01-01' as timestamp))")

val metrics = runAndFetchMetrics {
val df = sql(s"SELECT * FROM testcat.ns.$items").coalesce(1)
df.collect()
}
assert(metrics("number of rows read") == "2")
}

test("SPARK-55411: Fix ArrayIndexOutOfBoundsException when join keys " +
"are less than cluster keys") {
withSQLConf(
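For context, a rough sketch (not in the PR) of why this query exercises the new per-thread state, reusing the test's own `items` table and `sql` helper and assuming, as the suite's identity partitioning implies, that each id value lands in its own input partition:

// Both scan partitions (id = 1 and id = 2) are pulled by the single coalesced task, so
// DataSourceRDD.compute() runs twice on the same thread and the second reader must be
// seeded with the first reader's metrics for "number of rows read" to total 1 + 1 = 2.
val df = sql(s"SELECT * FROM testcat.ns.$items").coalesce(1)
assert(df.rdd.getNumPartitions == 1)  // one task reads both underlying scan partitions
assert(df.count() == 2)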