diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/config/AbstractMessageListenerContainerFactory.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/config/AbstractMessageListenerContainerFactory.java index 1de54cbae..09ce153d8 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/config/AbstractMessageListenerContainerFactory.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/config/AbstractMessageListenerContainerFactory.java @@ -16,26 +16,20 @@ package io.awspring.cloud.sqs.config; import io.awspring.cloud.sqs.ConfigUtils; -import io.awspring.cloud.sqs.listener.AbstractMessageListenerContainer; -import io.awspring.cloud.sqs.listener.AsyncComponentAdapters; -import io.awspring.cloud.sqs.listener.AsyncMessageListener; -import io.awspring.cloud.sqs.listener.ContainerComponentFactory; -import io.awspring.cloud.sqs.listener.ContainerOptions; -import io.awspring.cloud.sqs.listener.ContainerOptionsBuilder; -import io.awspring.cloud.sqs.listener.MessageListener; -import io.awspring.cloud.sqs.listener.MessageListenerContainer; +import io.awspring.cloud.sqs.listener.*; import io.awspring.cloud.sqs.listener.acknowledgement.AcknowledgementResultCallback; import io.awspring.cloud.sqs.listener.acknowledgement.AsyncAcknowledgementResultCallback; import io.awspring.cloud.sqs.listener.errorhandler.AsyncErrorHandler; import io.awspring.cloud.sqs.listener.errorhandler.ErrorHandler; import io.awspring.cloud.sqs.listener.interceptor.AsyncMessageInterceptor; import io.awspring.cloud.sqs.listener.interceptor.MessageInterceptor; +import org.springframework.messaging.Message; +import org.springframework.util.Assert; + import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.function.Consumer; -import org.springframework.messaging.Message; -import org.springframework.util.Assert; /** * Base implementation for a {@link MessageListenerContainerFactory}. Contains the components and diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/config/SqsMessageListenerContainerFactory.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/config/SqsMessageListenerContainerFactory.java index e2f9f4902..768b46431 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/config/SqsMessageListenerContainerFactory.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/config/SqsMessageListenerContainerFactory.java @@ -17,29 +17,24 @@ import io.awspring.cloud.sqs.ConfigUtils; import io.awspring.cloud.sqs.annotation.SqsListener; -import io.awspring.cloud.sqs.listener.AsyncMessageListener; -import io.awspring.cloud.sqs.listener.ContainerComponentFactory; -import io.awspring.cloud.sqs.listener.ContainerOptions; -import io.awspring.cloud.sqs.listener.MessageListener; -import io.awspring.cloud.sqs.listener.SqsContainerOptions; -import io.awspring.cloud.sqs.listener.SqsContainerOptionsBuilder; -import io.awspring.cloud.sqs.listener.SqsMessageListenerContainer; +import io.awspring.cloud.sqs.listener.*; import io.awspring.cloud.sqs.listener.acknowledgement.AcknowledgementResultCallback; import io.awspring.cloud.sqs.listener.acknowledgement.AsyncAcknowledgementResultCallback; import io.awspring.cloud.sqs.listener.errorhandler.AsyncErrorHandler; import io.awspring.cloud.sqs.listener.errorhandler.ErrorHandler; import io.awspring.cloud.sqs.listener.interceptor.AsyncMessageInterceptor; import io.awspring.cloud.sqs.listener.interceptor.MessageInterceptor; -import java.util.ArrayList; -import java.util.Collection; -import java.util.function.Consumer; -import java.util.function.Supplier; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.messaging.Message; import org.springframework.util.Assert; import software.amazon.awssdk.services.sqs.SqsAsyncClient; +import java.util.ArrayList; +import java.util.Collection; +import java.util.function.Consumer; +import java.util.function.Supplier; + /** * {@link MessageListenerContainerFactory} implementation for creating {@link SqsMessageListenerContainer} instances. A * factory can be assigned to a {@link io.awspring.cloud.sqs.annotation.SqsListener @SqsListener} by using the diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AbstractContainerOptions.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AbstractContainerOptions.java index 81f4eb3f2..e5adaa0fc 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AbstractContainerOptions.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AbstractContainerOptions.java @@ -20,6 +20,9 @@ import io.awspring.cloud.sqs.support.converter.MessagingMessageConverter; import io.awspring.cloud.sqs.support.converter.SqsMessagingMessageConverter; import java.time.Duration; + +import io.awspring.cloud.sqs.support.filter.DefaultMessageFilter; +import io.awspring.cloud.sqs.support.filter.MessageFilter; import org.springframework.core.task.TaskExecutor; import org.springframework.lang.Nullable; import org.springframework.retry.backoff.BackOffPolicy; @@ -59,6 +62,8 @@ public abstract class AbstractContainerOptions, private final AcknowledgementMode acknowledgementMode; + private final MessageFilter messageFilter; + @Nullable private final AcknowledgementOrdering acknowledgementOrdering; @@ -92,6 +97,8 @@ protected AbstractContainerOptions(Builder builder) { this.acknowledgementThreshold = builder.acknowledgementThreshold; this.componentsTaskExecutor = builder.componentsTaskExecutor; this.acknowledgementResultTaskExecutor = builder.acknowledgementResultTaskExecutor; + this.messageFilter = builder.messageFilter; + Assert.isTrue(this.maxMessagesPerPoll <= this.maxConcurrentMessages, String.format( "messagesPerPoll should be less than or equal to maxConcurrentMessages. Values provided: %s and %s respectively", this.maxMessagesPerPoll, this.maxConcurrentMessages)); @@ -164,6 +171,11 @@ public MessagingMessageConverter getMessageConverter() { return this.messageConverter; } + @Override + public MessageFilter getMessageFilter() { + return this.messageFilter; + } + @Nullable @Override public Duration getAcknowledgementInterval() { @@ -244,6 +256,8 @@ protected abstract static class Builder, private AcknowledgementMode acknowledgementMode = DEFAULT_ACKNOWLEDGEMENT_MODE; + private MessageFilter messageFilter = new DefaultMessageFilter<>(); + @Nullable private AcknowledgementOrdering acknowledgementOrdering; @@ -280,6 +294,7 @@ protected Builder(AbstractContainerOptions options) { this.acknowledgementThreshold = options.acknowledgementThreshold; this.componentsTaskExecutor = options.componentsTaskExecutor; this.acknowledgementResultTaskExecutor = options.acknowledgementResultTaskExecutor; + this.messageFilter = options.messageFilter; } @Override @@ -400,6 +415,13 @@ public B messageConverter(MessagingMessageConverter messageConverter) { return self(); } + @Override + public B messageFilter(MessageFilter messageFilter) { + Assert.notNull(messageFilter, "messageFilter cannot be null"); + this.messageFilter = messageFilter; + return self(); + } + @SuppressWarnings("unchecked") private B self() { return (B) this; diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AbstractMessageListenerContainer.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AbstractMessageListenerContainer.java index 9566fbb7a..952844761 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AbstractMessageListenerContainer.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AbstractMessageListenerContainer.java @@ -21,11 +21,6 @@ import io.awspring.cloud.sqs.listener.errorhandler.ErrorHandler; import io.awspring.cloud.sqs.listener.interceptor.AsyncMessageInterceptor; import io.awspring.cloud.sqs.listener.interceptor.MessageInterceptor; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.function.Consumer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.context.SmartLifecycle; @@ -33,6 +28,12 @@ import org.springframework.messaging.Message; import org.springframework.util.Assert; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.function.Consumer; + /** * Base implementation for {@link MessageListenerContainer} with {@link SmartLifecycle} and component management * capabilities. diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AbstractPipelineMessageListenerContainer.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AbstractPipelineMessageListenerContainer.java index 6808f647a..cffbda357 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AbstractPipelineMessageListenerContainer.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AbstractPipelineMessageListenerContainer.java @@ -30,6 +30,7 @@ import io.awspring.cloud.sqs.listener.pipeline.MessageProcessingConfiguration; import io.awspring.cloud.sqs.listener.pipeline.MessageProcessingPipeline; import io.awspring.cloud.sqs.listener.pipeline.MessageProcessingPipelineBuilder; +import io.awspring.cloud.sqs.listener.sink.AbstractMessageProcessingPipelineSink; import io.awspring.cloud.sqs.listener.sink.MessageProcessingPipelineSink; import io.awspring.cloud.sqs.listener.sink.MessageSink; import io.awspring.cloud.sqs.listener.source.AcknowledgementProcessingMessageSource; @@ -42,6 +43,8 @@ import java.util.concurrent.ThreadFactory; import java.util.stream.Collectors; import java.util.stream.IntStream; + +import io.awspring.cloud.sqs.support.filter.MessageFilter; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.core.task.SimpleAsyncTaskExecutor; @@ -174,7 +177,10 @@ protected void configureMessageSink(MessageProcessingPipeline messageProcessi .acceptIfInstance(this.messageSink, TaskExecutorAware.class, teac -> teac.setTaskExecutor(getComponentsTaskExecutor())) .acceptIfInstance(this.messageSink, MessageProcessingPipelineSink.class, - mls -> mls.setMessagePipeline(messageProcessingPipeline)); + mls -> mls.setMessagePipeline(messageProcessingPipeline)) + .acceptIfInstance(this.messageSink, AbstractMessageProcessingPipelineSink.class, + s -> s.setMessageFilter(getContainerOptions().getMessageFilter())); + doConfigureMessageSink(this.messageSink); } diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AsyncComponentAdapters.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AsyncComponentAdapters.java index c2563002a..a11eca8bf 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AsyncComponentAdapters.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AsyncComponentAdapters.java @@ -23,16 +23,17 @@ import io.awspring.cloud.sqs.listener.errorhandler.ErrorHandler; import io.awspring.cloud.sqs.listener.interceptor.AsyncMessageInterceptor; import io.awspring.cloud.sqs.listener.interceptor.MessageInterceptor; -import java.util.Collection; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CompletionException; -import java.util.function.Supplier; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.core.task.TaskExecutor; import org.springframework.messaging.Message; import org.springframework.util.Assert; +import java.util.Collection; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.function.Supplier; + /** * Utility class for adapting blocking components to their asynchronous variants. * diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ContainerOptions.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ContainerOptions.java index ad7313cf6..c5390c237 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ContainerOptions.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ContainerOptions.java @@ -20,6 +20,8 @@ import io.awspring.cloud.sqs.support.converter.MessagingMessageConverter; import java.time.Duration; import java.util.Collection; + +import io.awspring.cloud.sqs.support.filter.MessageFilter; import org.springframework.core.task.TaskExecutor; import org.springframework.lang.Nullable; import org.springframework.retry.backoff.BackOffPolicy; @@ -139,6 +141,11 @@ default BackOffPolicy getPollBackOffPolicy() { */ MessagingMessageConverter getMessageConverter(); + /** Return the message filter applied before message processing. + * @return the message filter. + */ + MessageFilter getMessageFilter(); + /** * Return the maximum interval between acknowledgements for batch acknowledgements. * @return the interval. diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ContainerOptionsBuilder.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ContainerOptionsBuilder.java index 9d03b7964..29780630a 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ContainerOptionsBuilder.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ContainerOptionsBuilder.java @@ -19,6 +19,8 @@ import io.awspring.cloud.sqs.listener.acknowledgement.handler.AcknowledgementMode; import io.awspring.cloud.sqs.support.converter.MessagingMessageConverter; import java.time.Duration; + +import io.awspring.cloud.sqs.support.filter.MessageFilter; import org.springframework.core.task.TaskExecutor; import org.springframework.retry.backoff.BackOffPolicy; @@ -187,6 +189,14 @@ default B pollBackOffPolicy(BackOffPolicy pollBackOffPolicy) { */ B messageConverter(MessagingMessageConverter messageConverter); + /** + * Set the {@link MessagingMessageConverter} for this container. + * + * @param messageFilter the message filter. + * @return this instance. + */ + B messageFilter(MessageFilter messageFilter); + /** * Create the {@link ContainerOptions} instance. * diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/FifoSqsComponentFactory.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/FifoSqsComponentFactory.java index 3ed94b16a..2321a56ef 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/FifoSqsComponentFactory.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/FifoSqsComponentFactory.java @@ -22,9 +22,7 @@ import io.awspring.cloud.sqs.listener.acknowledgement.AcknowledgementProcessor; import io.awspring.cloud.sqs.listener.acknowledgement.BatchingAcknowledgementProcessor; import io.awspring.cloud.sqs.listener.acknowledgement.ImmediateAcknowledgementProcessor; -import io.awspring.cloud.sqs.listener.sink.BatchMessageSink; -import io.awspring.cloud.sqs.listener.sink.MessageSink; -import io.awspring.cloud.sqs.listener.sink.OrderedMessageSink; +import io.awspring.cloud.sqs.listener.sink.*; import io.awspring.cloud.sqs.listener.sink.adapter.MessageGroupingSinkAdapter; import io.awspring.cloud.sqs.listener.sink.adapter.MessageVisibilityExtendingSinkAdapter; import io.awspring.cloud.sqs.listener.source.FifoSqsMessageSource; @@ -69,7 +67,7 @@ public MessageSource createMessageSource(SqsContainerOptions options) { @Override public MessageSink createMessageSink(SqsContainerOptions options) { - MessageSink deliverySink = createDeliverySink(options.getListenerMode()); + MessageSink deliverySink = createDeliverySink(options.getListenerMode(), options); MessageSink wrappedDeliverySink = maybeWrapWithVisibilityAdapter(deliverySink, options.getMessageVisibility()); return maybeWrapWithMessageGroupingAdapter(options, wrappedDeliverySink); @@ -84,7 +82,7 @@ private MessageSink maybeWrapWithMessageGroupingAdapter(SqsContainerOptions o } // @formatter:off - private MessageSink createDeliverySink(ListenerMode listenerMode) { + private MessageSink createDeliverySink(ListenerMode listenerMode, SqsContainerOptions options) { return ListenerMode.SINGLE_MESSAGE.equals(listenerMode) ? new OrderedMessageSink<>() : new BatchMessageSink<>(); diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/StandardSqsComponentFactory.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/StandardSqsComponentFactory.java index cba188c70..70fcc5fed 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/StandardSqsComponentFactory.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/StandardSqsComponentFactory.java @@ -21,9 +21,7 @@ import io.awspring.cloud.sqs.listener.acknowledgement.AcknowledgementProcessor; import io.awspring.cloud.sqs.listener.acknowledgement.BatchingAcknowledgementProcessor; import io.awspring.cloud.sqs.listener.acknowledgement.ImmediateAcknowledgementProcessor; -import io.awspring.cloud.sqs.listener.sink.BatchMessageSink; -import io.awspring.cloud.sqs.listener.sink.FanOutMessageSink; -import io.awspring.cloud.sqs.listener.sink.MessageSink; +import io.awspring.cloud.sqs.listener.sink.*; import io.awspring.cloud.sqs.listener.source.MessageSource; import io.awspring.cloud.sqs.listener.source.StandardSqsMessageSource; import java.time.Duration; @@ -58,11 +56,10 @@ public MessageSource createMessageSource(SqsContainerOptions options) { @Override public MessageSink createMessageSink(SqsContainerOptions options) { return ListenerMode.SINGLE_MESSAGE.equals(options.getListenerMode()) - ? new FanOutMessageSink<>() - : new BatchMessageSink<>(); + ? new FanOutMessageSink<>() : new BatchMessageSink<>(); } - // @formatter:on + // @formatter:on @Override public AcknowledgementProcessor createAcknowledgementProcessor(SqsContainerOptions options) { validateAcknowledgementOrdering(options); diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/sink/AbstractMessageProcessingPipelineSink.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/sink/AbstractMessageProcessingPipelineSink.java index bac3c116d..79020afb2 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/sink/AbstractMessageProcessingPipelineSink.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/sink/AbstractMessageProcessingPipelineSink.java @@ -25,6 +25,8 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; import java.util.function.Supplier; + +import io.awspring.cloud.sqs.support.filter.MessageFilter; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.context.SmartLifecycle; @@ -56,6 +58,8 @@ public abstract class AbstractMessageProcessingPipelineSink private MessageProcessingPipeline messageProcessingPipeline; + protected MessageFilter messageFilter; + private String id; @Override @@ -70,6 +74,11 @@ public void setTaskExecutor(TaskExecutor taskExecutor) { this.taskExecutor = taskExecutor; } + public void setMessageFilter(MessageFilter messageFilter) { + Assert.notNull(messageFilter, "messageFilter must not be null."); + this.messageFilter = messageFilter; + } + @Override public CompletableFuture emit(Collection> messages, MessageProcessingContext context) { Assert.notNull(messages, "messages cannot be null"); @@ -87,6 +96,10 @@ public CompletableFuture emit(Collection> messages, MessageProc protected abstract CompletableFuture doEmit(Collection> messages, MessageProcessingContext context); + protected CompletableFuture>> filterAsync(Collection> messages) { + return CompletableFuture.supplyAsync(() -> this.messageFilter.process(messages), this.taskExecutor); + } + /** * Send the provided {@link Message} to the {@link TaskExecutor} as a unit of work. * @param message the message to be executed. diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/sink/BatchMessageSink.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/sink/BatchMessageSink.java index bca81afbd..b2fbfd14b 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/sink/BatchMessageSink.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/sink/BatchMessageSink.java @@ -31,7 +31,8 @@ public class BatchMessageSink extends AbstractMessageProcessingPipelineSink doEmit(Collection> messages, MessageProcessingContext context) { - return execute(messages, context).exceptionally(t -> logError(t, messages)); + return filterAsync(messages).thenCompose( + filtered -> execute(filtered, context).exceptionally(t -> logError(t, messages))); } } diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/sink/FanOutMessageSink.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/sink/FanOutMessageSink.java index 83406f041..e07924d3d 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/sink/FanOutMessageSink.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/sink/FanOutMessageSink.java @@ -38,9 +38,10 @@ public class FanOutMessageSink extends AbstractMessageProcessingPipelineSink< @Override protected CompletableFuture doEmit(Collection> messages, MessageProcessingContext context) { logger.trace("Emitting messages {}", MessageHeaderUtils.getId(messages)); - return CompletableFuture.allOf(messages.stream().map(msg -> execute(msg, context) + return filterAsync(messages) + .thenCompose(filtered -> CompletableFuture.allOf(filtered.stream().map(msg -> execute(msg, context) // Should log errors individually - no need to propagate upstream - .exceptionally(t -> logError(t, msg))).toArray(CompletableFuture[]::new)); + .exceptionally(t -> logError(t, msg))).toArray(CompletableFuture[]::new))); } } diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/sink/OrderedMessageSink.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/sink/OrderedMessageSink.java index c498d4467..5ac5a4f9f 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/sink/OrderedMessageSink.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/sink/OrderedMessageSink.java @@ -40,7 +40,9 @@ public class OrderedMessageSink extends AbstractMessageProcessingPipelineSink @Override protected CompletableFuture doEmit(Collection> messages, MessageProcessingContext context) { logger.trace("Emitting messages {}", MessageHeaderUtils.getId(messages)); - CompletableFuture execution = messages.stream().reduce(CompletableFuture.completedFuture(null), + CompletableFuture execution = filterAsync(messages) + .thenCompose(filtered -> + filtered.stream().reduce(CompletableFuture.completedFuture(null), (resultFuture, msg) -> CompletableFutures.handleCompose(resultFuture, (v, t) -> { if (t == null) { return execute(msg, context).whenComplete(logIfError(msg)); @@ -48,7 +50,7 @@ protected CompletableFuture doEmit(Collection> messages, Messag // Release backpressure from subsequent interrupted executions in case of errors. context.runBackPressureReleaseCallback(); return CompletableFutures.failedFuture(t); - }), (a, b) -> a); + }), (a, b) -> a)); return execution.exceptionally(t -> null); } diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/support/filter/DefaultMessageFilter.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/support/filter/DefaultMessageFilter.java new file mode 100644 index 000000000..238ce1eba --- /dev/null +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/support/filter/DefaultMessageFilter.java @@ -0,0 +1,13 @@ +package io.awspring.cloud.sqs.support.filter; + +import org.springframework.messaging.Message; + +import java.util.Collection; + +public class DefaultMessageFilter implements MessageFilter { + + @Override + public Collection> process(Collection> message) { + return message; + } +} diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/support/filter/MessageFilter.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/support/filter/MessageFilter.java new file mode 100644 index 000000000..fe51b5251 --- /dev/null +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/support/filter/MessageFilter.java @@ -0,0 +1,10 @@ +package io.awspring.cloud.sqs.support.filter; + +import org.springframework.messaging.Message; + +import java.util.Collection; + +public interface MessageFilter { + + Collection> process(Collection> message); +} diff --git a/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/integration/SqsMessageBatchFilterIntegrationTests.java b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/integration/SqsMessageBatchFilterIntegrationTests.java new file mode 100644 index 000000000..fb6d48bcf --- /dev/null +++ b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/integration/SqsMessageBatchFilterIntegrationTests.java @@ -0,0 +1,124 @@ +package io.awspring.cloud.sqs.integration; + +import io.awspring.cloud.sqs.annotation.SqsListener; +import io.awspring.cloud.sqs.config.SqsBootstrapConfiguration; +import io.awspring.cloud.sqs.config.SqsMessageListenerContainerFactory; +import io.awspring.cloud.sqs.listener.ListenerMode; +import io.awspring.cloud.sqs.operations.SqsTemplate; +import io.awspring.cloud.sqs.support.filter.MessageFilter; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Import; +import org.springframework.messaging.Message; +import software.amazon.awssdk.services.sqs.SqsAsyncClient; + +import java.util.Collection; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; + +@SpringBootTest +public class SqsMessageBatchFilterIntegrationTests extends BaseSqsIntegrationTest { + + private static final Logger log = LoggerFactory.getLogger(SqsMessageBatchFilterIntegrationTests.class); + + private static final String FILTER_QUEUE_BATCH = "filter-queue-batch"; + private static final String BATCH_FACTORY = "batchFilteringFactory"; + + @Autowired SqsTemplate sqsTemplate; + @Autowired LatchContainer latch; + + record SampleRecord(String propertyOne, String propertyTwo) {} + + @BeforeAll + static void setupQueues() { + SqsAsyncClient client = createAsyncClient(); + createQueue(client, FILTER_QUEUE_BATCH).join(); + } + + @Test + void shouldDeliverAllowedInBatch() throws Exception { + sqsTemplate.send(FILTER_QUEUE_BATCH, new SampleRecord("Hello", "world1")); + sqsTemplate.send(FILTER_QUEUE_BATCH, new SampleRecord("NotHello", "world2")); + sqsTemplate.send(FILTER_QUEUE_BATCH, new SampleRecord("Hello", "world3")); + + assertThat(latch.pass.await(10, TimeUnit.SECONDS)).isTrue(); + + assertThat(latch.received).containsExactly("world1", "world3"); + } + + // ==== Config + @Import(SqsBootstrapConfiguration.class) + @Configuration + static class Config { + @Bean(name = BATCH_FACTORY) + SqsMessageListenerContainerFactory batchFactory() { + return SqsMessageListenerContainerFactory.builder() + .configure(o -> { + o.messageFilter(new AllowHelloOnlyFilter()); + o.listenerMode(ListenerMode.BATCH); + o.maxMessagesPerPoll(10); + }) + .sqsAsyncClientSupplier(BaseSqsIntegrationTest::createAsyncClient) + .build(); + } + + @Bean SqsTemplate sqsTemplate() { + return SqsTemplate.builder().sqsAsyncClient(BaseSqsIntegrationTest.createAsyncClient()).build(); + } + + @Bean FilteringBatchListener batchListener() { return new FilteringBatchListener(); } + @Bean LatchContainer latch() { return new LatchContainer(); } + } + + static class AllowHelloOnlyFilter implements MessageFilter { + @Override public Collection> process(Collection> msgs) { + return msgs.stream() + .filter(m -> { + SampleRecord p = m.getPayload(); + log.info("Filtering message: {}", p); + return "Hello".equals(p.propertyOne()); + }) + .collect(Collectors.toList()); + } + } + + static class FilteringBatchListener { + @Autowired LatchContainer latch; + + @SqsListener(queueNames = FILTER_QUEUE_BATCH, id = "filter-batch", factory = BATCH_FACTORY) + void listen(List batch) { + log.info("Received batch size={}", batch.size()); + latch.batchSizes.add(batch.size()); + + for (SampleRecord r : batch) { + if ("Hello".equals(r.propertyOne())) { + latch.received.add(r.propertyTwo()); + latch.pass.countDown(); + } else { + latch.block.countDown(); + } + } + } + } + + // ==== Latch + static class LatchContainer { + final CountDownLatch pass = new CountDownLatch(2); + final CountDownLatch block = new CountDownLatch(1); + + final List received = new CopyOnWriteArrayList<>(); + final List batchSizes = new CopyOnWriteArrayList<>(); + } +} diff --git a/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/integration/SqsMessageFifoFilterIntegrationTests.java b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/integration/SqsMessageFifoFilterIntegrationTests.java new file mode 100644 index 000000000..9ace70146 --- /dev/null +++ b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/integration/SqsMessageFifoFilterIntegrationTests.java @@ -0,0 +1,137 @@ +package io.awspring.cloud.sqs.integration; + +import io.awspring.cloud.sqs.annotation.SqsListener; +import io.awspring.cloud.sqs.config.SqsBootstrapConfiguration; +import io.awspring.cloud.sqs.config.SqsMessageListenerContainerFactory; +import io.awspring.cloud.sqs.listener.SqsHeaders; +import io.awspring.cloud.sqs.operations.SqsTemplate; +import io.awspring.cloud.sqs.support.filter.MessageFilter; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.ApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Import; +import org.springframework.messaging.Message; +import org.springframework.messaging.support.MessageBuilder; +import software.amazon.awssdk.services.sqs.SqsAsyncClient; + +import java.util.Collection; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; + +@SpringBootTest +public class SqsMessageFifoFilterIntegrationTests extends BaseSqsIntegrationTest { + + private static final Logger log = LoggerFactory.getLogger(SqsMessageFifoFilterIntegrationTests.class); + + private static final String FILTER_QUEUE_FIFO = "filter-queue-test.fifo"; + private static final String FIFO_FACTORY = "fifoFilteringFactory"; + + @Autowired SqsTemplate sqsTemplate; + @Autowired LatchContainer latch; + @Autowired ApplicationContext applicationContext; + + record SampleRecord(String propertyOne, String propertyTwo) {} + + @BeforeAll + static void setupQueues() { + SqsAsyncClient client = createAsyncClient(); + createFifoQueue(client, FILTER_QUEUE_FIFO).join(); + } + + @Test + void shouldPreserveOrderForAllowedInFifo() throws Exception { + var m1 = MessageBuilder.withPayload(new SampleRecord("Hello", "A1")) + .setHeader(SqsHeaders.MessageSystemAttributes.SQS_MESSAGE_GROUP_ID_HEADER, "g1") + .setHeader(SqsHeaders.MessageSystemAttributes.SQS_MESSAGE_DEDUPLICATION_ID_HEADER, UUID.randomUUID().toString()) + .build(); + + var m2 = MessageBuilder.withPayload(new SampleRecord("NotHello", "A2")) + .setHeader(SqsHeaders.MessageSystemAttributes.SQS_MESSAGE_GROUP_ID_HEADER, "g1") + .setHeader(SqsHeaders.MessageSystemAttributes.SQS_MESSAGE_DEDUPLICATION_ID_HEADER, UUID.randomUUID().toString()) + .build(); + + var m3 = MessageBuilder.withPayload(new SampleRecord("Hello", "A3")) + .setHeader(SqsHeaders.MessageSystemAttributes.SQS_MESSAGE_GROUP_ID_HEADER, "g1") + .setHeader(SqsHeaders.MessageSystemAttributes.SQS_MESSAGE_DEDUPLICATION_ID_HEADER, UUID.randomUUID().toString()) + .build(); + + sqsTemplate.send(FILTER_QUEUE_FIFO, m1); + sqsTemplate.send(FILTER_QUEUE_FIFO, m2); + sqsTemplate.send(FILTER_QUEUE_FIFO, m3); + + // 통과 메시지 2건(A1, A3) 수신 대기 + assertThat(latch.pass.await(30, TimeUnit.SECONDS)).isTrue(); + + var ordered = applicationContext.getBean(FilteringOrderedListener.class); + assertThat(ordered.receivedOrder).containsExactly("A1", "A3"); + } + + // ==== Config + @Import(SqsBootstrapConfiguration.class) + @Configuration + static class Config { + @Bean(name = FIFO_FACTORY) + SqsMessageListenerContainerFactory fifoFactory() { + return SqsMessageListenerContainerFactory.builder() + .configure(o -> { + o.messageFilter(new AllowHelloOnlyFilter()); + o.maxMessagesPerPoll(10); // 선택 + }) + .sqsAsyncClientSupplier(BaseSqsIntegrationTest::createAsyncClient) + .build(); + } + + @Bean SqsTemplate sqsTemplate() { + return SqsTemplate.builder() + .sqsAsyncClient(BaseSqsIntegrationTest.createAsyncClient()) + .build(); + } + + @Bean FilteringOrderedListener orderedListener() { return new FilteringOrderedListener(); } + @Bean LatchContainer latch() { return new LatchContainer(); } + } + + // ==== Filter + static class AllowHelloOnlyFilter implements MessageFilter { + @Override public Collection> process(Collection> msgs) { + return msgs.stream() + .filter(m -> { + SampleRecord p = m.getPayload(); + log.info("Filtering message: {}", p); + return "Hello".equals(p.propertyOne()); + }) + .collect(Collectors.toList()); + } + } + + // ==== Listener + static class FilteringOrderedListener { + @Autowired LatchContainer latch; + final List receivedOrder = new CopyOnWriteArrayList<>(); + + @SqsListener(queueNames = FILTER_QUEUE_FIFO, id = "filter-fifo", factory = FIFO_FACTORY) + void listen(SampleRecord r) { + log.info("Received(fifo): {}", r); + receivedOrder.add(r.propertyTwo()); + latch.pass.countDown(); + } + } + + // ==== Latch + static class LatchContainer { + final CountDownLatch pass = new CountDownLatch(2); // A1, A3 두 건 + final CountDownLatch block = new CountDownLatch(1); + } +} diff --git a/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/integration/SqsMessageFilterIntegrationTests.java b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/integration/SqsMessageFilterIntegrationTests.java new file mode 100644 index 000000000..0d883e908 --- /dev/null +++ b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/integration/SqsMessageFilterIntegrationTests.java @@ -0,0 +1,147 @@ +package io.awspring.cloud.sqs.integration; + +import io.awspring.cloud.sqs.annotation.SqsListener; +import io.awspring.cloud.sqs.config.SqsBootstrapConfiguration; +import io.awspring.cloud.sqs.config.SqsMessageListenerContainerFactory; +import io.awspring.cloud.sqs.operations.SqsTemplate; +import io.awspring.cloud.sqs.support.filter.MessageFilter; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Import; +import org.springframework.messaging.Message; +import software.amazon.awssdk.services.sqs.SqsAsyncClient; + +import java.util.Collection; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + +@SpringBootTest +public class SqsMessageFilterIntegrationTests extends BaseSqsIntegrationTest { + + private static final Logger logger = LoggerFactory.getLogger(SqsMessageFilterIntegrationTests.class); + private static final String FILTER_QUEUE_PASS = "filter-queue-pass"; + private static final String FILTER_QUEUE_BLOCK = "filter-queue-block"; + + private static final String FILTERING_FACTORY = "filteringFactory"; + + @Autowired + LatchContainer latchContainer; + + @Autowired + SqsTemplate sqsTemplate; + + @BeforeAll + static void setupQueues() { + SqsAsyncClient client = createAsyncClient(); + CompletableFuture.allOf( + createQueue(client, FILTER_QUEUE_PASS), + createQueue(client, FILTER_QUEUE_BLOCK) + ).join(); + } + + record SampleRecord(String propertyOne, String propertyTwo) {} + + @Test + void shouldReceiveMessageThatPassesProcess() throws Exception { + sqsTemplate.send(FILTER_QUEUE_PASS, new SampleRecord("Hello", "Accepted")); + assertThat(latchContainer.latchForPass.await(10, TimeUnit.SECONDS)).isTrue(); + } + + @Test + void shouldNotReceiveMessageThatFailsProcess() throws Exception { + sqsTemplate.send(FILTER_QUEUE_BLOCK, new SampleRecord("NotHello", "Rejected")); + assertThat(latchContainer.latchForBlock.await(10, TimeUnit.SECONDS)).isFalse(); + } + + // Configuration + @Import(SqsBootstrapConfiguration.class) + @Configuration + static class FilterTestConfig { + + @Bean(name = FILTERING_FACTORY) + public SqsMessageListenerContainerFactory messageFilterFactory() { + return SqsMessageListenerContainerFactory.builder() + .configure(options -> options.messageFilter(new AllowHelloOnlyFilter())) + .sqsAsyncClientSupplier(BaseSqsIntegrationTest::createAsyncClient) + .build(); + } + + @Bean + public SqsTemplate sqsTemplate() { + return SqsTemplate.builder() + .sqsAsyncClient(BaseSqsIntegrationTest.createAsyncClient()) + .build(); + } + + @Bean + public FilteringListenerPass filteringListenerPass() { + return new FilteringListenerPass(); + } + + @Bean + public FilteringListenerBlock filteringListenerBlock() { + return new FilteringListenerBlock(); + } + + @Bean + public LatchContainer latchContainer() { + return new LatchContainer(); + } + } + + // Sample Filter + static class AllowHelloOnlyFilter implements MessageFilter { + @Override + public Collection> process(Collection> messages) { + return messages.stream() + .filter(msg -> { + SampleRecord p = msg.getPayload(); + logger.info("Filtering message: {}", p); + return "Hello".equals(p.propertyOne()); + }) + .collect(Collectors.toList()); + } + } + + // Listener for PASS case + static class FilteringListenerPass { + + @Autowired + LatchContainer latchContainer; + + @SqsListener(queueNames = FILTER_QUEUE_PASS, id = "filter-pass", factory = FILTERING_FACTORY) + void listen(SampleRecord record) { + logger.info("Received (pass): {}", record); + latchContainer.latchForPass.countDown(); + } + } + + // Listener for BLOCK case + static class FilteringListenerBlock { + + @Autowired + LatchContainer latchContainer; + + @SqsListener(queueNames = FILTER_QUEUE_BLOCK, id = "filter-block", factory = FILTERING_FACTORY) + void listen(SampleRecord record) { + logger.info("Received (block): {}", record); + latchContainer.latchForBlock.countDown(); + } + } + + // Shared latch + static class LatchContainer { + final CountDownLatch latchForPass = new CountDownLatch(1); + final CountDownLatch latchForBlock = new CountDownLatch(1); + } +}