diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/assignment/assignors/StickyTaskAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/assignment/assignors/StickyTaskAssignor.java index 6ad61851ddae3..60c715fea4f3b 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/assignment/assignors/StickyTaskAssignor.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/assignment/assignors/StickyTaskAssignor.java @@ -146,7 +146,6 @@ private static void assignActive(final ApplicationState applicationState, final int totalCapacity = computeTotalProcessingThreads(clients); final Set allTaskIds = applicationState.allTasks().keySet(); final int taskCount = allTaskIds.size(); - final int activeTasksPerThread = taskCount / totalCapacity; final Set unassigned = new HashSet<>(allTaskIds); // first try and re-assign existing active tasks to clients that previously had @@ -154,7 +153,7 @@ private static void assignActive(final ApplicationState applicationState, for (final TaskId taskId : assignmentState.previousActiveAssignment.keySet()) { final ProcessId previousClientForTask = assignmentState.previousActiveAssignment.get(taskId); if (allTaskIds.contains(taskId)) { - if (mustPreserveActiveTaskAssignment || assignmentState.hasRoomForActiveTask(previousClientForTask, activeTasksPerThread)) { + if (mustPreserveActiveTaskAssignment || assignmentState.hasRoomForActiveTask(previousClientForTask, taskCount, totalCapacity)) { assignmentState.finalizeAssignment(taskId, previousClientForTask, AssignedTask.Type.ACTIVE); unassigned.remove(taskId); } @@ -167,7 +166,7 @@ private static void assignActive(final ApplicationState applicationState, final TaskId taskId = iterator.next(); final Set previousClientsForStandbyTask = assignmentState.previousStandbyAssignment.getOrDefault(taskId, new HashSet<>()); for (final ProcessId client: previousClientsForStandbyTask) { - if (assignmentState.hasRoomForActiveTask(client, activeTasksPerThread)) { + if (assignmentState.hasRoomForActiveTask(client, taskCount, totalCapacity)) { assignmentState.finalizeAssignment(taskId, client, AssignedTask.Type.ACTIVE); iterator.remove(); break; @@ -295,14 +294,15 @@ private void processOptimizedAssignments(final Map KafkaStreamsAssignment.of(processId, new HashSet<>())) .tasks().values() .stream().filter(assignedTask -> assignedTask.type() == AssignedTask.Type.ACTIVE) .collect(Collectors.toSet()) .size(); - return newActiveTaskCount < capacity * activeTasksPerThread; + final int instanceLimit = (taskCount * capacity + totalCapacity - 1) / totalCapacity; + return newActiveTaskCount < instanceLimit; } private ProcessId findBestClientForTask(final TaskId taskId, final Set clientsWithin) { diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/CustomStickyTaskAssignorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/CustomStickyTaskAssignorTest.java index fa92a55dcae4c..a417926b50288 100644 --- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/CustomStickyTaskAssignorTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/CustomStickyTaskAssignorTest.java @@ -39,6 +39,8 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; @@ -69,6 +71,7 @@ import static org.apache.kafka.streams.processor.internals.assignment.TaskAssignmentUtilsTest.mkStreamState; import static org.apache.kafka.streams.processor.internals.assignment.TaskAssignmentUtilsTest.mkTaskInfo; import static org.apache.kafka.streams.processor.internals.assignment.TaskAssignmentUtilsTest.processId; +import static org.junit.jupiter.api.Assertions.fail; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThanOrEqualTo; @@ -913,4 +916,78 @@ private void assertActiveTaskTopicGroupIdsEvenlyDistributed(final Map tasks = buildStatelessTasks(numTasks); + final Set firstHalf = buildTaskIdRange(0, numTasks / 2); + + final Map streamStates = mkMap( + mkStreamState(1, threadsPerInstance, Optional.empty(), firstHalf, Set.of()), + mkStreamState(2, threadsPerInstance, Optional.empty()) + ); + + final Map assignments = + assign(streamStates, tasks, StreamsConfig.RACK_AWARE_ASSIGNMENT_STRATEGY_NONE); + + final Set instance1Tasks = activeTasks(assignments, 1); + final int retained = (int) firstHalf.stream().filter(instance1Tasks::contains).count(); + + assertThat("Instance should retain all of its fair share tasks, but only retained " + retained + " of " + firstHalf.size(), + retained, equalTo(firstHalf.size())); + } + + @Test + public void shouldConvergeWithinTwoRoundsWhenScalingUp() { + // Reporter's scenario: 450 tasks with 10+10=20 threads. + // 450/20 = 22.5, floor gives limit 220 per instance (10 task overflow), + // causing repeated reassignment across rounds. + final int numTasks = 450; + final int maxRounds = 2; + final Map tasks = buildStatelessTasks(numTasks); + + Set instance1Prev = buildTaskIdRange(0, numTasks); + Set instance2Prev = Set.of(); + + for (int round = 1; round <= maxRounds; round++) { + final Map streamStates = mkMap( + mkStreamState(1, 10, Optional.empty(), instance1Prev, Set.of()), + mkStreamState(2, 10, Optional.empty(), instance2Prev, Set.of()) + ); + + final Map assignments = + assign(streamStates, tasks, StreamsConfig.RACK_AWARE_ASSIGNMENT_STRATEGY_NONE); + + final Set newInstance1 = activeTasks(assignments, 1); + final Set newInstance2 = activeTasks(assignments, 2); + + if (newInstance1.equals(instance1Prev) && newInstance2.equals(instance2Prev)) { + return; // converged + } + instance1Prev = newInstance1; + instance2Prev = newInstance2; + } + fail("Assignment did not converge within " + maxRounds + " rounds"); + } + + private static Map buildStatelessTasks(final int count) { + final Map tasks = new HashMap<>(); + for (int i = 0; i < count; i++) { + final TaskId taskId = new TaskId(0, i); + tasks.put(taskId, mkTaskInfo(taskId, false).getValue()); + } + return tasks; + } + + private static Set buildTaskIdRange(final int from, final int to) { + final Set set = new HashSet<>(); + for (int i = from; i < to; i++) { + set.add(new TaskId(0, i)); + } + return set; + } }