Skip to content

Add FilteringAdapter in SQS #1388

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,20 @@
package io.awspring.cloud.sqs.config;

import io.awspring.cloud.sqs.ConfigUtils;
import io.awspring.cloud.sqs.listener.AbstractMessageListenerContainer;
import io.awspring.cloud.sqs.listener.AsyncComponentAdapters;
import io.awspring.cloud.sqs.listener.AsyncMessageListener;
import io.awspring.cloud.sqs.listener.ContainerComponentFactory;
import io.awspring.cloud.sqs.listener.ContainerOptions;
import io.awspring.cloud.sqs.listener.ContainerOptionsBuilder;
import io.awspring.cloud.sqs.listener.MessageListener;
import io.awspring.cloud.sqs.listener.MessageListenerContainer;
import io.awspring.cloud.sqs.listener.*;
import io.awspring.cloud.sqs.listener.acknowledgement.AcknowledgementResultCallback;
import io.awspring.cloud.sqs.listener.acknowledgement.AsyncAcknowledgementResultCallback;
import io.awspring.cloud.sqs.listener.errorhandler.AsyncErrorHandler;
import io.awspring.cloud.sqs.listener.errorhandler.ErrorHandler;
import io.awspring.cloud.sqs.listener.interceptor.AsyncMessageInterceptor;
import io.awspring.cloud.sqs.listener.interceptor.MessageInterceptor;
import org.springframework.messaging.Message;
import org.springframework.util.Assert;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.function.Consumer;
import org.springframework.messaging.Message;
import org.springframework.util.Assert;

/**
* Base implementation for a {@link MessageListenerContainerFactory}. Contains the components and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,24 @@

import io.awspring.cloud.sqs.ConfigUtils;
import io.awspring.cloud.sqs.annotation.SqsListener;
import io.awspring.cloud.sqs.listener.AsyncMessageListener;
import io.awspring.cloud.sqs.listener.ContainerComponentFactory;
import io.awspring.cloud.sqs.listener.ContainerOptions;
import io.awspring.cloud.sqs.listener.MessageListener;
import io.awspring.cloud.sqs.listener.SqsContainerOptions;
import io.awspring.cloud.sqs.listener.SqsContainerOptionsBuilder;
import io.awspring.cloud.sqs.listener.SqsMessageListenerContainer;
import io.awspring.cloud.sqs.listener.*;
import io.awspring.cloud.sqs.listener.acknowledgement.AcknowledgementResultCallback;
import io.awspring.cloud.sqs.listener.acknowledgement.AsyncAcknowledgementResultCallback;
import io.awspring.cloud.sqs.listener.errorhandler.AsyncErrorHandler;
import io.awspring.cloud.sqs.listener.errorhandler.ErrorHandler;
import io.awspring.cloud.sqs.listener.interceptor.AsyncMessageInterceptor;
import io.awspring.cloud.sqs.listener.interceptor.MessageInterceptor;
import java.util.ArrayList;
import java.util.Collection;
import java.util.function.Consumer;
import java.util.function.Supplier;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.messaging.Message;
import org.springframework.util.Assert;
import software.amazon.awssdk.services.sqs.SqsAsyncClient;

import java.util.ArrayList;
import java.util.Collection;
import java.util.function.Consumer;
import java.util.function.Supplier;

/**
* {@link MessageListenerContainerFactory} implementation for creating {@link SqsMessageListenerContainer} instances. A
* factory can be assigned to a {@link io.awspring.cloud.sqs.annotation.SqsListener @SqsListener} by using the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
import io.awspring.cloud.sqs.support.converter.MessagingMessageConverter;
import io.awspring.cloud.sqs.support.converter.SqsMessagingMessageConverter;
import java.time.Duration;

import io.awspring.cloud.sqs.support.filter.DefaultMessageFilter;
import io.awspring.cloud.sqs.support.filter.MessageFilter;
import org.springframework.core.task.TaskExecutor;
import org.springframework.lang.Nullable;
import org.springframework.retry.backoff.BackOffPolicy;
Expand Down Expand Up @@ -59,6 +62,8 @@ public abstract class AbstractContainerOptions<O extends ContainerOptions<O, B>,

private final AcknowledgementMode acknowledgementMode;

private final MessageFilter<?> messageFilter;

@Nullable
private final AcknowledgementOrdering acknowledgementOrdering;

Expand Down Expand Up @@ -92,6 +97,8 @@ protected AbstractContainerOptions(Builder<?, ?> builder) {
this.acknowledgementThreshold = builder.acknowledgementThreshold;
this.componentsTaskExecutor = builder.componentsTaskExecutor;
this.acknowledgementResultTaskExecutor = builder.acknowledgementResultTaskExecutor;
this.messageFilter = builder.messageFilter;

Assert.isTrue(this.maxMessagesPerPoll <= this.maxConcurrentMessages, String.format(
"messagesPerPoll should be less than or equal to maxConcurrentMessages. Values provided: %s and %s respectively",
this.maxMessagesPerPoll, this.maxConcurrentMessages));
Expand Down Expand Up @@ -164,6 +171,11 @@ public MessagingMessageConverter<?> getMessageConverter() {
return this.messageConverter;
}

@Override
public MessageFilter<?> getMessageFilter() {
return this.messageFilter;
}

@Nullable
@Override
public Duration getAcknowledgementInterval() {
Expand Down Expand Up @@ -244,6 +256,8 @@ protected abstract static class Builder<B extends ContainerOptionsBuilder<B, O>,

private AcknowledgementMode acknowledgementMode = DEFAULT_ACKNOWLEDGEMENT_MODE;

private MessageFilter<?> messageFilter = new DefaultMessageFilter<>();

@Nullable
private AcknowledgementOrdering acknowledgementOrdering;

Expand Down Expand Up @@ -280,6 +294,7 @@ protected Builder(AbstractContainerOptions<?, ?> options) {
this.acknowledgementThreshold = options.acknowledgementThreshold;
this.componentsTaskExecutor = options.componentsTaskExecutor;
this.acknowledgementResultTaskExecutor = options.acknowledgementResultTaskExecutor;
this.messageFilter = options.messageFilter;
}

@Override
Expand Down Expand Up @@ -400,6 +415,13 @@ public B messageConverter(MessagingMessageConverter<?> messageConverter) {
return self();
}

@Override
public B messageFilter(MessageFilter<?> messageFilter) {
Assert.notNull(messageFilter, "messageFilter cannot be null");
this.messageFilter = messageFilter;
return self();
}

@SuppressWarnings("unchecked")
private B self() {
return (B) this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,19 @@
import io.awspring.cloud.sqs.listener.errorhandler.ErrorHandler;
import io.awspring.cloud.sqs.listener.interceptor.AsyncMessageInterceptor;
import io.awspring.cloud.sqs.listener.interceptor.MessageInterceptor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.function.Consumer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.SmartLifecycle;
import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
import org.springframework.util.Assert;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.function.Consumer;

/**
* Base implementation for {@link MessageListenerContainer} with {@link SmartLifecycle} and component management
* capabilities.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import io.awspring.cloud.sqs.listener.pipeline.MessageProcessingConfiguration;
import io.awspring.cloud.sqs.listener.pipeline.MessageProcessingPipeline;
import io.awspring.cloud.sqs.listener.pipeline.MessageProcessingPipelineBuilder;
import io.awspring.cloud.sqs.listener.sink.AbstractMessageProcessingPipelineSink;
import io.awspring.cloud.sqs.listener.sink.MessageProcessingPipelineSink;
import io.awspring.cloud.sqs.listener.sink.MessageSink;
import io.awspring.cloud.sqs.listener.source.AcknowledgementProcessingMessageSource;
Expand All @@ -42,6 +43,8 @@
import java.util.concurrent.ThreadFactory;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import io.awspring.cloud.sqs.support.filter.MessageFilter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.task.SimpleAsyncTaskExecutor;
Expand Down Expand Up @@ -174,7 +177,10 @@ protected void configureMessageSink(MessageProcessingPipeline<T> messageProcessi
.acceptIfInstance(this.messageSink, TaskExecutorAware.class,
teac -> teac.setTaskExecutor(getComponentsTaskExecutor()))
.acceptIfInstance(this.messageSink, MessageProcessingPipelineSink.class,
mls -> mls.setMessagePipeline(messageProcessingPipeline));
mls -> mls.setMessagePipeline(messageProcessingPipeline))
.acceptIfInstance(this.messageSink, AbstractMessageProcessingPipelineSink.class,
s -> s.setMessageFilter(getContainerOptions().getMessageFilter()));

doConfigureMessageSink(this.messageSink);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,17 @@
import io.awspring.cloud.sqs.listener.errorhandler.ErrorHandler;
import io.awspring.cloud.sqs.listener.interceptor.AsyncMessageInterceptor;
import io.awspring.cloud.sqs.listener.interceptor.MessageInterceptor;
import java.util.Collection;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.function.Supplier;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.task.TaskExecutor;
import org.springframework.messaging.Message;
import org.springframework.util.Assert;

import java.util.Collection;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.function.Supplier;

/**
* Utility class for adapting blocking components to their asynchronous variants.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import io.awspring.cloud.sqs.support.converter.MessagingMessageConverter;
import java.time.Duration;
import java.util.Collection;

import io.awspring.cloud.sqs.support.filter.MessageFilter;
import org.springframework.core.task.TaskExecutor;
import org.springframework.lang.Nullable;
import org.springframework.retry.backoff.BackOffPolicy;
Expand Down Expand Up @@ -139,6 +141,11 @@ default BackOffPolicy getPollBackOffPolicy() {
*/
MessagingMessageConverter<?> getMessageConverter();

/** Return the message filter applied before message processing.
* @return the message filter.
*/
MessageFilter<?> getMessageFilter();

/**
* Return the maximum interval between acknowledgements for batch acknowledgements.
* @return the interval.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import io.awspring.cloud.sqs.listener.acknowledgement.handler.AcknowledgementMode;
import io.awspring.cloud.sqs.support.converter.MessagingMessageConverter;
import java.time.Duration;

import io.awspring.cloud.sqs.support.filter.MessageFilter;
import org.springframework.core.task.TaskExecutor;
import org.springframework.retry.backoff.BackOffPolicy;

Expand Down Expand Up @@ -187,6 +189,14 @@ default B pollBackOffPolicy(BackOffPolicy pollBackOffPolicy) {
*/
B messageConverter(MessagingMessageConverter<?> messageConverter);

/**
* Set the {@link MessagingMessageConverter} for this container.
*
* @param messageFilter the message filter.
* @return this instance.
*/
B messageFilter(MessageFilter<?> messageFilter);

/**
* Create the {@link ContainerOptions} instance.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@
import io.awspring.cloud.sqs.listener.acknowledgement.AcknowledgementProcessor;
import io.awspring.cloud.sqs.listener.acknowledgement.BatchingAcknowledgementProcessor;
import io.awspring.cloud.sqs.listener.acknowledgement.ImmediateAcknowledgementProcessor;
import io.awspring.cloud.sqs.listener.sink.BatchMessageSink;
import io.awspring.cloud.sqs.listener.sink.MessageSink;
import io.awspring.cloud.sqs.listener.sink.OrderedMessageSink;
import io.awspring.cloud.sqs.listener.sink.*;
import io.awspring.cloud.sqs.listener.sink.adapter.MessageGroupingSinkAdapter;
import io.awspring.cloud.sqs.listener.sink.adapter.MessageVisibilityExtendingSinkAdapter;
import io.awspring.cloud.sqs.listener.source.FifoSqsMessageSource;
Expand Down Expand Up @@ -69,7 +67,7 @@ public MessageSource<T> createMessageSource(SqsContainerOptions options) {

@Override
public MessageSink<T> createMessageSink(SqsContainerOptions options) {
MessageSink<T> deliverySink = createDeliverySink(options.getListenerMode());
MessageSink<T> deliverySink = createDeliverySink(options.getListenerMode(), options);
MessageSink<T> wrappedDeliverySink = maybeWrapWithVisibilityAdapter(deliverySink,
options.getMessageVisibility());
return maybeWrapWithMessageGroupingAdapter(options, wrappedDeliverySink);
Expand All @@ -84,7 +82,7 @@ private MessageSink<T> maybeWrapWithMessageGroupingAdapter(SqsContainerOptions o
}

// @formatter:off
private MessageSink<T> createDeliverySink(ListenerMode listenerMode) {
private MessageSink<T> createDeliverySink(ListenerMode listenerMode, SqsContainerOptions options) {
return ListenerMode.SINGLE_MESSAGE.equals(listenerMode)
? new OrderedMessageSink<>()
: new BatchMessageSink<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@
import io.awspring.cloud.sqs.listener.acknowledgement.AcknowledgementProcessor;
import io.awspring.cloud.sqs.listener.acknowledgement.BatchingAcknowledgementProcessor;
import io.awspring.cloud.sqs.listener.acknowledgement.ImmediateAcknowledgementProcessor;
import io.awspring.cloud.sqs.listener.sink.BatchMessageSink;
import io.awspring.cloud.sqs.listener.sink.FanOutMessageSink;
import io.awspring.cloud.sqs.listener.sink.MessageSink;
import io.awspring.cloud.sqs.listener.sink.*;
import io.awspring.cloud.sqs.listener.source.MessageSource;
import io.awspring.cloud.sqs.listener.source.StandardSqsMessageSource;
import java.time.Duration;
Expand Down Expand Up @@ -58,11 +56,10 @@ public MessageSource<T> createMessageSource(SqsContainerOptions options) {
@Override
public MessageSink<T> createMessageSink(SqsContainerOptions options) {
return ListenerMode.SINGLE_MESSAGE.equals(options.getListenerMode())
? new FanOutMessageSink<>()
: new BatchMessageSink<>();
? new FanOutMessageSink<>() : new BatchMessageSink<>();
}
// @formatter:on

// @formatter:on
@Override
public AcknowledgementProcessor<T> createAcknowledgementProcessor(SqsContainerOptions options) {
validateAcknowledgementOrdering(options);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.function.Supplier;

import io.awspring.cloud.sqs.support.filter.MessageFilter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.SmartLifecycle;
Expand Down Expand Up @@ -56,6 +58,8 @@ public abstract class AbstractMessageProcessingPipelineSink<T>

private MessageProcessingPipeline<T> messageProcessingPipeline;

protected MessageFilter<T> messageFilter;

private String id;

@Override
Expand All @@ -70,6 +74,11 @@ public void setTaskExecutor(TaskExecutor taskExecutor) {
this.taskExecutor = taskExecutor;
}

public void setMessageFilter(MessageFilter<T> messageFilter) {
Assert.notNull(messageFilter, "messageFilter must not be null.");
this.messageFilter = messageFilter;
}

@Override
public CompletableFuture<Void> emit(Collection<Message<T>> messages, MessageProcessingContext<T> context) {
Assert.notNull(messages, "messages cannot be null");
Expand All @@ -87,6 +96,10 @@ public CompletableFuture<Void> emit(Collection<Message<T>> messages, MessageProc
protected abstract CompletableFuture<Void> doEmit(Collection<Message<T>> messages,
MessageProcessingContext<T> context);

protected CompletableFuture<Collection<Message<T>>> filterAsync(Collection<Message<T>> messages) {
return CompletableFuture.supplyAsync(() -> this.messageFilter.process(messages), this.taskExecutor);
}

/**
* Send the provided {@link Message} to the {@link TaskExecutor} as a unit of work.
* @param message the message to be executed.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ public class BatchMessageSink<T> extends AbstractMessageProcessingPipelineSink<T

@Override
protected CompletableFuture<Void> doEmit(Collection<Message<T>> messages, MessageProcessingContext<T> context) {
return execute(messages, context).exceptionally(t -> logError(t, messages));
return filterAsync(messages).thenCompose(
filtered -> execute(filtered, context).exceptionally(t -> logError(t, messages)));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@ public class FanOutMessageSink<T> extends AbstractMessageProcessingPipelineSink<
@Override
protected CompletableFuture<Void> doEmit(Collection<Message<T>> messages, MessageProcessingContext<T> context) {
logger.trace("Emitting messages {}", MessageHeaderUtils.getId(messages));
return CompletableFuture.allOf(messages.stream().map(msg -> execute(msg, context)
return filterAsync(messages)
.thenCompose(filtered -> CompletableFuture.allOf(filtered.stream().map(msg -> execute(msg, context)
// Should log errors individually - no need to propagate upstream
.exceptionally(t -> logError(t, msg))).toArray(CompletableFuture[]::new));
.exceptionally(t -> logError(t, msg))).toArray(CompletableFuture[]::new)));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,17 @@ public class OrderedMessageSink<T> extends AbstractMessageProcessingPipelineSink
@Override
protected CompletableFuture<Void> doEmit(Collection<Message<T>> messages, MessageProcessingContext<T> context) {
logger.trace("Emitting messages {}", MessageHeaderUtils.getId(messages));
CompletableFuture<Void> execution = messages.stream().reduce(CompletableFuture.completedFuture(null),
CompletableFuture<Void> execution = filterAsync(messages)
.thenCompose(filtered ->
filtered.stream().reduce(CompletableFuture.completedFuture(null),
(resultFuture, msg) -> CompletableFutures.handleCompose(resultFuture, (v, t) -> {
if (t == null) {
return execute(msg, context).whenComplete(logIfError(msg));
}
// Release backpressure from subsequent interrupted executions in case of errors.
context.runBackPressureReleaseCallback();
return CompletableFutures.failedFuture(t);
}), (a, b) -> a);
}), (a, b) -> a));
return execution.exceptionally(t -> null);
}

Expand Down
Loading