@@ -146,15 +146,14 @@ private static void assignActive(final ApplicationState applicationState,
final int totalCapacity = computeTotalProcessingThreads(clients);
final Set<TaskId> allTaskIds = applicationState.allTasks().keySet();
final int taskCount = allTaskIds.size();
- final int activeTasksPerThread = taskCount / totalCapacity;
final Set<TaskId> unassigned = new HashSet<>(allTaskIds);

// first try and re-assign existing active tasks to clients that previously had
// the same active task
for (final TaskId taskId : assignmentState.previousActiveAssignment.keySet()) {
final ProcessId previousClientForTask = assignmentState.previousActiveAssignment.get(taskId);
if (allTaskIds.contains(taskId)) {
- if (mustPreserveActiveTaskAssignment || assignmentState.hasRoomForActiveTask(previousClientForTask, activeTasksPerThread)) {
+ if (mustPreserveActiveTaskAssignment || assignmentState.hasRoomForActiveTask(previousClientForTask, taskCount, totalCapacity)) {
assignmentState.finalizeAssignment(taskId, previousClientForTask, AssignedTask.Type.ACTIVE);
unassigned.remove(taskId);
}
@@ -167,7 +166,7 @@ private static void assignActive(final ApplicationState applicationState,
final TaskId taskId = iterator.next();
final Set<ProcessId> previousClientsForStandbyTask = assignmentState.previousStandbyAssignment.getOrDefault(taskId, new HashSet<>());
for (final ProcessId client: previousClientsForStandbyTask) {
- if (assignmentState.hasRoomForActiveTask(client, activeTasksPerThread)) {
+ if (assignmentState.hasRoomForActiveTask(client, taskCount, totalCapacity)) {
assignmentState.finalizeAssignment(taskId, client, AssignedTask.Type.ACTIVE);
iterator.remove();
break;
@@ -295,14 +294,15 @@ private void processOptimizedAssignments(final Map<ProcessId, KafkaStreamsAssignment>
this.newAssignments = optimizedAssignments;
}

- private boolean hasRoomForActiveTask(final ProcessId processId, final int activeTasksPerThread) {
+ private boolean hasRoomForActiveTask(final ProcessId processId, final int taskCount, final int totalCapacity) {
final int capacity = clients.get(processId).numProcessingThreads();
final int newActiveTaskCount = newAssignments.computeIfAbsent(processId, k -> KafkaStreamsAssignment.of(processId, new HashSet<>()))
.tasks().values()
.stream().filter(assignedTask -> assignedTask.type() == AssignedTask.Type.ACTIVE)
.collect(Collectors.toSet())
.size();
- return newActiveTaskCount < capacity * activeTasksPerThread;
+ final int instanceLimit = (taskCount * capacity + totalCapacity - 1) / totalCapacity;

Review comment:
I think this needs a regression test that fails with the old quota math and passes with the new proportional ceiling.

On top of that, can we add a behavior test that simulates repeated StickyTaskAssignor.assign(...) rounds and then asserts that an uneven-capacity case converges in 2 rounds? That would cover the actual failure mode described in the PR, not just the arithmetic change.

+ return newActiveTaskCount < instanceLimit;
}

private ProcessId findBestClientForTask(final TaskId taskId, final Set<ProcessId> clientsWithin) {
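To make the quota change concrete, below is a minimal, self-contained sketch contrasting the old floor-based limit with the new proportional ceiling, using the same numbers as the scale-up regression test further down (450 tasks, two instances of 10 threads each). The class and method names here are illustrative, not part of the PR:

public class QuotaMathSketch {

    // Old limit: capacity * floor(taskCount / totalCapacity) active tasks per instance.
    static int oldLimit(final int taskCount, final int capacity, final int totalCapacity) {
        return capacity * (taskCount / totalCapacity);
    }

    // New limit: ceil(taskCount * capacity / totalCapacity), i.e. the instance's
    // proportional share of all tasks, rounded up.
    static int newLimit(final int taskCount, final int capacity, final int totalCapacity) {
        return (taskCount * capacity + totalCapacity - 1) / totalCapacity;
    }

    public static void main(final String[] args) {
        // Scale-up from 1 to 2 instances: 450 tasks, 2 x 10 threads, fair share 225 each.
        System.out.println(oldLimit(450, 10, 20)); // 220 -> 5 previously-owned tasks are evicted each rebalance
        System.out.println(newLimit(450, 10, 20)); // 225 -> the full fair share can stay sticky
    }
}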
@@ -913,4 +913,76 @@ private void assertActiveTaskTopicGroupIdsEvenlyDistributed(final Map<ProcessId,
assertThat(topicGroupIds, equalTo(asList(1, 2)));
}
}

// --- KAFKA-20198 regression tests ---

@Test
public void shouldRetainFairShareOfPreviousTasksWhenScalingUp() {
final int numTasks = 450;
final int threadsPerInstance = 10;
final Map<TaskId, TaskInfo> tasks = buildStatelessTasks(numTasks);
final Set<TaskId> firstHalf = buildTaskIdRange(0, numTasks / 2);

final Map<ProcessId, KafkaStreamsState> streamStates = mkMap(
mkStreamState(1, threadsPerInstance, Optional.empty(), firstHalf, Set.of()),
mkStreamState(2, threadsPerInstance, Optional.empty())
);

final Map<ProcessId, KafkaStreamsAssignment> assignments =
assign(streamStates, tasks, StreamsConfig.RACK_AWARE_ASSIGNMENT_STRATEGY_NONE);

final Set<TaskId> instance1Tasks = activeTasks(assignments, 1);
final int retained = (int) firstHalf.stream().filter(instance1Tasks::contains).count();

assertThat("Instance should retain all of its fair share tasks, but only retained " + retained + " of " + firstHalf.size(),
retained, equalTo(firstHalf.size()));
}

@Test
public void shouldConvergeWithinFewRoundsWhenScalingUp() {
final int numTasks = 450;
final int threadsPerInstance = 10;
final int maxRounds = 5;
final Map<TaskId, TaskInfo> tasks = buildStatelessTasks(numTasks);

Set<TaskId> instance1Prev = buildTaskIdRange(0, numTasks);
Set<TaskId> instance2Prev = Set.of();

for (int round = 1; round <= maxRounds; round++) {
final Map<ProcessId, KafkaStreamsState> streamStates = mkMap(
mkStreamState(1, threadsPerInstance, Optional.empty(), instance1Prev, Set.of()),
mkStreamState(2, threadsPerInstance, Optional.empty(), instance2Prev, Set.of())
);

final Map<ProcessId, KafkaStreamsAssignment> assignments =
assign(streamStates, tasks, StreamsConfig.RACK_AWARE_ASSIGNMENT_STRATEGY_NONE);

final Set<TaskId> newInstance1 = activeTasks(assignments, 1);
final Set<TaskId> newInstance2 = activeTasks(assignments, 2);

if (newInstance1.equals(instance1Prev) && newInstance2.equals(instance2Prev)) {
return; // converged
}
instance1Prev = newInstance1;
instance2Prev = newInstance2;
}
throw new AssertionError("Assignment did not converge within " + maxRounds + " rounds");
}

private static Map<TaskId, TaskInfo> buildStatelessTasks(final int count) {
final Map<TaskId, TaskInfo> tasks = new java.util.HashMap<>();
for (int i = 0; i < count; i++) {
final TaskId taskId = new TaskId(0, i);
tasks.put(taskId, mkTaskInfo(taskId, false).getValue());
}
return tasks;
}

private static Set<TaskId> buildTaskIdRange(final int from, final int to) {
final Set<TaskId> set = new java.util.HashSet<>();
for (int i = from; i < to; i++) {
set.add(new TaskId(0, i));
}
return set;
}
}
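
Both tests above use equal threadsPerInstance, while the review comment also asks about uneven capacities. Assuming the per-instance limit is ceil(taskCount * capacity / totalCapacity), as in the change above, the limits always sum to at least taskCount, so the sticky passes can keep every previously-owned task somewhere; the old floor-based quotas can sum to less than taskCount, forcing some tasks out of the sticky passes even when the previous assignment was balanced. A quick standalone check of that invariant (illustrative, not part of the PR):

public class CeilingInvariantSketch {
    public static void main(final String[] args) {
        final int taskCount = 20;
        final int[] capacities = {10, 3}; // two instances with uneven thread counts
        final int totalCapacity = 13;

        int floorSum = 0;
        int ceilSum = 0;
        for (final int capacity : capacities) {
            floorSum += capacity * (taskCount / totalCapacity);                    // old quota
            ceilSum += (taskCount * capacity + totalCapacity - 1) / totalCapacity; // new limit
        }
        System.out.println(floorSum); // 13 -> 7 of the 20 tasks exceed the sticky quotas and must move
        System.out.println(ceilSum);  // 21 -> at least taskCount, so every sticky task can be retained
    }
}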