Skip to content

Commit 38745bc

Browse files
authored
fix: SQS FIFO visibility timeout extension fails due to Message object mutation (#1430) (#1432)
1 parent 1c2556a commit 38745bc

File tree

2 files changed

+82
-15
lines changed

2 files changed

+82
-15
lines changed

spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/sink/adapter/MessageVisibilityExtendingSinkAdapter.java

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@
2222
import io.awspring.cloud.sqs.listener.interceptor.AsyncMessageInterceptor;
2323
import io.awspring.cloud.sqs.listener.sink.MessageSink;
2424
import java.time.Duration;
25-
import java.util.ArrayList;
2625
import java.util.Collection;
2726
import java.util.Collections;
27+
import java.util.Map;
2828
import java.util.UUID;
2929
import java.util.concurrent.CompletableFuture;
3030
import java.util.stream.Collectors;
@@ -114,41 +114,42 @@ private void logResult(Collection<Message<T>> messages, Throwable t) {
114114
}
115115

116116
private class OriginalBatchMessageVisibilityExtendingInterceptor implements AsyncMessageInterceptor<T> {
117-
118-
private final Collection<Message<T>> originalMessageBatchCopy;
119-
117+
private final Map<String, Message<T>> originalMessageBatchMap;
120118
private final int initialBatchSize;
121119

122120
private OriginalBatchMessageVisibilityExtendingInterceptor(Collection<Message<T>> originalMessageBatch) {
123-
this.originalMessageBatchCopy = Collections.synchronizedCollection(new ArrayList<>(originalMessageBatch));
124121
this.initialBatchSize = originalMessageBatch.size();
122+
this.originalMessageBatchMap = Collections.synchronizedMap(originalMessageBatch.stream()
123+
.collect(Collectors.toMap(MessageHeaderUtils::getId, message -> message)));
125124
}
126125

127-
// @formatter:off
128126
@Override
129127
public CompletableFuture<Message<T>> intercept(Message<T> message) {
130-
return originalMessageBatchCopy.size() == initialBatchSize
131-
? CompletableFuture.completedFuture(message)
132-
: changeVisibility(this.originalMessageBatchCopy).thenApply(response -> message);
128+
if (this.originalMessageBatchMap.size() == this.initialBatchSize) {
129+
return CompletableFuture.completedFuture(message);
130+
}
131+
132+
return changeVisibility(this.originalMessageBatchMap.values()).thenApply(response -> message);
133133
}
134134

135135
@Override
136136
public CompletableFuture<Collection<Message<T>>> intercept(Collection<Message<T>> messages) {
137-
return originalMessageBatchCopy.size() == initialBatchSize
138-
? CompletableFuture.completedFuture(messages)
139-
: changeVisibility(this.originalMessageBatchCopy).thenApply(response -> messages);
137+
if (this.originalMessageBatchMap.size() == this.initialBatchSize) {
138+
return CompletableFuture.completedFuture(messages);
139+
}
140+
141+
return changeVisibility(this.originalMessageBatchMap.values()).thenApply(response -> messages);
140142
}
141-
// @formatter:on
142143

143144
@Override
144145
public CompletableFuture<Void> afterProcessing(Collection<Message<T>> messages, Throwable t) {
145-
this.originalMessageBatchCopy.removeAll(messages);
146+
messages.forEach(message -> this.originalMessageBatchMap.remove(MessageHeaderUtils.getId(message)));
146147
return CompletableFuture.completedFuture(null);
147148
}
148149

149150
@Override
150151
public CompletableFuture<Void> afterProcessing(Message<T> message, Throwable t) {
151-
this.originalMessageBatchCopy.remove(message);
152+
this.originalMessageBatchMap.remove(MessageHeaderUtils.getId(message));
152153
return CompletableFuture.completedFuture(null);
153154
}
154155

spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/integration/SqsFifoIntegrationTests.java

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
import static java.util.stream.Collectors.toList;
2020
import static org.assertj.core.api.Assertions.assertThat;
2121
import static org.awaitility.Awaitility.await;
22+
import static org.mockito.ArgumentMatchers.any;
23+
import static org.mockito.Mockito.doAnswer;
24+
import static org.mockito.Mockito.spy;
2225

2326
import com.fasterxml.jackson.databind.ObjectMapper;
2427
import io.awspring.cloud.sqs.CompletableFutures;
@@ -81,6 +84,8 @@
8184
import org.springframework.util.Assert;
8285
import org.springframework.util.StopWatch;
8386
import software.amazon.awssdk.services.sqs.SqsAsyncClient;
87+
import software.amazon.awssdk.services.sqs.model.ChangeMessageVisibilityBatchRequest;
88+
import software.amazon.awssdk.services.sqs.model.ChangeMessageVisibilityBatchRequestEntry;
8489
import software.amazon.awssdk.services.sqs.model.QueueAttributeName;
8590

8691
/**
@@ -116,7 +121,10 @@ class SqsFifoIntegrationTests extends BaseSqsIntegrationTest {
116121

117122
static final String OBSERVES_MESSAGE_FIFO_QUEUE_NAME = "observes_fifo_message_test_queue.fifo";
118123

124+
static final String FIFO_VISIBILITY_TIMEOUT_EXTENSION_QUEUE_NAME = "fifo_visibility_timeout_extension_test_queue.fifo";
125+
119126
private static final String ERROR_ON_ACK_FACTORY = "errorOnAckFactory";
127+
private static final String VISIBILITY_TIMEOUT_EXTENSION_FACTORY = "visibilityTimeoutExtensionFactory";
120128

121129
@Autowired
122130
LatchContainer latchContainer;
@@ -165,6 +173,7 @@ static void beforeTests() {
165173
createFifoQueue(client, FIFO_MANUALLY_CREATE_FACTORY_QUEUE_NAME),
166174
createFifoQueue(client, FIFO_MANUALLY_CREATE_BATCH_CONTAINER_QUEUE_NAME),
167175
createFifoQueue(client, OBSERVES_MESSAGE_FIFO_QUEUE_NAME),
176+
createFifoQueue(client, FIFO_VISIBILITY_TIMEOUT_EXTENSION_QUEUE_NAME, getVisibilityAttribute("5")),
168177
createFifoQueue(client, FIFO_MANUALLY_CREATE_BATCH_FACTORY_QUEUE_NAME)).join();
169178
}
170179

@@ -460,6 +469,26 @@ public void onMessage(Collection<Message<String>> messages) {
460469

461470
}
462471

472+
@Test
473+
void visibilityTimeoutExtensionWorksForFifoBatch() throws Exception {
474+
final int messageCount = this.settings.messagesPerMessageGroup;
475+
// There will be messageCount - 1 requests to change visibility, before each message except the first one
476+
latchContainer.visibilityTimeoutExtensionLatch = new CountDownLatch(messageCount - 1);
477+
478+
List<String> values = IntStream.range(0, messageCount).mapToObj(String::valueOf).toList();
479+
String messageGroupId = UUID.randomUUID().toString();
480+
sqsTemplate.sendMany(FIFO_VISIBILITY_TIMEOUT_EXTENSION_QUEUE_NAME,
481+
createMessagesFromValues(messageGroupId, values));
482+
483+
assertThat(latchContainer.visibilityTimeoutExtensionLatch.await(settings.latchTimeoutSeconds, TimeUnit.SECONDS))
484+
.isTrue();
485+
List<Integer> expectedRequestEntryCounts = IntStream.range(1, messageCount).map(i -> messageCount - i).boxed()
486+
.toList();
487+
assertThat(messagesContainer.visibilityTimeoutExtensionBatchRequests).as(
488+
"Number of entries in each ChangeMessageVisibilityBatchRequest should decrease by 1 on each message")
489+
.extracting(List::size).containsExactlyElementsOf(expectedRequestEntryCounts);
490+
}
491+
463492
@Test
464493
void manuallyCreatesContainer() throws Exception {
465494
List<String> values = IntStream.range(0, this.settings.messagesPerTest).mapToObj(String::valueOf)
@@ -657,6 +686,13 @@ void listen(List<Message<String>> messages) {
657686

658687
}
659688

689+
static class VisibilityTimeoutExtensionListener {
690+
@SqsListener(queueNames = FIFO_VISIBILITY_TIMEOUT_EXTENSION_QUEUE_NAME, messageVisibilitySeconds = "5", factory = VISIBILITY_TIMEOUT_EXTENSION_FACTORY)
691+
void listen(String message) {
692+
logger.debug("Processing message: {}", message);
693+
}
694+
}
695+
660696
static class LatchContainer {
661697

662698
Settings settings;
@@ -678,6 +714,7 @@ static class LatchContainer {
678714
CountDownLatch stopsProcessingOnAckErrorHasThrown;
679715
CountDownLatch receivesBatchManyGroupsLatch;
680716
CountDownLatch receivesFifoBatchGroupingStrategyMultipleGroupsInSameBatchLatch;
717+
CountDownLatch visibilityTimeoutExtensionLatch;
681718

682719
LatchContainer(Settings settings) {
683720
this.settings = settings;
@@ -698,6 +735,7 @@ static class LatchContainer {
698735
this.receivesFifoBatchGroupingStrategyMultipleGroupsInSameBatchLatch = new CountDownLatch(1);
699736
this.stopsProcessingOnAckErrorHasThrown = new CountDownLatch(1);
700737
this.observesFifoMessageLatch = new CountDownLatch(1);
738+
this.visibilityTimeoutExtensionLatch = new CountDownLatch(1);
701739
}
702740

703741
}
@@ -711,6 +749,7 @@ static class MessagesContainer {
711749
List<String> manuallyCreatedBatchFactoryMessages = Collections.synchronizedList(new ArrayList<>());
712750
List<String> stopsProcessingOnAckErrorBeforeThrown = Collections.synchronizedList(new ArrayList<>());
713751
List<String> stopsProcessingOnAckErrorAfterThrown = Collections.synchronizedList(new ArrayList<>());
752+
List<List<String>> visibilityTimeoutExtensionBatchRequests = Collections.synchronizedList(new ArrayList<>());
714753

715754
}
716755

@@ -820,6 +859,28 @@ private void handleResult(Message<String> message) {
820859
return factory;
821860
}
822861

862+
@Bean(VISIBILITY_TIMEOUT_EXTENSION_FACTORY)
863+
SqsMessageListenerContainerFactory<String> visibilityTrackingSqsListenerContainerFactory() {
864+
SqsAsyncClient spyAsyncClient = spy(createAsyncClient());
865+
866+
doAnswer(invocation -> {
867+
ChangeMessageVisibilityBatchRequest request = invocation.getArgument(0);
868+
messagesContainer.visibilityTimeoutExtensionBatchRequests.add(request.entries().stream().map(ChangeMessageVisibilityBatchRequestEntry::receiptHandle).toList());
869+
latchContainer.visibilityTimeoutExtensionLatch.countDown();
870+
871+
return invocation.callRealMethod();
872+
}).when(spyAsyncClient).changeMessageVisibilityBatch(any(ChangeMessageVisibilityBatchRequest.class));
873+
874+
SqsMessageListenerContainerFactory<String> factory = new SqsMessageListenerContainerFactory<>();
875+
factory.configure(options -> options
876+
.maxConcurrentMessages(10)
877+
.acknowledgementThreshold(10)
878+
.acknowledgementOrdering(AcknowledgementOrdering.ORDERED_BY_GROUP)
879+
.messageVisibility(Duration.ofSeconds(5)));
880+
factory.setSqsAsyncClientSupplier(() -> spyAsyncClient);
881+
return factory;
882+
}
883+
823884
@Bean
824885
public MessageListenerContainer<String> manuallyCreatedContainer() {
825886
SqsMessageListenerContainer<String> container = new SqsMessageListenerContainer<>(createAsyncClient());
@@ -931,6 +992,11 @@ ReceivesBatchesFromManyGroupsListener receiveBatchesFromManyGroupsListener() {
931992
return new ReceivesBatchesFromManyGroupsListener();
932993
}
933994

995+
@Bean
996+
VisibilityTimeoutExtensionListener visibilityTimeoutExtensionListener() {
997+
return new VisibilityTimeoutExtensionListener();
998+
}
999+
9341000
@Bean
9351001
ObservesFifoMessageListener observesFifoMessageListener() {
9361002
return new ObservesFifoMessageListener();

0 commit comments

Comments
 (0)