diff --git a/docs/src/main/asciidoc/sqs.adoc b/docs/src/main/asciidoc/sqs.adoc index 048c02a65..923b84627 100644 --- a/docs/src/main/asciidoc/sqs.adoc +++ b/docs/src/main/asciidoc/sqs.adoc @@ -762,7 +762,6 @@ NOTE: The same factory can be used to create both `single message` and `batch` c IMPORTANT: In case the same factory is shared by both delivery methods, any supplied `ErrorHandler`, `MessageInterceptor` or `MessageListener` should implement the proper methods. - ==== Container Options Each `MessageListenerContainer` can have a different set of options. @@ -1757,6 +1756,7 @@ If after the 5 seconds for `maxDelayBetweenPolls` 6 messages have been processed If the queue is depleted and a poll returns no messages, it'll enter `low throughput` mode again and perform only one poll at a time. ==== Configuring BackPressureMode +The default `BackPressureHandler` can be configured to optimize the polling behavior based on the application's throughput requirements. The following `BackPressureMode` values can be set in `SqsContainerOptions` to configure polling behavior: * `AUTO` - The default mode, as described in the previous section. @@ -1767,6 +1767,70 @@ Useful for really high throughput scenarios where the risk of making parallel po NOTE: The `AUTO` setting should be balanced for most use cases, including high throughput ones. +==== Advanced Backpressure management + +Even though the default `BackPressureHandler` should be enough for most use cases, there are scenarios where more fine-grained control over message consumption is required not to overwhelm downstream systems or exceed resource limits. +In such a case, it is necessary to replace the default `BackPressureHandler` with a custom one that implements the `BackPressureHandler` interface. +A `backPressureHandlerFactory` can be set in `SqsContainerOptions` to configure which `BackPressureHandler` to use. + +===== What is a BackPressureHandler? + +A `BackPressureHandler` is an interface that determines whether the container should apply backpressure (i.e., slow down or pause polling) based on the current state of the system. +It is invoked before each poll to SQS and can prevent polling or poll for fewer messages if certain conditions are met, e.g., too many inflight messages, custom resource constraints, etc. + +===== Creating a custom BackPressureHandler + +To implement a custom backpressure logic, the `BackPressureHandler` interface must be implemented. + +A `SqsMessageListenerContainer` can be configured to use the desired `BackPressureHandler` by setting the `backPressureHandlerFactory` on the `ContainerOptions`. + +```java +SqsMessageListenerContainer container = SqsMessageListenerContainer.builder() + .configure(options -> options + .backPressureHandlerFactory(containerOptions -> new CustomBackPressureHandler()) + // ... other options + ) + // ... other container settings ... + .build(); +``` + +===== Combining Multiple BackPressureHandlers + +If necessary, multiple `BackPressureHandler` can be combined by using the `CompositeBackPressureHandler`. +Each of the `BackPressureHandler` (which we'll call delegates) are chained in the order they are provided. +The first delegate will be requested the initial amount of permits and will return the number of permits it accepts to grant. +The second delegate will get that potentially reduced number of permits as a request and might in turn reduce it further. +The process continues until all delegates have been called or one of them returns 0, which will prevent the polling of messages from SQS. + +For example, to implement the `BackPressureMode.ALWAYS_POLL_MAX_MESSAGES` strategy, we can combine a concurrency limiter, an adaptative throughput handler, and a "full batch only" handler. +The resulting `CompositeBackPressureHandler` looks like this: + +```java +Duration maxIdleWaitTime = Duration.ofMillis(50L); +List backPressureHandlers = List.of( + BackPressureHandlerFactories.concurrencyLimiterBackPressureHandler(options), + BackPressureHandlerFactories.throughputBackPressureHandler(options), + BackPressureHandlerFactories.fullBatchBackPressureHandler(options) +); +CompositeBackPressureHandler backPressureHandler = BackPressureHandlerFactories.compositeBackPressureHandler( + options, maxIdleWaitTime, backPressureHandlers); +``` + +===== Built-in BackPressureHandlers + +Spring Cloud AWS provides several built-in `BackPressureHandler` implementations: + +- `ConcurrencyLimiterBackPressureHandler`: Limits the number of messages being processed concurrently. +- `ThroughputBackPressureHandler`: Switches between high and low throughput modes. In high throughput mode, multiple polls can be done in parallel. +In low throughput mode, only one poll is done at a time. +- `FullBatchBackPressureHandler`: Ensure polls will always be done with a full batch of messages, meaning that the number of messages polled will always be equal to `maxMessagesPerPoll` if possible or `0` if not possible. +This `FullBatchBackPressureHandler` must always be the last in the chain for it to work properly. + +The `BackPressureHandlerFactories` class provides factory methods to create these handlers easily. +These handlers can be used directly or combined with custom ones using the `CompositeBackPressureHandler` to fit the application's needs. + +Additionally, the `BackPressureHandlerFactories#adaptativeThroughputBackPressureHandler` factory method combines the `ConcurrencyLimiterBackPressureHandler`, `ThroughputBackPressureHandler`, and `FullBatchBackPressureHandler` as per the desired `BackPressureMode`. + === Blocking and Non-Blocking (Async) Components The SQS integration leverages the `CompletableFuture`-based async capabilities of `AWS SDK 2.0` to deliver a fully non-blocking infrastructure. diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AbstractContainerOptions.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AbstractContainerOptions.java index 2a120792f..8c9316eb7 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AbstractContainerOptions.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AbstractContainerOptions.java @@ -32,6 +32,7 @@ * Base implementation for {@link ContainerOptions}. * * @author Tomaz Fernandes + * @author Loïc Rouchon * @since 3.0 */ public abstract class AbstractContainerOptions, B extends ContainerOptionsBuilder> @@ -55,6 +56,8 @@ public abstract class AbstractContainerOptions, private final BackPressureMode backPressureMode; + private final BackPressureHandlerFactory backPressureHandlerFactory; + private final ListenerMode listenerMode; private final MessagingMessageConverter messageConverter; @@ -90,6 +93,7 @@ protected AbstractContainerOptions(Builder builder) { this.listenerShutdownTimeout = builder.listenerShutdownTimeout; this.acknowledgementShutdownTimeout = builder.acknowledgementShutdownTimeout; this.backPressureMode = builder.backPressureMode; + this.backPressureHandlerFactory = builder.backPressureHandlerFactory; this.listenerMode = builder.listenerMode; this.messageConverter = builder.messageConverter; this.acknowledgementMode = builder.acknowledgementMode; @@ -162,6 +166,11 @@ public BackPressureMode getBackPressureMode() { return this.backPressureMode; } + @Override + public BackPressureHandlerFactory getBackPressureHandlerFactory() { + return this.backPressureHandlerFactory; + } + @Override public ListenerMode getListenerMode() { return this.listenerMode; @@ -232,6 +241,8 @@ protected abstract static class Builder, private static final BackPressureMode DEFAULT_THROUGHPUT_CONFIGURATION = BackPressureMode.AUTO; + private static final BackPressureHandlerFactory DEFAULT_BACKPRESSURE_FACTORY = BackPressureHandlerFactories::semaphoreBackPressureHandler; + private static final ListenerMode DEFAULT_MESSAGE_DELIVERY_STRATEGY = ListenerMode.SINGLE_MESSAGE; private static final MessagingMessageConverter DEFAULT_MESSAGE_CONVERTER = new SqsMessagingMessageConverter(); @@ -254,6 +265,8 @@ protected abstract static class Builder, private BackPressureMode backPressureMode = DEFAULT_THROUGHPUT_CONFIGURATION; + private BackPressureHandlerFactory backPressureHandlerFactory = DEFAULT_BACKPRESSURE_FACTORY; + private Duration listenerShutdownTimeout = DEFAULT_LISTENER_SHUTDOWN_TIMEOUT; private Duration acknowledgementShutdownTimeout = DEFAULT_ACKNOWLEDGEMENT_SHUTDOWN_TIMEOUT; @@ -296,6 +309,7 @@ protected Builder(AbstractContainerOptions options) { this.listenerShutdownTimeout = options.listenerShutdownTimeout; this.acknowledgementShutdownTimeout = options.acknowledgementShutdownTimeout; this.backPressureMode = options.backPressureMode; + this.backPressureHandlerFactory = options.backPressureHandlerFactory; this.listenerMode = options.listenerMode; this.messageConverter = options.messageConverter; this.acknowledgementMode = options.acknowledgementMode; @@ -390,6 +404,12 @@ public B backPressureMode(BackPressureMode backPressureMode) { return self(); } + @Override + public B backPressureHandlerFactory(BackPressureHandlerFactory backPressureHandlerFactory) { + this.backPressureHandlerFactory = backPressureHandlerFactory; + return self(); + } + @Override public B acknowledgementInterval(Duration acknowledgementInterval) { Assert.notNull(acknowledgementInterval, "acknowledgementInterval cannot be null"); diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AbstractPipelineMessageListenerContainer.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AbstractPipelineMessageListenerContainer.java index c5b0c19e8..66645f02b 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AbstractPipelineMessageListenerContainer.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AbstractPipelineMessageListenerContainer.java @@ -230,10 +230,9 @@ private TaskExecutor validateCustomExecutor(TaskExecutor taskExecutor) { } protected BackPressureHandler createBackPressureHandler() { - return SemaphoreBackPressureHandler.builder().batchSize(getContainerOptions().getMaxMessagesPerPoll()) - .totalPermits(getContainerOptions().getMaxConcurrentMessages()) - .acquireTimeout(getContainerOptions().getMaxDelayBetweenPolls()) - .throughputConfiguration(getContainerOptions().getBackPressureMode()).build(); + O containerOptions = getContainerOptions(); + BackPressureHandlerFactory factory = containerOptions.getBackPressureHandlerFactory(); + return factory.createBackPressureHandler(containerOptions); } protected TaskExecutor createSourcesTaskExecutor() { diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/BackPressureHandler.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/BackPressureHandler.java index 1d76d6589..13214dac1 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/BackPressureHandler.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/BackPressureHandler.java @@ -24,12 +24,13 @@ * semaphore-based, rate limiter-based, a mix of both, or any other. * * @author Tomaz Fernandes + * @author Loïc Rouchon * @since 3.0 */ public interface BackPressureHandler { /** - * Request a number of permits. Each obtained permit allows the + * Requests a number of permits. Each obtained permit allows the * {@link io.awspring.cloud.sqs.listener.source.MessageSource} to retrieve one message. * @param amount the amount of permits to request. * @return the amount of permits obtained. @@ -37,12 +38,41 @@ public interface BackPressureHandler { */ int request(int amount) throws InterruptedException; + /** + * Releases the specified amount of permits for processed messages. Each message that has been processed should + * release one permit, whether processing was successful or not. + *

+ * This method can be called in the following use cases: + *

    + *
  • {@link ReleaseReason#LIMITED}: all/some permits were not used because another BackPressureHandler has a lower + * permits limit and the difference in permits needs to be returned.
  • + *
  • {@link ReleaseReason#NONE_FETCHED}: none of the permits were actually used because no messages were retrieved + * from SQS. Permits need to be returned.
  • + *
  • {@link ReleaseReason#PARTIAL_FETCH}: some of the permits were used (some messages were retrieved from SQS). + * The unused ones need to be returned. The amount to be returned might be {@literal 0}, in which case it means all + * the permits will be used as the same number of messages were fetched from SQS.
  • + *
  • {@link ReleaseReason#PROCESSED}: a message processing finished, successfully or not.
  • + *
+ * @param amount the amount of permits to release. + * @param reason the reason why the permits were released. + */ + default void release(int amount, ReleaseReason reason) { + release(amount); + } + /** * Release the specified amount of permits. Each message that has been processed should release one permit, whether * processing was successful or not. * @param amount the amount of permits to release. + * + * @deprecated This method is deprecated and will not be called by the Spring Cloud AWS SQS listener anymore. + * Implement {@link #release(int, ReleaseReason)} instead. */ - void release(int amount); + @Deprecated + default void release(int amount) { + // Do not implement this method. It is not called anymore outside of backward compatibility use cases. + // Implement `#release(int amount, ReleaseReason reason)` instead. + } /** * Attempts to acquire all permits up to the specified timeout. If successful, means all permits were returned and @@ -52,4 +82,24 @@ public interface BackPressureHandler { */ boolean drain(Duration timeout); + enum ReleaseReason { + /** + * All/Some permits were not used because another BackPressureHandler has a lower permits limit and the permits + * difference need to be aligned across all handlers. + */ + LIMITED, + /** + * No messages were retrieved from SQS, so all permits need to be returned. + */ + NONE_FETCHED, + /** + * Some messages were fetched from SQS. Unused permits if any need to be returned. + */ + PARTIAL_FETCH, + /** + * The processing of one or more messages finished, successfully or not. + */ + PROCESSED; + } + } diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/BackPressureHandlerFactories.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/BackPressureHandlerFactories.java new file mode 100644 index 000000000..edb90105e --- /dev/null +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/BackPressureHandlerFactories.java @@ -0,0 +1,175 @@ +/* + * Copyright 2013-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.awspring.cloud.sqs.listener; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; + +/** + * Spring Cloud AWS provides the following {@link BackPressureHandler} implementations: + *
    + *
  • {@link ConcurrencyLimiterBlockingBackPressureHandler}: Limits the maximum number of messages that can be * + * processed concurrently by the application.
  • * + *
  • {@link ThroughputBackPressureHandler}: Adapts the throughput dynamically between high and low modes in order to * + * reduce SQS pull costs when few messages are coming in.
  • * + *
  • {@link CompositeBackPressureHandler}: Allows combining multiple {@link BackPressureHandler} together and ensures + * * they cooperate.
  • * + *
+ *

+ * Below are a few examples of how common use cases can be achieved. Keep in mind you can always create your own * + * {@link BackPressureHandler} implementation and if needed combine it with the provided ones thanks to the * + * {@link CompositeBackPressureHandler}. * * + *

A BackPressureHandler limiting the max concurrency with high throughput

* * + * + *
{@code
+ * containerOptionsBuilder.backPressureHandlerFactory(containerOptions -> {
+ * 		return ConcurrencyLimiterBlockingBackPressureHandler.builder()
+ * 			.batchSize(containerOptions.getMaxMessagesPerPoll())
+ * 			.totalPermits(containerOptions.getMaxConcurrentMessages())
+ * 			.acquireTimeout(containerOptions.getMaxDelayBetweenPolls())
+ * 			.throughputConfiguration(BackPressureMode.FIXED_HIGH_THROUGHPUT)
+ * 			.build()
+ * }}
+ *

+ * * * + *

A BackPressureHandler limiting the max concurrency with dynamic throughput

* * + * + *
{@code
+ * containerOptionsBuilder.backPressureHandlerFactory(containerOptions -> {
+ * 		int batchSize = containerOptions.getMaxMessagesPerPoll();
+ * 		var concurrencyLimiterBlockingBackPressureHandler = ConcurrencyLimiterBlockingBackPressureHandler.builder()
+ * 			.batchSize(batchSize)
+ * 			.totalPermits(containerOptions.getMaxConcurrentMessages())
+ * 			.acquireTimeout(containerOptions.getMaxDelayBetweenPolls())
+ * 			.throughputConfiguration(BackPressureMode.AUTO)
+ * 			.build()
+ * 		var throughputBackPressureHandler = ThroughputBackPressureHandler.builder()
+ * 			.batchSize(batchSize)
+ * 			.build();
+ * 		return new CompositeBackPressureHandler(List.of(
+ * 				concurrencyLimiterBlockingBackPressureHandler,
+ * 				throughputBackPressureHandler
+ * 			),
+ * 			batchSize,
+ * 			standbyLimitPollingInterval
+ * 		);
+ * }}
+ * + * @author Loïc Rouchon + */ +public class BackPressureHandlerFactories { + + private BackPressureHandlerFactories() { + } + + /** + * Creates a new {@link SemaphoreBackPressureHandler} instance based on the provided {@link ContainerOptions}. + * + * @param options the container options. + * @return the created SemaphoreBackPressureHandler. + */ + public static BatchAwareBackPressureHandler semaphoreBackPressureHandler(ContainerOptions options) { + return SemaphoreBackPressureHandler.builder().batchSize(options.getMaxMessagesPerPoll()) + .totalPermits(options.getMaxConcurrentMessages()).acquireTimeout(options.getMaxDelayBetweenPolls()) + .throughputConfiguration(options.getBackPressureMode()).build(); + } + + /** + * Creates a new {@link BackPressureHandler} instance based on the provided {@link ContainerOptions} combining a + * {@link ConcurrencyLimiterBlockingBackPressureHandler}, a {@link ThroughputBackPressureHandler} and a + * {@link FullBatchBackPressureHandler}. The exact combination of depends on the given {@link ContainerOptions}. + * + * @param options the container options. + * @param maxIdleWaitTime the maximum amount of time to wait for a permit to be released in case no permits were + * obtained. + * @return the created SemaphoreBackPressureHandler. + */ + public static BatchAwareBackPressureHandler adaptativeThroughputBackPressureHandler(ContainerOptions options, + Duration maxIdleWaitTime) { + BackPressureMode backPressureMode = options.getBackPressureMode(); + + var concurrencyLimiterBlockingBackPressureHandler = concurrencyLimiterBackPressureHandler(options); + if (backPressureMode == BackPressureMode.FIXED_HIGH_THROUGHPUT) { + return concurrencyLimiterBlockingBackPressureHandler; + } + var backPressureHandlers = new ArrayList(); + backPressureHandlers.add(concurrencyLimiterBlockingBackPressureHandler); + + // The ThroughputBackPressureHandler should run second in the chain as it is non-blocking. + // Running it first would result in more polls as it would potentially limit the + // ConcurrencyLimiterBlockingBackPressureHandler to a lower amount of requested permits + // which means the ConcurrencyLimiterBlockingBackPressureHandler blocking behavior would + // not be optimally leveraged. + if (backPressureMode == BackPressureMode.AUTO + || backPressureMode == BackPressureMode.ALWAYS_POLL_MAX_MESSAGES) { + backPressureHandlers.add(throughputBackPressureHandler(options)); + } + + // The FullBatchBackPressureHandler should run last in the chain to ensure that a full batch is requested or not + if (backPressureMode == BackPressureMode.ALWAYS_POLL_MAX_MESSAGES) { + backPressureHandlers.add(fullBatchBackPressureHandler(options)); + } + return compositeBackPressureHandler(options, maxIdleWaitTime, backPressureHandlers); + } + + /** + * Creates a new {@link ConcurrencyLimiterBlockingBackPressureHandler} instance based on the provided + * {@link ContainerOptions}. + * + * @param options the container options. + * @return the created ConcurrencyLimiterBlockingBackPressureHandler. + */ + public static CompositeBackPressureHandler compositeBackPressureHandler(ContainerOptions options, + Duration maxIdleWaitTime, List backPressureHandlers) { + return CompositeBackPressureHandler.builder().batchSize(options.getMaxMessagesPerPoll()) + .noPermitsReturnedWaitTimeout(maxIdleWaitTime).backPressureHandlers(backPressureHandlers).build(); + } + + /** + * Creates a new {@link ConcurrencyLimiterBlockingBackPressureHandler} instance based on the provided + * {@link ContainerOptions}. + * + * @param options the container options. + * @return the created ConcurrencyLimiterBlockingBackPressureHandler. + */ + public static ConcurrencyLimiterBlockingBackPressureHandler concurrencyLimiterBackPressureHandler( + ContainerOptions options) { + return ConcurrencyLimiterBlockingBackPressureHandler.builder().batchSize(options.getMaxMessagesPerPoll()) + .totalPermits(options.getMaxConcurrentMessages()).acquireTimeout(options.getMaxDelayBetweenPolls()) + .build(); + } + + /** + * Creates a new {@link ThroughputBackPressureHandler} instance based on the provided {@link ContainerOptions}. + * + * @param options the container options. + * @return the created ThroughputBackPressureHandler. + */ + public static ThroughputBackPressureHandler throughputBackPressureHandler(ContainerOptions options) { + return ThroughputBackPressureHandler.builder().batchSize(options.getMaxMessagesPerPoll()).build(); + } + + /** + * Creates a new {@link FullBatchBackPressureHandler} instance based on the provided {@link ContainerOptions}. + * + * @param options the container options. + * @return the created FullBatchBackPressureHandler. + */ + public static FullBatchBackPressureHandler fullBatchBackPressureHandler(ContainerOptions options) { + return FullBatchBackPressureHandler.builder().batchSize(options.getMaxMessagesPerPoll()).build(); + } +} diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/BackPressureHandlerFactory.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/BackPressureHandlerFactory.java new file mode 100644 index 000000000..4c7e455aa --- /dev/null +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/BackPressureHandlerFactory.java @@ -0,0 +1,42 @@ +/* + * Copyright 2013-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.awspring.cloud.sqs.listener; + +/** + * Factory interface for creating {@link BackPressureHandler} instances to manage queue consumption backpressure. + *

+ * Implementations of this interface are responsible for producing a new {@link BackPressureHandler} for each container, + * configured according to the provided {@link ContainerOptions}. This ensures that internal resources (such as counters + * or semaphores) are not shared across containers, which could lead to unintended side effects. + *

+ * Default factory implementations can be found in the {@link BackPressureHandlerFactories} class. + * + * @author Loïc Rouchon + */ +public interface BackPressureHandlerFactory { + + /** + * Creates a new {@link BackPressureHandler} instance based on the provided {@link ContainerOptions}. + *

+ * NOTE: it is important for the factory to always return a new instance as otherwise it might + * result in a BackPressureHandler internal resources (counters, semaphores, ...) to be shared by multiple + * containers which is very likely not the desired behavior. + * + * @param containerOptions the container options to use for creating the BackPressureHandler. + * @return the created BackPressureHandler + */ + BackPressureHandler createBackPressureHandler(ContainerOptions containerOptions); +} diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/BatchAwareBackPressureHandler.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/BatchAwareBackPressureHandler.java index 51e12e0a0..685dbdcd6 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/BatchAwareBackPressureHandler.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/BatchAwareBackPressureHandler.java @@ -20,6 +20,7 @@ * configured by the implementations. * * @author Tomaz Fernandes + * @author Loïc Rouchon * @since 3.0 */ public interface BatchAwareBackPressureHandler extends BackPressureHandler { @@ -35,13 +36,35 @@ public interface BatchAwareBackPressureHandler extends BackPressureHandler { * Release a batch of permits. This has the semantics of letting the {@link BackPressureHandler} know that all * permits from a batch are being released, in opposition to {@link #release(int)} in which any number of permits * can be specified. + * + * @deprecated This method is deprecated and will not be called by the Spring Cloud AWS SQS listener anymore. + * Implement {@link BackPressureHandler#release(int, ReleaseReason)} instead. */ - void releaseBatch(); + @Deprecated + default void releaseBatch() { + // Do not implement this method. It is not called anymore outside of backward compatibility use cases. + // Implement `#release(int amount, ReleaseReason reason)` instead. + } + + @Override + default void release(int amount, ReleaseReason reason) { + if (amount == getBatchSize() && reason == ReleaseReason.NONE_FETCHED) { + releaseBatch(); + } + else { + release(amount); + } + } /** * Return the configured batch size for this handler. * @return the batch size. + * + * @deprecated This method is deprecated and will not be used by the Spring Cloud AWS SQS listener anymore. */ - int getBatchSize(); + @Deprecated + default int getBatchSize() { + return 0; + } } diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/BlockingBackPressureHandler.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/BlockingBackPressureHandler.java new file mode 100644 index 000000000..d54c8b2ce --- /dev/null +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/BlockingBackPressureHandler.java @@ -0,0 +1,26 @@ +/* + * Copyright 2013-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.awspring.cloud.sqs.listener; + +/** + * Marker interface for a blocking {@link BackPressureHandler}. This handler is used to control the flow of messages in + * a blocking manner. + * + * @author Loïc Rouchon + */ +public interface BlockingBackPressureHandler extends BackPressureHandler { + +} diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/CompositeBackPressureHandler.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/CompositeBackPressureHandler.java new file mode 100644 index 000000000..ed94bbf50 --- /dev/null +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/CompositeBackPressureHandler.java @@ -0,0 +1,207 @@ +/* + * Copyright 2013-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.awspring.cloud.sqs.listener; + +import java.time.Duration; +import java.time.Instant; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.ReentrantLock; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.util.Assert; + +/** + * Composite {@link BackPressureHandler} implementation that delegates the back-pressure handling to a list of + * {@link BackPressureHandler}s. + *

+ * This class is used to combine multiple back-pressure handlers into a single one. It allows for more complex + * back-pressure handling strategies by combining different implementations. + *

+ * The order in which the back-pressure handlers are registered in the {@link CompositeBackPressureHandler} is important + * as it will affect the blocking and limiting behaviour of the back-pressure handling. + *

+ * When {@link #request(int amount)} is called, the first back-pressure handler in the list is called with + * {@code amount} as the requested amount of permits. The returned amount of permits (which is less than or equal to the + * initial amount) is then passed to the next back-pressure handler in the list. This process of reducing the amount to + * request for the next handlers in the chain is called "limiting". This process continues until all back-pressure + * handlers have been called or {@literal 0} permits has been returned. + *

+ * Once the final amount of available permits have been computed, unused acquired permits on back-pressure handlers (due + * to later limiting happening in the chain) are released. + *

+ * If no permits were obtained, the {@link #request(int)} method will wait up to {@code noPermitsReturnedWaitTimeout} + * for a release of permits before returning. + * + * @author Loïc Rouchon + */ +public class CompositeBackPressureHandler implements BatchAwareBackPressureHandler, IdentifiableContainerComponent { + + private static final Logger logger = LoggerFactory.getLogger(CompositeBackPressureHandler.class); + + private String id; + + private final int batchSize; + + private final Duration noPermitsReturnedWaitTimeout; + + private final List backPressureHandlers; + + private final ReentrantLock noPermitsReturnedWaitLock = new ReentrantLock(); + + private final Condition permitsReleasedCondition = noPermitsReturnedWaitLock.newCondition(); + + private CompositeBackPressureHandler(Builder builder) { + this.batchSize = builder.batchSize; + this.noPermitsReturnedWaitTimeout = builder.noPermitsReturnedWaitTimeout; + this.backPressureHandlers = List.copyOf(builder.backPressureHandlers); + + } + + @Override + public void setId(String id) { + this.id = id; + backPressureHandlers.stream().filter(IdentifiableContainerComponent.class::isInstance) + .map(IdentifiableContainerComponent.class::cast) + .forEach(bph -> bph.setId(bph.getClass().getSimpleName() + "-" + id)); + } + + @Override + public String getId() { + return id; + } + + @Override + public int requestBatch() throws InterruptedException { + return request(batchSize); + } + + @Override + public int request(int amount) throws InterruptedException { + logger.debug("[{}] Requesting {} permits", this.id, amount); + int obtained = amount; + int[] obtainedPerBph = new int[backPressureHandlers.size()]; + for (int i = 0; i < backPressureHandlers.size() && obtained > 0; i++) { + obtainedPerBph[i] = backPressureHandlers.get(i).request(obtained); + obtained = Math.min(obtained, obtainedPerBph[i]); + } + for (int i = 0; i < backPressureHandlers.size(); i++) { + int obtainedForBph = obtainedPerBph[i]; + if (obtainedForBph > obtained) { + backPressureHandlers.get(i).release(obtainedForBph - obtained, ReleaseReason.LIMITED); + } + } + if (obtained == 0) { + waitForPermitsToBeReleased(); + } + logger.debug("[{}] Obtained {} permits ({} requested)", this.id, obtained, amount); + return obtained; + } + + @Override + public void release(int amount, ReleaseReason reason) { + logger.debug("[{}] Releasing {} permits ({})", this.id, amount, reason); + for (BackPressureHandler handler : backPressureHandlers) { + handler.release(amount, reason); + } + if (amount > 0) { + signalPermitsWereReleased(); + } + } + + /** + * Waits for permits to be released up to {@link #noPermitsReturnedWaitTimeout}. If no permits were released within + * the configured {@link #noPermitsReturnedWaitTimeout}, returns immediately. This allows {@link #request(int)} to + * return {@code 0} permits and will trigger another round of back-pressure handling. + * + * @throws InterruptedException if the Thread is interrupted while waiting for permits. + */ + @SuppressWarnings({ "java:S899" // we are not interested in the await return value here + }) + private void waitForPermitsToBeReleased() throws InterruptedException { + noPermitsReturnedWaitLock.lock(); + try { + logger.trace("[{}] No permits were obtained, waiting for a release up to {}", this.id, + noPermitsReturnedWaitTimeout); + permitsReleasedCondition.await(noPermitsReturnedWaitTimeout.toMillis(), TimeUnit.MILLISECONDS); + } + finally { + noPermitsReturnedWaitLock.unlock(); + } + } + + private void signalPermitsWereReleased() { + noPermitsReturnedWaitLock.lock(); + try { + permitsReleasedCondition.signal(); + } + finally { + noPermitsReturnedWaitLock.unlock(); + } + } + + @Override + public boolean drain(Duration timeout) { + logger.debug("[{}] Draining back-pressure handlers initiated", this.id); + boolean result = true; + Instant start = Instant.now(); + for (BackPressureHandler handler : backPressureHandlers) { + Duration remainingTimeout = maxDuration(timeout.minus(Duration.between(start, Instant.now())), + Duration.ZERO); + result &= handler.drain(remainingTimeout); + } + logger.debug("[{}] Draining back-pressure handlers completed", this.id); + return result; + } + + private static Duration maxDuration(Duration first, Duration second) { + return first.compareTo(second) > 0 ? first : second; + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private int batchSize; + private Duration noPermitsReturnedWaitTimeout; + private List backPressureHandlers; + + public Builder backPressureHandlers(List backPressureHandlers) { + this.backPressureHandlers = backPressureHandlers; + return this; + } + + public Builder batchSize(int batchSize) { + this.batchSize = batchSize; + return this; + } + + public Builder noPermitsReturnedWaitTimeout(Duration noPermitsReturnedWaitTimeout) { + this.noPermitsReturnedWaitTimeout = noPermitsReturnedWaitTimeout; + return this; + } + + public CompositeBackPressureHandler build() { + Assert.notNull(this.batchSize, "Missing configuration for batch size"); + Assert.notNull(this.noPermitsReturnedWaitTimeout, "Missing configuration for noPermitsReturnedWaitTimeout"); + Assert.noNullElements(this.backPressureHandlers, "backPressureHandlers must not be null"); + return new CompositeBackPressureHandler(this); + } + } +} diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ConcurrencyLimiterBlockingBackPressureHandler.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ConcurrencyLimiterBlockingBackPressureHandler.java new file mode 100644 index 000000000..38086bacd --- /dev/null +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ConcurrencyLimiterBlockingBackPressureHandler.java @@ -0,0 +1,154 @@ +/* + * Copyright 2013-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.awspring.cloud.sqs.listener; + +import io.awspring.cloud.sqs.listener.source.PollingMessageSource; +import java.time.Duration; +import java.util.Arrays; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.util.Assert; + +/** + * Blocking {@link BackPressureHandler} implementation that uses a {@link Semaphore} for handling the number of + * concurrent messages being processed. + * + * @see PollingMessageSource + * + * @author Loïc Rouchon + */ +public class ConcurrencyLimiterBlockingBackPressureHandler + implements BlockingBackPressureHandler, BatchAwareBackPressureHandler, IdentifiableContainerComponent { + + private static final Logger logger = LoggerFactory.getLogger(ConcurrencyLimiterBlockingBackPressureHandler.class); + + private final Semaphore semaphore; + + private final int batchSize; + + private final int totalPermits; + + private final Duration acquireTimeout; + + private String id = getClass().getSimpleName(); + + private ConcurrencyLimiterBlockingBackPressureHandler(Builder builder) { + this.batchSize = builder.batchSize; + this.totalPermits = builder.totalPermits; + this.acquireTimeout = builder.acquireTimeout; + logger.debug( + "ConcurrencyLimiterBlockingBackPressureHandler created with configuration " + + "totalPermits: {}, batchSize: {}, acquireTimeout: {}", + this.totalPermits, this.batchSize, this.acquireTimeout); + this.semaphore = new Semaphore(totalPermits); + } + + public static Builder builder() { + return new Builder(); + } + + @Override + public void setId(String id) { + this.id = id; + } + + @Override + public String getId() { + return this.id; + } + + @Override + public int requestBatch() throws InterruptedException { + return request(this.batchSize); + } + + @Override + public int request(int amount) throws InterruptedException { + int acquiredPermits = tryAcquire(amount, this.acquireTimeout); + if (acquiredPermits > 0) { + return acquiredPermits; + } + int availablePermits = Math.min(this.semaphore.availablePermits(), amount); + if (availablePermits > 0) { + return tryAcquire(availablePermits, this.acquireTimeout); + } + return 0; + } + + private int tryAcquire(int amount, Duration duration) throws InterruptedException { + if (this.semaphore.tryAcquire(amount, duration.toMillis(), TimeUnit.MILLISECONDS)) { + logger.debug("[{}] Acquired {} permits ({} / {} available)", this.id, amount, + this.semaphore.availablePermits(), this.totalPermits); + return amount; + } + return 0; + } + + @Override + public void release(int amount, ReleaseReason reason) { + this.semaphore.release(amount); + logger.debug("[{}] Released {} permits ({}) ({} / {} available)", this.id, amount, reason, + this.semaphore.availablePermits(), this.totalPermits); + } + + @Override + public boolean drain(Duration timeout) { + logger.debug("[{}] Waiting for up to {} for approx. {} permits to be released", this.id, timeout, + this.totalPermits - this.semaphore.availablePermits()); + try { + return tryAcquire(this.totalPermits, timeout) > 0; + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + logger.debug("[{}] Draining interrupted", this.id); + return false; + } + } + + public static class Builder { + + private int batchSize; + + private int totalPermits; + + private Duration acquireTimeout; + + public Builder batchSize(int batchSize) { + this.batchSize = batchSize; + return this; + } + + public Builder totalPermits(int totalPermits) { + this.totalPermits = totalPermits; + return this; + } + + public Builder acquireTimeout(Duration acquireTimeout) { + this.acquireTimeout = acquireTimeout; + return this; + } + + public ConcurrencyLimiterBlockingBackPressureHandler build() { + Assert.noNullElements(Arrays.asList(this.batchSize, this.totalPermits, this.acquireTimeout), + "Missing configuration"); + Assert.isTrue(this.batchSize > 0, "The batch size must be greater than 0"); + Assert.isTrue(this.totalPermits >= this.batchSize, "Total permits must be greater than the batch size"); + return new ConcurrencyLimiterBlockingBackPressureHandler(this); + } + } +} diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ContainerOptions.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ContainerOptions.java index 838312ad2..e15461071 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ContainerOptions.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ContainerOptions.java @@ -34,6 +34,7 @@ * original instance and the copy. * * @author Tomaz Fernandes + * @author Loïc Rouchon * @since 3.0 */ public interface ContainerOptions, B extends ContainerOptionsBuilder> { @@ -61,7 +62,7 @@ public interface ContainerOptions, B extends Co boolean isAutoStartup(); /** - * Set the maximum time the polling thread should wait for a full batch of permits to be available before trying to + * Sets the maximum time the polling thread should wait for a full batch of permits to be available before trying to * acquire a partial batch if so configured. A poll is only actually executed if at least one permit is available. * Default is 10 seconds. * @@ -129,6 +130,12 @@ default BackOffPolicy getPollBackOffPolicy() { */ BackPressureMode getBackPressureMode(); + /** + * Return the {@link BackPressureHandlerFactory} to create a {@link BackPressureHandler} for this container. + * @return the BackPressureHandlerFactory. + */ + BackPressureHandlerFactory getBackPressureHandlerFactory(); + /** * Return the {@link ListenerMode} mode for this container. * @return the listener mode. diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ContainerOptionsBuilder.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ContainerOptionsBuilder.java index 1e6bb38e7..1976cb70f 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ContainerOptionsBuilder.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ContainerOptionsBuilder.java @@ -27,6 +27,9 @@ * A builder for creating a {@link ContainerOptions} instance. * @param the concrete {@link ContainerOptionsBuilder} type. * @param the concrete {@link ContainerOptions} type. + * + * @author Tomaz Fernandes + * @author Loïc Rouchon */ public interface ContainerOptionsBuilder, O extends ContainerOptions> { @@ -146,6 +149,16 @@ default B pollBackOffPolicy(BackOffPolicy pollBackOffPolicy) { */ B backPressureMode(BackPressureMode backPressureMode); + /** + * Sets the {@link BackPressureHandlerFactory} for this container. Default is + * {@code AbstractContainerOptions.DEFAULT_BACKPRESSURE_FACTORY} which results in a default + * {@link SemaphoreBackPressureHandler} to be instantiated. + * + * @param backPressureHandlerFactory the BackPressureHandler supplier. + * @return this instance. + */ + B backPressureHandlerFactory(BackPressureHandlerFactory backPressureHandlerFactory); + /** * Set the maximum interval between acknowledgements for batch acknowledgements. The default depends on the specific * {@link ContainerComponentFactory} implementation. diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/FullBatchBackPressureHandler.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/FullBatchBackPressureHandler.java new file mode 100644 index 000000000..1d4cbf854 --- /dev/null +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/FullBatchBackPressureHandler.java @@ -0,0 +1,102 @@ +/* + * Copyright 2013-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.awspring.cloud.sqs.listener; + +import io.awspring.cloud.sqs.listener.source.PollingMessageSource; +import java.time.Duration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.util.Assert; + +/** + * Non-blocking {@link BackPressureHandler} implementation that ensures the exact batch size is requested. + *

+ * If the amount of permits being requested is not equal to the batch size, permits will be limited to {@literal 0}. For + * this limiting mechanism to work, the {@link FullBatchBackPressureHandler} must be used in combination with another + * {@link BackPressureHandler} and be the last one in the chain of the {@link CompositeBackPressureHandler} + * + * @see PollingMessageSource + * + * @author Loïc Rouchon + */ +public class FullBatchBackPressureHandler implements BatchAwareBackPressureHandler, IdentifiableContainerComponent { + + private static final Logger logger = LoggerFactory.getLogger(FullBatchBackPressureHandler.class); + + private final int batchSize; + + private String id = getClass().getSimpleName(); + + private FullBatchBackPressureHandler(Builder builder) { + this.batchSize = builder.batchSize; + logger.debug("FullBatchBackPressureHandler created with configuration: batchSize: {}", this.batchSize); + } + + public static Builder builder() { + return new Builder(); + } + + @Override + public void setId(String id) { + this.id = id; + } + + @Override + public String getId() { + return this.id; + } + + @Override + public int requestBatch() throws InterruptedException { + return request(this.batchSize); + } + + @Override + public int request(int amount) throws InterruptedException { + if (amount == batchSize) { + return amount; + } + logger.warn("[{}] Could not acquire a full batch ({} / {}), cancelling current poll", this.id, amount, + this.batchSize); + return 0; + } + + @Override + public void release(int amount, ReleaseReason reason) { + // NO-OP + } + + @Override + public boolean drain(Duration timeout) { + return true; + } + + public static class Builder { + + private int batchSize; + + public Builder batchSize(int batchSize) { + this.batchSize = batchSize; + return this; + } + + public FullBatchBackPressureHandler build() { + Assert.notNull(this.batchSize, "Missing configuration for batch size"); + Assert.isTrue(this.batchSize > 0, "The batch size must be greater than 0"); + return new FullBatchBackPressureHandler(this); + } + } +} diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/SemaphoreBackPressureHandler.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/SemaphoreBackPressureHandler.java index 310b64519..f682400a9 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/SemaphoreBackPressureHandler.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/SemaphoreBackPressureHandler.java @@ -31,7 +31,8 @@ * @since 3.0 * @see io.awspring.cloud.sqs.listener.source.PollingMessageSource */ -public class SemaphoreBackPressureHandler implements BatchAwareBackPressureHandler, IdentifiableContainerComponent { +public class SemaphoreBackPressureHandler + implements BlockingBackPressureHandler, BatchAwareBackPressureHandler, IdentifiableContainerComponent { private static final Logger logger = LoggerFactory.getLogger(SemaphoreBackPressureHandler.class); diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ThroughputBackPressureHandler.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ThroughputBackPressureHandler.java new file mode 100644 index 000000000..e5b416eaa --- /dev/null +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ThroughputBackPressureHandler.java @@ -0,0 +1,157 @@ +/* + * Copyright 2013-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.awspring.cloud.sqs.listener; + +import io.awspring.cloud.sqs.listener.source.PollingMessageSource; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.util.Assert; + +/** + * Non-blocking {@link BackPressureHandler} implementation that uses a switch between high and low throughput modes. + *

+ * Throughput modes + *

    + *
  • In low-throughput mode, a single batch can be requested at a time. The number of permits that will be * delivered + * is the requested amount or 0 is a batch is already in-flight.
  • + *
  • In high-throughput mode, multiple batches can be requested at a time. The number of permits that will be + * delivered is the requested amount.
  • + *
+ *

+ * Throughput mode switch: The initial throughput mode is the low-throughput mode. If some messages are + * fetched, then the throughput mode is switched to high-throughput mode. If no messages are returned fetched by a poll, + * the throughput mode is switched back to low-throughput mode. + *

+ * This {@link BackPressureHandler} is designed to be used in combination with another {@link BackPressureHandler} like + * the {@link ConcurrencyLimiterBlockingBackPressureHandler} that will handle the maximum concurrency level within the + * application in a blocking way. + * + * @see PollingMessageSource + * + * @author Loïc Rouchon + */ +public class ThroughputBackPressureHandler implements BackPressureHandler, IdentifiableContainerComponent { + + private static final Logger logger = LoggerFactory.getLogger(ThroughputBackPressureHandler.class); + + private final AtomicReference currentThroughputMode = new AtomicReference<>( + CurrentThroughputMode.LOW); + + private final AtomicBoolean occupied = new AtomicBoolean(false); + + private final AtomicBoolean drained = new AtomicBoolean(false); + + private final int batchSize; + + private String id = getClass().getSimpleName(); + + private ThroughputBackPressureHandler(Builder builder) { + this.batchSize = builder.batchSize; + logger.debug("ThroughputBackPressureHandler created with batch size: {}", this.batchSize); + } + + public static Builder builder() { + return new Builder(); + } + + @Override + public void setId(String id) { + this.id = id; + } + + @Override + public String getId() { + return this.id; + } + + @Override + public int request(int amount) throws InterruptedException { + if (drained.get()) { + return 0; + } + CurrentThroughputMode throughputMode = this.currentThroughputMode.get(); + if (throughputMode == CurrentThroughputMode.LOW) { + if (this.occupied.get()) { + logger.debug("[{}] No permits acquired because a batch already being processed in low throughput mode", + this.id); + return 0; + } + this.occupied.set(true); + } + logger.debug("[{}] Acquired {} permits ({} mode)", this.id, amount, throughputMode); + return Math.min(amount, this.batchSize); + } + + @Override + public void release(int amount, ReleaseReason reason) { + if (drained.get()) { + return; + } + logger.debug("[{}] Releasing {} permits ({})", this.id, amount, reason); + switch (reason) { + case NONE_FETCHED -> { + this.occupied.compareAndSet(true, false); + updateThroughputMode(CurrentThroughputMode.HIGH, CurrentThroughputMode.LOW); + } + case PARTIAL_FETCH -> { + this.occupied.compareAndSet(true, false); + updateThroughputMode(CurrentThroughputMode.LOW, CurrentThroughputMode.HIGH); + } + case LIMITED, PROCESSED -> { + // No need to switch throughput mode + } + } + } + + private void updateThroughputMode(CurrentThroughputMode currentTarget, CurrentThroughputMode newTarget) { + if (this.currentThroughputMode.compareAndSet(currentTarget, newTarget)) { + logger.debug("[{}] throughput mode updated to {}", this.id, newTarget); + } + } + + @Override + public boolean drain(Duration timeout) { + logger.debug("[{}] Draining", this.id); + drained.set(true); + return true; + } + + private enum CurrentThroughputMode { + + HIGH, + + LOW; + + } + + public static class Builder { + + private int batchSize; + + public Builder batchSize(int batchSize) { + this.batchSize = batchSize; + return this; + } + + public ThroughputBackPressureHandler build() { + Assert.isTrue(this.batchSize > 0, "The batch size must be greater than 0"); + return new ThroughputBackPressureHandler(this); + } + } +} diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/source/AbstractPollingMessageSource.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/source/AbstractPollingMessageSource.java index e71dc4319..76d6bb8a6 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/source/AbstractPollingMessageSource.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/source/AbstractPollingMessageSource.java @@ -17,6 +17,7 @@ import io.awspring.cloud.sqs.ConfigUtils; import io.awspring.cloud.sqs.listener.BackPressureHandler; +import io.awspring.cloud.sqs.listener.BackPressureHandler.ReleaseReason; import io.awspring.cloud.sqs.listener.BatchAwareBackPressureHandler; import io.awspring.cloud.sqs.listener.ContainerOptions; import io.awspring.cloud.sqs.listener.IdentifiableContainerComponent; @@ -57,6 +58,7 @@ * {@link MessageProcessingContext} and executed downstream when applicable. * * @author Tomaz Fernandes + * @author Loïc Rouchon * @since 3.0 */ public abstract class AbstractPollingMessageSource extends AbstractMessageConvertingMessageSource @@ -214,7 +216,7 @@ private void pollAndEmitMessages() { if (!isRunning()) { logger.debug("MessageSource was stopped after permits where acquired. Returning {} permits", acquiredPermits); - this.backPressureHandler.release(acquiredPermits); + this.backPressureHandler.release(acquiredPermits, ReleaseReason.NONE_FETCHED); continue; } // @formatter:off @@ -252,15 +254,12 @@ private void handlePollBackOff() { protected abstract CompletableFuture> doPollForMessages(int messagesToRequest); public Collection> releaseUnusedPermits(int permits, Collection> msgs) { - if (msgs.isEmpty() && permits == this.backPressureHandler.getBatchSize()) { - this.backPressureHandler.releaseBatch(); - logger.trace("Released batch of unused permits for queue {}", this.pollingEndpointName); - } - else { - int permitsToRelease = permits - msgs.size(); - this.backPressureHandler.release(permitsToRelease); - logger.trace("Released {} unused permits for queue {}", permitsToRelease, this.pollingEndpointName); - } + int polledMessages = msgs.size(); + int permitsToRelease = permits - polledMessages; + ReleaseReason releaseReason = polledMessages == 0 ? ReleaseReason.NONE_FETCHED : ReleaseReason.PARTIAL_FETCH; + this.backPressureHandler.release(permitsToRelease, releaseReason); + logger.trace("Released {} unused ({}) permits for queue {} (messages polled {})", permitsToRelease, + releaseReason, this.pollingEndpointName, polledMessages); return msgs; } @@ -285,7 +284,7 @@ protected AcknowledgementCallback getAcknowledgementCallback() { private void releaseBackPressure() { logger.debug("Releasing permit for queue {}", this.pollingEndpointName); - this.backPressureHandler.release(1); + this.backPressureHandler.release(1, ReleaseReason.PROCESSED); } private Void handleSinkException(Throwable t) { diff --git a/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/integration/SqsBackPressureIntegrationTests.java b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/integration/SqsBackPressureIntegrationTests.java new file mode 100644 index 000000000..ae5c35828 --- /dev/null +++ b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/integration/SqsBackPressureIntegrationTests.java @@ -0,0 +1,528 @@ +/* + * Copyright 2013-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.awspring.cloud.sqs.integration; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.awspring.cloud.sqs.config.SqsBootstrapConfiguration; +import io.awspring.cloud.sqs.listener.*; +import io.awspring.cloud.sqs.operations.SqsTemplate; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardOpenOption; +import java.time.Duration; +import java.time.Instant; +import java.util.List; +import java.util.Queue; +import java.util.Random; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Semaphore; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.IntUnaryOperator; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Import; +import org.springframework.messaging.Message; +import org.springframework.messaging.support.MessageBuilder; + +/** + * Integration tests for SQS containers back pressure management. + * + * @author Loïc Rouchon + */ +@SpringBootTest +class SqsBackPressureIntegrationTests extends BaseSqsIntegrationTest { + + private static final Logger logger = LoggerFactory.getLogger(SqsBackPressureIntegrationTests.class); + + @Autowired + SqsTemplate sqsTemplate; + + static final class NonBlockingExternalConcurrencyLimiterBackPressureHandler implements BackPressureHandler { + private final AtomicInteger limit; + private final AtomicInteger inFlight = new AtomicInteger(0); + private final AtomicBoolean draining = new AtomicBoolean(false); + + NonBlockingExternalConcurrencyLimiterBackPressureHandler(int max) { + limit = new AtomicInteger(max); + } + + public void setLimit(int value) { + logger.info("adjusting limit from {} to {}", limit.get(), value); + limit.set(value); + } + + @Override + public int request(int amount) { + if (draining.get()) { + return 0; + } + int permits = Math.max(0, Math.min(limit.get() - inFlight.get(), amount)); + inFlight.addAndGet(permits); + return permits; + } + + @Override + public void release(int amount, ReleaseReason reason) { + inFlight.addAndGet(-amount); + } + + @Override + public boolean drain(Duration timeout) { + Duration drainingTimeout = Duration.ofSeconds(10L); + Duration drainingPollingIntervalCheck = Duration.ofMillis(50L); + draining.set(true); + limit.set(0); + Instant start = Instant.now(); + while (Duration.between(start, Instant.now()).compareTo(drainingTimeout) < 0) { + if (inFlight.get() == 0) { + return true; + } + sleep(drainingPollingIntervalCheck.toMillis()); + } + return false; + } + } + + @ParameterizedTest + @CsvSource({ "2,2", "4,4", "5,5", "20,5" }) + void staticBackPressureLimitShouldCapQueueProcessingCapacity(int staticLimit, int expectedMaxConcurrentRequests) + throws Exception { + AtomicInteger concurrentRequest = new AtomicInteger(); + AtomicInteger maxConcurrentRequest = new AtomicInteger(); + NonBlockingExternalConcurrencyLimiterBackPressureHandler limiter = new NonBlockingExternalConcurrencyLimiterBackPressureHandler( + staticLimit); + String queueName = "BACK_PRESSURE_LIMITER_QUEUE_NAME_STATIC_LIMIT_" + staticLimit; + IntStream.range(0, 10).forEach(index -> { + List> messages = create10Messages("staticBackPressureLimit" + staticLimit); + sqsTemplate.sendMany(queueName, messages); + }); + logger.debug("Sent 100 messages to queue {}", queueName); + var latch = new CountDownLatch(100); + var container = SqsMessageListenerContainer.builder().sqsAsyncClient(BaseSqsIntegrationTest.createAsyncClient()) + .queueNames(queueName) + .configure( + options -> options.maxMessagesPerPoll(5).maxConcurrentMessages(5) + .backPressureMode(BackPressureMode.AUTO).maxDelayBetweenPolls(Duration.ofSeconds(1)) + .pollTimeout(Duration.ofSeconds(1)) + .backPressureHandlerFactory(containerOptions -> BackPressureHandlerFactories + .compositeBackPressureHandler(containerOptions, Duration.ofMillis(50L), + List.of(limiter, BackPressureHandlerFactories + .concurrencyLimiterBackPressureHandler(containerOptions))))) + .messageListener(msg -> { + int concurrentRqs = concurrentRequest.incrementAndGet(); + maxConcurrentRequest.updateAndGet(max -> Math.max(max, concurrentRqs)); + sleep(50L); + logger.debug("concurrent rq {}, max concurrent rq {}, latch count {}", concurrentRequest.get(), + maxConcurrentRequest.get(), latch.getCount()); + latch.countDown(); + concurrentRequest.decrementAndGet(); + }).build(); + container.start(); + assertThat(latch.await(10, TimeUnit.SECONDS)).isTrue(); + assertThat(maxConcurrentRequest.get()).isLessThanOrEqualTo(expectedMaxConcurrentRequests); + container.stop(); + } + + @Test + void zeroBackPressureLimitShouldStopQueueProcessing() throws Exception { + AtomicInteger concurrentRequest = new AtomicInteger(); + AtomicInteger maxConcurrentRequest = new AtomicInteger(); + NonBlockingExternalConcurrencyLimiterBackPressureHandler limiter = new NonBlockingExternalConcurrencyLimiterBackPressureHandler( + 0); + String queueName = "BACK_PRESSURE_LIMITER_QUEUE_NAME_STATIC_LIMIT_0"; + IntStream.range(0, 10).forEach(index -> { + List> messages = create10Messages("staticBackPressureLimit0"); + sqsTemplate.sendMany(queueName, messages); + }); + logger.debug("Sent 100 messages to queue {}", queueName); + var latch = new CountDownLatch(100); + var container = SqsMessageListenerContainer.builder().sqsAsyncClient(BaseSqsIntegrationTest.createAsyncClient()) + .queueNames(queueName) + .configure( + options -> options.maxMessagesPerPoll(5).maxConcurrentMessages(5) + .backPressureMode(BackPressureMode.AUTO).maxDelayBetweenPolls(Duration.ofSeconds(1)) + .pollTimeout(Duration.ofSeconds(1)) + .backPressureHandlerFactory(containerOptions -> BackPressureHandlerFactories + .compositeBackPressureHandler(containerOptions, Duration.ofMillis(50L), + List.of(limiter, BackPressureHandlerFactories + .concurrencyLimiterBackPressureHandler(containerOptions))))) + .messageListener(msg -> { + int concurrentRqs = concurrentRequest.incrementAndGet(); + maxConcurrentRequest.updateAndGet(max -> Math.max(max, concurrentRqs)); + sleep(50L); + logger.debug("concurrent rq {}, max concurrent rq {}, latch count {}", concurrentRequest.get(), + maxConcurrentRequest.get(), latch.getCount()); + latch.countDown(); + concurrentRequest.decrementAndGet(); + }).build(); + container.start(); + assertThat(latch.await(2, TimeUnit.SECONDS)).isFalse(); + assertThat(maxConcurrentRequest.get()).isZero(); + assertThat(latch.getCount()).isEqualTo(100L); + container.stop(); + } + + @Test + void changeInBackPressureLimitShouldAdaptQueueProcessingCapacity() throws Exception { + AtomicInteger concurrentRequest = new AtomicInteger(); + AtomicInteger maxConcurrentRequest = new AtomicInteger(); + NonBlockingExternalConcurrencyLimiterBackPressureHandler limiter = new NonBlockingExternalConcurrencyLimiterBackPressureHandler( + 5); + String queueName = "BACK_PRESSURE_LIMITER_QUEUE_NAME_SYNC_ADAPTIVE_LIMIT"; + int nbMessages = 280; + IntStream.range(0, nbMessages / 10).forEach(index -> { + List> messages = create10Messages("syncAdaptiveBackPressureLimit"); + sqsTemplate.sendMany(queueName, messages); + }); + logger.debug("Sent {} messages to queue {}", nbMessages, queueName); + var latch = new CountDownLatch(nbMessages); + var controlSemaphore = new Semaphore(0); + var advanceSemaphore = new Semaphore(0); + var processingFailed = new AtomicBoolean(false); + var isDraining = new AtomicBoolean(false); + var container = SqsMessageListenerContainer.builder().sqsAsyncClient(BaseSqsIntegrationTest.createAsyncClient()) + .queueNames(queueName) + .configure( + options -> options.maxMessagesPerPoll(5).maxConcurrentMessages(5) + .backPressureMode(BackPressureMode.AUTO).maxDelayBetweenPolls(Duration.ofSeconds(1)) + .pollTimeout(Duration.ofSeconds(1)) + .backPressureHandlerFactory(containerOptions -> BackPressureHandlerFactories + .compositeBackPressureHandler(containerOptions, Duration.ofMillis(50L), + List.of(limiter, BackPressureHandlerFactories + .concurrencyLimiterBackPressureHandler(containerOptions))))) + .messageListener(msg -> { + try { + if (!controlSemaphore.tryAcquire(5, TimeUnit.SECONDS) && !isDraining.get()) { + processingFailed.set(true); + throw new IllegalStateException("Failed to wait for control semaphore"); + } + } + catch (InterruptedException e) { + if (!isDraining.get()) { + processingFailed.set(true); + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + } + int concurrentRqs = concurrentRequest.incrementAndGet(); + maxConcurrentRequest.updateAndGet(max -> Math.max(max, concurrentRqs)); + latch.countDown(); + logger.debug("concurrent rq {}, max concurrent rq {}, latch count {}", concurrentRequest.get(), + maxConcurrentRequest.get(), latch.getCount()); + sleep(10L); + concurrentRequest.decrementAndGet(); + advanceSemaphore.release(); + }).build(); + class Controller { + private final Semaphore advanceSemaphore; + private final Semaphore controlSemaphore; + private final NonBlockingExternalConcurrencyLimiterBackPressureHandler limiter; + private final AtomicInteger maxConcurrentRequest; + private final AtomicBoolean processingFailed; + + Controller(Semaphore advanceSemaphore, Semaphore controlSemaphore, + NonBlockingExternalConcurrencyLimiterBackPressureHandler limiter, + AtomicInteger maxConcurrentRequest, AtomicBoolean processingFailed) { + this.advanceSemaphore = advanceSemaphore; + this.controlSemaphore = controlSemaphore; + this.limiter = limiter; + this.maxConcurrentRequest = maxConcurrentRequest; + this.processingFailed = processingFailed; + } + + public void updateLimit(int newLimit) { + limiter.setLimit(newLimit); + } + + void updateLimitAndWaitForReset(int newLimit) throws InterruptedException { + updateLimit(newLimit); + int atLeastTwoPollingCycles = 2 * 5; + controlSemaphore.release(atLeastTwoPollingCycles); + waitForAdvance(atLeastTwoPollingCycles); + maxConcurrentRequest.set(0); + } + + void advance(int permits) { + controlSemaphore.release(permits); + } + + void waitForAdvance(int permits) throws InterruptedException { + assertThat(advanceSemaphore.tryAcquire(permits, 5, TimeUnit.SECONDS)) + .withFailMessage(() -> "Waiting for %d permits timed out. Only %d permits available" + .formatted(permits, advanceSemaphore.availablePermits())) + .isTrue(); + assertThat(processingFailed.get()).isFalse(); + } + } + var controller = new Controller(advanceSemaphore, controlSemaphore, limiter, maxConcurrentRequest, + processingFailed); + try { + container.start(); + + controller.advance(50); + controller.waitForAdvance(50); + // not limiting queue processing capacity + assertThat(controller.maxConcurrentRequest.get()).isLessThanOrEqualTo(5); + controller.updateLimitAndWaitForReset(2); + controller.advance(50); + + controller.waitForAdvance(50); + // limiting queue processing capacity + assertThat(controller.maxConcurrentRequest.get()).isLessThanOrEqualTo(2); + controller.updateLimitAndWaitForReset(7); + controller.advance(50); + + controller.waitForAdvance(50); + // not limiting queue processing capacity + assertThat(controller.maxConcurrentRequest.get()).isLessThanOrEqualTo(5); + controller.updateLimitAndWaitForReset(3); + controller.advance(50); + sleep(10L); + limiter.setLimit(1); + sleep(10L); + limiter.setLimit(2); + sleep(10L); + limiter.setLimit(3); + + controller.waitForAdvance(50); + assertThat(controller.maxConcurrentRequest.get()).isLessThanOrEqualTo(3); + // stopping processing of the queue + controller.updateLimit(0); + controller.advance(50); + assertThat(advanceSemaphore.tryAcquire(10, 5, TimeUnit.SECONDS)) + .withFailMessage("Acquiring semaphore should have timed out as limit was set to 0").isFalse(); + + // resume queue processing + controller.updateLimit(6); + + controller.waitForAdvance(50); + assertThat(latch.await(10, TimeUnit.SECONDS)).isTrue(); + assertThat(controller.maxConcurrentRequest.get()).isLessThanOrEqualTo(5); + assertThat(processingFailed.get()).isFalse(); + } + finally { + isDraining.set(true); + container.stop(); + } + } + + static class EventsCsvWriter { + private final Queue events = new ConcurrentLinkedQueue<>(List.of("event,time,value")); + + void registerEvent(String event, int value) { + events.add("%s,%s,%d".formatted(event, Instant.now(), value)); + } + + void write(Path path) throws Exception { + Files.writeString(path, String.join("\n", events), StandardCharsets.UTF_8, StandardOpenOption.CREATE, + StandardOpenOption.TRUNCATE_EXISTING); + } + } + + static class StatisticsBphDecorator implements BatchAwareBackPressureHandler, IdentifiableContainerComponent { + private final BatchAwareBackPressureHandler delegate; + private final EventsCsvWriter eventCsv; + private String id; + + StatisticsBphDecorator(BatchAwareBackPressureHandler delegate, EventsCsvWriter eventsCsvWriter) { + this.delegate = delegate; + this.eventCsv = eventsCsvWriter; + } + + @Override + public int requestBatch() throws InterruptedException { + int permits = delegate.requestBatch(); + if (permits > 0) { + eventCsv.registerEvent("obtained_permits", permits); + } + return permits; + } + + @Override + public int request(int amount) throws InterruptedException { + int permits = delegate.request(amount); + if (permits > 0) { + eventCsv.registerEvent("obtained_permits", permits); + } + return permits; + } + + @Override + public void release(int amount, ReleaseReason reason) { + if (amount > 0) { + eventCsv.registerEvent("release_" + reason, amount); + } + delegate.release(amount, reason); + } + + @Override + public boolean drain(Duration timeout) { + eventCsv.registerEvent("drain", 1); + return delegate.drain(timeout); + } + + @Override + public void setId(String id) { + this.id = id; + if (delegate instanceof IdentifiableContainerComponent icc) { + icc.setId("delegate-" + id); + } + } + + @Override + public String getId() { + return id; + } + } + + /** + * This test simulates a progressive change in the back pressure limit. Unlike + * {@link #changeInBackPressureLimitShouldAdaptQueueProcessingCapacity()}, this test does not block message + * consumption while updating the limit. + *

+ * The limit is updated in a loop until all messages are consumed. The update follows a triangle wave pattern with a + * minimum of 0, a maximum of 15, and a period of 30 iterations. After each update of the limit, the test waits up + * to 10ms and samples the maximum number of concurrent messages that were processed since the update. This number + * can be higher than the defined limit during the adaptation period of the decreasing limit wave. For the + * increasing limit wave, it is usually lower due to the adaptation delay. In both cases, the maximum number of + * concurrent messages being processed rapidly converges toward the defined limit. + *

+ * The test passes if the sum of the sampled maximum number of concurrently processed messages is lower than the sum + * of the limits at those points in time. + */ + @Test + void unsynchronizedChangesInBackPressureLimitShouldAdaptQueueProcessingCapacity() throws Exception { + AtomicInteger concurrentRequest = new AtomicInteger(); + AtomicInteger maxConcurrentRequest = new AtomicInteger(); + NonBlockingExternalConcurrencyLimiterBackPressureHandler limiter = new NonBlockingExternalConcurrencyLimiterBackPressureHandler( + 0); + String queueName = "REACTIVE_BACK_PRESSURE_LIMITER_QUEUE_NAME_ADAPTIVE_LIMIT"; + int nbMessages = 1000; + Semaphore advanceSemaphore = new Semaphore(0); + IntStream.range(0, nbMessages / 10).forEach(index -> { + List> messages = create10Messages("reactAdaptiveBackPressureLimit"); + sqsTemplate.sendMany(queueName, messages); + }); + logger.debug("Sent {} messages to queue {}", nbMessages, queueName); + var latch = new CountDownLatch(nbMessages); + EventsCsvWriter eventsCsvWriter = new EventsCsvWriter(); + var container = SqsMessageListenerContainer.builder().sqsAsyncClient(BaseSqsIntegrationTest.createAsyncClient()) + .queueNames(queueName) + .configure( + options -> options.maxMessagesPerPoll(10).maxConcurrentMessages(10) + .backPressureMode(BackPressureMode.AUTO).maxDelayBetweenPolls(Duration.ofSeconds(1)) + .pollTimeout(Duration.ofSeconds(1)) + .backPressureHandlerFactory(containerOptions -> new StatisticsBphDecorator( + BackPressureHandlerFactories.compositeBackPressureHandler(containerOptions, + Duration.ofMillis(50L), + List.of(limiter, BackPressureHandlerFactories + .concurrencyLimiterBackPressureHandler(containerOptions))), + eventsCsvWriter))) + .messageListener(msg -> { + int currentConcurrentRq = concurrentRequest.incrementAndGet(); + maxConcurrentRequest.updateAndGet(max -> Math.max(max, currentConcurrentRq)); + sleep(ThreadLocalRandom.current().nextInt(10)); + latch.countDown(); + logger.debug("concurrent rq {}, max concurrent rq {}, latch count {}", concurrentRequest.get(), + maxConcurrentRequest.get(), latch.getCount()); + concurrentRequest.decrementAndGet(); + advanceSemaphore.release(); + }).build(); + IntUnaryOperator progressiveLimitChange = (int x) -> { + int period = 30; + int halfPeriod = period / 2; + if (x % period < halfPeriod) { + return (x % halfPeriod); + } + else { + return (halfPeriod - (x % halfPeriod)); + } + }; + try { + container.start(); + Random random = new Random(); + int limitsSum = 0; + int maxConcurrentRqSum = 0; + int changeLimitCount = 0; + while (latch.getCount() > 0 && changeLimitCount < nbMessages) { + changeLimitCount++; + int limit = progressiveLimitChange.applyAsInt(changeLimitCount); + int expectedMax = Math.min(10, limit); + limiter.setLimit(limit); + maxConcurrentRequest.set(0); + sleep(random.nextInt(20)); + int actualLimit = Math.min(10, limit); + int max = maxConcurrentRequest.get(); + if (max > 0) { + // Ignore iterations where nothing was polled (messages consumption slower than iteration) + limitsSum += actualLimit; + maxConcurrentRqSum += max; + } + eventsCsvWriter.registerEvent("max_concurrent_rq", max); + eventsCsvWriter.registerEvent("concurrent_rq", concurrentRequest.get()); + eventsCsvWriter.registerEvent("limit", limit); + eventsCsvWriter.registerEvent("in_flight", limiter.inFlight.get()); + eventsCsvWriter.registerEvent("expected_max", expectedMax); + eventsCsvWriter.registerEvent("max_minus_expected_max", max - expectedMax); + } + eventsCsvWriter.write(Path.of("target/stats-%s.csv".formatted(queueName))); + assertThat(maxConcurrentRqSum).isLessThanOrEqualTo(limitsSum); + assertThat(latch.await(10, TimeUnit.SECONDS)).isTrue(); + } + finally { + container.stop(); + } + } + + private static void sleep(long millis) { + try { + Thread.sleep(millis); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + + private List> create10Messages(String testName) { + return IntStream.range(0, 10).mapToObj(index -> testName + "-payload-" + index) + .map(payload -> MessageBuilder.withPayload(payload).build()).collect(Collectors.toList()); + } + + @Import(SqsBootstrapConfiguration.class) + @Configuration + static class SQSConfiguration { + + @Bean + SqsTemplate sqsTemplate() { + return SqsTemplate.builder().sqsAsyncClient(BaseSqsIntegrationTest.createAsyncClient()).build(); + } + } +} diff --git a/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/integration/SqsIntegrationTests.java b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/integration/SqsIntegrationTests.java index 0a1157cea..95920deb9 100644 --- a/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/integration/SqsIntegrationTests.java +++ b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/integration/SqsIntegrationTests.java @@ -392,6 +392,7 @@ void manuallyCreatesInactiveContainer() throws Exception { logger.debug("Sent message to queue {} with messageBody {}", MANUALLY_CREATE_INACTIVE_CONTAINER_QUEUE_NAME, messageBody); assertThat(latchContainer.manuallyInactiveCreatedContainerLatch.await(10, TimeUnit.SECONDS)).isTrue(); + inactiveMessageListenerContainer.stop(); } // @formatter:off diff --git a/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/listener/CompositeBackPressureHandlerTest.java b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/listener/CompositeBackPressureHandlerTest.java new file mode 100644 index 000000000..b8b6ac47a --- /dev/null +++ b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/listener/CompositeBackPressureHandlerTest.java @@ -0,0 +1,163 @@ +/* + * Copyright 2013-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.awspring.cloud.sqs.listener; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.*; + +import java.time.Duration; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import org.jetbrains.annotations.NotNull; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +/** + * Tests for {@link CompositeBackPressureHandler}. + * + * @author Loïc Rouchon + */ +class CompositeBackPressureHandlerTest { + + private BackPressureHandler handler1; + private BackPressureHandler handler2; + + @BeforeEach + void setUp() { + handler1 = mock(BackPressureHandler.class); + handler2 = mock(BackPressureHandler.class); + } + + @Test + void request_shouldDelegateToHandlersAndReturnMinPermits() throws InterruptedException { + // given + CompositeBackPressureHandler compositeHandler = compositeHandlerBuilder() + .noPermitsReturnedWaitTimeout(Duration.ofSeconds(30)).backPressureHandlers(List.of(handler1, handler2)) + .build(); + when(handler1.request(5)).thenReturn(5); + when(handler2.request(5)).thenReturn(3); + // when + int permits = compositeHandler.request(5); + // then + assertThat(permits).isEqualTo(3); + verify(handler1).request(5); + verify(handler2).request(5); + } + + @Test + void release_shouldDelegateToHandlers() { + // given + CompositeBackPressureHandler compositeHandler = compositeHandlerBuilder() + .noPermitsReturnedWaitTimeout(Duration.ofSeconds(30)).backPressureHandlers(List.of(handler1, handler2)) + .build(); + // when + compositeHandler.release(2, BackPressureHandler.ReleaseReason.PROCESSED); + // then + verify(handler1).release(2, BackPressureHandler.ReleaseReason.PROCESSED); + verify(handler2).release(2, BackPressureHandler.ReleaseReason.PROCESSED); + } + + @Test + void request_shouldWaitIfNoPermitsAndTimeout() throws InterruptedException { + // given + CompositeBackPressureHandler compositeHandler = compositeHandlerBuilder() + .noPermitsReturnedWaitTimeout(Duration.ofSeconds(5)).backPressureHandlers(List.of(handler1, handler2)) + .build(); + when(handler1.request(5)).thenReturn(0); + when(handler2.request(5)).thenReturn(0); + // when + long start = System.nanoTime(); + int permits = compositeHandler.request(5); + Duration duration = Duration.ofNanos(System.nanoTime() - start); + // then + assertThat(permits).isZero(); + assertThat(duration).isGreaterThanOrEqualTo(Duration.ofSeconds(1L)); + } + + @Test + void request_shouldPassReducedPermitsToSubsequentHandlers() throws InterruptedException { + // given + CompositeBackPressureHandler compositeHandler = compositeHandlerBuilder() + .noPermitsReturnedWaitTimeout(Duration.ofSeconds(30)).backPressureHandlers(List.of(handler1, handler2)) + .build(); + when(handler1.request(10)).thenReturn(5); + when(handler2.request(5)).thenReturn(5); + // when + int permits = compositeHandler.request(10); + // then + assertThat(permits).isEqualTo(5); + verify(handler1).request(10); + verify(handler2).request(5); + } + + @Test + void request_whenLaterHandlerReturnsLessPermits_shouldReleaseDiffWithLimitedOnPreviousHandlers() + throws InterruptedException { + // given + BackPressureHandler handler3 = mock(BackPressureHandler.class); + CompositeBackPressureHandler compositeHandler = compositeHandlerBuilder() + .noPermitsReturnedWaitTimeout(Duration.ofMillis(50)) + .backPressureHandlers(List.of(handler1, handler2, handler3)).build(); + when(handler1.request(5)).thenReturn(4); + when(handler2.request(4)).thenReturn(2); + when(handler3.request(2)).thenReturn(1); + // when + int permits = compositeHandler.request(5); + // then + assertThat(permits).isEqualTo(1); + verify(handler1).request(5); + verify(handler2).request(4); + verify(handler3).request(2); + verify(handler1).release(3, BackPressureHandler.ReleaseReason.LIMITED); + verify(handler2).release(1, BackPressureHandler.ReleaseReason.LIMITED); + verify(handler3, never()).release(anyInt(), any()); + } + + @Test + void request_shouldUnblockWhenPermitsAreReleased() throws InterruptedException { + // given + CompositeBackPressureHandler compositeHandler = compositeHandlerBuilder() + .noPermitsReturnedWaitTimeout(Duration.ofSeconds(30)).backPressureHandlers(List.of(handler1, handler2)) + .build(); + when(handler1.request(5)).thenReturn(0, 5); + when(handler2.request(5)).thenReturn(5); + + AtomicInteger result = new AtomicInteger(-1); + Thread requester = new Thread(() -> { + try { + // when + result.set(compositeHandler.request(5)); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + }); + requester.start(); + Thread.sleep(200); // Ensure requester is waiting + assertThat(requester.isAlive()).isTrue(); + // when + compositeHandler.release(5, BackPressureHandler.ReleaseReason.PROCESSED); + requester.join(2000); + // then + assertThat(requester.isAlive()).isFalse(); + assertThat(result.get()).isZero(); + assertThat(compositeHandler.request(5)).isEqualTo(5); + } + + private static CompositeBackPressureHandler.@NotNull Builder compositeHandlerBuilder() { + return CompositeBackPressureHandler.builder().batchSize(5); + } +} diff --git a/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/listener/ConcurrencyLimiterBlockingBackPressureHandlerTest.java b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/listener/ConcurrencyLimiterBlockingBackPressureHandlerTest.java new file mode 100644 index 000000000..f655069d0 --- /dev/null +++ b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/listener/ConcurrencyLimiterBlockingBackPressureHandlerTest.java @@ -0,0 +1,61 @@ +/* + * Copyright 2013-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.awspring.cloud.sqs.listener; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.time.Duration; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +/** + * Tests for {@link ConcurrencyLimiterBlockingBackPressureHandler}. + * + * @author Loïc Rouchon + */ +class ConcurrencyLimiterBlockingBackPressureHandlerTest { + + private static final int BATCH_SIZE = 5; + private static final int TOTAL_PERMITS = 10; + + private ConcurrencyLimiterBlockingBackPressureHandler handler; + + @BeforeEach + void setUp() { + handler = ConcurrencyLimiterBlockingBackPressureHandler.builder().totalPermits(TOTAL_PERMITS) + .batchSize(BATCH_SIZE).acquireTimeout(Duration.ofMillis(100)).build(); + } + + @Test + void request_shouldAcquirePermits() throws InterruptedException { + // Requesting a first batch should acquire the permits + assertThat(handler.request(BATCH_SIZE)).isEqualTo(BATCH_SIZE); + // Requesting a second batch should acquire the remaining permits + assertThat(handler.request(BATCH_SIZE)).isEqualTo(BATCH_SIZE); + // No permits left + assertThat(handler.request(1)).isZero(); + } + + @Test + void release_shouldAllowFurtherRequests() throws InterruptedException { + // Given all permits are acquired + assertThat(handler.request(TOTAL_PERMITS)).isEqualTo(TOTAL_PERMITS); + assertThat(handler.request(1)).isZero(); + // When releasing some permits, new requests should be allowed + handler.release(3, BackPressureHandler.ReleaseReason.PROCESSED); + assertThat(handler.request(5)).isEqualTo(3); // Only 3 permits were released so far + } +} diff --git a/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/listener/FullBatchBackPressureHandlerTest.java b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/listener/FullBatchBackPressureHandlerTest.java new file mode 100644 index 000000000..43df103e0 --- /dev/null +++ b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/listener/FullBatchBackPressureHandlerTest.java @@ -0,0 +1,56 @@ +/* + * Copyright 2013-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.awspring.cloud.sqs.listener; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +/** + * Tests for {@link FullBatchBackPressureHandler}. + * + * @author Loïc Rouchon + */ +class FullBatchBackPressureHandlerTest { + + private FullBatchBackPressureHandler handler; + + private final int batchSize = 10; + + @BeforeEach + void setUp() { + handler = FullBatchBackPressureHandler.builder().batchSize(batchSize).build(); + } + + @Test + void request_withExactBatchSize_shouldReturnBatchSize() throws InterruptedException { + assertThat(handler.request(batchSize)).isEqualTo(batchSize); + } + + @Test + void request_withNonBatchSize_shouldReturnZero() throws InterruptedException { + int permits = handler.request(batchSize - 1); + assertThat(permits).isZero(); + permits = handler.request(batchSize + 1); + assertThat(permits).isZero(); + } + + @Test + void requestBatch_shouldReturnBatchSize() throws InterruptedException { + assertThat(handler.requestBatch()).isEqualTo(batchSize); + } +} diff --git a/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/listener/ThroughputBackPressureHandlerTest.java b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/listener/ThroughputBackPressureHandlerTest.java new file mode 100644 index 000000000..56bfc563c --- /dev/null +++ b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/listener/ThroughputBackPressureHandlerTest.java @@ -0,0 +1,94 @@ +/* + * Copyright 2013-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.awspring.cloud.sqs.listener; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.time.Duration; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; + +/** + * Tests for {@link ThroughputBackPressureHandler}. + * + * @author Loïc Rouchon + */ +class ThroughputBackPressureHandlerTest { + private ThroughputBackPressureHandler handler; + + @BeforeEach + void setUp() { + handler = new ThroughputBackPressureHandler.Builder().batchSize(5).build(); + } + + @ParameterizedTest + @CsvSource({ "4,4", "5,5", "6,5", }) + void amountIsCappedAtBatchSize(int requestedAmount, int expectedPermits) throws InterruptedException { + assertThat(handler.request(requestedAmount)).isEqualTo(expectedPermits); + } + + @ParameterizedTest + @CsvSource({ "LIMITED,0", "PROCESSED,0", "NONE_FETCHED,5", "PARTIAL_FETCH,5", }) + void lowThroughputMode_shouldReturnZeroUntilRelease(BackPressureHandler.ReleaseReason releaseReason, + int expectedPermitsAfterRelease) throws InterruptedException { + // Given a first batch + int batchSize = 5; + assertThat(handler.request(batchSize)).isEqualTo(batchSize); + // When a second batch is requested, it should return zero permits (because low throughput mode) + assertThat(handler.request(batchSize)).isZero(); + // When a batch is requested after a release, the expected permits should be + // returned depending on the release reason + handler.release(1, releaseReason); + assertThat(handler.request(batchSize)).isEqualTo(expectedPermitsAfterRelease); + } + + @Test + void highThroughputMode_shouldAllowMultipleConcurrentRequests() throws InterruptedException { + // Given a first batch with polled messages + int batchSize = 5; + assertThat(handler.request(batchSize)).isEqualTo(batchSize); + handler.release(0, BackPressureHandler.ReleaseReason.PARTIAL_FETCH); // switch to HIGH + // Then subsequent requests should return the same batch size + // because we are in high throughput mode + assertThat(handler.request(batchSize)).isEqualTo(batchSize); + assertThat(handler.request(batchSize)).isEqualTo(batchSize); + handler.release(0, BackPressureHandler.ReleaseReason.PARTIAL_FETCH); + handler.release(0, BackPressureHandler.ReleaseReason.PARTIAL_FETCH); + // When a fetch returns no messages, throughput mode should switch to LOW + assertThat(handler.request(batchSize)).isEqualTo(batchSize); + handler.release(5, BackPressureHandler.ReleaseReason.NONE_FETCHED); + assertThat(handler.request(batchSize)).isEqualTo(batchSize); + // And subsequent requests should return zero permits until the current batch finishes with NONE_FETCHED + assertThat(handler.request(batchSize)).isZero(); + assertThat(handler.request(batchSize)).isZero(); + handler.release(5, BackPressureHandler.ReleaseReason.NONE_FETCHED); + assertThat(handler.request(batchSize)).isEqualTo(5); + // or until it (the current batch) finishes with PARTIAL_FETCH + assertThat(handler.request(batchSize)).isZero(); + assertThat(handler.request(batchSize)).isZero(); + handler.release(3, BackPressureHandler.ReleaseReason.PARTIAL_FETCH); + assertThat(handler.request(batchSize)).isEqualTo(5); + } + + @Test + void drain_shouldSetDrainedAndReturnTrue() throws InterruptedException { + boolean result = handler.drain(Duration.ofSeconds(1)); + assertThat(result).isTrue(); + assertThat(handler.request(5)).isZero(); + } +} diff --git a/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/listener/source/AbstractPollingMessageSourceTests.java b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/listener/source/AbstractPollingMessageSourceTests.java index b03b308c6..f9371450f 100644 --- a/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/listener/source/AbstractPollingMessageSourceTests.java +++ b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/listener/source/AbstractPollingMessageSourceTests.java @@ -23,34 +23,23 @@ import static org.mockito.Mockito.times; import io.awspring.cloud.sqs.MessageExecutionThreadFactory; -import io.awspring.cloud.sqs.listener.BackPressureMode; -import io.awspring.cloud.sqs.listener.SemaphoreBackPressureHandler; -import io.awspring.cloud.sqs.listener.SqsContainerOptions; +import io.awspring.cloud.sqs.listener.*; import io.awspring.cloud.sqs.listener.acknowledgement.AcknowledgementCallback; import io.awspring.cloud.sqs.listener.acknowledgement.AcknowledgementProcessor; import io.awspring.cloud.sqs.support.converter.MessageConversionContext; import io.awspring.cloud.sqs.support.converter.SqsMessagingMessageConverter; import java.time.Duration; -import java.util.Collection; -import java.util.Collections; -import java.util.List; -import java.util.UUID; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.Semaphore; -import java.util.concurrent.ThreadFactory; -import java.util.concurrent.TimeUnit; +import java.util.*; +import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import java.util.stream.IntStream; import org.assertj.core.api.InstanceOfAssertFactories; import org.awaitility.Awaitility; import org.junit.jupiter.api.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.core.task.TaskExecutor; import org.springframework.lang.Nullable; import org.springframework.retry.backoff.BackOffContext; import org.springframework.retry.backoff.BackOffPolicy; @@ -60,6 +49,7 @@ /** * @author Tomaz Fernandes + * @author Loïc Rouchon */ class AbstractPollingMessageSourceTests { @@ -68,10 +58,12 @@ class AbstractPollingMessageSourceTests { @Test void shouldAcquireAndReleaseFullPermits() { String testName = "shouldAcquireAndReleaseFullPermits"; + SqsContainerOptions options = SqsContainerOptions.builder().maxMessagesPerPoll(10).maxConcurrentMessages(10) + .backPressureMode(BackPressureMode.ALWAYS_POLL_MAX_MESSAGES) + .maxDelayBetweenPolls(Duration.ofMillis(200)).listenerShutdownTimeout(Duration.ZERO).build(); + BackPressureHandler backPressureHandler = BackPressureHandlerFactories + .adaptativeThroughputBackPressureHandler(options, Duration.ofMillis(100L)); - SemaphoreBackPressureHandler backPressureHandler = SemaphoreBackPressureHandler.builder() - .acquireTimeout(Duration.ofMillis(200)).batchSize(10).totalPermits(10) - .throughputConfiguration(BackPressureMode.ALWAYS_POLL_MAX_MESSAGES).build(); ExecutorService threadPool = Executors.newCachedThreadPool(); CountDownLatch pollingCounter = new CountDownLatch(3); CountDownLatch processingCounter = new CountDownLatch(1); @@ -80,8 +72,6 @@ void shouldAcquireAndReleaseFullPermits() { private final AtomicBoolean hasReceived = new AtomicBoolean(false); - private final AtomicBoolean hasMadeSecondPoll = new AtomicBoolean(false); - @Override protected CompletableFuture> doPollForMessages(int messagesToRequest) { return CompletableFuture.supplyAsync(() -> { @@ -90,52 +80,128 @@ protected CompletableFuture> doPollForMessages(int messagesT assertThat(messagesToRequest).isEqualTo(10); assertAvailablePermits(backPressureHandler, 0); boolean firstPoll = hasReceived.compareAndSet(false, true); - if (firstPoll) { - logger.debug("First poll"); - // No permits released yet, should be TM low + return firstPoll + ? (Collection) List.of(Message.builder() + .messageId(UUID.randomUUID().toString()).body("message").build()) + : Collections. emptyList(); + } + catch (Throwable t) { + logger.error("Error", t); + throw new RuntimeException(t); + } + }, threadPool).whenComplete((v, t) -> { + if (t == null) { + pollingCounter.countDown(); + } + }); + } + }; + + source.setBackPressureHandler(backPressureHandler); + source.setMessageSink((msgs, context) -> { + assertAvailablePermits(backPressureHandler, 9); + msgs.forEach(msg -> context.runBackPressureReleaseCallback()); + return CompletableFuture.runAsync(processingCounter::countDown); + }); + + source.setId(testName + " source"); + source.configure(options); + source.setTaskExecutor(createTaskExecutor(testName)); + source.setAcknowledgementProcessor(getNoOpsAcknowledgementProcessor()); + try { + source.start(); + assertThat(doAwait(pollingCounter)).isTrue(); + assertThat(doAwait(processingCounter)).isTrue(); + } + finally { + source.stop(); + threadPool.shutdownNow(); + } + } + + @Test + void shouldAdaptThroughputMode() { + String testName = "shouldAdaptThroughputMode"; + SqsContainerOptions options = SqsContainerOptions.builder().maxMessagesPerPoll(10).maxConcurrentMessages(10) + .backPressureMode(BackPressureMode.ALWAYS_POLL_MAX_MESSAGES) + .maxDelayBetweenPolls(Duration.ofMillis(150)).listenerShutdownTimeout(Duration.ZERO).build(); + BackPressureHandler backPressureHandler = BackPressureHandlerFactories + .adaptativeThroughputBackPressureHandler(options, Duration.ofMillis(100L)); + + ExecutorService threadPool = Executors.newCachedThreadPool(); + CountDownLatch pollingCounter = new CountDownLatch(3); + CountDownLatch processingCounter = new CountDownLatch(1); + Collection errors = new ConcurrentLinkedQueue<>(); + + AbstractPollingMessageSource source = new AbstractPollingMessageSource<>() { + + private final AtomicInteger pollAttemptCounter = new AtomicInteger(0); + + @Override + protected CompletableFuture> doPollForMessages(int messagesToRequest) { + return CompletableFuture.supplyAsync(() -> { + try { + int pollAttempt = pollAttemptCounter.incrementAndGet(); + logger.warn("Poll attempt {}", pollAttempt); + if (pollAttempt == 1) { + // Initial poll; throughput mode should be low assertThroughputMode(backPressureHandler, "low"); + // Since no permits were acquired yet, should be 10 + assertThat(messagesToRequest).isEqualTo(10); + return (Collection) List.of( + Message.builder().messageId(UUID.randomUUID().toString()).body("message").build()); } - else if (hasMadeSecondPoll.compareAndSet(false, true)) { - logger.debug("Second poll"); - // Permits returned, should be high + else if (pollAttempt == 2) { + // Messages returned in the previous poll; throughput mode should be high assertThroughputMode(backPressureHandler, "high"); + // Since throughput mode is high, should be 10 + assertThat(messagesToRequest).isEqualTo(10); + return Collections. emptyList(); } else { - logger.debug("Third poll"); - // Already returned full permits, should be low + // No Messages returned in the previous poll; throughput mode should be low assertThroughputMode(backPressureHandler, "low"); + return Collections. emptyList(); } - return firstPoll - ? (Collection) List.of(Message.builder() - .messageId(UUID.randomUUID().toString()).body("message").build()) - : Collections. emptyList(); } catch (Throwable t) { - logger.error("Error", t); + logger.error("Error (not expecting it)", t); + errors.add(t); throw new RuntimeException(t); } }, threadPool).whenComplete((v, t) -> { if (t == null) { + logger.warn("Polling succeeded", t); pollingCounter.countDown(); } + else { + logger.warn("Polling failed with error", t); + errors.add(t); + } }); } }; source.setBackPressureHandler(backPressureHandler); source.setMessageSink((msgs, context) -> { - assertAvailablePermits(backPressureHandler, 9); msgs.forEach(msg -> context.runBackPressureReleaseCallback()); return CompletableFuture.runAsync(processingCounter::countDown); }); source.setId(testName + " source"); - source.configure(SqsContainerOptions.builder().build()); + source.configure(options); source.setTaskExecutor(createTaskExecutor(testName)); source.setAcknowledgementProcessor(getNoOpsAcknowledgementProcessor()); - source.start(); - assertThat(doAwait(pollingCounter)).isTrue(); - assertThat(doAwait(processingCounter)).isTrue(); + try { + source.start(); + assertThat(doAwait(pollingCounter)).isTrue(); + assertThat(doAwait(processingCounter)).isTrue(); + assertThat(errors).isEmpty(); + } + finally { + source.stop(); + threadPool.shutdownNow(); + } } private static final AtomicInteger testCounter = new AtomicInteger(); @@ -143,9 +209,11 @@ else if (hasMadeSecondPoll.compareAndSet(false, true)) { @Test void shouldAcquireAndReleasePartialPermits() { String testName = "shouldAcquireAndReleasePartialPermits"; - SemaphoreBackPressureHandler backPressureHandler = SemaphoreBackPressureHandler.builder() - .acquireTimeout(Duration.ofMillis(150)).batchSize(10).totalPermits(10) - .throughputConfiguration(BackPressureMode.AUTO).build(); + SqsContainerOptions options = SqsContainerOptions.builder().maxMessagesPerPoll(10).maxConcurrentMessages(10) + .backPressureMode(BackPressureMode.AUTO).maxDelayBetweenPolls(Duration.ofMillis(150)) + .listenerShutdownTimeout(Duration.ZERO).build(); + BackPressureHandler backPressureHandler = BackPressureHandlerFactories + .adaptativeThroughputBackPressureHandler(options, Duration.ofMillis(200L)); ExecutorService threadPool = Executors .newCachedThreadPool(new MessageExecutionThreadFactory("test " + testCounter.incrementAndGet())); CountDownLatch pollingCounter = new CountDownLatch(4); @@ -155,60 +223,34 @@ void shouldAcquireAndReleasePartialPermits() { AbstractPollingMessageSource source = new AbstractPollingMessageSource<>() { - private final AtomicBoolean hasReceived = new AtomicBoolean(false); - - private final AtomicBoolean hasAcquired9 = new AtomicBoolean(false); - - private final AtomicBoolean hasMadeThirdPoll = new AtomicBoolean(false); + private final AtomicInteger pollAttemptCounter = new AtomicInteger(0); @Override protected CompletableFuture> doPollForMessages(int messagesToRequest) { return CompletableFuture.supplyAsync(() -> { try { - // Give it some time between returning empty and polling again - // doSleep(100); - - // Will only be true the first time it sets hasReceived to true - boolean shouldReturnMessage = hasReceived.compareAndSet(false, true); - if (shouldReturnMessage) { + int pollAttempt = pollAttemptCounter.incrementAndGet(); + if (pollAttempt == 1) { // First poll, should have 10 logger.debug("First poll - should request 10 messages"); assertThat(messagesToRequest).isEqualTo(10); - assertAvailablePermits(backPressureHandler, 0); - // No permits have been released yet - assertThroughputMode(backPressureHandler, "low"); + Message message = Message.builder().messageId(UUID.randomUUID().toString()).body("message") + .build(); + return (Collection) List.of(message); } - else if (hasAcquired9.compareAndSet(false, true)) { + else if (pollAttempt == 2) { // Second poll, should have 9 logger.debug("Second poll - should request 9 messages"); assertThat(messagesToRequest).isEqualTo(9); - assertAvailablePermitsLessThanOrEqualTo(backPressureHandler, 1); - // Has released 9 permits, should be TM HIGH - assertThroughputMode(backPressureHandler, "high"); processingLatch.countDown(); // Release processing now + return Collections. emptyList(); } else { - boolean thirdPoll = hasMadeThirdPoll.compareAndSet(false, true); // Third poll or later, should have 10 again - logger.debug("Third poll - should request 10 messages"); + logger.debug("Third (or later) poll - should request 10 messages"); assertThat(messagesToRequest).isEqualTo(10); - assertAvailablePermits(backPressureHandler, 0); - if (thirdPoll) { - // Hasn't yet returned a full batch, should be TM High - assertThroughputMode(backPressureHandler, "high"); - } - else { - // Has returned all permits in third poll - assertThroughputMode(backPressureHandler, "low"); - } - } - if (shouldReturnMessage) { - logger.debug("shouldReturnMessage, returning one message"); - return (Collection) List.of( - Message.builder().messageId(UUID.randomUUID().toString()).body("message").build()); + return Collections. emptyList(); } - logger.debug("should not return message, returning empty list"); - return Collections. emptyList(); } catch (Error e) { hasThrownError.set(true); @@ -228,27 +270,26 @@ else if (hasAcquired9.compareAndSet(false, true)) { return CompletableFuture.completedFuture(null).thenRun(processingCounter::countDown); }); source.setId(testName + " source"); - source.configure(SqsContainerOptions.builder().build()); + source.configure(options); source.setTaskExecutor(createTaskExecutor(testName)); source.setAcknowledgementProcessor(getNoOpsAcknowledgementProcessor()); - source.start(); - assertThat(doAwait(processingCounter)).isTrue(); - assertThat(doAwait(pollingCounter)).isTrue(); - source.stop(); - assertThat(hasThrownError.get()).isFalse(); + try { + source.start(); + assertThat(doAwait(processingCounter)).isTrue(); + assertThat(doAwait(pollingCounter)).isTrue(); + assertThat(hasThrownError.get()).isFalse(); + } + finally { + threadPool.shutdownNow(); + source.stop(); + } } @Test void shouldReleasePermitsOnConversionErrors() { String testName = "shouldReleasePermitsOnConversionErrors"; - SemaphoreBackPressureHandler backPressureHandler = SemaphoreBackPressureHandler.builder() - .acquireTimeout(Duration.ofMillis(150)).batchSize(10).totalPermits(10) - .throughputConfiguration(BackPressureMode.AUTO).build(); AtomicInteger convertedMessages = new AtomicInteger(0); - AtomicInteger messagesInSink = new AtomicInteger(0); - AtomicBoolean hasFailed = new AtomicBoolean(false); - var converter = new SqsMessagingMessageConverter() { @Override public org.springframework.messaging.Message toMessagingMessage(Message source, @@ -262,6 +303,16 @@ public org.springframework.messaging.Message toMessagingMessage(Message sourc } }; + SqsContainerOptions options = SqsContainerOptions.builder().maxMessagesPerPoll(10).maxConcurrentMessages(10) + .backPressureMode(BackPressureMode.ALWAYS_POLL_MAX_MESSAGES) + .maxDelayBetweenPolls(Duration.ofMillis(150)).messageConverter(converter) + .listenerShutdownTimeout(Duration.ZERO).build(); + BackPressureHandler backPressureHandler = BackPressureHandlerFactories + .adaptativeThroughputBackPressureHandler(options, Duration.ofMillis(100L)); + + AtomicInteger messagesInSink = new AtomicInteger(0); + AtomicBoolean hasFailed = new AtomicBoolean(false); + AbstractPollingMessageSource source = new AbstractPollingMessageSource<>() { @Override @@ -288,25 +339,37 @@ private Collection create10Messages() { return CompletableFuture.completedFuture(null); }); source.setId(testName + " source"); - source.configure(SqsContainerOptions.builder().messageConverter(converter).build()); + source.configure(options); source.setPollingEndpointName("shouldReleasePermitsOnConversionErrors-queue"); - source.setTaskExecutor(createTaskExecutor(testName)); + ThreadPoolTaskExecutor taskExecutor = createTaskExecutor(testName); + source.setTaskExecutor(taskExecutor); source.setAcknowledgementProcessor(getNoOpsAcknowledgementProcessor()); - source.start(); - Awaitility.waitAtMost(Duration.ofSeconds(10)).until(() -> convertedMessages.get() == 30); - assertThat(hasFailed).isFalse(); - assertThat(messagesInSink).hasValue(27); - source.stop(); + try { + source.start(); + Awaitility.waitAtMost(Duration.ofSeconds(10)).until(() -> convertedMessages.get() == 30); + assertThat(hasFailed).isFalse(); + assertThat(messagesInSink).hasValue(27); + } + finally { + source.stop(); + taskExecutor.shutdown(); + } } @Test void shouldBackOffIfPollingThrowsAnError() { - var testName = "shouldBackOffIfPollingThrowsAnError"; - var backPressureHandler = SemaphoreBackPressureHandler.builder().acquireTimeout(Duration.ofMillis(200)) - .batchSize(10).totalPermits(40).throughputConfiguration(BackPressureMode.ALWAYS_POLL_MAX_MESSAGES) - .build(); + var policy = mock(BackOffPolicy.class); + var backOffContext = mock(BackOffContext.class); + given(policy.start(null)).willReturn(backOffContext); + SqsContainerOptions options = SqsContainerOptions.builder().maxMessagesPerPoll(10).maxConcurrentMessages(40) + .backPressureMode(BackPressureMode.ALWAYS_POLL_MAX_MESSAGES) + .maxDelayBetweenPolls(Duration.ofMillis(200)).pollBackOffPolicy(policy) + .listenerShutdownTimeout(Duration.ZERO).build(); + BackPressureHandler backPressureHandler = BackPressureHandlerFactories + .adaptativeThroughputBackPressureHandler(options, Duration.ofMillis(100L)); + var currentPoll = new AtomicInteger(0); var waitThirdPollLatch = new CountDownLatch(4); @@ -333,24 +396,26 @@ else if (currentPoll.compareAndSet(2, 3)) { } }; - var policy = mock(BackOffPolicy.class); - var backOffContext = mock(BackOffContext.class); - given(policy.start(null)).willReturn(backOffContext); - source.setBackPressureHandler(backPressureHandler); source.setMessageSink((msgs, context) -> CompletableFuture.completedFuture(null)); source.setId(testName + " source"); - source.configure(SqsContainerOptions.builder().pollBackOffPolicy(policy).build()); + source.configure(options); - source.setTaskExecutor(createTaskExecutor(testName)); + ThreadPoolTaskExecutor taskExecutor = createTaskExecutor(testName); + source.setTaskExecutor(taskExecutor); source.setAcknowledgementProcessor(getNoOpsAcknowledgementProcessor()); - source.start(); - - doAwait(waitThirdPollLatch); + try { + source.start(); - then(policy).should().start(null); - then(policy).should(times(2)).backOff(backOffContext); + doAwait(waitThirdPollLatch); + then(policy).should().start(null); + then(policy).should(times(2)).backOff(backOffContext); + } + finally { + source.stop(); + taskExecutor.shutdown(); + } } private static boolean doAwait(CountDownLatch processingLatch) { @@ -363,24 +428,45 @@ private static boolean doAwait(CountDownLatch processingLatch) { } } - private void assertThroughputMode(SemaphoreBackPressureHandler backPressureHandler, String expectedThroughputMode) { - assertThat(ReflectionTestUtils.getField(backPressureHandler, "currentThroughputMode")) - .extracting(Object::toString).extracting(String::toLowerCase) + private void assertThroughputMode(BackPressureHandler backPressureHandler, String expectedThroughputMode) { + var bph = extractBackPressureHandler(backPressureHandler, ThroughputBackPressureHandler.class); + assertThat(getThroughputModeValue(bph, "currentThroughputMode")) .isEqualTo(expectedThroughputMode.toLowerCase()); } - private void assertAvailablePermits(SemaphoreBackPressureHandler backPressureHandler, int expectedPermits) { - assertThat(ReflectionTestUtils.getField(backPressureHandler, "semaphore")).asInstanceOf(type(Semaphore.class)) + private static String getThroughputModeValue(ThroughputBackPressureHandler bph, String targetThroughputMode) { + return ((AtomicReference) ReflectionTestUtils.getField(bph, targetThroughputMode)).get().toString() + .toLowerCase(Locale.ROOT); + } + + private void assertAvailablePermits(BackPressureHandler backPressureHandler, int expectedPermits) { + var bph = extractBackPressureHandler(backPressureHandler, ConcurrencyLimiterBlockingBackPressureHandler.class); + assertThat(ReflectionTestUtils.getField(bph, "semaphore")).asInstanceOf(type(Semaphore.class)) .extracting(Semaphore::availablePermits).isEqualTo(expectedPermits); } - private void assertAvailablePermitsLessThanOrEqualTo(SemaphoreBackPressureHandler backPressureHandler, + private void assertAvailablePermitsLessThanOrEqualTo(BackPressureHandler backPressureHandler, int maxExpectedPermits) { - assertThat(ReflectionTestUtils.getField(backPressureHandler, "semaphore")).asInstanceOf(type(Semaphore.class)) + var bph = extractBackPressureHandler(backPressureHandler, ConcurrencyLimiterBlockingBackPressureHandler.class); + assertThat(ReflectionTestUtils.getField(bph, "semaphore")).asInstanceOf(type(Semaphore.class)) .extracting(Semaphore::availablePermits).asInstanceOf(InstanceOfAssertFactories.INTEGER) .isLessThanOrEqualTo(maxExpectedPermits); } + private T extractBackPressureHandler(BackPressureHandler bph, Class type) { + if (type.isInstance(bph)) { + return type.cast(bph); + } + if (bph instanceof CompositeBackPressureHandler cbph) { + List backPressureHandlers = (List) ReflectionTestUtils + .getField(cbph, "backPressureHandlers"); + return extractBackPressureHandler( + backPressureHandlers.stream().filter(type::isInstance).map(type::cast).findFirst().orElseThrow(), + type); + } + throw new NoSuchElementException("%s not found in %s".formatted(type.getSimpleName(), bph)); + } + // Used to slow down tests while developing private void doSleep(int time) { try { @@ -392,7 +478,7 @@ private void doSleep(int time) { } } - protected TaskExecutor createTaskExecutor(String testName) { + protected ThreadPoolTaskExecutor createTaskExecutor(String testName) { ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor(); int poolSize = 10; executor.setMaxPoolSize(poolSize); diff --git a/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/listener/source/SemaphoreBackPressureHandlerAbstractPollingMessageSourceTests.java b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/listener/source/SemaphoreBackPressureHandlerAbstractPollingMessageSourceTests.java new file mode 100644 index 000000000..d5d3597c0 --- /dev/null +++ b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/listener/source/SemaphoreBackPressureHandlerAbstractPollingMessageSourceTests.java @@ -0,0 +1,446 @@ +/* + * Copyright 2013-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.awspring.cloud.sqs.listener.source; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.InstanceOfAssertFactories.type; +import static org.mockito.BDDMockito.given; +import static org.mockito.BDDMockito.then; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; + +import io.awspring.cloud.sqs.MessageExecutionThreadFactory; +import io.awspring.cloud.sqs.listener.*; +import io.awspring.cloud.sqs.listener.acknowledgement.AcknowledgementCallback; +import io.awspring.cloud.sqs.listener.acknowledgement.AcknowledgementProcessor; +import io.awspring.cloud.sqs.support.converter.MessageConversionContext; +import io.awspring.cloud.sqs.support.converter.SqsMessagingMessageConverter; +import java.time.Duration; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Semaphore; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.IntStream; +import org.assertj.core.api.InstanceOfAssertFactories; +import org.awaitility.Awaitility; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.core.task.TaskExecutor; +import org.springframework.lang.Nullable; +import org.springframework.retry.backoff.BackOffContext; +import org.springframework.retry.backoff.BackOffPolicy; +import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; +import org.springframework.test.util.ReflectionTestUtils; +import software.amazon.awssdk.services.sqs.model.Message; + +/** + * @author Tomaz Fernandes + * @author Loïc Rouchon + */ +class SemaphoreBackPressureHandlerAbstractPollingMessageSourceTests { + + private static final Logger logger = LoggerFactory.getLogger(AbstractPollingMessageSourceTests.class); + + @Test + void shouldAcquireAndReleaseFullPermits() { + String testName = "shouldAcquireAndReleaseFullPermits"; + BackPressureHandler backPressureHandler = BackPressureHandlerFactories + .semaphoreBackPressureHandler(SqsContainerOptions.builder().maxMessagesPerPoll(10) + .maxConcurrentMessages(10).backPressureMode(BackPressureMode.ALWAYS_POLL_MAX_MESSAGES) + .maxDelayBetweenPolls(Duration.ofMillis(200)).build()); + + ExecutorService threadPool = Executors.newCachedThreadPool(); + CountDownLatch pollingCounter = new CountDownLatch(3); + CountDownLatch processingCounter = new CountDownLatch(1); + + AbstractPollingMessageSource source = new AbstractPollingMessageSource<>() { + + private final AtomicBoolean hasReceived = new AtomicBoolean(false); + + private final AtomicBoolean hasMadeSecondPoll = new AtomicBoolean(false); + + @Override + protected CompletableFuture> doPollForMessages(int messagesToRequest) { + return CompletableFuture.supplyAsync(() -> { + try { + // Since BackPressureMode.ALWAYS_POLL_MAX_MESSAGES, should always be 10. + assertThat(messagesToRequest).isEqualTo(10); + assertAvailablePermits(backPressureHandler, 0); + boolean firstPoll = hasReceived.compareAndSet(false, true); + if (firstPoll) { + logger.debug("First poll"); + // No permits released yet, should be TM low + assertThroughputMode(backPressureHandler, "low"); + } + else if (hasMadeSecondPoll.compareAndSet(false, true)) { + logger.debug("Second poll"); + // Permits returned, should be high + assertThroughputMode(backPressureHandler, "high"); + } + else { + logger.debug("Third poll"); + // Already returned full permits, should be low + assertThroughputMode(backPressureHandler, "low"); + } + return firstPoll + ? (Collection) List.of(Message.builder() + .messageId(UUID.randomUUID().toString()).body("message").build()) + : Collections. emptyList(); + } + catch (Throwable t) { + logger.error("Error", t); + throw new RuntimeException(t); + } + }, threadPool).whenComplete((v, t) -> { + if (t == null) { + pollingCounter.countDown(); + } + }); + } + }; + + source.setBackPressureHandler(backPressureHandler); + source.setMessageSink((msgs, context) -> { + assertAvailablePermits(backPressureHandler, 9); + msgs.forEach(msg -> context.runBackPressureReleaseCallback()); + return CompletableFuture.runAsync(processingCounter::countDown); + }); + + source.setId(testName + " source"); + source.configure(SqsContainerOptions.builder().build()); + source.setTaskExecutor(createTaskExecutor(testName)); + source.setAcknowledgementProcessor(getNoOpsAcknowledgementProcessor()); + source.start(); + assertThat(doAwait(pollingCounter)).isTrue(); + assertThat(doAwait(processingCounter)).isTrue(); + } + + private static final AtomicInteger testCounter = new AtomicInteger(); + + @Test + void shouldAcquireAndReleasePartialPermits() { + String testName = "shouldAcquireAndReleasePartialPermits"; + BackPressureHandler backPressureHandler = BackPressureHandlerFactories.semaphoreBackPressureHandler( + SqsContainerOptions.builder().maxMessagesPerPoll(10).maxConcurrentMessages(10) + .backPressureMode(BackPressureMode.AUTO).maxDelayBetweenPolls(Duration.ofMillis(150)).build()); + + ExecutorService threadPool = Executors + .newCachedThreadPool(new MessageExecutionThreadFactory("test " + testCounter.incrementAndGet())); + CountDownLatch pollingCounter = new CountDownLatch(4); + CountDownLatch processingCounter = new CountDownLatch(1); + CountDownLatch processingLatch = new CountDownLatch(1); + AtomicBoolean hasThrownError = new AtomicBoolean(false); + + AbstractPollingMessageSource source = new AbstractPollingMessageSource<>() { + + private final AtomicBoolean hasReceived = new AtomicBoolean(false); + + private final AtomicBoolean hasAcquired9 = new AtomicBoolean(false); + + private final AtomicBoolean hasMadeThirdPoll = new AtomicBoolean(false); + + @Override + protected CompletableFuture> doPollForMessages(int messagesToRequest) { + return CompletableFuture.supplyAsync(() -> { + try { + // Give it some time between returning empty and polling again + // doSleep(100); + + // Will only be true the first time it sets hasReceived to true + boolean shouldReturnMessage = hasReceived.compareAndSet(false, true); + if (shouldReturnMessage) { + // First poll, should have 10 + logger.debug("First poll - should request 10 messages"); + assertThat(messagesToRequest).isEqualTo(10); + assertAvailablePermits(backPressureHandler, 0); + // No permits have been released yet + assertThroughputMode(backPressureHandler, "low"); + } + else if (hasAcquired9.compareAndSet(false, true)) { + // Second poll, should have 9 + logger.debug("Second poll - should request 9 messages"); + assertThat(messagesToRequest).isEqualTo(9); + assertAvailablePermitsLessThanOrEqualTo(backPressureHandler, 1); + // Has released 9 permits, should be TM HIGH + assertThroughputMode(backPressureHandler, "high"); + processingLatch.countDown(); // Release processing now + } + else { + boolean thirdPoll = hasMadeThirdPoll.compareAndSet(false, true); + // Third poll or later, should have 10 again + logger.debug("Third poll - should request 10 messages"); + assertThat(messagesToRequest).isEqualTo(10); + assertAvailablePermits(backPressureHandler, 0); + if (thirdPoll) { + // Hasn't yet returned a full batch, should be TM High + assertThroughputMode(backPressureHandler, "high"); + } + else { + // Has returned all permits in third poll + assertThroughputMode(backPressureHandler, "low"); + } + } + if (shouldReturnMessage) { + logger.debug("shouldReturnMessage, returning one message"); + return (Collection) List.of( + Message.builder().messageId(UUID.randomUUID().toString()).body("message").build()); + } + logger.debug("should not return message, returning empty list"); + return Collections. emptyList(); + } + catch (Error e) { + hasThrownError.set(true); + throw new RuntimeException("Error polling for messages", e); + } + }, threadPool).whenComplete((v, t) -> pollingCounter.countDown()); + } + }; + + source.setBackPressureHandler(backPressureHandler); + source.setMessageSink((msgs, context) -> { + logger.debug("Processing {} messages", msgs.size()); + assertAvailablePermits(backPressureHandler, 9); + assertThat(doAwait(processingLatch)).isTrue(); + logger.debug("Finished processing {} messages", msgs.size()); + msgs.forEach(msg -> context.runBackPressureReleaseCallback()); + return CompletableFuture.completedFuture(null).thenRun(processingCounter::countDown); + }); + source.setId(testName + " source"); + source.configure(SqsContainerOptions.builder().build()); + source.setTaskExecutor(createTaskExecutor(testName)); + source.setAcknowledgementProcessor(getNoOpsAcknowledgementProcessor()); + source.start(); + assertThat(doAwait(processingCounter)).isTrue(); + assertThat(doAwait(pollingCounter)).isTrue(); + source.stop(); + assertThat(hasThrownError.get()).isFalse(); + } + + @Test + void shouldReleasePermitsOnConversionErrors() { + String testName = "shouldReleasePermitsOnConversionErrors"; + BackPressureHandler backPressureHandler = BackPressureHandlerFactories.semaphoreBackPressureHandler( + SqsContainerOptions.builder().maxMessagesPerPoll(10).maxConcurrentMessages(10) + .backPressureMode(BackPressureMode.AUTO).maxDelayBetweenPolls(Duration.ofMillis(150)).build()); + + AtomicInteger convertedMessages = new AtomicInteger(0); + AtomicInteger messagesInSink = new AtomicInteger(0); + AtomicBoolean hasFailed = new AtomicBoolean(false); + + var converter = new SqsMessagingMessageConverter() { + @Override + public org.springframework.messaging.Message toMessagingMessage(Message source, + @Nullable MessageConversionContext context) { + var converted = convertedMessages.incrementAndGet(); + logger.trace("Messages converted: {}", converted); + if (converted % 9 == 0) { + throw new RuntimeException("Expected error"); + } + return super.toMessagingMessage(source, context); + } + }; + + AbstractPollingMessageSource source = new AbstractPollingMessageSource<>() { + + @Override + protected CompletableFuture> doPollForMessages(int messagesToRequest) { + if (messagesToRequest != 10) { + logger.error("Expected 10 messages to requesst, received {}", messagesToRequest); + hasFailed.set(true); + } + return convertedMessages.get() < 30 ? CompletableFuture.completedFuture(create10Messages()) + : CompletableFuture.completedFuture(List.of()); + } + + private Collection create10Messages() { + return IntStream.range(0, 10).mapToObj( + index -> Message.builder().messageId(UUID.randomUUID().toString()).body("test-message").build()) + .toList(); + } + }; + + source.setBackPressureHandler(backPressureHandler); + source.setMessageSink((msgs, context) -> { + msgs.forEach(message -> messagesInSink.incrementAndGet()); + msgs.forEach(msg -> context.runBackPressureReleaseCallback()); + return CompletableFuture.completedFuture(null); + }); + source.setId(testName + " source"); + source.configure(SqsContainerOptions.builder().messageConverter(converter).build()); + source.setPollingEndpointName("shouldReleasePermitsOnConversionErrors-queue"); + source.setTaskExecutor(createTaskExecutor(testName)); + source.setAcknowledgementProcessor(getNoOpsAcknowledgementProcessor()); + source.start(); + Awaitility.waitAtMost(Duration.ofSeconds(10)).until(() -> convertedMessages.get() == 30); + assertThat(hasFailed).isFalse(); + assertThat(messagesInSink).hasValue(27); + source.stop(); + } + + @Test + void shouldBackOffIfPollingThrowsAnError() { + var testName = "shouldBackOffIfPollingThrowsAnError"; + BackPressureHandler backPressureHandler = BackPressureHandlerFactories + .semaphoreBackPressureHandler(SqsContainerOptions.builder().maxMessagesPerPoll(10) + .maxConcurrentMessages(40).backPressureMode(BackPressureMode.ALWAYS_POLL_MAX_MESSAGES) + .maxDelayBetweenPolls(Duration.ofMillis(200)).build()); + + var currentPoll = new AtomicInteger(0); + var waitThirdPollLatch = new CountDownLatch(4); + + AbstractPollingMessageSource source = new AbstractPollingMessageSource<>() { + @Override + protected CompletableFuture> doPollForMessages(int messagesToRequest) { + waitThirdPollLatch.countDown(); + if (currentPoll.compareAndSet(0, 1)) { + logger.debug("First poll - returning empty list"); + return CompletableFuture.completedFuture(List.of()); + } + else if (currentPoll.compareAndSet(1, 2)) { + logger.debug("Second poll - returning error"); + return CompletableFuture.failedFuture(new RuntimeException("Expected exception on second poll")); + } + else if (currentPoll.compareAndSet(2, 3)) { + logger.debug("Third poll - returning error"); + return CompletableFuture.failedFuture(new RuntimeException("Expected exception on third poll")); + } + else { + logger.debug("Fourth poll - returning empty list"); + return CompletableFuture.completedFuture(List.of()); + } + } + }; + + var policy = mock(BackOffPolicy.class); + var backOffContext = mock(BackOffContext.class); + given(policy.start(null)).willReturn(backOffContext); + + source.setBackPressureHandler(backPressureHandler); + source.setMessageSink((msgs, context) -> CompletableFuture.completedFuture(null)); + source.setId(testName + " source"); + source.configure(SqsContainerOptions.builder().pollBackOffPolicy(policy).build()); + + source.setTaskExecutor(createTaskExecutor(testName)); + source.setAcknowledgementProcessor(getNoOpsAcknowledgementProcessor()); + source.start(); + + doAwait(waitThirdPollLatch); + + then(policy).should().start(null); + then(policy).should(times(2)).backOff(backOffContext); + + } + + private static boolean doAwait(CountDownLatch processingLatch) { + try { + return processingLatch.await(4, TimeUnit.SECONDS); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted while waiting for latch", e); + } + } + + private void assertThroughputMode(BackPressureHandler backPressureHandler, String expectedThroughputMode) { + assertThat(ReflectionTestUtils.getField(backPressureHandler, "currentThroughputMode")) + .extracting(Object::toString).extracting(String::toLowerCase) + .isEqualTo(expectedThroughputMode.toLowerCase()); + } + + private void assertAvailablePermits(BackPressureHandler backPressureHandler, int expectedPermits) { + assertThat(ReflectionTestUtils.getField(backPressureHandler, "semaphore")).asInstanceOf(type(Semaphore.class)) + .extracting(Semaphore::availablePermits).isEqualTo(expectedPermits); + } + + private void assertAvailablePermitsLessThanOrEqualTo(BackPressureHandler backPressureHandler, + int maxExpectedPermits) { + assertThat(ReflectionTestUtils.getField(backPressureHandler, "semaphore")).asInstanceOf(type(Semaphore.class)) + .extracting(Semaphore::availablePermits).asInstanceOf(InstanceOfAssertFactories.INTEGER) + .isLessThanOrEqualTo(maxExpectedPermits); + } + + // Used to slow down tests while developing + private void doSleep(int time) { + try { + Thread.sleep(time); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + } + + protected TaskExecutor createTaskExecutor(String testName) { + ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor(); + int poolSize = 10; + executor.setMaxPoolSize(poolSize); + executor.setCorePoolSize(10); + executor.setQueueCapacity(poolSize); + executor.setAllowCoreThreadTimeOut(true); + executor.setThreadFactory(createThreadFactory(testName)); + executor.afterPropertiesSet(); + return executor; + } + + protected ThreadFactory createThreadFactory(String testName) { + MessageExecutionThreadFactory threadFactory = new MessageExecutionThreadFactory(); + threadFactory.setThreadNamePrefix(testName + "-thread" + "-"); + return threadFactory; + } + + private AcknowledgementProcessor getNoOpsAcknowledgementProcessor() { + return new AcknowledgementProcessor<>() { + @Override + public AcknowledgementCallback getAcknowledgementCallback() { + return new AcknowledgementCallback<>() { + }; + } + + @Override + public void setId(String id) { + } + + @Override + public String getId() { + return "test processor"; + } + + @Override + public void start() { + } + + @Override + public void stop() { + } + + @Override + public boolean isRunning() { + return false; + } + }; + } + +}