diff --git a/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/StateUpdaterFailureIntegrationTest.java b/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/StateUpdaterFailureIntegrationTest.java new file mode 100644 index 0000000000000..06d487f371293 --- /dev/null +++ b/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/StateUpdaterFailureIntegrationTest.java @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.integration; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.TopologyWrapper; +import org.apache.kafka.streams.errors.ProcessorStateException; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.internals.AbstractStoreBuilder; +import org.apache.kafka.test.MockApiProcessorSupplier; +import org.apache.kafka.test.MockKeyValueStore; +import org.apache.kafka.test.TestUtils; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInfo; + +import java.io.IOException; +import java.time.Duration; +import java.util.Properties; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.apache.kafka.streams.utils.TestUtils.safeUniqueTestName; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class StateUpdaterFailureIntegrationTest { + + private static final int NUM_BROKERS = 1; + protected static final String INPUT_TOPIC_NAME = "input-topic"; + private static final int NUM_PARTITIONS = 6; + + private final EmbeddedKafkaCluster cluster = new EmbeddedKafkaCluster(NUM_BROKERS); + + private Properties streamsConfiguration; + private final MockTime mockTime = cluster.time; + private KafkaStreams streams; + + @BeforeEach + public void before(final TestInfo testInfo) throws InterruptedException, IOException { + cluster.start(); + cluster.createTopic(INPUT_TOPIC_NAME, NUM_PARTITIONS, 1); + streamsConfiguration = new Properties(); + final String safeTestName = safeUniqueTestName(testInfo); + streamsConfiguration.put(StreamsConfig.APPLICATION_ID_CONFIG, "app-" + safeTestName); + streamsConfiguration.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, cluster.bootstrapServers()); + streamsConfiguration.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + streamsConfiguration.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()); + streamsConfiguration.put(StreamsConfig.STATESTORE_CACHE_MAX_BYTES_CONFIG, 0); + streamsConfiguration.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 100L); + streamsConfiguration.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.IntegerSerde.class); + streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.StringSerde.class); + streamsConfiguration.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, 2); + + } + + @AfterEach + public void after() { + cluster.stop(); + if (streams != null) { + streams.close(Duration.ofSeconds(30)); + } + } + + /** + * The conditions that we need to meet: + *

+ * If all conditions are met, {@code TaskManager} needs to be able to handle the failed task from the {@code DefaultStateUpdater} correctly and not hang. + */ + @Test + public void correctlyHandleFlushErrorsDuringRebalance() throws Exception { + final AtomicInteger numberOfStoreInits = new AtomicInteger(); + final CountDownLatch pendingShutdownLatch = new CountDownLatch(1); + + final StoreBuilder> storeBuilder = new AbstractStoreBuilder<>("testStateStore", Serdes.Integer(), Serdes.ByteArray(), new MockTime()) { + + @Override + public KeyValueStore build() { + return new MockKeyValueStore(name, false) { + + @Override + public void init(final StateStoreContext stateStoreContext, final StateStore root) { + super.init(stateStoreContext, root); + numberOfStoreInits.incrementAndGet(); + } + + @Override + public void flush() { + // we want to throw the ProcessorStateException here only when the rebalance finished(we reassigned the 3 tasks from the removed thread to the existing thread) + // we use waitForCondition to wait until the current state is PENDING_SHUTDOWN to make sure the Stream Thread will not handle the exception and we can get to in TaskManager#shutdownStateUpdater + if (numberOfStoreInits.get() == 9) { + try { + pendingShutdownLatch.await(); + } catch (final InterruptedException e) { + throw new RuntimeException(e); + } + throw new ProcessorStateException("flush"); + } + } + }; + } + }; + + final TopologyWrapper topology = new TopologyWrapper(); + topology.addSource("ingest", INPUT_TOPIC_NAME); + topology.addProcessor("my-processor", new MockApiProcessorSupplier<>(), "ingest"); + topology.addStateStore(storeBuilder, "my-processor"); + + streams = new KafkaStreams(topology, streamsConfiguration); + streams.setStateListener((newState, oldState) -> { + if (newState == KafkaStreams.State.PENDING_SHUTDOWN) { + pendingShutdownLatch.countDown(); + } + }); + streams.start(); + + TestUtils.waitForCondition(() -> streams.state() == KafkaStreams.State.RUNNING, "Streams never reached RUNNING state"); + + streams.removeStreamThread(); + + TestUtils.waitForCondition(() -> streams.state() == KafkaStreams.State.REBALANCING, TimeUnit.MINUTES.toMillis(2), "Streams never reached REBALANCING state"); + + // Before shutting down, we want the tasks to be reassigned + TestUtils.waitForCondition(() -> numberOfStoreInits.get() == 9, "Streams never reinitialized the store enough times"); + + assertTrue(streams.close(Duration.ofSeconds(60))); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java index a3a44f6f02d31..30fdd499b0f8b 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java @@ -349,23 +349,26 @@ private void handleTaskCorruptedException(final TaskCorruptedException taskCorru // TODO: we can let the exception encode the actual corrupted changelog partitions and only // mark those instead of marking all changelogs private void removeCheckpointForCorruptedTask(final Task task) { - task.markChangelogAsCorrupted(task.changelogPartitions()); + try { + task.markChangelogAsCorrupted(task.changelogPartitions()); - // we need to enforce a checkpoint that removes the corrupted partitions - measureCheckpointLatency(() -> task.maybeCheckpoint(true)); + // we need to enforce a checkpoint that removes the corrupted partitions + measureCheckpointLatency(() -> task.maybeCheckpoint(true)); + } catch (final StreamsException swallow) { + log.warn("Checkpoint failed for corrupted task {}", task.id(), swallow); + } } private void handleStreamsException(final StreamsException streamsException) { log.info("Encountered streams exception: ", streamsException); if (streamsException.taskId().isPresent()) { - handleStreamsExceptionWithTask(streamsException); + handleStreamsExceptionWithTask(streamsException, streamsException.taskId().get()); } else { handleStreamsExceptionWithoutTask(streamsException); } } - private void handleStreamsExceptionWithTask(final StreamsException streamsException) { - final TaskId failedTaskId = streamsException.taskId().get(); + private void handleStreamsExceptionWithTask(final StreamsException streamsException, final TaskId failedTaskId) { if (updatingTasks.containsKey(failedTaskId)) { addToExceptionsAndFailedTasksThenRemoveFromUpdatingTasks( new ExceptionAndTask(streamsException, updatingTasks.get(failedTaskId)) @@ -518,7 +521,7 @@ private void removeTask(final TaskId taskId, final CompletableFuture task.maybeCheckpoint(true)); - pausedTasks.put(taskId, task); - updatingTasks.remove(taskId); - if (task.isActive()) { - transitToUpdateStandbysIfOnlyStandbysLeft(); + try { + measureCheckpointLatency(() -> task.maybeCheckpoint(true)); + pausedTasks.put(taskId, task); + updatingTasks.remove(taskId); + if (task.isActive()) { + transitToUpdateStandbysIfOnlyStandbysLeft(); + } + log.info((task.isActive() ? "Active" : "Standby") + + " task " + task.id() + " was paused from the updating tasks and added to the paused tasks."); + + } catch (final StreamsException streamsException) { + handleStreamsExceptionWithTask(streamsException, taskId); } - log.info((task.isActive() ? "Active" : "Standby") - + " task " + task.id() + " was paused from the updating tasks and added to the paused tasks."); } private void resumeTask(final Task task) { @@ -671,11 +679,15 @@ private void maybeCompleteRestoration(final StreamTask task, final Set restoredChangelogs) { final Collection changelogPartitions = task.changelogPartitions(); if (restoredChangelogs.containsAll(changelogPartitions)) { - measureCheckpointLatency(() -> task.maybeCheckpoint(true)); - changelogReader.unregister(changelogPartitions); - addToRestoredTasks(task); - log.info("Stateful active task " + task.id() + " completed restoration"); - transitToUpdateStandbysIfOnlyStandbysLeft(); + try { + measureCheckpointLatency(() -> task.maybeCheckpoint(true)); + changelogReader.unregister(changelogPartitions); + addToRestoredTasks(task); + log.info("Stateful active task " + task.id() + " completed restoration"); + transitToUpdateStandbysIfOnlyStandbysLeft(); + } catch (final StreamsException streamsException) { + handleStreamsExceptionWithTask(streamsException, task.id()); + } } } @@ -707,8 +719,12 @@ private void maybeCheckpointTasks(final long now) { measureCheckpointLatency(() -> { for (final Task task : updatingTasks.values()) { - // do not enforce checkpointing during restoration if its position has not advanced much - task.maybeCheckpoint(false); + try { + // do not enforce checkpointing during restoration if its position has not advanced much + task.maybeCheckpoint(false); + } catch (final StreamsException streamsException) { + handleStreamsExceptionWithTask(streamsException, task.id()); + } } }); diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java index 9207ec74a7cb0..859d62906ab8e 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java @@ -65,6 +65,7 @@ import java.util.TreeSet; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -704,7 +705,7 @@ private StateUpdater.RemovedTaskResult waitForFuture(final TaskId taskId, final CompletableFuture future) { final StateUpdater.RemovedTaskResult removedTaskResult; try { - removedTaskResult = future.get(); + removedTaskResult = future.get(5, TimeUnit.MINUTES); if (removedTaskResult == null) { throw new IllegalStateException("Task " + taskId + " was not found in the state updater. " + BUG_ERROR_MESSAGE); @@ -719,6 +720,10 @@ private StateUpdater.RemovedTaskResult waitForFuture(final TaskId taskId, Thread.currentThread().interrupt(); log.error(INTERRUPTED_ERROR_MESSAGE, shouldNotHappen); throw new IllegalStateException(INTERRUPTED_ERROR_MESSAGE, shouldNotHappen); + } catch (final java.util.concurrent.TimeoutException timeoutException) { + log.warn("The state updater wasn't able to remove task {} in time. The state updater thread may be dead. " + + BUG_ERROR_MESSAGE, taskId, timeoutException); + return null; } } @@ -1499,6 +1504,12 @@ void shutdown(final boolean clean) { private void shutdownStateUpdater() { if (stateUpdater != null) { + // If there are failed tasks handling them first + for (final StateUpdater.ExceptionAndTask exceptionAndTask : stateUpdater.drainExceptionsAndFailedTasks()) { + final Task failedTask = exceptionAndTask.task(); + closeTaskDirty(failedTask, false); + } + final Map> futures = new LinkedHashMap<>(); for (final Task task : stateUpdater.tasks()) { final CompletableFuture future = stateUpdater.remove(task.id()); @@ -1507,7 +1518,8 @@ private void shutdownStateUpdater() { final Set tasksToCloseClean = new HashSet<>(); final Set tasksToCloseDirty = new HashSet<>(); addToTasksToClose(futures, tasksToCloseClean, tasksToCloseDirty); - stateUpdater.shutdown(Duration.ofMillis(Long.MAX_VALUE)); + // at this point we removed all tasks, so the shutdown should not take a lot of time + stateUpdater.shutdown(Duration.ofMinutes(1L)); for (final Task task : tasksToCloseClean) { tasks.addTask(task); @@ -1515,16 +1527,22 @@ private void shutdownStateUpdater() { for (final Task task : tasksToCloseDirty) { closeTaskDirty(task, false); } + // Handling all failures that occurred during the remove process for (final StateUpdater.ExceptionAndTask exceptionAndTask : stateUpdater.drainExceptionsAndFailedTasks()) { final Task failedTask = exceptionAndTask.task(); closeTaskDirty(failedTask, false); } + + // If there is anything left unhandled due to timeouts, handling now + for (final Task task : stateUpdater.tasks()) { + closeTaskDirty(task, false); + } } } private void shutdownSchedulingTaskManager() { if (schedulingTaskManager != null) { - schedulingTaskManager.shutdown(Duration.ofMillis(Long.MAX_VALUE)); + schedulingTaskManager.shutdown(Duration.ofMinutes(5L)); } } diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java index 8f767af93e01b..4f9d1b3c0e6b4 100644 --- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java @@ -23,6 +23,7 @@ import org.apache.kafka.common.utils.MockTime; import org.apache.kafka.common.utils.Time; import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.ProcessorStateException; import org.apache.kafka.streams.errors.StreamsException; import org.apache.kafka.streams.errors.TaskCorruptedException; import org.apache.kafka.streams.processor.TaskId; @@ -73,6 +74,7 @@ import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyMap; import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.inOrder; @@ -1717,6 +1719,114 @@ public void shouldRemoveMetricsWithoutInterference() { } } + @Test + public void shouldNotFailTheThreadIfMaybeCheckpointFails() throws Exception { + final StreamTask activeTask1 = statefulTask(TASK_0_0, Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build(); + final StreamTask activeTask2 = statefulTask(TASK_0_1, Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build(); + final StreamTask failedStatefulTask = statefulTask(TASK_0_2, Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build(); + final ProcessorStateException processorStateException = new ProcessorStateException("flush"); + doThrow(processorStateException).when(failedStatefulTask).maybeCheckpoint(anyBoolean()); + + stateUpdater.add(failedStatefulTask); + stateUpdater.add(activeTask1); + stateUpdater.start(); + verifyExceptionsAndFailedTasks(new ExceptionAndTask(processorStateException, failedStatefulTask)); + verifyUpdatingTasks(activeTask1); + + stateUpdater.add(activeTask2); + verifyUpdatingTasks(activeTask1, activeTask2); + } + + @Test + public void shouldNotFailTheThreadIfMaybeCheckpointFailsForCorruptedTask() throws Exception { + final StreamTask activeTask1 = statefulTask(TASK_0_0, Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build(); + final StreamTask activeTask2 = statefulTask(TASK_0_1, Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build(); + final StreamTask failedStatefulTask = statefulTask(TASK_0_2, Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build(); + final ProcessorStateException processorStateException = new ProcessorStateException("flush"); + doThrow(processorStateException).when(failedStatefulTask).maybeCheckpoint(anyBoolean()); + + final TaskCorruptedException taskCorruptedException = new TaskCorruptedException(Set.of(TASK_0_2)); + when(changelogReader.restore(Map.of( + TASK_0_0, activeTask1, + TASK_0_2, failedStatefulTask)) + ).thenThrow(taskCorruptedException); + + stateUpdater.add(failedStatefulTask); + stateUpdater.add(activeTask1); + stateUpdater.start(); + verifyExceptionsAndFailedTasks(new ExceptionAndTask(taskCorruptedException, failedStatefulTask)); + verifyUpdatingTasks(activeTask1); + + stateUpdater.add(activeTask2); + verifyUpdatingTasks(activeTask1, activeTask2); + } + + @Test + public void shouldNotFailTheThreadIfMaybeCheckpointFailsDuringTaskRemoval() throws Exception { + final StreamTask activeTask1 = statefulTask(TASK_0_0, Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build(); + final StreamTask activeTask2 = statefulTask(TASK_0_1, Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build(); + final StreamTask failedStatefulTask = statefulTask(TASK_0_2, Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build(); + final ProcessorStateException processorStateException = new ProcessorStateException("flush"); + final AtomicBoolean throwException = new AtomicBoolean(false); + doAnswer(invocation -> { + if (throwException.get()) { + throw processorStateException; + } + return null; + }).when(failedStatefulTask).maybeCheckpoint(anyBoolean()); + when(changelogReader.allChangelogsCompleted()).thenReturn(true); + + stateUpdater.add(failedStatefulTask); + stateUpdater.add(activeTask1); + stateUpdater.start(); + verifyUpdatingTasks(failedStatefulTask, activeTask1); + + throwException.set(true); + final ExecutionException exception = assertThrows(ExecutionException.class, () -> stateUpdater.remove(TASK_0_2).get()); + assertEquals(processorStateException, exception.getCause()); + + stateUpdater.add(activeTask2); + verifyUpdatingTasks(activeTask1, activeTask2); + } + + @Test + public void shouldNotFailTheThreadIfMaybeCheckpointFailsDuringTaskPause() throws Exception { + final StreamTask activeTask1 = statefulTask(TASK_0_0, Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build(); + final StreamTask activeTask2 = statefulTask(TASK_0_1, Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build(); + final StreamTask failedStatefulTask = statefulTask(TASK_0_2, Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build(); + final ProcessorStateException processorStateException = new ProcessorStateException("flush"); + doThrow(processorStateException).when(failedStatefulTask).maybeCheckpoint(anyBoolean()); + when(topologyMetadata.isPaused(null)).thenReturn(false).thenReturn(false).thenReturn(true); + + stateUpdater.add(failedStatefulTask); + stateUpdater.add(activeTask1); + stateUpdater.start(); + verifyExceptionsAndFailedTasks(new ExceptionAndTask(processorStateException, failedStatefulTask)); + verifyPausedTasks(activeTask1); + + stateUpdater.add(activeTask2); + verifyPausedTasks(activeTask1, activeTask2); + } + + @Test + public void shouldNotFailTheThreadIfMaybeCheckpointFailsDuringTaskRestore() throws Exception { + final StreamTask activeTask1 = statefulTask(TASK_0_0, Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build(); + final StreamTask activeTask2 = statefulTask(TASK_0_1, Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build(); + final StreamTask failedStatefulTask = statefulTask(TASK_0_2, Set.of(TOPIC_PARTITION_B_0)).inState(State.RESTORING).build(); + final ProcessorStateException processorStateException = new ProcessorStateException("flush"); + doThrow(processorStateException).when(failedStatefulTask).maybeCheckpoint(anyBoolean()); + when(changelogReader.completedChangelogs()).thenReturn(Set.of(TOPIC_PARTITION_B_0)); + + stateUpdater.add(failedStatefulTask); + stateUpdater.add(activeTask1); + stateUpdater.start(); + verifyExceptionsAndFailedTasks(new ExceptionAndTask(processorStateException, failedStatefulTask)); + verifyUpdatingTasks(activeTask1); + + stateUpdater.add(activeTask2); + verifyUpdatingTasks(activeTask1, activeTask2); + } + private static List getMetricNames(final String threadId) { final Map tagMap = Map.of("thread-id", threadId); return List.of( @@ -1779,7 +1889,8 @@ private void verifyRestoredActiveTasks(final StreamTask... tasks) throws Excepti && restoredTasks.size() == expectedRestoredTasks.size(); }, VERIFICATION_TIMEOUT, - "Did not get all restored active task within the given timeout!" + () -> "Did not get all restored active task within the given timeout! Expected: " + + expectedRestoredTasks + ", actual: " + restoredTasks ); } } @@ -1794,7 +1905,8 @@ private void verifyDrainingRestoredActiveTasks(final StreamTask... tasks) throws && restoredTasks.size() == expectedRestoredTasks.size(); }, VERIFICATION_TIMEOUT, - "Did not get all restored active task within the given timeout!" + () -> "Did not get all restored active task within the given timeout! Expected: " + + expectedRestoredTasks + ", actual: " + restoredTasks ); assertTrue(stateUpdater.drainRestoredActiveTasks(Duration.ZERO).isEmpty()); } @@ -1816,7 +1928,8 @@ private void verifyUpdatingTasks(final Task... tasks) throws Exception { && updatingTasks.size() == expectedUpdatingTasks.size(); }, VERIFICATION_TIMEOUT, - "Did not get all updating task within the given timeout!" + () -> "Did not get all updating task within the given timeout! Expected: " + + expectedUpdatingTasks + ", actual: " + updatingTasks ); } } @@ -1831,7 +1944,8 @@ private void verifyUpdatingStandbyTasks(final StandbyTask... tasks) throws Excep && standbyTasks.size() == expectedStandbyTasks.size(); }, VERIFICATION_TIMEOUT, - "Did not see all standby task within the given timeout!" + () -> "Did not see all standby task within the given timeout! Expected: " + + expectedStandbyTasks + ", actual: " + standbyTasks ); } @@ -1860,7 +1974,8 @@ private void verifyPausedTasks(final Task... tasks) throws Exception { && pausedTasks.size() == expectedPausedTasks.size(); }, VERIFICATION_TIMEOUT, - "Did not get all paused task within the given timeout!" + () -> "Did not get all paused task within the given timeout! Expected: " + + expectedPausedTasks + ", actual: " + pausedTasks ); } } @@ -1875,7 +1990,8 @@ private void verifyExceptionsAndFailedTasks(final ExceptionAndTask... exceptions && failedTasks.size() == expectedExceptionAndTasks.size(); }, VERIFICATION_TIMEOUT, - "Did not get all exceptions and failed tasks within the given timeout!" + () -> "Did not get all exceptions and failed tasks within the given timeout! Expected: " + + expectedExceptionAndTasks + ", actual: " + failedTasks ); } @@ -1893,7 +2009,8 @@ private void verifyFailedTasks(final Class clazz, fi && failedTasks.size() == expectedFailedTasks.size(); }, VERIFICATION_TIMEOUT, - "Did not get all exceptions and failed tasks within the given timeout!" + () -> "Did not get all exceptions and failed tasks within the given timeout! Expected: " + + expectedFailedTasks + ", actual: " + failedTasks ); } @@ -1911,7 +2028,8 @@ private void verifyDrainingExceptionsAndFailedTasks(final ExceptionAndTask... ex && failedTasks.size() == expectedExceptionAndTasks.size(); }, VERIFICATION_TIMEOUT, - "Did not get all exceptions and failed tasks within the given timeout!" + () -> "Did not get all exceptions and failed tasks within the given timeout! Expected: " + + expectedExceptionAndTasks + ", actual: " + failedTasks ); assertFalse(stateUpdater.hasExceptionsAndFailedTasks()); assertTrue(stateUpdater.drainExceptionsAndFailedTasks().isEmpty()); diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java index 00b4d5d8a07fc..25cf8fdc401bf 100644 --- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java @@ -126,6 +126,7 @@ import static org.mockito.Mockito.lenient; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; @@ -3386,6 +3387,28 @@ public Set changelogPartitions() { verify(activeTaskCreator).close(); } + @SuppressWarnings("unchecked") + @Test + public void shouldCloseTasksIfStateUpdaterTimesOutOnRemove() throws Exception { + final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, null, true); + final Map> assignment = mkMap( + mkEntry(taskId00, taskId00Partitions) + ); + final Task task00 = spy(new StateMachineTask(taskId00, taskId00Partitions, true, stateManager)); + + when(activeTaskCreator.createTasks(any(), eq(assignment))).thenReturn(singletonList(task00)); + taskManager.handleAssignment(assignment, emptyMap()); + + when(stateUpdater.tasks()).thenReturn(singleton(task00)); + final CompletableFuture future = mock(CompletableFuture.class); + when(stateUpdater.remove(eq(taskId00))).thenReturn(future); + when(future.get(anyLong(), any())).thenThrow(new java.util.concurrent.TimeoutException()); + + taskManager.shutdown(true); + + verify(task00).closeDirty(); + } + @Test public void shouldOnlyCommitRevokedStandbyTaskAndPropagatePrepareCommitException() { setUpTaskManager(ProcessingMode.EXACTLY_ONCE_V2, false); @@ -3553,13 +3576,14 @@ public void shouldShutDownStateUpdaterAndCloseFailedTasksDirty() { .thenReturn(Arrays.asList( new ExceptionAndTask(new RuntimeException(), failedStatefulTask), new ExceptionAndTask(new RuntimeException(), failedStandbyTask)) - ); + ) + .thenReturn(Collections.emptyList()); final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true); taskManager.shutdown(true); verify(activeTaskCreator).close(); - verify(stateUpdater).shutdown(Duration.ofMillis(Long.MAX_VALUE)); + verify(stateUpdater).shutdown(Duration.ofMinutes(1L)); verify(failedStatefulTask).prepareCommit(false); verify(failedStatefulTask).suspend(); verify(failedStatefulTask).closeDirty(); @@ -3572,7 +3596,7 @@ public void shouldShutdownSchedulingTaskManager() { taskManager.shutdown(true); - verify(schedulingTaskManager).shutdown(Duration.ofMillis(Long.MAX_VALUE)); + verify(schedulingTaskManager).shutdown(Duration.ofMinutes(5L)); } @Test @@ -3597,8 +3621,8 @@ public void shouldShutDownStateUpdaterAndCloseDirtyTasksFailedDuringRemoval() { removedFailedStatefulTask, removedFailedStandbyTask, removedFailedStatefulTaskDuringRemoval, - removedFailedStandbyTaskDuringRemoval - )); + removedFailedStandbyTaskDuringRemoval) + ).thenReturn(Collections.emptySet()); final CompletableFuture futureForRemovedStatefulTask = new CompletableFuture<>(); final CompletableFuture futureForRemovedStandbyTask = new CompletableFuture<>(); final CompletableFuture futureForRemovedFailedStatefulTask = new CompletableFuture<>(); @@ -3613,10 +3637,11 @@ public void shouldShutDownStateUpdaterAndCloseDirtyTasksFailedDuringRemoval() { .thenReturn(futureForRemovedFailedStatefulTaskDuringRemoval); when(stateUpdater.remove(removedFailedStandbyTaskDuringRemoval.id())) .thenReturn(futureForRemovedFailedStandbyTaskDuringRemoval); - when(stateUpdater.drainExceptionsAndFailedTasks()).thenReturn(Arrays.asList( - new ExceptionAndTask(new StreamsException("KABOOM!"), removedFailedStatefulTaskDuringRemoval), - new ExceptionAndTask(new StreamsException("KABOOM!"), removedFailedStandbyTaskDuringRemoval) - )); + when(stateUpdater.drainExceptionsAndFailedTasks()) + .thenReturn(Arrays.asList( + new ExceptionAndTask(new StreamsException("KABOOM!"), removedFailedStatefulTaskDuringRemoval), + new ExceptionAndTask(new StreamsException("KABOOM!"), removedFailedStandbyTaskDuringRemoval)) + ).thenReturn(Collections.emptyList()); final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true); futureForRemovedStatefulTask.complete(new StateUpdater.RemovedTaskResult(removedStatefulTask)); futureForRemovedStandbyTask.complete(new StateUpdater.RemovedTaskResult(removedStandbyTask)); @@ -3631,7 +3656,7 @@ public void shouldShutDownStateUpdaterAndCloseDirtyTasksFailedDuringRemoval() { taskManager.shutdown(true); - verify(stateUpdater).shutdown(Duration.ofMillis(Long.MAX_VALUE)); + verify(stateUpdater).shutdown(Duration.ofMinutes(1L)); verify(tasks).addTask(removedStatefulTask); verify(tasks).addTask(removedStandbyTask); verify(removedFailedStatefulTask).prepareCommit(false);