Skip to content

Commit 4b0ba42

Browse files
authored
KAFKA-19690 Add epoch check before verification guard check to prevent unexpected fatal error (#20607)
Cherry-pick changes (#20534) to 4.0 Conflicts: -> storage/src/main/java/org/apache/kafka/storage/internals/log/UnifiedLog.java - kept the file the same, and the rest of the code is in UnifiedLog.scala in 4.0 so added it there -> core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala - just added the required test and kept everything else the same Reviewers: Justine Olshan [[email protected]](mailto:[email protected]), Chia-Ping Tsai [[email protected]](mailto:[email protected])
1 parent f181048 commit 4b0ba42

File tree

2 files changed

+102
-2
lines changed

2 files changed

+102
-2
lines changed

core/src/main/scala/kafka/log/UnifiedLog.scala

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1057,8 +1057,17 @@ class UnifiedLog(@volatile var logStartOffset: Long,
10571057
// transaction is completed or aborted. We can guarantee the transaction coordinator knows about the transaction given step 1 and that the transaction is still
10581058
// ongoing. If the transaction is expected to be ongoing, we will not set a VerificationGuard. If the transaction is aborted, hasOngoingTransaction is false and
10591059
// requestVerificationGuard is the sentinel, so we will throw an error. A subsequent produce request (retry) should create verification state and return to phase 1.
1060-
if (batch.isTransactional && !hasOngoingTransaction(batch.producerId, batch.producerEpoch()) && batchMissingRequiredVerification(batch, requestVerificationGuard))
1061-
throw new InvalidTxnStateException("Record was not part of an ongoing transaction")
1060+
if (batch.isTransactional && !hasOngoingTransaction(batch.producerId, batch.producerEpoch)) {
1061+
// Check epoch first: if producer epoch is stale, throw recoverable InvalidProducerEpochException.
1062+
val entry = producerStateManager.activeProducers.get(batch.producerId)
1063+
if (entry != null && batch.producerEpoch < entry.producerEpoch) {
1064+
val message = "Epoch of producer " + batch.producerId + " is " + batch.producerEpoch + ", which is smaller than the last seen epoch " + entry.producerEpoch
1065+
throw new InvalidProducerEpochException(message)
1066+
}
1067+
// Only check verification if epoch is current
1068+
if (batchMissingRequiredVerification(batch, requestVerificationGuard))
1069+
throw new InvalidTxnStateException("Record was not part of an ongoing transaction")
1070+
}
10621071
}
10631072

10641073
// We cache offset metadata for the start of each transaction. This allows us to

core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ import org.mockito.Mockito.{doAnswer, doThrow, spy}
5757
import net.jqwik.api.AfterFailureMode
5858
import net.jqwik.api.ForAll
5959
import net.jqwik.api.Property
60+
import org.apache.kafka.server.common.RequestLocal
6061

6162
import java.io._
6263
import java.nio.ByteBuffer
@@ -4660,6 +4661,96 @@ class UnifiedLogTest {
46604661
assertEquals(new OffsetResultHolder(Optional.empty(), Optional.empty()), result)
46614662
}
46624663

4664+
@Test
4665+
def testStaleProducerEpochReturnsRecoverableErrorForTV1Clients(): Unit = {
4666+
// Producer epoch gets incremented (coordinator fail over, completed transaction, etc.)
4667+
// and client has stale cached epoch. Fix prevents fatal InvalidTxnStateException.
4668+
4669+
val producerStateManagerConfig = new ProducerStateManagerConfig(86400000, true)
4670+
val logConfig = LogTestUtils.createLogConfig(segmentBytes = 2048 * 5)
4671+
val log = createLog(logDir, logConfig, producerStateManagerConfig = producerStateManagerConfig)
4672+
4673+
val producerId = 123L
4674+
val oldEpoch = 5.toShort
4675+
val newEpoch = 6.toShort
4676+
4677+
// Step 1: Simulate a scenario where producer epoch was incremented to fence the producer
4678+
val previousRecords = MemoryRecords.withTransactionalRecords(
4679+
Compression.NONE, producerId, newEpoch, 0,
4680+
new SimpleRecord("previous-key".getBytes, "previous-value".getBytes)
4681+
)
4682+
val previousGuard = log.maybeStartTransactionVerification(producerId, 0, newEpoch, false) // TV1 = supportsEpochBump = false
4683+
log.appendAsLeader(previousRecords, 0, AppendOrigin.CLIENT, RequestLocal.noCaching, previousGuard)
4684+
4685+
// Complete the transaction normally (commits do update producer state with current epoch)
4686+
val commitMarker = MemoryRecords.withEndTransactionMarker(
4687+
producerId, newEpoch, new EndTransactionMarker(ControlRecordType.COMMIT, 0)
4688+
)
4689+
log.appendAsLeader(commitMarker, 0, AppendOrigin.COORDINATOR, RequestLocal.noCaching, VerificationGuard.SENTINEL)
4690+
4691+
// Step 2: TV1 client tries to write with stale cached epoch (before learning about epoch increment)
4692+
val staleEpochRecords = MemoryRecords.withTransactionalRecords(
4693+
Compression.NONE, producerId, oldEpoch, 0,
4694+
new SimpleRecord("stale-epoch-key".getBytes, "stale-epoch-value".getBytes)
4695+
)
4696+
4697+
// Step 3: Verify our fix - should get InvalidProducerEpochException (recoverable), not InvalidTxnStateException (fatal)
4698+
val exception = assertThrows(classOf[InvalidProducerEpochException], () => {
4699+
val staleGuard = log.maybeStartTransactionVerification(producerId, 0, oldEpoch, false)
4700+
log.appendAsLeader(staleEpochRecords, 0, AppendOrigin.CLIENT, RequestLocal.noCaching, staleGuard)
4701+
})
4702+
4703+
// Verify the error message indicates epoch mismatch
4704+
assertTrue(exception.getMessage.contains("smaller than the last seen epoch"))
4705+
assertTrue(exception.getMessage.contains(s"$oldEpoch"))
4706+
assertTrue(exception.getMessage.contains(s"$newEpoch"))
4707+
}
4708+
4709+
@Test
4710+
def testStaleProducerEpochReturnsRecoverableErrorForTV2Clients(): Unit = {
4711+
// Check producer epoch FIRST - if stale, return recoverable error before verification checks.
4712+
4713+
val producerStateManagerConfig = new ProducerStateManagerConfig(86400000, true)
4714+
val logConfig = LogTestUtils.createLogConfig(segmentBytes = 2048 * 5)
4715+
val log = createLog(logDir, logConfig, producerStateManagerConfig = producerStateManagerConfig)
4716+
4717+
val producerId = 456L
4718+
val originalEpoch = 3.toShort
4719+
val bumpedEpoch = 4.toShort
4720+
4721+
// Step 1: Start transaction with epoch 3 (before timeout)
4722+
val initialRecords = MemoryRecords.withTransactionalRecords(
4723+
Compression.NONE, producerId, originalEpoch, 0,
4724+
new SimpleRecord("ks-initial-key".getBytes, "ks-initial-value".getBytes)
4725+
)
4726+
val initialGuard = log.maybeStartTransactionVerification(producerId, 0, originalEpoch, true) // TV2 = supportsEpochBump = true
4727+
log.appendAsLeader(initialRecords, 0, AppendOrigin.CLIENT, RequestLocal.noCaching, initialGuard)
4728+
4729+
// Step 2: Coordinator times out and aborts transaction
4730+
// TV2 (KIP-890): Coordinator bumps epoch from 3 → 4 and sends abort marker with epoch 4
4731+
val abortMarker = MemoryRecords.withEndTransactionMarker(
4732+
producerId, bumpedEpoch, new EndTransactionMarker(ControlRecordType.ABORT, 0)
4733+
)
4734+
log.appendAsLeader(abortMarker, 0, AppendOrigin.COORDINATOR, RequestLocal.noCaching, VerificationGuard.SENTINEL)
4735+
4736+
// Step 3: TV2 transactional producer tries to append with stale epoch (timeout recovery scenario)
4737+
val staleEpochRecords = MemoryRecords.withTransactionalRecords(
4738+
Compression.NONE, producerId, originalEpoch, 0,
4739+
new SimpleRecord("ks-resume-key".getBytes, "ks-resume-value".getBytes)
4740+
)
4741+
4742+
// Step 4: Verify our fix works for TV2 - should get InvalidProducerEpochException (recoverable), not InvalidTxnStateException (fatal)
4743+
val exception = assertThrows(classOf[InvalidProducerEpochException], () => {
4744+
val staleGuard = log.maybeStartTransactionVerification(producerId, 0, originalEpoch, true) // TV2 = supportsEpochBump = true
4745+
log.appendAsLeader(staleEpochRecords, 0, AppendOrigin.CLIENT, RequestLocal.noCaching, staleGuard)
4746+
})
4747+
4748+
// Verify the error message indicates epoch mismatch (3 < 4)
4749+
assertTrue(exception.getMessage.contains("smaller than the last seen epoch"))
4750+
assertTrue(exception.getMessage.contains(s"$originalEpoch"))
4751+
assertTrue(exception.getMessage.contains(s"$bumpedEpoch"))
4752+
}
4753+
46634754
private def appendTransactionalToBuffer(buffer: ByteBuffer,
46644755
producerId: Long,
46654756
producerEpoch: Short,

0 commit comments

Comments
 (0)