diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/annotation/AbstractListenerAnnotationBeanPostProcessor.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/annotation/AbstractListenerAnnotationBeanPostProcessor.java index dbab79ff8..e9841ae4d 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/annotation/AbstractListenerAnnotationBeanPostProcessor.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/annotation/AbstractListenerAnnotationBeanPostProcessor.java @@ -35,6 +35,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -125,6 +126,7 @@ protected void detectAnnotationsAndRegisterEndpoints(Object bean, Class targe } annotatedMethods.entrySet().stream() .map(entry -> createAndConfigureEndpoint(bean, entry.getKey(), entry.getValue())) + .filter(Objects::nonNull) .forEach(this.endpointRegistrar::registerEndpoint); } diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/annotation/SqsListenerAnnotationBeanPostProcessor.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/annotation/SqsListenerAnnotationBeanPostProcessor.java index 711dd289b..56a6c40ef 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/annotation/SqsListenerAnnotationBeanPostProcessor.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/annotation/SqsListenerAnnotationBeanPostProcessor.java @@ -29,7 +29,10 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; +import java.util.Collections; import java.util.List; +import java.util.Optional; + import org.springframework.messaging.converter.MessageConverter; import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver; @@ -44,12 +47,23 @@ public class SqsListenerAnnotationBeanPostProcessor extends AbstractListenerAnno private static final String GENERATED_ID_PREFIX = "io.awspring.cloud.sqs.sqsListenerEndpointContainer#"; + private final List filters; + + public SqsListenerAnnotationBeanPostProcessor(Optional> filters) { + this.filters = filters.orElseGet(Collections::emptyList); + } + @Override protected Class getAnnotationClass() { return SqsListener.class; } + @Override protected Endpoint createEndpoint(SqsListener sqsListenerAnnotation) { + if (filters.stream().anyMatch(f -> !f.createEndpoint(sqsListenerAnnotation))) { + return null; + } + return SqsEndpoint.builder().queueNames(resolveEndpointNames(sqsListenerAnnotation.value())) .factoryBeanName(resolveAsString(sqsListenerAnnotation.factory(), "factory")) .id(getEndpointId(sqsListenerAnnotation.id())) diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/annotation/SqsListenerFilter.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/annotation/SqsListenerFilter.java new file mode 100644 index 000000000..e4831aee3 --- /dev/null +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/annotation/SqsListenerFilter.java @@ -0,0 +1,10 @@ +package io.awspring.cloud.sqs.annotation; + +/** + * Predicate interface to filter {@link SqsListener} annotations during bean post-processing. + */ +@FunctionalInterface +public interface SqsListenerFilter { + + boolean createEndpoint(SqsListener annotation); +} diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/config/EndpointRegistrar.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/config/EndpointRegistrar.java index 48e90df52..4e814058d 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/config/EndpointRegistrar.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/config/EndpointRegistrar.java @@ -206,7 +206,9 @@ public void setBeanFactory(BeanFactory beanFactory) throws BeansException { * @param endpoint the endpoint. */ public void registerEndpoint(Endpoint endpoint) { - this.endpoints.add(endpoint); + if (endpoint != null) { + this.endpoints.add(endpoint); + } } @Override diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/config/SqsBootstrapConfiguration.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/config/SqsBootstrapConfiguration.java index b6adb689d..d3383eb9e 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/config/SqsBootstrapConfiguration.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/config/SqsBootstrapConfiguration.java @@ -35,7 +35,7 @@ public class SqsBootstrapConfiguration implements ImportBeanDefinitionRegistrar public void registerBeanDefinitions(AnnotationMetadata importingClassMetadata, BeanDefinitionRegistry registry) { if (!registry.containsBeanDefinition(SqsBeanNames.SQS_LISTENER_ANNOTATION_BEAN_POST_PROCESSOR_BEAN_NAME)) { registry.registerBeanDefinition(SqsBeanNames.SQS_LISTENER_ANNOTATION_BEAN_POST_PROCESSOR_BEAN_NAME, - new RootBeanDefinition(SqsListenerAnnotationBeanPostProcessor.class)); + new RootBeanDefinition(SqsListenerAnnotationBeanPostProcessor.class, RootBeanDefinition.AUTOWIRE_BY_TYPE, true)); } if (!registry.containsBeanDefinition(SqsBeanNames.ENDPOINT_REGISTRY_BEAN_NAME)) { diff --git a/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/annotation/SqsListenerAnnotationBeanPostProcessorTests.java b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/annotation/SqsListenerAnnotationBeanPostProcessorTests.java index 13f40a4a5..7e092f216 100644 --- a/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/annotation/SqsListenerAnnotationBeanPostProcessorTests.java +++ b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/annotation/SqsListenerAnnotationBeanPostProcessorTests.java @@ -19,9 +19,12 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.InstanceOfAssertFactories.list; import static org.assertj.core.api.InstanceOfAssertFactories.type; +import static org.mockito.ArgumentMatchers.isNotNull; import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.then; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; import com.fasterxml.jackson.databind.ObjectMapper; @@ -37,6 +40,8 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Optional; + import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; import org.springframework.beans.factory.ListableBeanFactory; @@ -100,7 +105,7 @@ public void registerEndpoint(Endpoint endpoint) { super.registerEndpoint(endpoint); } }; - SqsListenerAnnotationBeanPostProcessor processor = new SqsListenerAnnotationBeanPostProcessor() { + SqsListenerAnnotationBeanPostProcessor processor = new SqsListenerAnnotationBeanPostProcessor(Optional.empty()) { @Override protected EndpointRegistrar createEndpointRegistrar() { return registrar; @@ -150,7 +155,7 @@ void shouldChangeListenerRegistryBeanName() { EndpointRegistrar registrar = new EndpointRegistrar(); - SqsListenerAnnotationBeanPostProcessor processor = new SqsListenerAnnotationBeanPostProcessor() { + SqsListenerAnnotationBeanPostProcessor processor = new SqsListenerAnnotationBeanPostProcessor(Optional.empty()) { @Override protected EndpointRegistrar createEndpointRegistrar() { return registrar; @@ -183,7 +188,7 @@ void shouldThrowIfFactoryBeanNotFound() { when(beanFactory.containsBean(EndpointRegistrar.DEFAULT_LISTENER_CONTAINER_FACTORY_BEAN_NAME)) .thenReturn(false); - SqsListenerAnnotationBeanPostProcessor processor = new SqsListenerAnnotationBeanPostProcessor(); + SqsListenerAnnotationBeanPostProcessor processor = new SqsListenerAnnotationBeanPostProcessor(Optional.empty()); Listener bean = new Listener(); StringValueResolver valueResolver = mock(StringValueResolver.class); @@ -206,7 +211,7 @@ void shouldResolveListOfQueuesFromSPEL() { SqsQueueNameReader sqsQueueNameReader = new SqsQueueNameReader(); beanFactory.registerSingleton("sqsQueueNameReader", sqsQueueNameReader); - SqsListenerAnnotationBeanPostProcessor processor = new SqsListenerAnnotationBeanPostProcessor(); + SqsListenerAnnotationBeanPostProcessor processor = new SqsListenerAnnotationBeanPostProcessor(Optional.empty()); ManyQueuesListener bean = new ManyQueuesListener(sqsQueueNameReader); processor.setBeanFactory(beanFactory); @@ -222,6 +227,45 @@ void shouldResolveListOfQueuesFromSPEL() { } + @Test + void shouldApplyFiltersThatPrevent() { + EndpointRegistrar registrar = mock(EndpointRegistrar.class); + List filters = List.of( + annotation -> true, // Passes all annotations + annotation -> false // Denies all annotations + ); + + SqsListenerAnnotationBeanPostProcessor processor = new SqsListenerAnnotationBeanPostProcessor(Optional.of(filters)) { + @Override + protected EndpointRegistrar createEndpointRegistrar() { + return registrar; + } + }; + + Listener bean = new Listener(); + processor.postProcessAfterInitialization(bean, "listener"); + + verifyNoInteractions(registrar); + } + + @Test + void shouldApplyFiltersThatAllow() { + EndpointRegistrar registrar = mock(EndpointRegistrar.class); + List filters = List.of(annotation -> true); // Passes all annotations + + SqsListenerAnnotationBeanPostProcessor processor = new SqsListenerAnnotationBeanPostProcessor(Optional.of(filters)) { + @Override + protected EndpointRegistrar createEndpointRegistrar() { + return registrar; + } + }; + + Listener bean = new Listener(); + processor.postProcessAfterInitialization(bean, "listener"); + + verify(registrar).registerEndpoint(isNotNull()); + } + static class Listener { @SqsListener("myQueue")