diff --git a/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/StateUpdaterFailureIntegrationTest.java b/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/StateUpdaterFailureIntegrationTest.java
new file mode 100644
index 0000000000000..06d487f371293
--- /dev/null
+++ b/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/StateUpdaterFailureIntegrationTest.java
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.integration;
+
+import org.apache.kafka.clients.consumer.ConsumerConfig;
+import org.apache.kafka.common.serialization.Serdes;
+import org.apache.kafka.common.utils.MockTime;
+import org.apache.kafka.streams.KafkaStreams;
+import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.TopologyWrapper;
+import org.apache.kafka.streams.errors.ProcessorStateException;
+import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster;
+import org.apache.kafka.streams.processor.StateStore;
+import org.apache.kafka.streams.processor.StateStoreContext;
+import org.apache.kafka.streams.state.KeyValueStore;
+import org.apache.kafka.streams.state.StoreBuilder;
+import org.apache.kafka.streams.state.internals.AbstractStoreBuilder;
+import org.apache.kafka.test.MockApiProcessorSupplier;
+import org.apache.kafka.test.MockKeyValueStore;
+import org.apache.kafka.test.TestUtils;
+
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestInfo;
+
+import java.io.IOException;
+import java.time.Duration;
+import java.util.Properties;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.apache.kafka.streams.utils.TestUtils.safeUniqueTestName;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class StateUpdaterFailureIntegrationTest {
+
+ private static final int NUM_BROKERS = 1;
+ protected static final String INPUT_TOPIC_NAME = "input-topic";
+ private static final int NUM_PARTITIONS = 6;
+
+ private final EmbeddedKafkaCluster cluster = new EmbeddedKafkaCluster(NUM_BROKERS);
+
+ private Properties streamsConfiguration;
+ private final MockTime mockTime = cluster.time;
+ private KafkaStreams streams;
+
+ @BeforeEach
+ public void before(final TestInfo testInfo) throws InterruptedException, IOException {
+ cluster.start();
+ cluster.createTopic(INPUT_TOPIC_NAME, NUM_PARTITIONS, 1);
+ streamsConfiguration = new Properties();
+ final String safeTestName = safeUniqueTestName(testInfo);
+ streamsConfiguration.put(StreamsConfig.APPLICATION_ID_CONFIG, "app-" + safeTestName);
+ streamsConfiguration.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, cluster.bootstrapServers());
+ streamsConfiguration.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest");
+ streamsConfiguration.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath());
+ streamsConfiguration.put(StreamsConfig.STATESTORE_CACHE_MAX_BYTES_CONFIG, 0);
+ streamsConfiguration.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 100L);
+ streamsConfiguration.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.IntegerSerde.class);
+ streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.StringSerde.class);
+ streamsConfiguration.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, 2);
+
+ }
+
+ @AfterEach
+ public void after() {
+ cluster.stop();
+ if (streams != null) {
+ streams.close(Duration.ofSeconds(30));
+ }
+ }
+
+ /**
+ * The conditions that we need to meet:
+ * <ul>
+ * <li>There is an unhandled task in {@link org.apache.kafka.streams.processor.internals.DefaultStateUpdater}.</li>
+ * <li>The StreamThread is not running anymore, so {@link org.apache.kafka.streams.processor.internals.TaskManager#handleExceptionsFromStateUpdater} is no longer called.</li>
+ * <li>The task throws an exception in {@link org.apache.kafka.streams.processor.internals.Task#maybeCheckpoint(boolean)} while being processed by the {@code DefaultStateUpdater}.</li>
+ * <li>{@link org.apache.kafka.streams.processor.internals.TaskManager#shutdownStateUpdater} tries to clean up all tasks that are left in the {@code DefaultStateUpdater}.</li>
+ * </ul>
+ * If all conditions are met, the {@code TaskManager} needs to handle the failed task from the {@code DefaultStateUpdater} correctly and not hang.
+ */
+ @Test
+ public void correctlyHandleFlushErrorsDuringRebalance() throws Exception {
+ final AtomicInteger numberOfStoreInits = new AtomicInteger();
+ final CountDownLatch pendingShutdownLatch = new CountDownLatch(1);
+
+ final StoreBuilder<KeyValueStore<Object, Object>> storeBuilder = new AbstractStoreBuilder<>("testStateStore", Serdes.Integer(), Serdes.ByteArray(), new MockTime()) {
+
+ @Override
+ public KeyValueStore<Object, Object> build() {
+ return new MockKeyValueStore(name, false) {
+
+ @Override
+ public void init(final StateStoreContext stateStoreContext, final StateStore root) {
+ super.init(stateStoreContext, root);
+ numberOfStoreInits.incrementAndGet();
+ }
+
+ @Override
+ public void flush() {
+ // we only want to throw the ProcessorStateException once the rebalance has finished, i.e., once the 3 tasks
+ // from the removed thread have been reassigned (6 initial inits plus 3 re-inits = 9 store inits in total);
+ // the latch then delays the throw until the instance is in PENDING_SHUTDOWN, so the StreamThread no longer
+ // handles the exception and it has to be dealt with in TaskManager#shutdownStateUpdater
+ if (numberOfStoreInits.get() == 9) {
+ try {
+ pendingShutdownLatch.await();
+ } catch (final InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ throw new ProcessorStateException("flush");
+ }
+ }
+ };
+ }
+ };
+
+ final TopologyWrapper topology = new TopologyWrapper();
+ topology.addSource("ingest", INPUT_TOPIC_NAME);
+ topology.addProcessor("my-processor", new MockApiProcessorSupplier<>(), "ingest");
+ topology.addStateStore(storeBuilder, "my-processor");
+
+ streams = new KafkaStreams(topology, streamsConfiguration);
+ streams.setStateListener((newState, oldState) -> {
+ if (newState == KafkaStreams.State.PENDING_SHUTDOWN) {
+ pendingShutdownLatch.countDown();
+ }
+ });
+ streams.start();
+
+ TestUtils.waitForCondition(() -> streams.state() == KafkaStreams.State.RUNNING, "Streams never reached RUNNING state");
+
+ streams.removeStreamThread();
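+ // removing a stream thread triggers a rebalance that reassigns the removed thread's 3 tasks to the remaining thread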
+
+ TestUtils.waitForCondition(() -> streams.state() == KafkaStreams.State.REBALANCING, TimeUnit.MINUTES.toMillis(2), "Streams never reached REBALANCING state");
+
+ // Before shutting down, we want the tasks to be reassigned
+ TestUtils.waitForCondition(() -> numberOfStoreInits.get() == 9, "Streams never reinitialized the store enough times");
+
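+ // a clean close within the timeout proves that TaskManager#shutdownStateUpdater handled the failed task without hanging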
+ assertTrue(streams.close(Duration.ofSeconds(60)));
+ }
+}
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java
index a3a44f6f02d31..30fdd499b0f8b 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java
@@ -349,23 +349,26 @@ private void handleTaskCorruptedException(final TaskCorruptedException taskCorru
// TODO: we can let the exception encode the actual corrupted changelog partitions and only
// mark those instead of marking all changelogs
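+ // a checkpoint failure for an already-corrupted task is swallowed below so that it cannot kill the state updater thread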
private void removeCheckpointForCorruptedTask(final Task task) {
- task.markChangelogAsCorrupted(task.changelogPartitions());
- // we need to enforce a checkpoint that removes the corrupted partitions
- measureCheckpointLatency(() -> task.maybeCheckpoint(true));
+ try {
+ task.markChangelogAsCorrupted(task.changelogPartitions());
+ // we need to enforce a checkpoint that removes the corrupted partitions
+ measureCheckpointLatency(() -> task.maybeCheckpoint(true));
+ } catch (final StreamsException swallow) {
+ log.warn("Checkpoint failed for corrupted task {}", task.id(), swallow);
+ }
}
private void handleStreamsException(final StreamsException streamsException) {
log.info("Encountered streams exception: ", streamsException);
if (streamsException.taskId().isPresent()) {
- handleStreamsExceptionWithTask(streamsException);
+ handleStreamsExceptionWithTask(streamsException, streamsException.taskId().get());
} else {
handleStreamsExceptionWithoutTask(streamsException);
}
}
- private void handleStreamsExceptionWithTask(final StreamsException streamsException) {
- final TaskId failedTaskId = streamsException.taskId().get();
+ private void handleStreamsExceptionWithTask(final StreamsException streamsException, final TaskId failedTaskId) {
if (updatingTasks.containsKey(failedTaskId)) {
addToExceptionsAndFailedTasksThenRemoveFromUpdatingTasks(
new ExceptionAndTask(streamsException, updatingTasks.get(failedTaskId))
@@ -518,7 +521,7 @@ private void removeTask(final TaskId taskId, final CompletableFuture<StateUpdater.RemovedTaskResult> future) {
- measureCheckpointLatency(() -> task.maybeCheckpoint(true));
- pausedTasks.put(taskId, task);
- updatingTasks.remove(taskId);
- if (task.isActive()) {
- transitToUpdateStandbysIfOnlyStandbysLeft();
+ try {
+ measureCheckpointLatency(() -> task.maybeCheckpoint(true));
+ pausedTasks.put(taskId, task);
+ updatingTasks.remove(taskId);
+ if (task.isActive()) {
+ transitToUpdateStandbysIfOnlyStandbysLeft();
+ }
+ log.info((task.isActive() ? "Active" : "Standby")
+ + " task " + task.id() + " was paused from the updating tasks and added to the paused tasks.");
+
+ } catch (final StreamsException streamsException) {
+ handleStreamsExceptionWithTask(streamsException, taskId);
}
- log.info((task.isActive() ? "Active" : "Standby")
- + " task " + task.id() + " was paused from the updating tasks and added to the paused tasks.");
}
private void resumeTask(final Task task) {
@@ -671,11 +679,15 @@ private void maybeCompleteRestoration(final StreamTask task,
final Set<TopicPartition> restoredChangelogs) {
final Collection<TopicPartition> changelogPartitions = task.changelogPartitions();
if (restoredChangelogs.containsAll(changelogPartitions)) {
- measureCheckpointLatency(() -> task.maybeCheckpoint(true));
- changelogReader.unregister(changelogPartitions);
- addToRestoredTasks(task);
- log.info("Stateful active task " + task.id() + " completed restoration");
- transitToUpdateStandbysIfOnlyStandbysLeft();
+ try {
+ measureCheckpointLatency(() -> task.maybeCheckpoint(true));
+ changelogReader.unregister(changelogPartitions);
+ addToRestoredTasks(task);
+ log.info("Stateful active task " + task.id() + " completed restoration");
+ transitToUpdateStandbysIfOnlyStandbysLeft();
+ } catch (final StreamsException streamsException) {
+ handleStreamsExceptionWithTask(streamsException, task.id());
+ }
}
}
@@ -707,8 +719,12 @@ private void maybeCheckpointTasks(final long now) {
measureCheckpointLatency(() -> {
for (final Task task : updatingTasks.values()) {
- // do not enforce checkpointing during restoration if its position has not advanced much
- task.maybeCheckpoint(false);
+ try {
+ // do not enforce checkpointing during restoration if its position has not advanced much
+ task.maybeCheckpoint(false);
+ } catch (final StreamsException streamsException) {
+ handleStreamsExceptionWithTask(streamsException, task.id());
+ }
}
});
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
index 9207ec74a7cb0..859d62906ab8e 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
@@ -65,6 +65,7 @@
import java.util.TreeSet;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@@ -704,7 +705,7 @@ private StateUpdater.RemovedTaskResult waitForFuture(final TaskId taskId,
final CompletableFuture<StateUpdater.RemovedTaskResult> future) {
final StateUpdater.RemovedTaskResult removedTaskResult;
try {
- removedTaskResult = future.get();
+ removedTaskResult = future.get(5, TimeUnit.MINUTES);
if (removedTaskResult == null) {
throw new IllegalStateException("Task " + taskId + " was not found in the state updater. "
+ BUG_ERROR_MESSAGE);
@@ -719,6 +720,10 @@ private StateUpdater.RemovedTaskResult waitForFuture(final TaskId taskId,
Thread.currentThread().interrupt();
log.error(INTERRUPTED_ERROR_MESSAGE, shouldNotHappen);
throw new IllegalStateException(INTERRUPTED_ERROR_MESSAGE, shouldNotHappen);
+ } catch (final java.util.concurrent.TimeoutException timeoutException) {
+ log.warn("The state updater wasn't able to remove task {} in time. The state updater thread may be dead. "
+ + BUG_ERROR_MESSAGE, taskId, timeoutException);
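+ // return null instead of rethrowing so that shutdown can proceed; any task still left in the state updater is closed dirty at the end of shutdownStateUpdater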
+ return null;
}
}
@@ -1499,6 +1504,12 @@ void shutdown(final boolean clean) {
private void shutdownStateUpdater() {
if (stateUpdater != null) {
+ // If there are failed tasks, handle them first
+ for (final StateUpdater.ExceptionAndTask exceptionAndTask : stateUpdater.drainExceptionsAndFailedTasks()) {
+ final Task failedTask = exceptionAndTask.task();
+ closeTaskDirty(failedTask, false);
+ }
+
final Map<TaskId, CompletableFuture<StateUpdater.RemovedTaskResult>> futures = new LinkedHashMap<>();
for (final Task task : stateUpdater.tasks()) {
final CompletableFuture<StateUpdater.RemovedTaskResult> future = stateUpdater.remove(task.id());
@@ -1507,7 +1518,8 @@ private void shutdownStateUpdater() {
final Set<Task> tasksToCloseClean = new HashSet<>();
final Set<Task> tasksToCloseDirty = new HashSet<>();
addToTasksToClose(futures, tasksToCloseClean, tasksToCloseDirty);
- stateUpdater.shutdown(Duration.ofMillis(Long.MAX_VALUE));
+ // at this point all tasks have been removed, so the shutdown should not take long
+ stateUpdater.shutdown(Duration.ofMinutes(1L));
for (final Task task : tasksToCloseClean) {
tasks.addTask(task);
@@ -1515,16 +1527,22 @@ private void shutdownStateUpdater() {
for (final Task task : tasksToCloseDirty) {
closeTaskDirty(task, false);
}
+ // Handle all failures that occurred during the remove process
for (final StateUpdater.ExceptionAndTask exceptionAndTask : stateUpdater.drainExceptionsAndFailedTasks()) {
final Task failedTask = exceptionAndTask.task();
closeTaskDirty(failedTask, false);
}
+
+ // If anything is left unhandled due to timeouts, close it dirty now
+ for (final Task task : stateUpdater.tasks()) {
+ closeTaskDirty(task, false);
+ }
}
}
private void shutdownSchedulingTaskManager() {
if (schedulingTaskManager != null) {
- schedulingTaskManager.shutdown(Duration.ofMillis(Long.MAX_VALUE));
+ schedulingTaskManager.shutdown(Duration.ofMinutes(5L));
}
}
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java
index 8f767af93e01b..4f9d1b3c0e6b4 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java
@@ -23,6 +23,7 @@
import org.apache.kafka.common.utils.MockTime;
import org.apache.kafka.common.utils.Time;
import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.errors.ProcessorStateException;
import org.apache.kafka.streams.errors.StreamsException;
import org.apache.kafka.streams.errors.TaskCorruptedException;
import org.apache.kafka.streams.processor.TaskId;
@@ -73,6 +74,7 @@
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyMap;
import static org.mockito.Mockito.atLeast;
+import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.inOrder;
@@ -1717,6 +1719,114 @@ public void shouldRemoveMetricsWithoutInterference() {
}
}
+ @Test
+ public void shouldNotFailTheThreadIfMaybeCheckpointFails() throws Exception {
+ final StreamTask activeTask1 = statefulTask(TASK_0_0, Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+ final StreamTask activeTask2 = statefulTask(TASK_0_1, Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+ final StreamTask failedStatefulTask = statefulTask(TASK_0_2, Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+ final ProcessorStateException processorStateException = new ProcessorStateException("flush");
+ doThrow(processorStateException).when(failedStatefulTask).maybeCheckpoint(anyBoolean());
+
+ stateUpdater.add(failedStatefulTask);
+ stateUpdater.add(activeTask1);
+ stateUpdater.start();
+ verifyExceptionsAndFailedTasks(new ExceptionAndTask(processorStateException, failedStatefulTask));
+ verifyUpdatingTasks(activeTask1);
+
+ stateUpdater.add(activeTask2);
+ verifyUpdatingTasks(activeTask1, activeTask2);
+ }
+
+ @Test
+ public void shouldNotFailTheThreadIfMaybeCheckpointFailsForCorruptedTask() throws Exception {
+ final StreamTask activeTask1 = statefulTask(TASK_0_0, Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+ final StreamTask activeTask2 = statefulTask(TASK_0_1, Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+ final StreamTask failedStatefulTask = statefulTask(TASK_0_2, Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+ final ProcessorStateException processorStateException = new ProcessorStateException("flush");
+ doThrow(processorStateException).when(failedStatefulTask).maybeCheckpoint(anyBoolean());
+
+ final TaskCorruptedException taskCorruptedException = new TaskCorruptedException(Set.of(TASK_0_2));
+ when(changelogReader.restore(Map.of(
+ TASK_0_0, activeTask1,
+ TASK_0_2, failedStatefulTask))
+ ).thenThrow(taskCorruptedException);
+
+ stateUpdater.add(failedStatefulTask);
+ stateUpdater.add(activeTask1);
+ stateUpdater.start();
+ verifyExceptionsAndFailedTasks(new ExceptionAndTask(taskCorruptedException, failedStatefulTask));
+ verifyUpdatingTasks(activeTask1);
+
+ stateUpdater.add(activeTask2);
+ verifyUpdatingTasks(activeTask1, activeTask2);
+ }
+
+ @Test
+ public void shouldNotFailTheThreadIfMaybeCheckpointFailsDuringTaskRemoval() throws Exception {
+ final StreamTask activeTask1 = statefulTask(TASK_0_0, Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+ final StreamTask activeTask2 = statefulTask(TASK_0_1, Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+ final StreamTask failedStatefulTask = statefulTask(TASK_0_2, Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+ final ProcessorStateException processorStateException = new ProcessorStateException("flush");
+ final AtomicBoolean throwException = new AtomicBoolean(false);
+ doAnswer(invocation -> {
+ if (throwException.get()) {
+ throw processorStateException;
+ }
+ return null;
+ }).when(failedStatefulTask).maybeCheckpoint(anyBoolean());
+ when(changelogReader.allChangelogsCompleted()).thenReturn(true);
+
+ stateUpdater.add(failedStatefulTask);
+ stateUpdater.add(activeTask1);
+ stateUpdater.start();
+ verifyUpdatingTasks(failedStatefulTask, activeTask1);
+
+ throwException.set(true);
+ final ExecutionException exception = assertThrows(ExecutionException.class, () -> stateUpdater.remove(TASK_0_2).get());
+ assertEquals(processorStateException, exception.getCause());
+
+ stateUpdater.add(activeTask2);
+ verifyUpdatingTasks(activeTask1, activeTask2);
+ }
+
+ @Test
+ public void shouldNotFailTheThreadIfMaybeCheckpointFailsDuringTaskPause() throws Exception {
+ final StreamTask activeTask1 = statefulTask(TASK_0_0, Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+ final StreamTask activeTask2 = statefulTask(TASK_0_1, Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+ final StreamTask failedStatefulTask = statefulTask(TASK_0_2, Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+ final ProcessorStateException processorStateException = new ProcessorStateException("flush");
+ doThrow(processorStateException).when(failedStatefulTask).maybeCheckpoint(anyBoolean());
+ when(topologyMetadata.isPaused(null)).thenReturn(false).thenReturn(false).thenReturn(true);
+
+ stateUpdater.add(failedStatefulTask);
+ stateUpdater.add(activeTask1);
+ stateUpdater.start();
+ verifyExceptionsAndFailedTasks(new ExceptionAndTask(processorStateException, failedStatefulTask));
+ verifyPausedTasks(activeTask1);
+
+ stateUpdater.add(activeTask2);
+ verifyPausedTasks(activeTask1, activeTask2);
+ }
+
+ @Test
+ public void shouldNotFailTheThreadIfMaybeCheckpointFailsDuringTaskRestore() throws Exception {
+ final StreamTask activeTask1 = statefulTask(TASK_0_0, Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+ final StreamTask activeTask2 = statefulTask(TASK_0_1, Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+ final StreamTask failedStatefulTask = statefulTask(TASK_0_2, Set.of(TOPIC_PARTITION_B_0)).inState(State.RESTORING).build();
+ final ProcessorStateException processorStateException = new ProcessorStateException("flush");
+ doThrow(processorStateException).when(failedStatefulTask).maybeCheckpoint(anyBoolean());
+ when(changelogReader.completedChangelogs()).thenReturn(Set.of(TOPIC_PARTITION_B_0));
+
+ stateUpdater.add(failedStatefulTask);
+ stateUpdater.add(activeTask1);
+ stateUpdater.start();
+ verifyExceptionsAndFailedTasks(new ExceptionAndTask(processorStateException, failedStatefulTask));
+ verifyUpdatingTasks(activeTask1);
+
+ stateUpdater.add(activeTask2);
+ verifyUpdatingTasks(activeTask1, activeTask2);
+ }
+
private static List<MetricName> getMetricNames(final String threadId) {
final Map<String, String> tagMap = Map.of("thread-id", threadId);
return List.of(
@@ -1779,7 +1889,8 @@ private void verifyRestoredActiveTasks(final StreamTask... tasks) throws Excepti
&& restoredTasks.size() == expectedRestoredTasks.size();
},
VERIFICATION_TIMEOUT,
- "Did not get all restored active task within the given timeout!"
+ () -> "Did not get all restored active task within the given timeout! Expected: "
+ + expectedRestoredTasks + ", actual: " + restoredTasks
);
}
}
@@ -1794,7 +1905,8 @@ private void verifyDrainingRestoredActiveTasks(final StreamTask... tasks) throws
&& restoredTasks.size() == expectedRestoredTasks.size();
},
VERIFICATION_TIMEOUT,
- "Did not get all restored active task within the given timeout!"
+ () -> "Did not get all restored active task within the given timeout! Expected: "
+ + expectedRestoredTasks + ", actual: " + restoredTasks
);
assertTrue(stateUpdater.drainRestoredActiveTasks(Duration.ZERO).isEmpty());
}
@@ -1816,7 +1928,8 @@ private void verifyUpdatingTasks(final Task... tasks) throws Exception {
&& updatingTasks.size() == expectedUpdatingTasks.size();
},
VERIFICATION_TIMEOUT,
- "Did not get all updating task within the given timeout!"
+ () -> "Did not get all updating task within the given timeout! Expected: "
+ + expectedUpdatingTasks + ", actual: " + updatingTasks
);
}
}
@@ -1831,7 +1944,8 @@ private void verifyUpdatingStandbyTasks(final StandbyTask... tasks) throws Excep
&& standbyTasks.size() == expectedStandbyTasks.size();
},
VERIFICATION_TIMEOUT,
- "Did not see all standby task within the given timeout!"
+ () -> "Did not see all standby task within the given timeout! Expected: "
+ + expectedStandbyTasks + ", actual: " + standbyTasks
);
}
@@ -1860,7 +1974,8 @@ private void verifyPausedTasks(final Task... tasks) throws Exception {
&& pausedTasks.size() == expectedPausedTasks.size();
},
VERIFICATION_TIMEOUT,
- "Did not get all paused task within the given timeout!"
+ () -> "Did not get all paused task within the given timeout! Expected: "
+ + expectedPausedTasks + ", actual: " + pausedTasks
);
}
}
@@ -1875,7 +1990,8 @@ private void verifyExceptionsAndFailedTasks(final ExceptionAndTask... exceptions
&& failedTasks.size() == expectedExceptionAndTasks.size();
},
VERIFICATION_TIMEOUT,
- "Did not get all exceptions and failed tasks within the given timeout!"
+ () -> "Did not get all exceptions and failed tasks within the given timeout! Expected: "
+ + expectedExceptionAndTasks + ", actual: " + failedTasks
);
}
@@ -1893,7 +2009,8 @@ private void verifyFailedTasks(final Class<? extends RuntimeException> clazz, fi
&& failedTasks.size() == expectedFailedTasks.size();
},
VERIFICATION_TIMEOUT,
- "Did not get all exceptions and failed tasks within the given timeout!"
+ () -> "Did not get all exceptions and failed tasks within the given timeout! Expected: "
+ + expectedFailedTasks + ", actual: " + failedTasks
);
}
@@ -1911,7 +2028,8 @@ private void verifyDrainingExceptionsAndFailedTasks(final ExceptionAndTask... ex
&& failedTasks.size() == expectedExceptionAndTasks.size();
},
VERIFICATION_TIMEOUT,
- "Did not get all exceptions and failed tasks within the given timeout!"
+ () -> "Did not get all exceptions and failed tasks within the given timeout! Expected: "
+ + expectedExceptionAndTasks + ", actual: " + failedTasks
);
assertFalse(stateUpdater.hasExceptionsAndFailedTasks());
assertTrue(stateUpdater.drainExceptionsAndFailedTasks().isEmpty());
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
index 00b4d5d8a07fc..25cf8fdc401bf 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
@@ -126,6 +126,7 @@
import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
@@ -3386,6 +3387,28 @@ public Set<TopicPartition> changelogPartitions() {
verify(activeTaskCreator).close();
}
+ @SuppressWarnings("unchecked")
+ @Test
+ public void shouldCloseTasksIfStateUpdaterTimesOutOnRemove() throws Exception {
+ final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, null, true);
+ final Map<TaskId, Set<TopicPartition>> assignment = mkMap(
+ mkEntry(taskId00, taskId00Partitions)
+ );
+ final Task task00 = spy(new StateMachineTask(taskId00, taskId00Partitions, true, stateManager));
+
+ when(activeTaskCreator.createTasks(any(), eq(assignment))).thenReturn(singletonList(task00));
+ taskManager.handleAssignment(assignment, emptyMap());
+
+ when(stateUpdater.tasks()).thenReturn(singleton(task00));
+ final CompletableFuture<StateUpdater.RemovedTaskResult> future = mock(CompletableFuture.class);
+ when(stateUpdater.remove(eq(taskId00))).thenReturn(future);
+ when(future.get(anyLong(), any())).thenThrow(new java.util.concurrent.TimeoutException());
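+ // simulates TaskManager#waitForFuture timing out on its bounded future.get(5, TimeUnit.MINUTES); the task must still be closed dirty during shutdown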
+
+ taskManager.shutdown(true);
+
+ verify(task00).closeDirty();
+ }
+
@Test
public void shouldOnlyCommitRevokedStandbyTaskAndPropagatePrepareCommitException() {
setUpTaskManager(ProcessingMode.EXACTLY_ONCE_V2, false);
@@ -3553,13 +3576,14 @@ public void shouldShutDownStateUpdaterAndCloseFailedTasksDirty() {
.thenReturn(Arrays.asList(
new ExceptionAndTask(new RuntimeException(), failedStatefulTask),
new ExceptionAndTask(new RuntimeException(), failedStandbyTask))
- );
+ )
+ .thenReturn(Collections.emptyList());
final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true);
taskManager.shutdown(true);
verify(activeTaskCreator).close();
- verify(stateUpdater).shutdown(Duration.ofMillis(Long.MAX_VALUE));
+ verify(stateUpdater).shutdown(Duration.ofMinutes(1L));
verify(failedStatefulTask).prepareCommit(false);
verify(failedStatefulTask).suspend();
verify(failedStatefulTask).closeDirty();
@@ -3572,7 +3596,7 @@ public void shouldShutdownSchedulingTaskManager() {
taskManager.shutdown(true);
- verify(schedulingTaskManager).shutdown(Duration.ofMillis(Long.MAX_VALUE));
+ verify(schedulingTaskManager).shutdown(Duration.ofMinutes(5L));
}
@Test
@@ -3597,8 +3621,8 @@ public void shouldShutDownStateUpdaterAndCloseDirtyTasksFailedDuringRemoval() {
removedFailedStatefulTask,
removedFailedStandbyTask,
removedFailedStatefulTaskDuringRemoval,
- removedFailedStandbyTaskDuringRemoval
- ));
+ removedFailedStandbyTaskDuringRemoval)
+ ).thenReturn(Collections.emptySet());
final CompletableFuture<StateUpdater.RemovedTaskResult> futureForRemovedStatefulTask = new CompletableFuture<>();
final CompletableFuture<StateUpdater.RemovedTaskResult> futureForRemovedStandbyTask = new CompletableFuture<>();
final CompletableFuture<StateUpdater.RemovedTaskResult> futureForRemovedFailedStatefulTask = new CompletableFuture<>();
@@ -3613,10 +3637,11 @@ public void shouldShutDownStateUpdaterAndCloseDirtyTasksFailedDuringRemoval() {
.thenReturn(futureForRemovedFailedStatefulTaskDuringRemoval);
when(stateUpdater.remove(removedFailedStandbyTaskDuringRemoval.id()))
.thenReturn(futureForRemovedFailedStandbyTaskDuringRemoval);
- when(stateUpdater.drainExceptionsAndFailedTasks()).thenReturn(Arrays.asList(
- new ExceptionAndTask(new StreamsException("KABOOM!"), removedFailedStatefulTaskDuringRemoval),
- new ExceptionAndTask(new StreamsException("KABOOM!"), removedFailedStandbyTaskDuringRemoval)
- ));
+ when(stateUpdater.drainExceptionsAndFailedTasks())
+ .thenReturn(Arrays.asList(
+ new ExceptionAndTask(new StreamsException("KABOOM!"), removedFailedStatefulTaskDuringRemoval),
+ new ExceptionAndTask(new StreamsException("KABOOM!"), removedFailedStandbyTaskDuringRemoval))
+ ).thenReturn(Collections.emptyList());
final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true);
futureForRemovedStatefulTask.complete(new StateUpdater.RemovedTaskResult(removedStatefulTask));
futureForRemovedStandbyTask.complete(new StateUpdater.RemovedTaskResult(removedStandbyTask));
@@ -3631,7 +3656,7 @@ public void shouldShutDownStateUpdaterAndCloseDirtyTasksFailedDuringRemoval() {
taskManager.shutdown(true);
- verify(stateUpdater).shutdown(Duration.ofMillis(Long.MAX_VALUE));
+ verify(stateUpdater).shutdown(Duration.ofMinutes(1L));
verify(tasks).addTask(removedStatefulTask);
verify(tasks).addTask(removedStandbyTask);
verify(removedFailedStatefulTask).prepareCommit(false);