Skip to content

Add FilteringAdapter in SQS #1388

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,22 @@
package io.awspring.cloud.sqs.config;

import io.awspring.cloud.sqs.ConfigUtils;
import io.awspring.cloud.sqs.listener.AbstractMessageListenerContainer;
import io.awspring.cloud.sqs.listener.AsyncComponentAdapters;
import io.awspring.cloud.sqs.listener.AsyncMessageListener;
import io.awspring.cloud.sqs.listener.ContainerComponentFactory;
import io.awspring.cloud.sqs.listener.ContainerOptions;
import io.awspring.cloud.sqs.listener.ContainerOptionsBuilder;
import io.awspring.cloud.sqs.listener.MessageListener;
import io.awspring.cloud.sqs.listener.MessageListenerContainer;
import io.awspring.cloud.sqs.listener.*;
import io.awspring.cloud.sqs.listener.acknowledgement.AcknowledgementResultCallback;
import io.awspring.cloud.sqs.listener.acknowledgement.AsyncAcknowledgementResultCallback;
import io.awspring.cloud.sqs.listener.errorhandler.AsyncErrorHandler;
import io.awspring.cloud.sqs.listener.errorhandler.ErrorHandler;
import io.awspring.cloud.sqs.listener.interceptor.AsyncMessageInterceptor;
import io.awspring.cloud.sqs.listener.interceptor.MessageInterceptor;
import io.awspring.cloud.sqs.support.filter.DefaultMessageFilter;
import io.awspring.cloud.sqs.support.filter.MessageFilter;
import org.springframework.messaging.Message;
import org.springframework.util.Assert;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.function.Consumer;
import org.springframework.messaging.Message;
import org.springframework.util.Assert;

/**
* Base implementation for a {@link MessageListenerContainerFactory}. Contains the components and
Expand Down Expand Up @@ -180,6 +176,7 @@ public C createContainer(Endpoint endpoint) {
B options = this.containerOptionsBuilder.createCopy();
configure(endpoint, options);
C container = createContainerInstance(endpoint, options.build());

endpoint.setupContainer(container);
configureContainer(container, endpoint);
return container;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,24 @@

import io.awspring.cloud.sqs.ConfigUtils;
import io.awspring.cloud.sqs.annotation.SqsListener;
import io.awspring.cloud.sqs.listener.AsyncMessageListener;
import io.awspring.cloud.sqs.listener.ContainerComponentFactory;
import io.awspring.cloud.sqs.listener.ContainerOptions;
import io.awspring.cloud.sqs.listener.MessageListener;
import io.awspring.cloud.sqs.listener.SqsContainerOptions;
import io.awspring.cloud.sqs.listener.SqsContainerOptionsBuilder;
import io.awspring.cloud.sqs.listener.SqsMessageListenerContainer;
import io.awspring.cloud.sqs.listener.*;
import io.awspring.cloud.sqs.listener.acknowledgement.AcknowledgementResultCallback;
import io.awspring.cloud.sqs.listener.acknowledgement.AsyncAcknowledgementResultCallback;
import io.awspring.cloud.sqs.listener.errorhandler.AsyncErrorHandler;
import io.awspring.cloud.sqs.listener.errorhandler.ErrorHandler;
import io.awspring.cloud.sqs.listener.interceptor.AsyncMessageInterceptor;
import io.awspring.cloud.sqs.listener.interceptor.MessageInterceptor;
import java.util.ArrayList;
import java.util.Collection;
import java.util.function.Consumer;
import java.util.function.Supplier;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.messaging.Message;
import org.springframework.util.Assert;
import software.amazon.awssdk.services.sqs.SqsAsyncClient;

import java.util.ArrayList;
import java.util.Collection;
import java.util.function.Consumer;
import java.util.function.Supplier;

/**
* {@link MessageListenerContainerFactory} implementation for creating {@link SqsMessageListenerContainer} instances. A
* factory can be assigned to a {@link io.awspring.cloud.sqs.annotation.SqsListener @SqsListener} by using the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
import io.awspring.cloud.sqs.support.converter.MessagingMessageConverter;
import io.awspring.cloud.sqs.support.converter.SqsMessagingMessageConverter;
import java.time.Duration;

import io.awspring.cloud.sqs.support.filter.DefaultMessageFilter;
import io.awspring.cloud.sqs.support.filter.MessageFilter;
import org.springframework.core.task.TaskExecutor;
import org.springframework.lang.Nullable;
import org.springframework.retry.backoff.BackOffPolicy;
Expand Down Expand Up @@ -59,6 +62,9 @@ public abstract class AbstractContainerOptions<O extends ContainerOptions<O, B>,

private final AcknowledgementMode acknowledgementMode;

private final MessageFilter<?> messageFilter;


@Nullable
private final AcknowledgementOrdering acknowledgementOrdering;

Expand Down Expand Up @@ -92,6 +98,8 @@ protected AbstractContainerOptions(Builder<?, ?> builder) {
this.acknowledgementThreshold = builder.acknowledgementThreshold;
this.componentsTaskExecutor = builder.componentsTaskExecutor;
this.acknowledgementResultTaskExecutor = builder.acknowledgementResultTaskExecutor;
this.messageFilter = builder.messageFilter;

Assert.isTrue(this.maxMessagesPerPoll <= this.maxConcurrentMessages, String.format(
"messagesPerPoll should be less than or equal to maxConcurrentMessages. Values provided: %s and %s respectively",
this.maxMessagesPerPoll, this.maxConcurrentMessages));
Expand Down Expand Up @@ -164,6 +172,9 @@ public MessagingMessageConverter<?> getMessageConverter() {
return this.messageConverter;
}

@Override
public MessageFilter<?> getMessageFilter() {return this.messageFilter; }

@Nullable
@Override
public Duration getAcknowledgementInterval() {
Expand Down Expand Up @@ -244,6 +255,8 @@ protected abstract static class Builder<B extends ContainerOptionsBuilder<B, O>,

private AcknowledgementMode acknowledgementMode = DEFAULT_ACKNOWLEDGEMENT_MODE;

private MessageFilter<?> messageFilter = new DefaultMessageFilter<>();

@Nullable
private AcknowledgementOrdering acknowledgementOrdering;

Expand Down Expand Up @@ -280,6 +293,7 @@ protected Builder(AbstractContainerOptions<?, ?> options) {
this.acknowledgementThreshold = options.acknowledgementThreshold;
this.componentsTaskExecutor = options.componentsTaskExecutor;
this.acknowledgementResultTaskExecutor = options.acknowledgementResultTaskExecutor;
this.messageFilter = options.messageFilter;
}

@Override
Expand Down Expand Up @@ -400,6 +414,14 @@ public B messageConverter(MessagingMessageConverter<?> messageConverter) {
return self();
}

@Override
public B messageFilter(MessageFilter<?> messageFilter) {
Assert.notNull(messageFilter, "messageFilter cannot be null");
this.messageFilter = messageFilter;
return self();
}


@SuppressWarnings("unchecked")
private B self() {
return (B) this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,21 @@
import io.awspring.cloud.sqs.listener.errorhandler.ErrorHandler;
import io.awspring.cloud.sqs.listener.interceptor.AsyncMessageInterceptor;
import io.awspring.cloud.sqs.listener.interceptor.MessageInterceptor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.function.Consumer;
import io.awspring.cloud.sqs.support.filter.DefaultMessageFilter;
import io.awspring.cloud.sqs.support.filter.MessageFilter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.SmartLifecycle;
import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
import org.springframework.util.Assert;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.function.Consumer;

/**
* Base implementation for {@link MessageListenerContainer} with {@link SmartLifecycle} and component management
* capabilities.
Expand Down Expand Up @@ -132,7 +135,11 @@ public void addMessageInterceptor(AsyncMessageInterceptor<T> messageInterceptor)
@Override
public void setMessageListener(MessageListener<T> messageListener) {
Assert.notNull(messageListener, "messageListener cannot be null");
this.messageListener = AsyncComponentAdapters.adapt(messageListener);
if(containerOptions.getMessageFilter() instanceof DefaultMessageFilter) {
this.messageListener = AsyncComponentAdapters.adapt(messageListener);
} else {
this.messageListener = AsyncComponentAdapters.adaptFilter(messageListener, (MessageFilter<T>) containerOptions.getMessageFilter());
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.function.Supplier;

import io.awspring.cloud.sqs.support.filter.MessageFilter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.task.TaskExecutor;
Expand Down Expand Up @@ -76,6 +78,10 @@ public static <T> AsyncMessageListener<T> adapt(MessageListener<T> messageListen
return new BlockingMessageListenerAdapter<>(messageListener);
}

public static <T> AsyncMessageListener<T> adaptFilter(MessageListener<T> messageListener, MessageFilter<T> messageFilter) {
return new FilteredMessageListenerAdapter<>(messageListener, messageFilter);
}

public static <T> AsyncAcknowledgementResultCallback<T> adapt(
AcknowledgementResultCallback<T> acknowledgementResultCallback) {
return new BlockingAcknowledgementResultCallbackAdapter<>(acknowledgementResultCallback);
Expand Down Expand Up @@ -214,6 +220,44 @@ public CompletableFuture<Void> onMessage(Collection<Message<T>> messages) {
}
}

private static class FilteredMessageListenerAdapter<T> extends AbstractThreadingComponentAdapter
implements AsyncMessageListener<T> {

private final MessageListener<T> filteredMessageListener;
private final MessageFilter<T> filter;

public FilteredMessageListenerAdapter(MessageListener<T> filteredMessageListener, MessageFilter<T> filter) {
this.filteredMessageListener = filteredMessageListener;
this.filter = filter;
}

@Override
public CompletableFuture<Void> onMessage(Message<T> message) {
if (filter.process(message)) {
return execute(() -> this.filteredMessageListener.onMessage(message));
}
else {
logger.debug("Message filtered out: {}", message.getPayload());
return CompletableFuture.completedFuture(null);
}
}

@Override
public CompletableFuture<Void> onMessage(Collection<Message<T>> messages) {
Collection<Message<T>> filteredMessages = messages.stream()
.filter(filter::process)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Neat pick again but please switch to for. Since size will be always maximum 10 lets switch to for loop to take benefit of performance.

.toList();

if (filteredMessages.isEmpty()) {
logger.debug("All messages were filtered out.");
return CompletableFuture.completedFuture(null);
}

return execute(() -> this.filteredMessageListener.onMessage(filteredMessages));
}
}


private static class BlockingErrorHandlerAdapter<T> extends AbstractThreadingComponentAdapter
implements AsyncErrorHandler<T> {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import io.awspring.cloud.sqs.support.converter.MessagingMessageConverter;
import java.time.Duration;
import java.util.Collection;

import io.awspring.cloud.sqs.support.filter.MessageFilter;
import org.springframework.core.task.TaskExecutor;
import org.springframework.lang.Nullable;
import org.springframework.retry.backoff.BackOffPolicy;
Expand Down Expand Up @@ -139,6 +141,11 @@ default BackOffPolicy getPollBackOffPolicy() {
*/
MessagingMessageConverter<?> getMessageConverter();

/** Return the message filter applied before message processing.
* @return the message filter.
* */
MessageFilter<?> getMessageFilter();

/**
* Return the maximum interval between acknowledgements for batch acknowledgements.
* @return the interval.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import io.awspring.cloud.sqs.listener.acknowledgement.handler.AcknowledgementMode;
import io.awspring.cloud.sqs.support.converter.MessagingMessageConverter;
import java.time.Duration;

import io.awspring.cloud.sqs.support.filter.MessageFilter;
import org.springframework.core.task.TaskExecutor;
import org.springframework.retry.backoff.BackOffPolicy;

Expand Down Expand Up @@ -187,6 +189,14 @@ default B pollBackOffPolicy(BackOffPolicy pollBackOffPolicy) {
*/
B messageConverter(MessagingMessageConverter<?> messageConverter);

/**
* Set the {@link MessagingMessageConverter} for this container.
*
* @param messageFilter the message filter.
* @return this instance.
*/
B messageFilter(MessageFilter<?> messageFilter);

/**
* Create the {@link ContainerOptions} instance.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package io.awspring.cloud.sqs.support.filter;

import org.springframework.messaging.Message;

public class DefaultMessageFilter<T> implements MessageFilter<T> {
@Override
public boolean process(Message<T> message) {
return true;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package io.awspring.cloud.sqs.support.filter;

import org.springframework.messaging.Message;

public interface MessageFilter<T> {
boolean process(Message<T> message);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On a second thought, I think the interface should receive and return a Collection of messages rather than a single one.

Something like:

public interface MessageFilter<T> {
    Collection<Message<T>> filter(Collection<Message<T>> messages);
}

The reason is that if the filter includes some I/O (e.g. querying the DB to check if the users exist, or making an http request), that'll be more performant as a batch query than in a for loop, and if the user is receiving e.g. 500 messages in a batch that can make a sizeable difference.

If the user wants to filter individual messages, it's simple enough for them to iterate on the collection themselves.

What do you folks think?

}
Loading