Skip to content

Commit 4dcecf1

Browse files
storage: use single batch find for a fetch
The fetch path performed two batch finds: one in DelayedFetch.tryCompleteDiskless and a second in DelayedFetch.onComplete. Collect the batches in tryCompleteDiskless and reuse them in onComplete. This reduces the number of database queries made to the batch coordinator.
1 parent 5875f03 commit 4dcecf1

File tree

10 files changed

+230
-202
lines changed

10 files changed

+230
-202
lines changed

core/src/main/scala/kafka/server/DelayedFetch.scala

Lines changed: 56 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
package kafka.server
1919

2020
import com.yammer.metrics.core.Meter
21-
import io.aiven.inkless.control_plane.FindBatchRequest
21+
import io.aiven.inkless.control_plane.FindBatchResponse
2222
import kafka.utils.Logging
2323

2424
import java.util.concurrent.TimeUnit
@@ -60,6 +60,7 @@ class DelayedFetch(
6060
minBytes: Option[Int] = None,
6161
responseCallback: Seq[(TopicIdPartition, FetchPartitionData)] => Unit,
6262
) extends DelayedOperation(maxWaitMs.getOrElse(params.maxWaitMs)) with Logging {
63+
var maybeBatchCoordinates: Option[Map[TopicIdPartition, FindBatchResponse]] = None
6364

6465
override def toString: String = {
6566
s"DelayedFetch(params=$params" +
@@ -153,7 +154,15 @@ class DelayedFetch(
153154
}
154155
}
155156

156-
tryCompleteDiskless(disklessFetchPartitionStatus) match {
157+
// adjust the max bytes for diskless fetches based on the percentage of diskless partitions
158+
// Complete the classic fetches first
159+
val classicRequestsSize = classicFetchPartitionStatus.size.toFloat
160+
val disklessRequestsSize = disklessFetchPartitionStatus.size.toFloat
161+
val totalRequestsSize = classicRequestsSize + disklessRequestsSize
162+
val disklessPercentage = disklessRequestsSize / totalRequestsSize
163+
val disklessParams = replicaManager.fetchParamsWithNewMaxBytes(params, disklessPercentage)
164+
165+
tryCompleteDiskless(disklessFetchPartitionStatus, disklessParams.maxBytes) match {
157166
case Some(disklessAccumulatedSize) => accumulatedSize += disklessAccumulatedSize
158167
case None => forceComplete()
159168
}
@@ -174,53 +183,55 @@ class DelayedFetch(
174183
* Case D: The fetch offset is equal to the end offset, meaning that we have reached the end of the log
175184
* Upon completion, should return whatever data is available for each valid partition
176185
*/
177-
private def tryCompleteDiskless(fetchPartitionStatus: Seq[(TopicIdPartition, FetchPartitionStatus)]): Option[Long] = {
186+
private def tryCompleteDiskless(
187+
fetchPartitionStatus: Seq[(TopicIdPartition, FetchPartitionStatus)],
188+
disklessMaxBytes: Int
189+
): Option[Long] = {
178190
var accumulatedSize = 0L
179191
val fetchPartitionStatusMap = fetchPartitionStatus.toMap
180-
val requests = fetchPartitionStatus.map { case (topicIdPartition, fetchStatus) =>
181-
new FindBatchRequest(topicIdPartition, fetchStatus.startOffsetMetadata.messageOffset, fetchStatus.fetchInfo.maxBytes)
182-
}
183-
if (requests.isEmpty) return Some(0)
184192

185-
val response = try {
186-
replicaManager.findDisklessBatches(requests, Int.MaxValue)
193+
maybeBatchCoordinates = try {
194+
Some(replicaManager.findDisklessBatches(fetchPartitionStatus, disklessMaxBytes))
187195
} catch {
188196
case e: Throwable =>
189197
error("Error while trying to find diskless batches on delayed fetch.", e)
190198
return None // Case C
191199
}
192200

193-
response.get.asScala.foreach { r =>
194-
r.errors() match {
195-
case Errors.NONE =>
196-
if (r.batches().size() > 0) {
197-
// Gather topic id partition from first batch. Same for all batches in the response.
198-
val topicIdPartition = r.batches().get(0).metadata().topicIdPartition()
199-
val endOffset = r.highWatermark()
200-
201-
val fetchPartitionStatus = fetchPartitionStatusMap.get(topicIdPartition)
202-
if (fetchPartitionStatus.isEmpty) {
203-
warn(s"Fetch partition status for $topicIdPartition not found in delayed fetch $this.")
204-
return None // Case C
205-
}
206-
207-
val fetchOffset = fetchPartitionStatus.get.startOffsetMetadata
208-
// If the fetch offset is greater than the end offset, it means that the log has been truncated
209-
// If it is equal to the end offset, it means that we have reached the end of the log
210-
// If the fetch offset is less than the end offset, we can accumulate the size of the batches
211-
if (fetchOffset.messageOffset > endOffset) {
212-
// Truncation happened
213-
debug(s"Satisfying fetch $this since it is fetching later segments of partition $topicIdPartition.")
214-
return None // Case A
215-
} else if (fetchOffset.messageOffset < endOffset) {
216-
val bytesAvailable = r.estimatedByteSize(fetchOffset.messageOffset)
217-
accumulatedSize += bytesAvailable // Case B: accumulate the size of the batches
218-
} // Case D: same as fetchOffset == endOffset, no new data available
201+
maybeBatchCoordinates match {
202+
case Some(exists) =>
203+
exists.values.foreach { r =>
204+
r.errors() match {
205+
case Errors.NONE =>
206+
if (r.batches().size() > 0) {
207+
// Gather topic id partition from first batch. Same for all batches in the response.
208+
val topicIdPartition = r.batches().get(0).metadata().topicIdPartition()
209+
val endOffset = r.highWatermark()
210+
211+
val fetchPartitionStatus = fetchPartitionStatusMap.get(topicIdPartition)
212+
if (fetchPartitionStatus.isEmpty) {
213+
warn(s"Fetch partition status for $topicIdPartition not found in delayed fetch $this.")
214+
return None // Case C
215+
}
216+
217+
val fetchOffset = fetchPartitionStatus.get.startOffsetMetadata
218+
// If the fetch offset is greater than the end offset, it means that the log has been truncated
219+
// If it is equal to the end offset, it means that we have reached the end of the log
220+
// If the fetch offset is less than the end offset, we can accumulate the size of the batches
221+
if (fetchOffset.messageOffset > endOffset) {
222+
// Truncation happened
223+
debug(s"Satisfying fetch $this since it is fetching later segments of partition $topicIdPartition.")
224+
return None // Case A
225+
} else if (fetchOffset.messageOffset < endOffset) {
226+
val bytesAvailable = r.estimatedByteSize(fetchOffset.messageOffset)
227+
accumulatedSize += bytesAvailable // Case B: accumulate the size of the batches
228+
} // Case D: same as fetchOffset == endOffset, no new data available
229+
}
230+
case _ => return None // Case C
219231
}
220-
case _ => return None // Case C
221-
}
232+
}
233+
case None => // Case D
222234
}
223-
224235
Some(accumulatedSize)
225236
}
226237

@@ -272,13 +283,16 @@ class DelayedFetch(
272283

273284
if (disklessRequestsSize > 0) {
274285
// Classic fetches are complete, now handle diskless fetches
275-
// adjust the max bytes for diskless fetches based on the percentage of diskless partitions
276-
val disklessPercentage = disklessRequestsSize / totalRequestsSize
277-
val disklessParams = replicaManager.fetchParamsWithNewMaxBytes(params, disklessPercentage)
278286
val disklessFetchInfos = disklessFetchPartitionStatus.map { case (tp, status) =>
279287
tp -> status.fetchInfo
280288
}
281-
val disklessFetchResponseFuture = replicaManager.fetchDisklessMessages(disklessParams, disklessFetchInfos)
289+
val batchCoordinates = maybeBatchCoordinates match {
290+
case Some(batchCoordinates) => batchCoordinates
291+
case None =>
292+
responseCallback(Seq.empty)
293+
return
294+
}
295+
val disklessFetchResponseFuture = replicaManager.fetchDisklessMessages(batchCoordinates, disklessFetchInfos)
282296

283297
// Combine the classic fetch results with the diskless fetch results
284298
disklessFetchResponseFuture.whenComplete { case (disklessFetchPartitionData, _) =>

core/src/main/scala/kafka/server/ReplicaManager.scala

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1719,16 +1719,40 @@ class ReplicaManager(val config: KafkaConfig,
17191719
}
17201720
}
17211721

1722-
def findDisklessBatches(requests: Seq[FindBatchRequest], maxBytes: Int): Option[util.List[FindBatchResponse]] = {
1723-
inklessSharedState.map { sharedState =>
1724-
sharedState.controlPlane().findBatches(requests.asJava, maxBytes, sharedState.config().maxBatchesPerPartitionToFind())
1722+
def findDisklessBatches(fetchPartitionStatus: Seq[(TopicIdPartition, FetchPartitionStatus)], maxBytes: Int): Map[TopicIdPartition, FindBatchResponse] = {
1723+
val requests = fetchPartitionStatus.map { case (topicIdPartition, fetchStatus) =>
1724+
new FindBatchRequest(topicIdPartition, fetchStatus.startOffsetMetadata.messageOffset, fetchStatus.fetchInfo.maxBytes)
1725+
}
1726+
if (requests.isEmpty) return Map.empty
1727+
1728+
val findBatchResponses = try {
1729+
inklessSharedState.map { sharedState =>
1730+
sharedState.controlPlane().findBatches(requests.asJava, maxBytes, sharedState.config().maxBatchesPerPartitionToFind())
1731+
}
1732+
} match {
1733+
case Some(responses) => responses
1734+
case None =>
1735+
return Map.empty
1736+
} catch {
1737+
case e: Throwable =>
1738+
// Best effort: log the failure at trace level and fall back to an empty result.
1739+
trace("Error while trying to find diskless batches.", e)
1740+
return Map.empty
17251741
}
1742+
1743+
val topicPartitionToFindBatchResponse = collection.mutable.Map[TopicIdPartition, FindBatchResponse]()
1744+
for (i <- requests.indices) {
1745+
val request = requests(i)
1746+
val response = findBatchResponses.get(i)
1747+
topicPartitionToFindBatchResponse.update(request.topicIdPartition, response)
1748+
}
1749+
topicPartitionToFindBatchResponse;
17261750
}
17271751

1728-
def fetchDisklessMessages(params: FetchParams,
1752+
def fetchDisklessMessages(batchCoordinates: Map[TopicIdPartition, FindBatchResponse],
17291753
fetchInfos: Seq[(TopicIdPartition, PartitionData)]): CompletableFuture[Seq[(TopicIdPartition, FetchPartitionData)]] = {
17301754
inklessFetchHandler match {
1731-
case Some(handler) => handler.handle(params, fetchInfos.toMap.asJava).thenApply(_.asScala.toSeq)
1755+
case Some(handler) => handler.handle(batchCoordinates.asJava, fetchInfos.toMap.asJava).thenApply(_.asScala.toSeq)
17321756
case None =>
17331757
if (fetchInfos.nonEmpty)
17341758
error(s"Received diskless fetch request for topics ${fetchInfos.map(_._1.topic()).distinct.mkString(", ")} but diskless fetch handler is not available. " +
@@ -1830,6 +1854,8 @@ class ReplicaManager(val config: KafkaConfig,
18301854
delayedFetchPurgatory.tryCompleteElseWatch(delayedFetch, (classicDelayedFetchKeys ++ disklessDelayedFetchKeys).asJava)
18311855
}
18321856

1857+
// If there is nothing to fetch for classic topics,
1858+
// create delayed response and fetch possible diskless data there.
18331859
if (classicFetchInfos.isEmpty) {
18341860
delayedResponse(Seq.empty)
18351861
return
@@ -1894,9 +1920,18 @@ class ReplicaManager(val config: KafkaConfig,
18941920
// In case of remote fetches, synchronously wait for diskless records and then perform the remote fetch.
18951921
// This is currently a workaround to avoid modifying the DelayedRemoteFetch in order to correctly process
18961922
// diskless fetches.
1923+
// Get diskless batch coordinates and hand over to fetching
1924+
val batchCoordinates = try {
1925+
findDisklessBatches(fetchPartitionStatus, Int.MaxValue)
1926+
} catch {
1927+
case e: Throwable =>
1928+
error("Error while trying to find diskless batches on remote fetch.", e)
1929+
responseCallback(Seq.empty)
1930+
return
1931+
}
1932+
18971933
val disklessFetchResults = try {
1898-
val disklessParams = fetchParamsWithNewMaxBytes(params, disklessFetchInfos.size.toFloat / fetchInfos.size.toFloat)
1899-
val disklessResponsesFuture = fetchDisklessMessages(disklessParams, disklessFetchInfos)
1934+
val disklessResponsesFuture = fetchDisklessMessages(batchCoordinates, disklessFetchInfos)
19001935

19011936
val response = disklessResponsesFuture.get(maxWaitMs, TimeUnit.MILLISECONDS)
19021937
response.map { case (tp, data) =>
@@ -1933,8 +1968,11 @@ class ReplicaManager(val config: KafkaConfig,
19331968
}
19341969
} else {
19351970
if (disklessFetchInfos.isEmpty && (bytesReadable >= params.minBytes || params.maxWaitMs <= 0)) {
1971+
// No remote fetch is needed and there are no diskless topics to fetch.
1972+
// Respond immediately.
19361973
responseCallback(fetchPartitionData)
19371974
} else {
1975+
// No remote fetch, requires fetching data from the diskless topics.
19381976
delayedResponse(fetchPartitionStatus)
19391977
}
19401978
}

core/src/test/scala/integration/kafka/server/DelayedFetchTest.scala

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
*/
1717
package kafka.server
1818

19-
import io.aiven.inkless.control_plane.{BatchInfo, BatchMetadata, FindBatchRequest, FindBatchResponse}
19+
import io.aiven.inkless.control_plane.{BatchInfo, BatchMetadata, FindBatchResponse}
2020

2121
import java.util.{Collections, Optional, OptionalLong}
2222
import scala.collection.Seq
@@ -213,6 +213,9 @@ class DelayedFetchTest {
213213
responseCallback = callback
214214
)
215215

216+
val batchCoordinates = Map.empty[TopicIdPartition, FindBatchResponse]
217+
when(replicaManager.findDisklessBatches(any(), anyInt())).thenReturn(batchCoordinates)
218+
216219
val partition: Partition = mock(classOf[Partition])
217220
when(replicaManager.getPartitionOrException(topicIdPartition.topicPartition)).thenReturn(partition)
218221
// Note that the high-watermark does not contain the complete metadata
@@ -345,12 +348,13 @@ class DelayedFetchTest {
345348
)))
346349
when(mockResponse.highWatermark()).thenReturn(endOffset) // endOffset < fetchOffset (truncation)
347350

348-
val future = Some(Collections.singletonList(mockResponse))
349-
when(replicaManager.findDisklessBatches(any[Seq[FindBatchRequest]], anyInt())).thenReturn(future)
351+
val batchCoordinates = Map((topicIdPartition, mockResponse))
352+
353+
when(replicaManager.findDisklessBatches(any(), anyInt())).thenReturn(batchCoordinates)
350354

351355
// Mock fetchDisklessMessages for onComplete
352356
when(replicaManager.fetchParamsWithNewMaxBytes(any[FetchParams], any[Float])).thenAnswer(_.getArgument(0))
353-
when(replicaManager.fetchDisklessMessages(any[FetchParams], any[Seq[(TopicIdPartition, FetchRequest.PartitionData)]]))
357+
when(replicaManager.fetchDisklessMessages(any[Map[TopicIdPartition, FindBatchResponse]], any[Seq[(TopicIdPartition, FetchRequest.PartitionData)]]))
354358
.thenReturn(CompletableFuture.completedFuture(Seq((topicIdPartition, mock(classOf[FetchPartitionData])))))
355359

356360
when(replicaManager.readFromLog(
@@ -434,8 +438,8 @@ class DelayedFetchTest {
434438
when(mockResponse.highWatermark()).thenReturn(fetchOffset) // fetchOffset == endOffset (no new data)
435439
when(mockResponse.estimatedByteSize(fetchOffset)).thenReturn(estimatedBatchSize)
436440

437-
val future = Some(Collections.singletonList(mockResponse))
438-
when(replicaManager.findDisklessBatches(any[Seq[FindBatchRequest]], anyInt())).thenReturn(future)
441+
val future = Map((topicIdPartition, mockResponse))
442+
when(replicaManager.findDisklessBatches(any[Seq[(TopicIdPartition, FetchPartitionStatus)]], anyInt())).thenReturn(future)
439443

440444
when(replicaManager.readFromLog(
441445
fetchParams,
@@ -451,7 +455,7 @@ class DelayedFetchTest {
451455
assertFalse(fetchResultOpt.isDefined)
452456

453457
// Verify that estimatedByteSize is never called since fetchOffset == endOffset
454-
verify(replicaManager, never()).fetchDisklessMessages(any[FetchParams], any[Seq[(TopicIdPartition, FetchRequest.PartitionData)]])
458+
verify(replicaManager, never()).fetchDisklessMessages(any[Map[TopicIdPartition, FindBatchResponse]], any[Seq[(TopicIdPartition, FetchRequest.PartitionData)]])
455459
verify(mockResponse, never()).estimatedByteSize(anyLong())
456460
}
457461

@@ -519,8 +523,8 @@ class DelayedFetchTest {
519523
when(mockResponse.highWatermark()).thenReturn(endOffset) // endOffset > fetchOffset (data available)
520524
when(mockResponse.estimatedByteSize(fetchOffset)).thenReturn(estimatedBatchSize)
521525

522-
val future = Some(Collections.singletonList(mockResponse))
523-
when(replicaManager.findDisklessBatches(any[Seq[FindBatchRequest]], anyInt())).thenReturn(future)
526+
val batchCoordinates = Map((topicIdPartition, mockResponse))
527+
when(replicaManager.findDisklessBatches(any(), anyInt())).thenReturn(batchCoordinates)
524528

525529
when(replicaManager.readFromLog(
526530
fetchParams,
@@ -601,12 +605,12 @@ class DelayedFetchTest {
601605
when(mockResponse.highWatermark()).thenReturn(endOffset) // endOffset > fetchOffset (data available)
602606
when(mockResponse.estimatedByteSize(fetchOffset)).thenReturn(estimatedBatchSize)
603607

604-
val future = Some(Collections.singletonList(mockResponse))
605-
when(replicaManager.findDisklessBatches(any[Seq[FindBatchRequest]], anyInt())).thenReturn(future)
608+
val batchCoordinates = Map((topicIdPartition, mockResponse))
609+
when(replicaManager.findDisklessBatches(any(), anyInt())).thenReturn(batchCoordinates)
606610

607611
// Mock fetchDisklessMessages for onComplete
608612
when(replicaManager.fetchParamsWithNewMaxBytes(any[FetchParams], anyFloat())).thenAnswer(_.getArgument(0))
609-
when(replicaManager.fetchDisklessMessages(any[FetchParams], any[Seq[(TopicIdPartition, FetchRequest.PartitionData)]]))
613+
when(replicaManager.fetchDisklessMessages(any[Map[TopicIdPartition, FindBatchResponse]], any[Seq[(TopicIdPartition, FetchRequest.PartitionData)]]))
610614
.thenReturn(CompletableFuture.completedFuture(Seq((topicIdPartition, mock(classOf[FetchPartitionData])))))
611615

612616
when(replicaManager.readFromLog(
@@ -685,12 +689,12 @@ class DelayedFetchTest {
685689
)))
686690
when(mockResponse.highWatermark()).thenReturn(600L)
687691

688-
val future = Some(Collections.singletonList(mockResponse))
689-
when(replicaManager.findDisklessBatches(any[Seq[FindBatchRequest]], anyInt())).thenReturn(future)
692+
val batchCoordinates = Map((topicIdPartition, mockResponse))
693+
when(replicaManager.findDisklessBatches(any(), anyInt())).thenReturn(batchCoordinates)
690694

691695
// Mock fetchDisklessMessages for onComplete
692696
when(replicaManager.fetchParamsWithNewMaxBytes(any[FetchParams], anyFloat())).thenAnswer(_.getArgument(0))
693-
when(replicaManager.fetchDisklessMessages(any[FetchParams], any[Seq[(TopicIdPartition, FetchRequest.PartitionData)]]))
697+
when(replicaManager.fetchDisklessMessages(any[Map[TopicIdPartition, FindBatchResponse]], any[Seq[(TopicIdPartition, FetchRequest.PartitionData)]]))
694698
.thenReturn(CompletableFuture.completedFuture(Seq((topicIdPartition, mock(classOf[FetchPartitionData])))))
695699

696700
when(replicaManager.readFromLog(

core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7098,9 +7098,10 @@ class ReplicaManagerTest {
70987098
// and the response does not satisfy minBytes, it should be delayed in the purgatory
70997099
// until the delayed fetch expires.
71007100
replicaManager.fetchMessages(fetchParams, fetchInfos, QuotaFactory.UNBOUNDED_QUOTA, responseCallback)
7101-
assertEquals(0, replicaManager.delayedFetchPurgatory.numDelayed())
7101+
assertEquals(1, replicaManager.delayedFetchPurgatory.numDelayed())
71027102

71037103
latch.await(10, TimeUnit.SECONDS) // Wait for the delayed fetch to expire
7104+
assertEquals(0, replicaManager.delayedFetchPurgatory.numDelayed())
71047105
assertNotNull(responseData)
71057106
assertEquals(2, responseData.size)
71067107
assertEquals(disklessResponse(disklessTopicPartition), responseData(disklessTopicPartition))

0 commit comments

Comments
 (0)