Skip to content

Commit 1f7631c

Browse files
authored
MINOR: Fix StreamsRebalanceListenerInvoker (#20575)
StreamsRebalanceListenerInvoker was implemented to match the behavior of ConsumerRebalanceListenerInvoker, however StreamsRebalanceListener has a subtly different interface than ConsumerRebalanceListener - it does not throw exceptions, but returns it as an Optional. In the interest of consistency, this change fixes this mismatch by changing the StreamsRebalanceListener interface to behave more like the ConsumerRebalanceListener - throwing exceptions directly. In another minor fix, the StreamsRebalanceListenerInvoker is changed to simply skip callback execution instead of throwing an IllegalStateException when no streamRebalanceListener is defined. This can happen when the consumer is closed before Consumer.subscribe is called. Reviewers: Lianet Magrans <[email protected]>, Matthias J. Sax <[email protected]>
1 parent 0a48361 commit 1f7631c

File tree

8 files changed

+66
-117
lines changed

8 files changed

+66
-117
lines changed

clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListener.java

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
*/
1717
package org.apache.kafka.clients.consumer.internals;
1818

19-
import java.util.Optional;
2019
import java.util.Set;
2120

2221
/**
@@ -28,22 +27,18 @@ public interface StreamsRebalanceListener {
2827
* Called when tasks are revoked from a stream thread.
2928
*
3029
* @param tasks The tasks to be revoked.
31-
* @return The exception thrown during the callback, if any.
3230
*/
33-
Optional<Exception> onTasksRevoked(final Set<StreamsRebalanceData.TaskId> tasks);
31+
void onTasksRevoked(final Set<StreamsRebalanceData.TaskId> tasks);
3432

3533
/**
3634
* Called when tasks are assigned from a stream thread.
3735
*
3836
* @param assignment The tasks assigned.
39-
* @return The exception thrown during the callback, if any.
4037
*/
41-
Optional<Exception> onTasksAssigned(final StreamsRebalanceData.Assignment assignment);
38+
void onTasksAssigned(final StreamsRebalanceData.Assignment assignment);
4239

4340
/**
44-
* Called when a stream thread loses all assigned tasks.
45-
*
46-
* @return The exception thrown during the callback, if any.
41+
* Called when a stream thread loses all assigned tasks
4742
*/
48-
Optional<Exception> onAllTasksLost();
43+
void onAllTasksLost();
4944
}

clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListenerInvoker.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,14 @@ public void setRebalanceListener(StreamsRebalanceListener streamsRebalanceListen
5151

5252
public Exception invokeAllTasksRevoked() {
5353
if (listener.isEmpty()) {
54-
throw new IllegalStateException("StreamsRebalanceListener is not defined");
54+
return null;
5555
}
5656
return invokeTasksRevoked(streamsRebalanceData.reconciledAssignment().activeTasks());
5757
}
5858

5959
public Exception invokeTasksAssigned(final StreamsRebalanceData.Assignment assignment) {
6060
if (listener.isEmpty()) {
61-
throw new IllegalStateException("StreamsRebalanceListener is not defined");
61+
return null;
6262
}
6363
log.info("Invoking tasks assigned callback for new assignment: {}", assignment);
6464
try {
@@ -78,7 +78,7 @@ public Exception invokeTasksAssigned(final StreamsRebalanceData.Assignment assig
7878

7979
public Exception invokeTasksRevoked(final Set<StreamsRebalanceData.TaskId> tasks) {
8080
if (listener.isEmpty()) {
81-
throw new IllegalStateException("StreamsRebalanceListener is not defined");
81+
return null;
8282
}
8383
log.info("Invoking task revoked callback for revoked active tasks {}", tasks);
8484
try {
@@ -98,7 +98,7 @@ public Exception invokeTasksRevoked(final Set<StreamsRebalanceData.TaskId> tasks
9898

9999
public Exception invokeAllTasksLost() {
100100
if (listener.isEmpty()) {
101-
throw new IllegalStateException("StreamsRebalanceListener is not defined");
101+
return null;
102102
}
103103
log.info("Invoking tasks lost callback for all tasks");
104104
try {

clients/src/test/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumerTest.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2218,7 +2218,6 @@ public void testCloseInvokesStreamsRebalanceListenerOnTasksRevokedWhenMemberEpoc
22182218
try (final MockedStatic<RequestManagers> requestManagers = mockStatic(RequestManagers.class)) {
22192219
consumer = newConsumerWithStreamRebalanceData(requiredConsumerConfigAndGroupId(groupId), streamsRebalanceData);
22202220
StreamsRebalanceListener mockStreamsListener = mock(StreamsRebalanceListener.class);
2221-
when(mockStreamsListener.onTasksRevoked(any())).thenReturn(Optional.empty());
22222221
consumer.subscribe(singletonList("topic"), mockStreamsListener);
22232222
final MemberStateListener groupMetadataUpdateListener = captureGroupMetadataUpdateListener(requestManagers);
22242223
final int memberEpoch = 42;
@@ -2239,7 +2238,6 @@ public void testCloseInvokesStreamsRebalanceListenerOnAllTasksLostWhenMemberEpoc
22392238
try (final MockedStatic<RequestManagers> requestManagers = mockStatic(RequestManagers.class)) {
22402239
consumer = newConsumerWithStreamRebalanceData(requiredConsumerConfigAndGroupId(groupId), streamsRebalanceData);
22412240
StreamsRebalanceListener mockStreamsListener = mock(StreamsRebalanceListener.class);
2242-
when(mockStreamsListener.onAllTasksLost()).thenReturn(Optional.empty());
22432241
consumer.subscribe(singletonList("topic"), mockStreamsListener);
22442242
final MemberStateListener groupMetadataUpdateListener = captureGroupMetadataUpdateListener(requestManagers);
22452243
final int memberEpoch = 0;

clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListenerInvokerTest.java

Lines changed: 8 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
import org.mockito.junit.jupiter.MockitoSettings;
2929
import org.mockito.quality.Strictness;
3030

31-
import java.util.Optional;
3231
import java.util.Set;
3332

3433
import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -73,7 +72,6 @@ public void testSetRebalanceListenerOverwritesExisting() {
7372

7473
StreamsRebalanceData.Assignment mockAssignment = createMockAssignment();
7574
when(streamsRebalanceData.reconciledAssignment()).thenReturn(mockAssignment);
76-
when(secondListener.onTasksRevoked(any())).thenReturn(Optional.empty());
7775

7876
// Set first listener
7977
invoker.setRebalanceListener(firstListener);
@@ -89,21 +87,10 @@ public void testSetRebalanceListenerOverwritesExisting() {
8987

9088
@Test
9189
public void testInvokeMethodsWithNoListener() {
92-
IllegalStateException exception1 = assertThrows(IllegalStateException.class,
93-
() -> invoker.invokeAllTasksRevoked());
94-
assertEquals("StreamsRebalanceListener is not defined", exception1.getMessage());
95-
96-
IllegalStateException exception2 = assertThrows(IllegalStateException.class,
97-
() -> invoker.invokeTasksAssigned(createMockAssignment()));
98-
assertEquals("StreamsRebalanceListener is not defined", exception2.getMessage());
99-
100-
IllegalStateException exception3 = assertThrows(IllegalStateException.class,
101-
() -> invoker.invokeTasksRevoked(createMockTasks()));
102-
assertEquals("StreamsRebalanceListener is not defined", exception3.getMessage());
103-
104-
IllegalStateException exception4 = assertThrows(IllegalStateException.class,
105-
() -> invoker.invokeAllTasksLost());
106-
assertEquals("StreamsRebalanceListener is not defined", exception4.getMessage());
90+
assertNull(invoker.invokeAllTasksRevoked());
91+
assertNull(invoker.invokeTasksAssigned(createMockAssignment()));
92+
assertNull(invoker.invokeTasksRevoked(createMockTasks()));
93+
assertNull(invoker.invokeAllTasksLost());
10794
}
10895

10996
@Test
@@ -112,8 +99,7 @@ public void testInvokeAllTasksRevokedWithListener() {
11299

113100
StreamsRebalanceData.Assignment mockAssignment = createMockAssignment();
114101
when(streamsRebalanceData.reconciledAssignment()).thenReturn(mockAssignment);
115-
when(mockListener.onTasksRevoked(any())).thenReturn(Optional.empty());
116-
102+
117103
Exception result = invoker.invokeAllTasksRevoked();
118104

119105
assertNull(result);
@@ -124,8 +110,7 @@ public void testInvokeAllTasksRevokedWithListener() {
124110
public void testInvokeTasksAssignedWithListener() {
125111
invoker.setRebalanceListener(mockListener);
126112
StreamsRebalanceData.Assignment assignment = createMockAssignment();
127-
when(mockListener.onTasksAssigned(assignment)).thenReturn(Optional.empty());
128-
113+
129114
Exception result = invoker.invokeTasksAssigned(assignment);
130115

131116
assertNull(result);
@@ -177,8 +162,7 @@ public void testInvokeTasksAssignedWithOtherException() {
177162
public void testInvokeTasksRevokedWithListener() {
178163
invoker.setRebalanceListener(mockListener);
179164
Set<StreamsRebalanceData.TaskId> tasks = createMockTasks();
180-
when(mockListener.onTasksRevoked(tasks)).thenReturn(Optional.empty());
181-
165+
182166
Exception result = invoker.invokeTasksRevoked(tasks);
183167

184168
assertNull(result);
@@ -229,8 +213,7 @@ public void testInvokeTasksRevokedWithOtherException() {
229213
@Test
230214
public void testInvokeAllTasksLostWithListener() {
231215
invoker.setRebalanceListener(mockListener);
232-
when(mockListener.onAllTasksLost()).thenReturn(Optional.empty());
233-
216+
234217
Exception result = invoker.invokeAllTasksLost();
235218

236219
assertNull(result);

core/src/test/scala/integration/kafka/api/AuthorizerIntegrationTest.scala

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3912,14 +3912,9 @@ class AuthorizerIntegrationTest extends AbstractAuthorizerIntegrationTest {
39123912
consumer.subscribe(
39133913
if (topicAsSourceTopic || topicAsRepartitionSourceTopic) util.Set.of(sourceTopic, topic) else util.Set.of(sourceTopic),
39143914
new StreamsRebalanceListener {
3915-
override def onTasksRevoked(tasks: util.Set[StreamsRebalanceData.TaskId]): Optional[Exception] =
3916-
Optional.empty()
3917-
3918-
override def onTasksAssigned(assignment: StreamsRebalanceData.Assignment): Optional[Exception] =
3919-
Optional.empty()
3920-
3921-
override def onAllTasksLost(): Optional[Exception] =
3922-
Optional.empty()
3915+
override def onTasksRevoked(tasks: util.Set[StreamsRebalanceData.TaskId]): Unit = ()
3916+
override def onTasksAssigned(assignment: StreamsRebalanceData.Assignment): Unit = ()
3917+
override def onAllTasksLost(): Unit = ()
39233918
}
39243919
)
39253920
consumer.poll(Duration.ofMillis(500L))

core/src/test/scala/integration/kafka/api/IntegrationTestHarness.scala

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -272,13 +272,9 @@ abstract class IntegrationTestHarness extends KafkaServerTestHarness {
272272
)
273273
consumer.subscribe(util.Set.of(inputTopic),
274274
new StreamsRebalanceListener {
275-
override def onTasksRevoked(tasks: util.Set[StreamsRebalanceData.TaskId]): Optional[Exception] =
276-
Optional.empty()
277-
override def onTasksAssigned(assignment: StreamsRebalanceData.Assignment): Optional[Exception] = {
278-
Optional.empty()
279-
}
280-
override def onAllTasksLost(): Optional[Exception] =
281-
Optional.empty()
275+
override def onTasksRevoked(tasks: util.Set[StreamsRebalanceData.TaskId]): Unit = ()
276+
override def onTasksAssigned(assignment: StreamsRebalanceData.Assignment): Unit = ()
277+
override def onAllTasksLost(): Unit = ()
282278
})
283279
consumer
284280
}

streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStreamsRebalanceListener.java

Lines changed: 28 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626

2727
import java.util.Collection;
2828
import java.util.Map;
29-
import java.util.Optional;
3029
import java.util.Set;
3130
import java.util.stream.Collectors;
3231
import java.util.stream.Stream;
@@ -52,59 +51,44 @@ public DefaultStreamsRebalanceListener(final Logger log,
5251
}
5352

5453
@Override
55-
public Optional<Exception> onTasksRevoked(final Set<StreamsRebalanceData.TaskId> tasks) {
56-
try {
57-
final Map<TaskId, Set<TopicPartition>> activeTasksToRevokeWithPartitions =
58-
pairWithTopicPartitions(tasks.stream());
59-
final Set<TopicPartition> partitionsToRevoke = activeTasksToRevokeWithPartitions.values().stream()
60-
.flatMap(Collection::stream)
61-
.collect(Collectors.toSet());
54+
public void onTasksRevoked(final Set<StreamsRebalanceData.TaskId> tasks) {
55+
final Map<TaskId, Set<TopicPartition>> activeTasksToRevokeWithPartitions =
56+
pairWithTopicPartitions(tasks.stream());
57+
final Set<TopicPartition> partitionsToRevoke = activeTasksToRevokeWithPartitions.values().stream()
58+
.flatMap(Collection::stream)
59+
.collect(Collectors.toSet());
6260

63-
final long start = time.milliseconds();
64-
try {
65-
log.info("Revoking active tasks {}.", tasks);
66-
taskManager.handleRevocation(partitionsToRevoke);
67-
} finally {
68-
log.info("partition revocation took {} ms.", time.milliseconds() - start);
69-
}
70-
if (streamThread.state() != StreamThread.State.PENDING_SHUTDOWN) {
71-
streamThread.setState(StreamThread.State.PARTITIONS_REVOKED);
72-
}
73-
} catch (final Exception exception) {
74-
return Optional.of(exception);
61+
final long start = time.milliseconds();
62+
try {
63+
log.info("Revoking active tasks {}.", tasks);
64+
taskManager.handleRevocation(partitionsToRevoke);
65+
} finally {
66+
log.info("partition revocation took {} ms.", time.milliseconds() - start);
67+
}
68+
if (streamThread.state() != StreamThread.State.PENDING_SHUTDOWN) {
69+
streamThread.setState(StreamThread.State.PARTITIONS_REVOKED);
7570
}
76-
return Optional.empty();
7771
}
7872

7973
@Override
80-
public Optional<Exception> onTasksAssigned(final StreamsRebalanceData.Assignment assignment) {
81-
try {
82-
final Map<TaskId, Set<TopicPartition>> activeTasksWithPartitions =
83-
pairWithTopicPartitions(assignment.activeTasks().stream());
84-
final Map<TaskId, Set<TopicPartition>> standbyTasksWithPartitions =
85-
pairWithTopicPartitions(Stream.concat(assignment.standbyTasks().stream(), assignment.warmupTasks().stream()));
74+
public void onTasksAssigned(final StreamsRebalanceData.Assignment assignment) {
75+
final Map<TaskId, Set<TopicPartition>> activeTasksWithPartitions =
76+
pairWithTopicPartitions(assignment.activeTasks().stream());
77+
final Map<TaskId, Set<TopicPartition>> standbyTasksWithPartitions =
78+
pairWithTopicPartitions(Stream.concat(assignment.standbyTasks().stream(), assignment.warmupTasks().stream()));
8679

87-
log.info("Processing new assignment {} from Streams Rebalance Protocol", assignment);
80+
log.info("Processing new assignment {} from Streams Rebalance Protocol", assignment);
8881

89-
taskManager.handleAssignment(activeTasksWithPartitions, standbyTasksWithPartitions);
90-
streamThread.setState(StreamThread.State.PARTITIONS_ASSIGNED);
91-
taskManager.handleRebalanceComplete();
92-
streamsRebalanceData.setReconciledAssignment(assignment);
93-
} catch (final Exception exception) {
94-
return Optional.of(exception);
95-
}
96-
return Optional.empty();
82+
taskManager.handleAssignment(activeTasksWithPartitions, standbyTasksWithPartitions);
83+
streamThread.setState(StreamThread.State.PARTITIONS_ASSIGNED);
84+
taskManager.handleRebalanceComplete();
85+
streamsRebalanceData.setReconciledAssignment(assignment);
9786
}
9887

9988
@Override
100-
public Optional<Exception> onAllTasksLost() {
101-
try {
102-
taskManager.handleLostAll();
103-
streamsRebalanceData.setReconciledAssignment(StreamsRebalanceData.Assignment.EMPTY);
104-
} catch (final Exception exception) {
105-
return Optional.of(exception);
106-
}
107-
return Optional.empty();
89+
public void onAllTasksLost() {
90+
taskManager.handleLostAll();
91+
streamsRebalanceData.setReconciledAssignment(StreamsRebalanceData.Assignment.EMPTY);
10892
}
10993

11094
private Map<TaskId, Set<TopicPartition>> pairWithTopicPartitions(final Stream<StreamsRebalanceData.TaskId> taskIdStream) {

streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStreamsRebalanceListenerTest.java

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,9 @@
3232
import java.util.Set;
3333
import java.util.UUID;
3434

35+
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
3536
import static org.junit.jupiter.api.Assertions.assertEquals;
36-
import static org.junit.jupiter.api.Assertions.assertTrue;
37+
import static org.junit.jupiter.api.Assertions.assertThrows;
3738
import static org.mockito.ArgumentMatchers.any;
3839
import static org.mockito.Mockito.doThrow;
3940
import static org.mockito.Mockito.inOrder;
@@ -84,11 +85,9 @@ void testOnTasksRevoked(final StreamThread.State state) {
8485
));
8586
when(streamThread.state()).thenReturn(state);
8687

87-
final Optional<Exception> result = defaultStreamsRebalanceListener.onTasksRevoked(
88+
assertDoesNotThrow(() -> defaultStreamsRebalanceListener.onTasksRevoked(
8889
Set.of(new StreamsRebalanceData.TaskId("1", 0))
89-
);
90-
91-
assertTrue(result.isEmpty());
90+
));
9291

9392
final InOrder inOrder = inOrder(taskManager, streamThread);
9493
inOrder.verify(taskManager).handleRevocation(
@@ -109,9 +108,9 @@ void testOnTasksRevokedWithException() {
109108

110109
createRebalanceListenerWithRebalanceData(new StreamsRebalanceData(UUID.randomUUID(), Optional.empty(), Map.of(), Map.of()));
111110

112-
final Optional<Exception> result = defaultStreamsRebalanceListener.onTasksRevoked(Set.of());
111+
final Exception actualException = assertThrows(RuntimeException.class, () -> defaultStreamsRebalanceListener.onTasksRevoked(Set.of()));
113112

114-
assertTrue(result.isPresent());
113+
assertEquals(actualException, exception);
115114
verify(taskManager).handleRevocation(any());
116115
verify(streamThread, never()).setState(any());
117116
}
@@ -153,9 +152,7 @@ void testOnTasksAssigned() {
153152
Set.of(new StreamsRebalanceData.TaskId("3", 0))
154153
);
155154

156-
final Optional<Exception> result = defaultStreamsRebalanceListener.onTasksAssigned(assignment);
157-
158-
assertTrue(result.isEmpty());
155+
assertDoesNotThrow(() -> defaultStreamsRebalanceListener.onTasksAssigned(assignment));
159156

160157
final InOrder inOrder = inOrder(taskManager, streamThread, streamsRebalanceData);
161158
inOrder.verify(taskManager).handleAssignment(
@@ -179,11 +176,11 @@ void testOnTasksAssignedWithException() {
179176
when(streamsRebalanceData.subtopologies()).thenReturn(Map.of());
180177
createRebalanceListenerWithRebalanceData(streamsRebalanceData);
181178

182-
final Optional<Exception> result = defaultStreamsRebalanceListener.onTasksAssigned(
179+
final Exception actualException = assertThrows(RuntimeException.class, () -> defaultStreamsRebalanceListener.onTasksAssigned(
183180
new StreamsRebalanceData.Assignment(Set.of(), Set.of(), Set.of())
184-
);
185-
assertTrue(result.isPresent());
186-
assertEquals(exception, result.get());
181+
));
182+
183+
assertEquals(exception, actualException);
187184
verify(taskManager).handleAssignment(any(), any());
188185
verify(streamThread, never()).setState(StreamThread.State.PARTITIONS_ASSIGNED);
189186
verify(taskManager, never()).handleRebalanceComplete();
@@ -196,7 +193,7 @@ void testOnAllTasksLost() {
196193
when(streamsRebalanceData.subtopologies()).thenReturn(Map.of());
197194
createRebalanceListenerWithRebalanceData(streamsRebalanceData);
198195

199-
assertTrue(defaultStreamsRebalanceListener.onAllTasksLost().isEmpty());
196+
assertDoesNotThrow(() -> defaultStreamsRebalanceListener.onAllTasksLost());
200197

201198
final InOrder inOrder = inOrder(taskManager, streamsRebalanceData);
202199
inOrder.verify(taskManager).handleLostAll();
@@ -211,9 +208,10 @@ void testOnAllTasksLostWithException() {
211208
final StreamsRebalanceData streamsRebalanceData = mock(StreamsRebalanceData.class);
212209
when(streamsRebalanceData.subtopologies()).thenReturn(Map.of());
213210
createRebalanceListenerWithRebalanceData(streamsRebalanceData);
214-
final Optional<Exception> result = defaultStreamsRebalanceListener.onAllTasksLost();
215-
assertTrue(result.isPresent());
216-
assertEquals(exception, result.get());
211+
212+
final Exception actualException = assertThrows(RuntimeException.class, () -> defaultStreamsRebalanceListener.onAllTasksLost());
213+
214+
assertEquals(exception, actualException);
217215
verify(taskManager).handleLostAll();
218216
verify(streamsRebalanceData, never()).setReconciledAssignment(any());
219217
}

0 commit comments

Comments
 (0)