Skip to content

Commit 58eeba4

Browse files
Fix Polling Behavior Flaky Test (#782) (#783)
Fixes #782
1 parent a1ec58c commit 58eeba4

File tree

3 files changed

+118
-75
lines changed

3 files changed

+118
-75
lines changed

spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/MessageExecutionThread.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
*/
1616
package io.awspring.cloud.sqs;
1717

18+
import org.springframework.lang.Nullable;
19+
1820
/**
1921
* A {@link Thread} implementation for processing messages.
2022
* @author Tomaz Fernandes
@@ -30,7 +32,7 @@ public class MessageExecutionThread extends Thread {
3032
* @param runnable see {@link Thread} javadoc.
3133
* @param nextThreadName see {@link Thread} javadoc.
3234
*/
33-
public MessageExecutionThread(ThreadGroup threadGroup, Runnable runnable, String nextThreadName) {
35+
public MessageExecutionThread(@Nullable ThreadGroup threadGroup, Runnable runnable, String nextThreadName) {
3436
super(threadGroup, runnable, nextThreadName);
3537
}
3638

spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/MessageExecutionThreadFactory.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,21 @@
2929
*/
3030
public class MessageExecutionThreadFactory extends CustomizableThreadFactory {
3131

32+
/**
33+
* Create a new MessageExecutionThreadFactory with default thread name prefix.
34+
*/
35+
public MessageExecutionThreadFactory() {
36+
super();
37+
}
38+
39+
/**
40+
* Create a new MessageExecutionThreadFactory with the given thread name prefix.
41+
* @param threadNamePrefix the prefix to use for the names of newly created threads
42+
*/
43+
public MessageExecutionThreadFactory(String threadNamePrefix) {
44+
super(threadNamePrefix);
45+
}
46+
3247
@Override
3348
public Thread createThread(Runnable runnable) {
3449
MessageExecutionThread thread = new MessageExecutionThread(getThreadGroup(), runnable, nextThreadName());

spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/listener/source/AbstractPollingMessageSourceTests.java

Lines changed: 100 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import java.util.concurrent.ThreadFactory;
3838
import java.util.concurrent.TimeUnit;
3939
import java.util.concurrent.atomic.AtomicBoolean;
40+
import java.util.concurrent.atomic.AtomicInteger;
4041
import org.assertj.core.api.InstanceOfAssertFactories;
4142
import org.junit.jupiter.api.Test;
4243
import org.slf4j.Logger;
@@ -53,9 +54,8 @@ class AbstractPollingMessageSourceTests {
5354

5455
private static final Logger logger = LoggerFactory.getLogger(AbstractPollingMessageSourceTests.class);
5556

56-
// @RepeatedTest(400)
5757
@Test
58-
void shouldAcquireAndReleaseFullPermits() throws Exception {
58+
void shouldAcquireAndReleaseFullPermits() {
5959
String testName = "shouldAcquireAndReleaseFullPermits";
6060

6161
SemaphoreBackPressureHandler backPressureHandler = SemaphoreBackPressureHandler.builder()
@@ -73,36 +73,47 @@ void shouldAcquireAndReleaseFullPermits() throws Exception {
7373

7474
@Override
7575
protected CompletableFuture<Collection<Message>> doPollForMessages(int messagesToRequest) {
76-
doSleep(100);
77-
// Since BackPressureMode.ALWAYS_POLL_MAX_MESSAGES, should always be 10.
78-
assertThat(messagesToRequest).isEqualTo(10);
79-
assertAvailablePermits(backPressureHandler, 0);
80-
boolean firstPoll = hasReceived.compareAndSet(false, true);
81-
if (firstPoll) {
82-
// No permits released yet, should be TM low
83-
assertThroughputMode(backPressureHandler, "low");
84-
}
85-
else if (hasMadeSecondPoll.compareAndSet(false, true)) {
86-
// Permits returned, should be high
87-
assertThroughputMode(backPressureHandler, "high");
88-
}
89-
else {
90-
// Already returned full permits, should be low
91-
assertThroughputMode(backPressureHandler, "low");
92-
}
93-
return CompletableFuture
94-
.supplyAsync(() -> firstPoll
76+
return CompletableFuture.supplyAsync(() -> {
77+
try {
78+
// Since BackPressureMode.ALWAYS_POLL_MAX_MESSAGES, should always be 10.
79+
assertThat(messagesToRequest).isEqualTo(10);
80+
assertAvailablePermits(backPressureHandler, 0);
81+
boolean firstPoll = hasReceived.compareAndSet(false, true);
82+
if (firstPoll) {
83+
logger.debug("First poll");
84+
// No permits released yet, should be TM low
85+
assertThroughputMode(backPressureHandler, "low");
86+
}
87+
else if (hasMadeSecondPoll.compareAndSet(false, true)) {
88+
logger.debug("Second poll");
89+
// Permits returned, should be high
90+
assertThroughputMode(backPressureHandler, "high");
91+
}
92+
else {
93+
logger.debug("Third poll");
94+
// Already returned full permits, should be low
95+
assertThroughputMode(backPressureHandler, "low");
96+
}
97+
return firstPoll
9598
? (Collection<Message>) List.of(Message.builder()
9699
.messageId(UUID.randomUUID().toString()).body("message").build())
97-
: Collections.<Message> emptyList(), threadPool)
98-
.whenComplete((v, t) -> pollingCounter.countDown());
100+
: Collections.<Message> emptyList();
101+
}
102+
catch (Throwable t) {
103+
logger.error("Error", t);
104+
throw new RuntimeException(t);
105+
}
106+
}, threadPool).whenComplete((v, t) -> {
107+
if (t == null) {
108+
pollingCounter.countDown();
109+
}
110+
});
99111
}
100112
};
101113

102114
source.setBackPressureHandler(backPressureHandler);
103115
source.setMessageSink((msgs, context) -> {
104116
assertAvailablePermits(backPressureHandler, 9);
105-
doSleep(500); // Longer than acquire timout + polling sleep
106117
msgs.forEach(msg -> context.runBackPressureReleaseCallback());
107118
return CompletableFuture.runAsync(processingCounter::countDown);
108119
});
@@ -112,20 +123,23 @@ else if (hasMadeSecondPoll.compareAndSet(false, true)) {
112123
source.setTaskExecutor(createTaskExecutor(testName));
113124
source.setAcknowledgementProcessor(getAcknowledgementProcessor());
114125
source.start();
115-
assertThat(pollingCounter.await(2, TimeUnit.SECONDS)).isTrue();
116-
assertThat(processingCounter.await(2, TimeUnit.SECONDS)).isTrue();
126+
assertThat(doAwait(pollingCounter)).isTrue();
127+
assertThat(doAwait(processingCounter)).isTrue();
117128
}
118129

119-
// @RepeatedTest(400)
130+
private static final AtomicInteger testCounter = new AtomicInteger();
131+
120132
@Test
121-
void shouldAcquireAndReleasePartialPermits() throws Exception {
133+
void shouldAcquireAndReleasePartialPermits() {
122134
String testName = "shouldAcquireAndReleasePartialPermits";
123135
SemaphoreBackPressureHandler backPressureHandler = SemaphoreBackPressureHandler.builder()
124-
.acquireTimeout(Duration.ofMillis(200)).batchSize(10).totalPermits(10)
136+
.acquireTimeout(Duration.ofMillis(150)).batchSize(10).totalPermits(10)
125137
.throughputConfiguration(BackPressureMode.AUTO).build();
126-
ExecutorService threadPool = Executors.newCachedThreadPool();
138+
ExecutorService threadPool = Executors
139+
.newCachedThreadPool(new MessageExecutionThreadFactory("test " + testCounter.incrementAndGet()));
127140
CountDownLatch pollingCounter = new CountDownLatch(4);
128141
CountDownLatch processingCounter = new CountDownLatch(1);
142+
CountDownLatch processingLatch = new CountDownLatch(1);
129143
AtomicBoolean hasThrownError = new AtomicBoolean(false);
130144

131145
AbstractPollingMessageSource<Object, Message> source = new AbstractPollingMessageSource<>() {
@@ -138,64 +152,67 @@ void shouldAcquireAndReleasePartialPermits() throws Exception {
138152

139153
@Override
140154
protected CompletableFuture<Collection<Message>> doPollForMessages(int messagesToRequest) {
141-
try {
142-
// Give it some time between returning empty and polling again
143-
doSleep(100);
144-
145-
// Will only be true the first time it sets hasReceived to true
146-
boolean shouldReturnMessage = hasReceived.compareAndSet(false, true);
147-
if (shouldReturnMessage) {
148-
// First poll, should have 10
149-
logger.debug("First poll - should request 10 messages");
150-
assertThat(messagesToRequest).isEqualTo(10);
151-
assertAvailablePermits(backPressureHandler, 0);
152-
// No permits have been released yet
153-
assertThroughputMode(backPressureHandler, "low");
154-
}
155-
else if (hasAcquired9.compareAndSet(false, true)) {
156-
// Second poll, should have 9
157-
logger.debug("Second poll - should request 9 messages");
158-
assertThat(messagesToRequest).isEqualTo(9);
159-
assertAvailablePermitsLessThanOrEqualTo(backPressureHandler, 1);
160-
// Has released 9 permits, should be TM HIGH
161-
assertThroughputMode(backPressureHandler, "high");
162-
}
163-
else {
164-
boolean thirdPoll = hasMadeThirdPoll.compareAndSet(false, true);
165-
// Third poll or later, should have 10 again
166-
logger.debug("Third poll - should request 10 messages");
167-
assertThat(messagesToRequest).isEqualTo(10);
168-
assertAvailablePermits(backPressureHandler, 0);
169-
if (thirdPoll) {
170-
// Hasn't yet returned a full batch, should be TM High
155+
return CompletableFuture.supplyAsync(() -> {
156+
try {
157+
// Give it some time between returning empty and polling again
158+
// doSleep(100);
159+
160+
// Will only be true the first time it sets hasReceived to true
161+
boolean shouldReturnMessage = hasReceived.compareAndSet(false, true);
162+
if (shouldReturnMessage) {
163+
// First poll, should have 10
164+
logger.debug("First poll - should request 10 messages");
165+
assertThat(messagesToRequest).isEqualTo(10);
166+
assertAvailablePermits(backPressureHandler, 0);
167+
// No permits have been released yet
168+
assertThroughputMode(backPressureHandler, "low");
169+
}
170+
else if (hasAcquired9.compareAndSet(false, true)) {
171+
// Second poll, should have 9
172+
logger.debug("Second poll - should request 9 messages");
173+
assertThat(messagesToRequest).isEqualTo(9);
174+
assertAvailablePermitsLessThanOrEqualTo(backPressureHandler, 1);
175+
// Has released 9 permits, should be TM HIGH
171176
assertThroughputMode(backPressureHandler, "high");
177+
processingLatch.countDown(); // Release processing now
172178
}
173179
else {
174-
// Has returned all permits in third poll
175-
assertThroughputMode(backPressureHandler, "low");
180+
boolean thirdPoll = hasMadeThirdPoll.compareAndSet(false, true);
181+
// Third poll or later, should have 10 again
182+
logger.debug("Third poll - should request 10 messages");
183+
assertThat(messagesToRequest).isEqualTo(10);
184+
assertAvailablePermits(backPressureHandler, 0);
185+
if (thirdPoll) {
186+
// Hasn't yet returned a full batch, should be TM High
187+
assertThroughputMode(backPressureHandler, "high");
188+
}
189+
else {
190+
// Has returned all permits in third poll
191+
assertThroughputMode(backPressureHandler, "low");
192+
}
176193
}
177-
}
178-
return CompletableFuture.supplyAsync(() -> {
179194
if (shouldReturnMessage) {
180195
logger.debug("shouldReturnMessage, returning one message");
181196
return (Collection<Message>) List.of(
182197
Message.builder().messageId(UUID.randomUUID().toString()).body("message").build());
183198
}
184199
logger.debug("should not return message, returning empty list");
185200
return Collections.<Message> emptyList();
186-
}, threadPool).whenComplete((v, t) -> pollingCounter.countDown());
187-
}
188-
catch (Error e) {
189-
hasThrownError.set(true);
190-
return CompletableFuture.failedFuture(new RuntimeException(e));
191-
}
201+
}
202+
catch (Error e) {
203+
hasThrownError.set(true);
204+
throw new RuntimeException("Error polling for messages", e);
205+
}
206+
}, threadPool).whenComplete((v, t) -> pollingCounter.countDown());
192207
}
193208
};
194209

195210
source.setBackPressureHandler(backPressureHandler);
196211
source.setMessageSink((msgs, context) -> {
212+
logger.debug("Processing {} messages", msgs.size());
197213
assertAvailablePermits(backPressureHandler, 9);
198-
doSleep(500); // Longer than acquire timout + polling sleep
214+
assertThat(doAwait(processingLatch)).isTrue();
215+
logger.debug("Finished processing {} messages", msgs.size());
199216
msgs.forEach(msg -> context.runBackPressureReleaseCallback());
200217
return CompletableFuture.completedFuture(null).thenRun(processingCounter::countDown);
201218
});
@@ -204,12 +221,22 @@ else if (hasAcquired9.compareAndSet(false, true)) {
204221
source.setTaskExecutor(createTaskExecutor(testName));
205222
source.setAcknowledgementProcessor(getAcknowledgementProcessor());
206223
source.start();
207-
assertThat(processingCounter.await(2, TimeUnit.SECONDS)).isTrue();
208-
assertThat(pollingCounter.await(2, TimeUnit.SECONDS)).isTrue();
224+
assertThat(doAwait(processingCounter)).isTrue();
225+
assertThat(doAwait(pollingCounter)).isTrue();
209226
source.stop();
210227
assertThat(hasThrownError.get()).isFalse();
211228
}
212229

230+
private static boolean doAwait(CountDownLatch processingLatch) {
231+
try {
232+
return processingLatch.await(4, TimeUnit.SECONDS);
233+
}
234+
catch (InterruptedException e) {
235+
Thread.currentThread().interrupt();
236+
throw new RuntimeException("Interrupted while waiting for latch", e);
237+
}
238+
}
239+
213240
private void assertThroughputMode(SemaphoreBackPressureHandler backPressureHandler, String expectedThroughputMode) {
214241
assertThat(ReflectionTestUtils.getField(backPressureHandler, "currentThroughputMode"))
215242
.extracting(Object::toString).extracting(String::toLowerCase)
@@ -243,7 +270,6 @@ protected TaskExecutor createTaskExecutor(String testName) {
243270
int poolSize = 10;
244271
executor.setMaxPoolSize(poolSize);
245272
executor.setCorePoolSize(10);
246-
// Necessary due to a small racing condition between releasing the permit and releasing the thread.
247273
executor.setQueueCapacity(poolSize);
248274
executor.setAllowCoreThreadTimeOut(true);
249275
executor.setThreadFactory(createThreadFactory(testName));

0 commit comments

Comments
 (0)