Skip to content

Commit 2f95817

Browse files
committed
[SPARK-54200][SS] Call close() against underlying InputPartition when LowLatencyReaderWrap.close() is called
### What changes were proposed in this pull request?

This PR proposes to fix the bug of missing close() on the underlying InputPartition when LowLatencyReaderWrap.close() is called.

### Why are the changes needed?

Not closing the underlying InputPartition could leak resources; e.g. the Kafka consumer is not returned to the pool, which ends up defeating the purpose of the connection pool and creating Kafka consumer instances every batch.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

A new UT for Kafka rather than a general one, since the Kafka data source has an internal metric that provides the necessary information for validation.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#52903 from HeartSaVioR/SPARK-54200.

Authored-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
1 parent 852daf5 commit 2f95817

File tree

4 files changed

+112
-3
lines changed

4 files changed

+112
-3
lines changed

connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/InternalKafkaConsumerPool.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,14 @@ private[consumer] class InternalKafkaConsumerPool(
129129

130130
def size(key: CacheKey): Int = numIdle(key) + numActive(key)
131131

132+
private[kafka010] def numActiveInGroupIdPrefix(groupIdPrefix: String): Int = {
133+
import scala.jdk.CollectionConverters._
134+
135+
pool.getNumActivePerKey().asScala.filter { case (key, _) =>
136+
key.startsWith(groupIdPrefix + "-")
137+
}.values.map(_.toInt).sum
138+
}
139+
132140
// TODO: revisit the relation between CacheKey and kafkaParams - for now it looks a bit weird
133141
// as we force all consumers having same (groupId, topicPartition) to have same kafkaParams
134142
// which might be viable in performance perspective (kafkaParams might be too huge to use

connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/KafkaDataConsumer.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -848,4 +848,8 @@ private[kafka010] object KafkaDataConsumer extends Logging {
848848

849849
new KafkaDataConsumer(topicPartition, kafkaParams, consumerPool, fetchedDataPool)
850850
}
851+
852+
private[kafka010] def getActiveSizeInConsumerPool(groupIdPrefix: String): Int = {
853+
consumerPool.numActiveInGroupIdPrefix(groupIdPrefix)
854+
}
851855
}

connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRealTimeModeSuite.scala

Lines changed: 97 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.sql.kafka010
1919

20+
import java.util.UUID
21+
2022
import org.scalatest.matchers.should.Matchers
2123
import org.scalatest.time.SpanSugar._
2224

@@ -26,6 +28,7 @@ import org.apache.spark.sql.execution.streaming._
2628
import org.apache.spark.sql.execution.streaming.sources.ContinuousMemorySink
2729
import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
2830
import org.apache.spark.sql.internal.SQLConf
31+
import org.apache.spark.sql.kafka010.consumer.KafkaDataConsumer
2932
import org.apache.spark.sql.streaming.{StreamingQuery, Trigger}
3033
import org.apache.spark.sql.streaming.OutputMode.Update
3134
import org.apache.spark.sql.streaming.util.GlobalSingletonManualClock
@@ -39,9 +42,7 @@ class KafkaRealTimeModeSuite
3942
override protected val defaultTrigger = RealTimeTrigger.apply("3 seconds")
4043

4144
override protected def sparkConf: SparkConf = {
42-
// Should turn to use StreamingShuffleManager when it is ready.
4345
super.sparkConf
44-
.set("spark.databricks.streaming.realTimeMode.enabled", "true")
4546
.set(
4647
SQLConf.STATE_STORE_PROVIDER_CLASS,
4748
classOf[RocksDBStateStoreProvider].getName)
@@ -679,3 +680,97 @@ class KafkaRealTimeModeSuite
679680
)
680681
}
681682
}
683+
684+
class KafkaConsumerPoolRealTimeModeSuite
685+
extends KafkaSourceTest
686+
with Matchers {
687+
override protected val defaultTrigger = RealTimeTrigger.apply("3 seconds")
688+
689+
override protected def sparkConf: SparkConf = {
690+
super.sparkConf
691+
.set(
692+
SQLConf.STATE_STORE_PROVIDER_CLASS,
693+
classOf[RocksDBStateStoreProvider].getName)
694+
}
695+
696+
import testImplicits._
697+
698+
override def beforeAll(): Unit = {
699+
super.beforeAll()
700+
spark.conf.set(
701+
SQLConf.STREAMING_REAL_TIME_MODE_MIN_BATCH_DURATION,
702+
defaultTrigger.batchDurationMs
703+
)
704+
}
705+
706+
test("SPARK-54200: Kafka consumers in consumer pool should be properly reused") {
707+
val topic = newTopic()
708+
testUtils.createTopic(topic, partitions = 2)
709+
710+
testUtils.sendMessages(topic, Array("1", "2"), Some(0))
711+
testUtils.sendMessages(topic, Array("3"), Some(1))
712+
713+
val groupIdPrefix = UUID.randomUUID().toString
714+
715+
val reader = spark
716+
.readStream
717+
.format("kafka")
718+
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
719+
.option("subscribe", topic)
720+
.option("startingOffsets", "earliest")
721+
.option("groupIdPrefix", groupIdPrefix)
722+
.load()
723+
.selectExpr("CAST(value AS STRING)")
724+
.as[String]
725+
.map(_.toInt)
726+
.map(_ + 1)
727+
728+
// At any point of time, Kafka consumer pool should only contain at most 2 active instances.
729+
testStream(reader, Update, sink = new ContinuousMemorySink())(
730+
StartStream(),
731+
CheckAnswerWithTimeout(60000, 2, 3, 4),
732+
WaitUntilCurrentBatchProcessed,
733+
// After completion of batch 0
734+
new ExternalAction() {
735+
override def runAction(): Unit = {
736+
assertActiveSizeOnConsumerPool(groupIdPrefix, 2)
737+
738+
testUtils.sendMessages(topic, Array("4", "5"), Some(0))
739+
testUtils.sendMessages(topic, Array("6"), Some(1))
740+
}
741+
},
742+
CheckAnswerWithTimeout(5000, 2, 3, 4, 5, 6, 7),
743+
WaitUntilCurrentBatchProcessed,
744+
// After completion of batch 1
745+
new ExternalAction() {
746+
override def runAction(): Unit = {
747+
assertActiveSizeOnConsumerPool(groupIdPrefix, 2)
748+
749+
testUtils.sendMessages(topic, Array("7"), Some(1))
750+
}
751+
},
752+
CheckAnswerWithTimeout(5000, 2, 3, 4, 5, 6, 7, 8),
753+
WaitUntilCurrentBatchProcessed,
754+
// After completion of batch 2
755+
new ExternalAction() {
756+
override def runAction(): Unit = {
757+
assertActiveSizeOnConsumerPool(groupIdPrefix, 2)
758+
}
759+
},
760+
StopStream
761+
)
762+
}
763+
764+
/**
765+
* NOTE: This method leverages that we run test code, driver and executor in a same process in
766+
* a normal unit test setup (say, local[<number, or *>] in spark master). With that setup, we
767+
* can access singleton object directly.
768+
*/
769+
private def assertActiveSizeOnConsumerPool(
770+
groupIdPrefix: String,
771+
maxAllowedActiveSize: Int): Unit = {
772+
val activeSize = KafkaDataConsumer.getActiveSizeInConsumerPool(groupIdPrefix)
773+
assert(activeSize <= maxAllowedActiveSize, s"Consumer pool size is expected to be less " +
774+
s"than $maxAllowedActiveSize, but $activeSize.")
775+
}
776+
}

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RealTimeStreamScanExec.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,9 @@ case class LowLatencyReaderWrap(
8383
reader.get()
8484
}
8585

86-
override def close(): Unit = {}
86+
override def close(): Unit = {
87+
reader.close()
88+
}
8789
}
8890

8991
/**

0 commit comments

Comments (0)