diff --git a/.changes/next-release/bugfix-AWSSDKforJavav2-03a897a.json b/.changes/next-release/bugfix-AWSSDKforJavav2-03a897a.json new file mode 100644 index 000000000000..f98e0facec32 --- /dev/null +++ b/.changes/next-release/bugfix-AWSSDKforJavav2-03a897a.json @@ -0,0 +1,6 @@ +{ + "type": "bugfix", + "category": "AWS SDK for Java v2", + "contributor": "", + "description": "Fix an issue where `StackOverflowError` can occur when iterating over large pages from an async paginator. This can manifest as the publisher hanging/never reaching the end of the stream. Fixes [#6411](https://github.com/aws/aws-sdk-java-v2/issues/6411)." +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/pagination/async/ItemsSubscription.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/pagination/async/ItemsSubscription.java index 03ed8a4d1ea1..ec87fd6fc426 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/pagination/async/ItemsSubscription.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/pagination/async/ItemsSubscription.java @@ -16,6 +16,8 @@ package software.amazon.awssdk.core.internal.pagination.async; import java.util.Iterator; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; import org.reactivestreams.Subscription; import software.amazon.awssdk.annotations.SdkInternalApi; @@ -32,6 +34,8 @@ public final class ItemsSubscription extends PaginationSubscription { private final Function> getIteratorFunction; private volatile Iterator singlePageItemsIterator; + private final AtomicBoolean handlingRequests = new AtomicBoolean(); + private volatile boolean awaitingNewPage = false; private ItemsSubscription(BuilderImpl builder) { super(builder); @@ -47,61 +51,83 @@ public static Builder builder() { @Override protected void handleRequests() { - if (!hasMoreItems() && !hasNextPage()) { - completeSubscription(); + // Prevent recursion if we already invoked handleRequests + if (!handlingRequests.compareAndSet(false, true)) { return; } - synchronized (this) { - if (outstandingRequests.get() <= 0) { - stopTask(); - return; - } - } - - if (!isTerminated()) { - /** - * Current page is null only the first time the method is called. - * Once initialized, current page will never be null - */ - if (currentPage == null || (!hasMoreItems() && hasNextPage())) { - fetchNextPage(); - - } else if (hasMoreItems()) { - sendNextElement(); - - // All valid cases are covered above. Throw an exception if any combination is missed - } else { - throw new IllegalStateException("Execution should have not reached here"); + try { + while (true) { + if (!hasMoreItems() && !hasNextPage()) { + completeSubscription(); + return; + } + + synchronized (this) { + if (outstandingRequests.get() <= 0) { + stopTask(); + return; + } + } + + if (isTerminated()) { + return; + } + + if (shouldFetchNextPage()) { + awaitingNewPage = true; + fetchNextPage().whenComplete((r, e) -> { + if (e == null) { + awaitingNewPage = false; + handleRequests(); + } + // note: signaling onError if e != null is taken care of by fetchNextPage(). No need to do it here. + }); + } else if (hasMoreItems()) { + synchronized (this) { + if (outstandingRequests.get() <= 0) { + continue; + } + + subscriber.onNext(singlePageItemsIterator.next()); + outstandingRequests.getAndDecrement(); + } + } else { + // Outstanding demand AND no items in current page AND waiting for next page. Just return for now, and + // we'll handle demand when the new page arrives. + return; + } } + } finally { + handlingRequests.set(false); } } - private void fetchNextPage() { - nextPageFetcher.nextPage(currentPage) - .whenComplete(((response, error) -> { - if (response != null) { - currentPage = response; - singlePageItemsIterator = getIteratorFunction.apply(response); - sendNextElement(); - } - if (error != null) { - subscriber.onError(error); - cleanup(); - } - })); + private CompletableFuture fetchNextPage() { + return nextPageFetcher.nextPage(currentPage) + .whenComplete((response, error) -> { + if (response != null) { + currentPage = response; + singlePageItemsIterator = getIteratorFunction.apply(response); + } else if (error != null) { + subscriber.onError(error); + cleanup(); + } + }); } - /** - * Calls onNext and calls the recursive method. - */ - private void sendNextElement() { - if (singlePageItemsIterator.hasNext()) { - subscriber.onNext(singlePageItemsIterator.next()); - outstandingRequests.getAndDecrement(); + // Conditions when to fetch the next page: + // - We're NOT already waiting for a new page AND either + // - We still need to fetch the first page OR + // - We've exhausted the current page AND there is a next page available + private boolean shouldFetchNextPage() { + if (awaitingNewPage) { + return false; } - handleRequests(); + // Current page is null only the first time the method is called. + // Once initialized, current page will never be null. + return currentPage == null || (!hasMoreItems() && hasNextPage()); } private boolean hasMoreItems() { diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/pagination/async/PaginatedItemsPublisherTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/pagination/async/PaginatedItemsPublisherTest.java new file mode 100644 index 000000000000..fa818ba9f6b1 --- /dev/null +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/pagination/async/PaginatedItemsPublisherTest.java @@ -0,0 +1,125 @@ +package software.amazon.awssdk.core.pagination.async; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.util.Iterator; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Function; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import software.amazon.awssdk.core.SdkResponse; + +public class PaginatedItemsPublisherTest { + @Test + @Timeout(value = 1, unit = TimeUnit.MINUTES) + void subscribe_largePage_doesNotFail() throws Exception { + int nItems = 100_000; + + Function> iteratorFn = resp -> + new Iterator() { + private int count = 0; + + @Override + public boolean hasNext() { + return count < nItems; + } + + @Override + public String next() { + ++count; + return "item"; + } + }; + + AsyncPageFetcher pageFetcher = new AsyncPageFetcher() { + @Override + public boolean hasNextPage(SdkResponse oldPage) { + return false; + } + + @Override + public CompletableFuture nextPage(SdkResponse oldPage) { + return CompletableFuture.completedFuture(mock(SdkResponse.class)); + } + }; + + PaginatedItemsPublisher publisher = PaginatedItemsPublisher.builder() + .isLastPage(false) + .nextPageFetcher(pageFetcher) + .iteratorFunction(iteratorFn) + .build(); + + AtomicLong counter = new AtomicLong(); + publisher.subscribe(i -> counter.incrementAndGet()).join(); + assertThat(counter.get()).isEqualTo(nItems); + } + + @Test + @Timeout(value = 1, unit = TimeUnit.MINUTES) + void subscribe_longStream_doesNotFail() throws Exception { + int nPages = 100_000; + int nItemsPerPage = 1; + Function> iteratorFn = resp -> + new Iterator() { + private int count = 0; + + @Override + public boolean hasNext() { + return count < nItemsPerPage; + } + + @Override + public String next() { + ++count; + return "item"; + } + }; + + AsyncPageFetcher pageFetcher = new AsyncPageFetcher() { + @Override + public boolean hasNextPage(TestResponse oldPage) { + return oldPage.pageNumber() < nPages - 1; + } + + @Override + public CompletableFuture nextPage(TestResponse oldPage) { + int nextPageNum; + if (oldPage == null) { + nextPageNum = 0; + } else { + nextPageNum = oldPage.pageNumber() + 1; + } + return CompletableFuture.completedFuture(createResponse(nextPageNum)); + } + }; + + PaginatedItemsPublisher publisher = PaginatedItemsPublisher.builder() + .isLastPage(false) + .nextPageFetcher(pageFetcher) + .iteratorFunction(iteratorFn) + .build(); + + AtomicLong counter = new AtomicLong(); + publisher.subscribe(i -> counter.incrementAndGet()).join(); + assertThat(counter.get()).isEqualTo(nPages * nItemsPerPage); + } + + private abstract class TestResponse extends SdkResponse { + + protected TestResponse(Builder builder) { + super(builder); + } + + abstract Integer pageNumber(); + } + + private static TestResponse createResponse(Integer pageNumber) { + TestResponse mock = mock(TestResponse.class); + when(mock.pageNumber()).thenReturn(pageNumber); + return mock; + } +} diff --git a/services/dynamodb/src/test/java/software/amazon/awssdk/services/dynamodb/PaginatorTest.java b/services/dynamodb/src/test/java/software/amazon/awssdk/services/dynamodb/PaginatorTest.java new file mode 100644 index 000000000000..36ab6edfacbd --- /dev/null +++ b/services/dynamodb/src/test/java/software/amazon/awssdk/services/dynamodb/PaginatorTest.java @@ -0,0 +1,123 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.dynamodb; + +import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.anyUrl; +import static org.assertj.core.api.Assertions.assertThat; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.github.tomakehurst.wiremock.WireMockServer; +import com.github.tomakehurst.wiremock.client.WireMock; +import com.github.tomakehurst.wiremock.core.WireMockConfiguration; +import java.net.URI; +import java.util.concurrent.atomic.AtomicLong; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.auth.credentials.AnonymousCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.dynamodb.model.ScanRequest; +import software.amazon.awssdk.services.dynamodb.paginators.ScanIterable; +import software.amazon.awssdk.services.dynamodb.paginators.ScanPublisher; + +public class PaginatorTest { + private static final WireMockServer wireMock = new WireMockServer(WireMockConfiguration.wireMockConfig().dynamicPort()); + private static final ObjectMapper mapper = new ObjectMapper(); + private static DynamoDbAsyncClient ddbAsync; + private static DynamoDbClient ddb; + + @BeforeAll + static void setup() { + wireMock.start(); + + ddbAsync = DynamoDbAsyncClient.builder() + .region(Region.US_WEST_2) + .endpointOverride(URI.create("http://localhost:" + wireMock.port())) + .credentialsProvider(AnonymousCredentialsProvider.create()) + .build(); + + ddb = DynamoDbClient.builder() + .region(Region.US_WEST_2) + .endpointOverride(URI.create("http://localhost:" + wireMock.port())) + .credentialsProvider(AnonymousCredentialsProvider.create()) + .build(); + } + + @AfterAll + static void teardown() { + ddb.close(); + ddbAsync.close(); + wireMock.stop(); + } + + @Test + void scanPaginator_async_largePage_subscribe_succeeds() { + int nItems = 10_000; + wireMock.stubFor(WireMock.any(anyUrl()) + .willReturn(aResponse() + .withStatus(200) + .withJsonBody(createScanResponse(nItems)))); + + ScanPublisher publisher = ddbAsync.scanPaginator(ScanRequest.builder().build()); + + AtomicLong counter = new AtomicLong(); + publisher.items().subscribe(item -> counter.incrementAndGet()).join(); + assertThat(counter.get()).isEqualTo(nItems); + } + + @Test + void scanPaginator_sync_largePage_subscribe_succeeds() { + int nItems = 10_000; + wireMock.stubFor(WireMock.any(anyUrl()) + .willReturn(aResponse() + .withStatus(200) + .withJsonBody(createScanResponse(nItems)))); + + ScanIterable iterable = ddb.scanPaginator(ScanRequest.builder().build()); + + AtomicLong counter = new AtomicLong(); + iterable.items().forEach(item -> counter.incrementAndGet()); + assertThat(counter.get()).isEqualTo(nItems); + } + + private static JsonNode createScanResponse(int nItems) { + ObjectNode resp = mapper.createObjectNode(); + resp.set("Count", mapper.valueToTree(nItems)); + + ArrayNode items = mapper.createArrayNode(); + + for (int i = 0; i < nItems; i++) { + // { + // "id": { + // "N": 1 + // } + // } + ObjectNode item = mapper.createObjectNode(); + ObjectNode idNode = mapper.createObjectNode(); + idNode.put("N", mapper.valueToTree(i)); + item.set("id", idNode); + items.add(item); + } + + resp.set("Items", items); + + return resp; + } +}