Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import io.awspring.cloud.sqs.operations.SqsTemplate;
import io.awspring.cloud.sqs.operations.SqsTemplateBuilder;
import io.awspring.cloud.sqs.support.converter.MessagingMessageConverter;
import io.awspring.cloud.sqs.support.converter.SqsHeaderMapper;
import io.awspring.cloud.sqs.support.converter.SqsMessagingMessageConverter;
import io.awspring.cloud.sqs.support.observation.SqsListenerObservation;
import io.awspring.cloud.sqs.support.observation.SqsTemplateObservation;
Expand All @@ -59,6 +60,7 @@
* @author Maciej Walkowiak
* @author Wei Jiang
* @author Dongha Kim
* @author Jeongmin Kim
* @since 3.0
*/
@AutoConfiguration
Expand Down Expand Up @@ -146,7 +148,13 @@ private void setMapperToConverter(MessagingMessageConverter<?> messagingMessageC
@ConditionalOnMissingBean
@Bean
public MessagingMessageConverter<Message> messageConverter() {
return new SqsMessagingMessageConverter();
SqsMessagingMessageConverter converter = new SqsMessagingMessageConverter();

SqsHeaderMapper headerMapper = new SqsHeaderMapper();
headerMapper.setConvertMessageIdToUuid(this.sqsProperties.getConvertMessageIdToUuid());
converter.setHeaderMapper(headerMapper);

return converter;
}

private void configureProperties(SqsContainerOptionsBuilder options) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
*
* @author Tomaz Fernandes
* @author Wei Jiang
* @author Jeongmin Kim
* @since 3.0
*/
@ConfigurationProperties(prefix = SqsProperties.PREFIX)
Expand All @@ -51,6 +52,16 @@ public void setListener(Listener listener) {

private Boolean observationEnabled = false;

private Boolean convertMessageIdToUuid = true;

public Boolean getConvertMessageIdToUuid() {
return convertMessageIdToUuid;
}

public void setConvertMessageIdToUuid(Boolean convertMessageIdToUuid) {
this.convertMessageIdToUuid = convertMessageIdToUuid;
}

/**
* Return the strategy to use if the queue is not found.
* @return the {@link QueueNotFoundStrategy}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package io.awspring.cloud.sqs;

import io.awspring.cloud.sqs.listener.SqsHeaders;
import io.awspring.cloud.sqs.support.converter.MessagingMessageHeaders;
import java.util.Collection;
import java.util.Map;
Expand All @@ -30,6 +31,7 @@
* Utility class for extracting {@link MessageHeaders} from a {@link Message}.
*
* @author Tomaz Fernandes
* @author Jeongmin Kim
* @since 3.0
*/
public class MessageHeaderUtils {
Expand Down Expand Up @@ -150,4 +152,22 @@ public static <T> Message<T> removeHeaderIfPresent(Message<T> message, String ke
return new GenericMessage<>(message.getPayload(), newHeaders);
}

/**
* Return the AWS message ID, falling back to Spring message ID if not present.
* @param message the message.
* @return the AWS ID or Spring ID.
*/
public static String getAwsMessageId(Message<?> message) {
String awsMessageId = message.getHeaders().get(SqsHeaders.SQS_AWS_MESSAGE_ID_HEADER, String.class);
return awsMessageId != null ? awsMessageId : getId(message);
}

/**
* Return the messages' AWS ID as a concatenated {@link String}.
* @param messages the messages.
* @return the AWS IDs.
*/
public static <T> String getAwsMessageId(Collection<Message<T>> messages) {
return messages.stream().map(MessageHeaderUtils::getAwsMessageId).collect(Collectors.joining("; "));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
* {@link org.springframework.messaging.MessageHeaders#get} or
* {@link org.springframework.messaging.handler.annotation.Header} parameter annotations.
* @author Tomaz Fernandes
* @author jeongmin Kim
* @since 3.0
* @see io.awspring.cloud.sqs.support.converter.SqsHeaderMapper
*/
Expand Down Expand Up @@ -84,6 +85,11 @@ private SqsHeaders() {
*/
public static final String SQS_DEFAULT_TYPE_HEADER = "JavaType";

/**
* Header for the original AWS MessageId when not using UUID conversion.
*/
public static final String SQS_AWS_MESSAGE_ID_HEADER = SQS_HEADER_PREFIX + "AWSMessageId";

public static class MessageSystemAttributes {

private MessageSystemAttributes() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
*
* @author Tomaz Fernandes
* @author Zhong Xi Lu
* @author Jeongmin Kim
*
* @since 3.0
*/
Expand Down Expand Up @@ -355,7 +356,7 @@ protected <T> CompletableFuture<SendResult.Batch<T>> doSendBatchAsync(String end
logger.debug("Sending messages {} to endpoint {}", messages, endpointName);
return createSendMessageBatchRequest(endpointName, messages).thenCompose(this.sqsAsyncClient::sendMessageBatch)
.thenApply(response -> createSendResultBatch(response, endpointName,
originalMessages.stream().collect(Collectors.toMap(MessageHeaderUtils::getId, msg -> msg))));
originalMessages.stream().collect(Collectors.toMap(MessageHeaderUtils::getAwsMessageId, msg -> msg))));
}

private <T> SendResult.Batch<T> createSendResultBatch(SendMessageBatchResponse response, String endpointName,
Expand Down Expand Up @@ -526,7 +527,7 @@ private Map<String, Object> addMissingFifoReceiveHeaders(Map<String, Object> hea
private CompletableFuture<Void> deleteMessages(String endpointName,
Collection<org.springframework.messaging.Message<?>> messages) {
logger.trace("Acknowledging in queue {} messages {}", endpointName,
MessageHeaderUtils.getId(addTypeToMessages(messages)));
MessageHeaderUtils.getAwsMessageId(addTypeToMessages(messages)));
return getQueueAttributes(endpointName)
.thenCompose(attributes -> this.sqsAsyncClient.deleteMessageBatch(DeleteMessageBatchRequest.builder()
.queueUrl(attributes.getQueueUrl()).entries(createDeleteMessageEntries(messages)).build()))
Expand All @@ -545,7 +546,7 @@ private Collection<org.springframework.messaging.Message<?>> getFailedAckMessage
DeleteMessageBatchResponse response, Collection<org.springframework.messaging.Message<?>> messages,
String endpointName) {
return response.failed().stream().map(BatchResultErrorEntry::id)
.map(id -> messages.stream().filter(msg -> MessageHeaderUtils.getId(msg).equals(id)).findFirst()
.map(id -> messages.stream().filter(msg -> MessageHeaderUtils.getAwsMessageId(msg).equals(id)).findFirst()
.orElseThrow(() -> new SqsAcknowledgementException(
"Could not correlate ids for acknowledgement failure", Collections.emptyList(),
messages, endpointName)))
Expand All @@ -556,7 +557,7 @@ private Collection<org.springframework.messaging.Message<?>> getSuccessfulAckMes
DeleteMessageBatchResponse response, Collection<org.springframework.messaging.Message<?>> messages,
String endpointName) {
return response.successful().stream().map(DeleteMessageBatchResultEntry::id)
.map(id -> messages.stream().filter(msg -> MessageHeaderUtils.getId(msg).equals(id)).findFirst()
.map(id -> messages.stream().filter(msg -> MessageHeaderUtils.getAwsMessageId(msg).equals(id)).findFirst()
.orElseThrow(() -> new SqsAcknowledgementException(
"Could not correlate ids for acknowledgement failure", Collections.emptyList(),
messages, endpointName)))
Expand All @@ -574,22 +575,22 @@ private void logAcknowledgement(String endpointName, Collection<org.springframew
DeleteMessageBatchResponse response, @Nullable Throwable t) {
if (t != null) {
logger.error("Error acknowledging in queue {} messages {}", endpointName,
MessageHeaderUtils.getId(addTypeToMessages(messages)));
MessageHeaderUtils.getAwsMessageId(addTypeToMessages(messages)));
}
else if (!response.failed().isEmpty()) {
logger.warn("Some messages could not be acknowledged in queue {}: {}", endpointName,
response.failed().stream().map(BatchResultErrorEntry::id).toList());
}
else {
logger.trace("Acknowledged messages in queue {}: {}", endpointName,
MessageHeaderUtils.getId(addTypeToMessages(messages)));
MessageHeaderUtils.getAwsMessageId(addTypeToMessages(messages)));
}
}

private Collection<DeleteMessageBatchRequestEntry> createDeleteMessageEntries(
Collection<org.springframework.messaging.Message<?>> messages) {
return messages.stream()
.map(message -> DeleteMessageBatchRequestEntry.builder().id(MessageHeaderUtils.getId(message))
.map(message -> DeleteMessageBatchRequestEntry.builder().id(MessageHeaderUtils.getAwsMessageId(message))
.receiptHandle(
MessageHeaderUtils.getHeaderAsString(message, SqsHeaders.SQS_RECEIPT_HANDLE_HEADER))
.build())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
* @author Tomaz Fernandes
* @author Alain Sahli
* @author Maciej Walkowiak
* @author Jeongmin Kim
*
* @since 3.0
* @see SqsMessagingMessageConverter
Expand All @@ -61,12 +62,18 @@ public class SqsHeaderMapper implements ContextAwareHeaderMapper<Message> {
private BiFunction<Message, MessageHeaderAccessor, MessageHeaders> additionalHeadersFunction = ((message,
accessor) -> accessor.toMessageHeaders());

private boolean convertMessageIdToUuid = true;

public void setAdditionalHeadersFunction(
BiFunction<Message, MessageHeaderAccessor, MessageHeaders> headerFunction) {
Assert.notNull(headerFunction, "headerFunction cannot be null");
this.additionalHeadersFunction = headerFunction;
}

public void setConvertMessageIdToUuid(boolean convertMessageIdToUuid) {
this.convertMessageIdToUuid = convertMessageIdToUuid;
}

@Override
public Message fromHeaders(MessageHeaders headers) {
Message.Builder builder = Message.builder();
Expand Down Expand Up @@ -156,9 +163,27 @@ public MessageHeaders toHeaders(Message source) {
accessor.copyHeadersIfAbsent(getMessageAttributesAsHeaders(source));
accessor.copyHeadersIfAbsent(createDefaultHeaders(source));
accessor.copyHeadersIfAbsent(createAdditionalHeaders(source));
MessageHeaders messageHeaders = accessor.toMessageHeaders();
logger.trace("Mapped headers {} for message {}", messageHeaders, source.messageId());
return new MessagingMessageHeaders(messageHeaders, UUID.fromString(source.messageId()));

if (convertMessageIdToUuid && isValidUuid(source.messageId())) {
MessageHeaders messageHeaders = accessor.toMessageHeaders();
logger.trace("Mapped headers {} for message {}", messageHeaders, source.messageId());
return new MessagingMessageHeaders(messageHeaders, UUID.fromString(source.messageId()));
} else {
accessor.setHeader(SqsHeaders.SQS_AWS_MESSAGE_ID_HEADER, source.messageId());
MessageHeaders messageHeaders = accessor.toMessageHeaders();
logger.trace("Mapped headers {} for message {}", messageHeaders, source.messageId());
return new MessagingMessageHeaders(messageHeaders, UUID.randomUUID());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Passing a random UUID does not feel right. Perhaps MessagingMessageHeaders can have id set as String instead of UUID? cc @tomazfernandes

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion! Updated to generate UUID from Message ID instead of random UUID.

}

}

private boolean isValidUuid(String messageId) {
try {
UUID.fromString(messageId);
return true;
} catch (IllegalArgumentException e) {
return false;
}
}

private MessageHeaders createAdditionalHeaders(Message source) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,19 @@

import static org.assertj.core.api.Assertions.assertThat;

import io.awspring.cloud.sqs.listener.SqsHeaders;
import org.junit.jupiter.api.Test;
import org.springframework.messaging.Message;
import org.springframework.messaging.support.MessageBuilder;

import java.util.Collection;
import java.util.List;

/**
* Tests for {@link MessageHeaderUtils}.
*
* @author Tomaz Fernandes
* @author Jeongmin Kim
*/
class MessageHeaderUtilsTest {

Expand Down Expand Up @@ -93,4 +98,54 @@ void shouldPreserveOtherHeaders() {
assertThat(result.getHeaders().get("another-header")).isEqualTo("another-value");
assertThat(result.getHeaders().size()).isEqualTo(message.getHeaders().size() - 1);
}

@Test
void shouldReturnAwsMessageIdWhenHeaderPresent() {
// given
String awsMessageId = "92898073-7bd6a160-5797b060-54a7e539";
Message<String> message = MessageBuilder.withPayload("test-payload")
.setHeader(SqsHeaders.SQS_AWS_MESSAGE_ID_HEADER, awsMessageId)
.build();

// when
String result = MessageHeaderUtils.getAwsMessageId(message);

// then
assertThat(result).isEqualTo(awsMessageId);
}

@Test
void shouldFallbackToSpringMessageIdWhenAwsHeaderNotPresent() {
// given
Message<String> message = MessageBuilder.withPayload("test-payload").build();
String expectedId = message.getHeaders().getId().toString();

// when
String result = MessageHeaderUtils.getAwsMessageId(message);

// then
assertThat(result).isEqualTo(expectedId);
}

@Test
void shouldConcatenateAwsMessageIdsFromCollection() {
// given
String awsMessageId1 = "aws-id-1";
String awsMessageId2 = "aws-id-2";

Message<String> message1 = MessageBuilder.withPayload("payload1")
.setHeader(SqsHeaders.SQS_AWS_MESSAGE_ID_HEADER, awsMessageId1)
.build();
Message<String> message2 = MessageBuilder.withPayload("payload2")
.setHeader(SqsHeaders.SQS_AWS_MESSAGE_ID_HEADER, awsMessageId2)
.build();

Collection<Message<String>> messages = List.of(message1, message2);

// when
String result = MessageHeaderUtils.getAwsMessageId(messages);

// then
assertThat(result).isEqualTo("aws-id-1; aws-id-2");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
*
* @author Tomaz Fernandes
* @author Maciej Walkowiak
* @author Jeongmin Kim
*/
class SqsHeaderMapperTests {

Expand Down Expand Up @@ -177,6 +178,35 @@ void createsMessageWithNumberHeader(String value, String type, Number expected)
assertThat(headers.get(headerName)).isEqualTo(expected);
}

@Test
void shouldConvertUuidMessageIdWhenConvertMessageIdToUuidIsTrue() {
SqsHeaderMapper mapper = new SqsHeaderMapper();
mapper.setConvertMessageIdToUuid(true);
String uuidMessageId = "550e8400-e29b-41d4-a716-446655440000";
Message message = Message.builder()
.body("payload")
.messageId(uuidMessageId)
.build();
MessageHeaders headers = mapper.toHeaders(message);
assertThat(headers.getId()).isEqualTo(UUID.fromString(uuidMessageId));
assertThat(headers.get(SqsHeaders.SQS_AWS_MESSAGE_ID_HEADER)).isNull();
}

@Test
void shouldStoreAwsMessageIdInHeaderWhenConvertMessageIdToUuidIsFalse() {
SqsHeaderMapper mapper = new SqsHeaderMapper();
mapper.setConvertMessageIdToUuid(false);
String nonUuidMessageId = "92898073-7bd6a160-5797b060-54a7e539";
Message message = Message.builder()
.body("payload")
.messageId(nonUuidMessageId)
.build();
MessageHeaders headers = mapper.toHeaders(message);
assertThat(headers.get(SqsHeaders.SQS_AWS_MESSAGE_ID_HEADER)).isEqualTo(nonUuidMessageId);
assertThat(headers.getId()).isNotEqualTo(nonUuidMessageId);
assertThat(headers.getId()).isNotNull();
}

private static Stream<Arguments> validArguments() {
return Stream.of(Arguments.of("10", "Number", BigDecimal.valueOf(10)),
Arguments.of("3", "Number.byte", (byte) 3), Arguments.of("3", "Number.Byte", (byte) 3),
Expand Down