diff --git a/.changes/next-release/bugfix-AmazonS3-263fed5.json b/.changes/next-release/bugfix-AmazonS3-263fed5.json new file mode 100644 index 000000000000..e3e053c54ecc --- /dev/null +++ b/.changes/next-release/bugfix-AmazonS3-263fed5.json @@ -0,0 +1,6 @@ +{ + "type": "bugfix", + "category": "Amazon S3", + "contributor": "", + "description": "Fix a bug in the Java based multipart client where retryable errors from getObject may not be retried correctly." +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingTransformer.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingTransformer.java index 2c76bbc1d88f..753b21e43aab 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingTransformer.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingTransformer.java @@ -18,6 +18,7 @@ import java.nio.ByteBuffer; import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; @@ -54,20 +55,10 @@ public class SplittingTransformer implements SdkPublisher upstreamResponseTransformer; - /** - * Set to true once {@code .prepare()} is called on the upstreamResponseTransformer - */ - private final AtomicBoolean preparedCalled = new AtomicBoolean(false); - - /** - * Set to true once {@code .onResponse()} is called on the upstreamResponseTransformer - */ - private final AtomicBoolean onResponseCalled = new AtomicBoolean(false); - /** * Set to true once {@code .onStream()} is called on the upstreamResponseTransformer */ - private final AtomicBoolean onStreamCalled = new AtomicBoolean(false); + private boolean onStreamCalled; /** * Set to true once {@code .cancel()} is called in the subscription of the downstream subscriber, or if the @@ -111,6 +102,17 @@ public class SplittingTransformer implements SdkPublisher upstreamFuture; + + /** + * Tracks the part number. Errors will only be retried for the first part. + */ + private final AtomicInteger partNumber = new AtomicInteger(0); + private SplittingTransformer(AsyncResponseTransformer upstreamResponseTransformer, Long maximumBufferSizeInBytes, CompletableFuture resultFuture) { @@ -198,7 +200,7 @@ private boolean doEmit() { } if (outstandingDemand.get() > 0) { demand = outstandingDemand.decrementAndGet(); - downstreamSubscriber.onNext(new IndividualTransformer()); + downstreamSubscriber.onNext(new IndividualTransformer(partNumber.incrementAndGet())); } } return false; @@ -216,7 +218,7 @@ private void handleSubscriptionCancel() { log.trace(() -> "downstreamSubscriber already null, skipping downstreamSubscriber.onComplete()"); return; } - if (!onStreamCalled.get()) { + if (!onStreamCalled) { // we never subscribe publisherToUpstream to the upstream, it would not complete downstreamSubscriber = null; return; @@ -230,6 +232,7 @@ private void handleSubscriptionCancel() { } else { log.trace(() -> "calling downstreamSubscriber.onComplete()"); downstreamSubscriber.onComplete(); + CompletableFutureUtils.forwardResultTo(upstreamFuture, resultFuture); } downstreamSubscriber = null; }); @@ -259,28 +262,27 @@ private void handleFutureCancel(Throwable e) { * body publisher. */ private class IndividualTransformer implements AsyncResponseTransformer { + private final int partNumber; private ResponseT response; private CompletableFuture individualFuture; + IndividualTransformer(int partNumber) { + this.partNumber = partNumber; + } + @Override public CompletableFuture prepare() { this.individualFuture = new CompletableFuture<>(); - if (preparedCalled.compareAndSet(false, true)) { + + if (partNumber == 1) { if (isCancelled.get()) { return individualFuture; } log.trace(() -> "calling prepare on the upstream transformer"); - CompletableFuture upstreamFuture = upstreamResponseTransformer.prepare(); - if (!resultFuture.isDone()) { - CompletableFutureUtils.forwardResultTo(upstreamFuture, resultFuture); - } + upstreamFuture = upstreamResponseTransformer.prepare(); + } - resultFuture.whenComplete((r, e) -> { - if (e == null) { - return; - } - individualFuture.completeExceptionally(e); - }); + individualFuture.whenComplete((r, e) -> { if (isCancelled.get()) { handleSubscriptionCancel(); @@ -291,7 +293,7 @@ public CompletableFuture prepare() { @Override public void onResponse(ResponseT response) { - if (onResponseCalled.compareAndSet(false, true)) { + if (partNumber == 1) { log.trace(() -> "calling onResponse on the upstream transformer"); upstreamResponseTransformer.onResponse(response); } @@ -304,7 +306,9 @@ public void onStream(SdkPublisher publisher) { return; } synchronized (cancelLock) { - if (onStreamCalled.compareAndSet(false, true)) { + if (partNumber == 1) { + CompletableFutureUtils.forwardResultTo(upstreamFuture, resultFuture); + onStreamCalled = true; log.trace(() -> "calling onStream on the upstream transformer"); upstreamResponseTransformer.onStream(upstreamSubscriber -> publisherToUpstream.subscribe( DelegatingBufferingSubscriber.builder() @@ -319,9 +323,19 @@ public void onStream(SdkPublisher publisher) { @Override public void exceptionOccurred(Throwable error) { - publisherToUpstream.error(error); - log.trace(() -> "calling exceptionOccurred on the upstream transformer"); - upstreamResponseTransformer.exceptionOccurred(error); + if (partNumber == 1) { + log.trace(() -> "calling exceptionOccurred on the upstream transformer"); + upstreamResponseTransformer.exceptionOccurred(error); + } + + // Invoking publisherToUpstream.error() essentially fails the request immediately. We should only call this if + // 1) The part number is greater than 1, since we want to retry errors on the first part OR 2) onStream() has + // already been invoked and data has started to be written + synchronized (cancelLock) { + if (partNumber > 1 || onStreamCalled) { + publisherToUpstream.error(error); + } + } } } diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/IndividualPartSubscriberTckTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/IndividualPartSubscriberTckTest.java index a72a3ab7aa1f..f5bcba85d70a 100644 --- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/IndividualPartSubscriberTckTest.java +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/IndividualPartSubscriberTckTest.java @@ -18,7 +18,6 @@ import java.nio.ByteBuffer; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.atomic.AtomicInteger; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import org.reactivestreams.tck.SubscriberWhiteboxVerification; diff --git a/services/s3/pom.xml b/services/s3/pom.xml index d588c1985e65..57bf2245a8db 100644 --- a/services/s3/pom.xml +++ b/services/s3/pom.xml @@ -129,6 +129,12 @@ ${awsjavasdk.version} + + software.amazon.awssdk + retries + test + ${awsjavasdk.version} + commons-io commons-io diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/DownloadObjectHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/DownloadObjectHelper.java index 2d6fadc5f505..57110b38e811 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/DownloadObjectHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/DownloadObjectHelper.java @@ -23,6 +23,7 @@ import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.model.GetObjectRequest; import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.utils.CompletableFutureUtils; import software.amazon.awssdk.utils.Logger; @SdkInternalApi @@ -49,6 +50,7 @@ public CompletableFuture downloadObject( .build()); MultipartDownloaderSubscriber subscriber = subscriber(getObjectRequest); split.publisher().subscribe(subscriber); + CompletableFutureUtils.forwardExceptionTo(subscriber.future(), split.resultFuture()); return split.resultFuture(); } diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloaderSubscriber.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloaderSubscriber.java index d369d0caff02..c84f55935d57 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloaderSubscriber.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloaderSubscriber.java @@ -15,7 +15,9 @@ package software.amazon.awssdk.services.s3.internal.multipart; +import java.util.Queue; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.atomic.AtomicInteger; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; @@ -79,6 +81,8 @@ public class MultipartDownloaderSubscriber implements Subscriber> getObjectFutures = new ConcurrentLinkedQueue<>(); + public MultipartDownloaderSubscriber(S3AsyncClient s3, GetObjectRequest getObjectRequest) { this(s3, getObjectRequest, 0); } @@ -119,6 +123,7 @@ public void onNext(AsyncResponseTransformer "Sending GetObjectRequest for next part with partNumber=" + nextPartToGet); CompletableFuture getObjectFuture = s3.getObject(actualRequest, asyncResponseTransformer); + getObjectFutures.add(getObjectFuture); getObjectFuture.whenComplete((response, error) -> { if (error != null) { log.debug(() -> "Error encountered during GetObjectRequest with partNumber=" + nextPartToGet); @@ -166,6 +171,10 @@ private void requestMoreIfNeeded(GetObjectResponse response) { @Override public void onError(Throwable t) { + CompletableFuture partFuture; + while ((partFuture = getObjectFutures.poll()) != null) { + partFuture.cancel(true); + } future.completeExceptionally(t); } diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartS3AsyncClient.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartS3AsyncClient.java index 6142db7772f7..d638ca2f4b6a 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartS3AsyncClient.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartS3AsyncClient.java @@ -38,9 +38,8 @@ import software.amazon.awssdk.utils.Validate; /** - * An {@link S3AsyncClient} that automatically converts PUT, COPY requests to their respective multipart call. CRC32 will be - * enabled for the PUT and COPY requests, unless the the checksum is specified or checksum validation is disabled. - * Note: GET is not yet supported. + * An {@link S3AsyncClient} that automatically converts PUT, COPY, and GET requests to their respective multipart call. CRC32 + * will be enabled for the requests, unless the checksum is specified or checksum validation is disabled. * * @see MultipartConfiguration */ diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloaderSubscriberWiremockTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloaderSubscriberWiremockTest.java index 1c6eb666a9c2..341febda47b1 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloaderSubscriberWiremockTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloaderSubscriberWiremockTest.java @@ -98,32 +98,6 @@ void happyPath_shouldReceiveAllBodyPartInCorrectOrder(AsyncResponseTransform util.verifyCorrectAmountOfRequestsMade(amountOfPartToTest); } - @ParameterizedTest - @MethodSource("argumentsProvider") - void errorOnFirstRequest_shouldCompleteExceptionally(AsyncResponseTransformerTestSupplier supplier, - int amountOfPartToTest, - int partSize) { - stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=1", testBucket, testKey))).willReturn( - aResponse() - .withStatus(400) - .withBody("400test error message"))); - AsyncResponseTransformer transformer = supplier.transformer(); - AsyncResponseTransformer.SplitResult split = transformer.split( - SplittingTransformerConfiguration.builder() - .bufferSizeInBytes(1024 * 32L) - .build()); - Subscriber> subscriber = new MultipartDownloaderSubscriber( - s3AsyncClient, - GetObjectRequest.builder() - .bucket(testBucket) - .key(testKey) - .build()); - - split.publisher().subscribe(subscriber); - assertThatThrownBy(() -> split.resultFuture().join()) - .hasMessageContaining("test error message"); - } - @ParameterizedTest @MethodSource("argumentsProvider") void errorOnThirdRequest_shouldCompleteExceptionallyOnlyPartsGreaterThanTwo( diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/S3MultipartClientGetObjectWiremockTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/S3MultipartClientGetObjectWiremockTest.java new file mode 100644 index 000000000000..e45646e32061 --- /dev/null +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/S3MultipartClientGetObjectWiremockTest.java @@ -0,0 +1,542 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + +import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.any; +import static com.github.tomakehurst.wiremock.client.WireMock.anyUrl; +import static com.github.tomakehurst.wiremock.client.WireMock.get; +import static com.github.tomakehurst.wiremock.client.WireMock.getRequestedFor; +import static com.github.tomakehurst.wiremock.client.WireMock.matching; +import static com.github.tomakehurst.wiremock.client.WireMock.stubFor; +import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo; +import static com.github.tomakehurst.wiremock.client.WireMock.verify; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.github.tomakehurst.wiremock.http.Fault; +import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; +import com.github.tomakehurst.wiremock.junit5.WireMockTest; +import com.github.tomakehurst.wiremock.stubbing.Scenario; +import java.net.URI; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.awscore.retry.AwsRetryStrategy; +import software.amazon.awssdk.core.ResponseBytes; +import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.core.async.SdkPublisher; +import software.amazon.awssdk.core.interceptor.Context; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; +import software.amazon.awssdk.http.SdkHttpResponse; +import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.S3Exception; + +@WireMockTest +@Timeout(value = 30, unit = TimeUnit.SECONDS) +public class S3MultipartClientGetObjectWiremockTest { + private static final CapturingInterceptor capturingInterceptor = new CapturingInterceptor(); + public static final String BUCKET = "Example-Bucket"; + public static final String KEY = "Key"; + private static final int MAX_ATTEMPTS = 7; + private static final int TOTAL_PARTS = 3; + private static final int PART_SIZE = 1024; + private static final byte[] PART_1_DATA = new byte[PART_SIZE]; + private static final byte[] PART_2_DATA = new byte[PART_SIZE]; + private static final byte[] PART_3_DATA = new byte[PART_SIZE]; + private static byte[] expectedBody; + private S3AsyncClient multipartClient; + + @BeforeAll + public static void init() { + new Random().nextBytes(PART_1_DATA); + new Random().nextBytes(PART_2_DATA); + new Random().nextBytes(PART_3_DATA); + + expectedBody = new byte[TOTAL_PARTS * PART_SIZE]; + System.arraycopy(PART_1_DATA, 0, expectedBody, 0, PART_SIZE); + System.arraycopy(PART_2_DATA, 0, expectedBody, PART_SIZE, PART_SIZE); + System.arraycopy(PART_3_DATA, 0, expectedBody, 2 * PART_SIZE, PART_SIZE); + } + + @BeforeEach + public void setup(WireMockRuntimeInfo wm) { + capturingInterceptor.clear(); + multipartClient = S3AsyncClient.builder() + .region(Region.US_EAST_1) + .endpointOverride(URI.create(wm.getHttpBaseUrl())) + .multipartEnabled(true) + .httpClientBuilder(NettyNioAsyncHttpClient.builder().maxConcurrency(100).connectionAcquisitionTimeout(Duration.ofSeconds(100))) + .credentialsProvider(StaticCredentialsProvider.create(AwsBasicCredentials.create("key", "secret"))) + .overrideConfiguration( + o -> o.retryStrategy(AwsRetryStrategy.standardRetryStrategy().toBuilder() + .maxAttempts(MAX_ATTEMPTS) + .circuitBreakerEnabled(false) + .build()) + .addExecutionInterceptor(capturingInterceptor)) + .build(); + } + + @Test + public void getObject_concurrentCallsReturn200_shouldSucceed() { + List>> futures = new ArrayList<>(); + + int numRuns = 1000; + for (int i = 0; i < numRuns; i++) { + CompletableFuture> resp = mock200Response(multipartClient, i); + futures.add(resp); + } + + CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join(); + } + + @Test + public void getObject_single500WithinMany200s_shouldRetrySuccessfully() { + List>> futures = new ArrayList<>(); + + int numRuns = 1000; + for (int i = 0; i < numRuns; i++) { + CompletableFuture> resp = mock200Response(multipartClient, i); + futures.add(resp); + } + + CompletableFuture> requestWithRetryableError = + mockRetryableErrorThen200Response(multipartClient, 1); + futures.add(requestWithRetryableError); + + for (int i = 0; i < numRuns; i++) { + CompletableFuture> resp = mock200Response(multipartClient, i + 1000); + futures.add(resp); + } + + CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join(); + } + + @Test + public void getObject_concurrent503s_shouldRetrySuccessfully() { + List>> futures = new ArrayList<>(); + + int numRuns = 1000; + for (int i = 0; i < numRuns; i++) { + CompletableFuture> resp = mockRetryableErrorThen200Response(multipartClient, i); + futures.add(resp); + } + + CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join(); + } + + @Test + public void getObject_5xxErrorResponses_shouldNotReuseInitialRequestId() { + String firstRequestId = UUID.randomUUID().toString(); + String secondRequestId = UUID.randomUUID().toString(); + + stubFor(any(anyUrl()) + .inScenario("errors") + .whenScenarioStateIs(Scenario.STARTED) + .willReturn(aResponse() + .withHeader("x-amz-request-id", firstRequestId) + .withStatus(503) + .withBody(internalErrorBody())) + .willSetStateTo("SecondAttempt")); + + stubFor(any(anyUrl()) + .inScenario("errors") + .whenScenarioStateIs("SecondAttempt") + .willReturn(aResponse() + .withHeader("x-amz-request-id", secondRequestId) + .withStatus(500))); + + assertThatThrownBy(() -> multipartClient.getObject(b -> b.bucket(BUCKET).key(KEY), + AsyncResponseTransformer.toBytes()).join()) + .isInstanceOf(CompletionException.class) + .hasCauseInstanceOf(S3Exception.class); + + + List responses = capturingInterceptor.getResponses(); + assertEquals(MAX_ATTEMPTS, responses.size(), () -> String.format("Expected exactly %s responses", MAX_ATTEMPTS)); + + String actualFirstRequestId = responses.get(0).firstMatchingHeader("x-amz-request-id").orElse(null); + String actualSecondRequestId = responses.get(1).firstMatchingHeader("x-amz-request-id").orElse(null); + + assertNotNull(actualFirstRequestId); + assertNotNull(actualSecondRequestId); + + assertNotEquals(actualFirstRequestId, actualSecondRequestId); + + assertEquals(firstRequestId, actualFirstRequestId); + assertEquals(secondRequestId, actualSecondRequestId); + + assertEquals(503, responses.get(0).statusCode()); + assertEquals(500, responses.get(1).statusCode()); + } + + @Test + public void multipartDownload_200Response_shouldSucceed() { + stub200SuccessPart1(); + stub200SuccessPart2(); + stub200SuccessPart3(); + + CompletableFuture> future = + multipartClient.getObject(GetObjectRequest.builder().bucket(BUCKET).key(KEY).build(), + AsyncResponseTransformer.toBytes()); + + ResponseBytes response = future.join(); + byte[] actualBody = response.asByteArray(); + assertArrayEquals(expectedBody, actualBody, "Downloaded body should match expected combined parts"); + + // Verify that all 3 parts were requested only once + verify(1, getRequestedFor(urlEqualTo(String.format("/%s/%s?partNumber=1", BUCKET, KEY)))); + verify(1, getRequestedFor(urlEqualTo(String.format("/%s/%s?partNumber=2", BUCKET, KEY)))); + verify(1, getRequestedFor(urlEqualTo(String.format("/%s/%s?partNumber=3", BUCKET, KEY)))); + } + + @Test + public void multipartDownload_503OnFirstPart_shouldRetrySuccessfully() { + // Stub Part 1 - 503 on first attempt, 200 on retry + String part1Scenario = "part1-retry"; + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=1", BUCKET, KEY))) + .inScenario(part1Scenario) + .whenScenarioStateIs(Scenario.STARTED) + .willReturn(aResponse() + .withStatus(503) + .withHeader("x-amz-request-id", UUID.randomUUID().toString()) + .withBody("\n" + + "\n" + + " SlowDown\n" + + " Please reduce your request rate.\n" + + "")) + .willSetStateTo("retry-attempt")); + + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=1", BUCKET, KEY))) + .inScenario(part1Scenario) + .whenScenarioStateIs("retry-attempt") + .willReturn(aResponse() + .withStatus(200) + .withHeader("x-amz-mp-parts-count", String.valueOf(TOTAL_PARTS)) + .withHeader("x-amz-request-id", UUID.randomUUID().toString()) + .withBody(PART_1_DATA))); + + stub200SuccessPart2(); + stub200SuccessPart3(); + + CompletableFuture> future = + multipartClient.getObject(GetObjectRequest.builder().bucket(BUCKET).key(KEY).build(), + AsyncResponseTransformer.toBytes()); + + ResponseBytes response = future.join(); + byte[] actualBody = response.asByteArray(); + assertArrayEquals(expectedBody, actualBody, "Downloaded body should match expected combined parts"); + + // Verify that part 1 was requested twice (initial 503 + retry) + verify(2, getRequestedFor(urlEqualTo(String.format("/%s/%s?partNumber=1", BUCKET, KEY)))); + verify(1, getRequestedFor(urlEqualTo(String.format("/%s/%s?partNumber=2", BUCKET, KEY)))); + verify(1, getRequestedFor(urlEqualTo(String.format("/%s/%s?partNumber=3", BUCKET, KEY)))); + } + + @Test + public void multipartDownload_503OnFirstPartAndSecondPart_shouldRetryFirstPartSuccessfullyAndFailOnSecondPart() { + // Stub Part 1 - 503 on first attempt, 200 on retry + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=1", BUCKET, KEY))) + .inScenario("part1-retry") + .whenScenarioStateIs(Scenario.STARTED) + .willReturn(aResponse() + .withStatus(503) + .withHeader("x-amz-request-id", UUID.randomUUID().toString()) + .withBody(slowdownErrorBody())) + .willSetStateTo("retry-attempt")); + + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=1", BUCKET, KEY))) + .inScenario("part1-retry") + .whenScenarioStateIs("retry-attempt") + .willReturn(aResponse() + .withStatus(200) + .withHeader("x-amz-mp-parts-count", String.valueOf(TOTAL_PARTS)) + .withHeader("x-amz-request-id", UUID.randomUUID().toString()) + .withBody(PART_1_DATA))); + + + // Stub Part 2 - 503 on first attempt, 200 on retry + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=2", BUCKET, KEY))) + .inScenario("part2-retry") + .whenScenarioStateIs(Scenario.STARTED) + .willReturn(aResponse() + .withStatus(500) + .withHeader("x-amz-request-id", UUID.randomUUID().toString()) + .withBody(internalErrorBody())) + .willSetStateTo("retry-attempt")); + + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=2", BUCKET, KEY))) + .inScenario("part2-retry") + .whenScenarioStateIs("retry-attempt") + .willReturn(aResponse() + .withStatus(200) + .withHeader("x-amz-mp-parts-count", String.valueOf(TOTAL_PARTS)) + .withHeader("x-amz-request-id", UUID.randomUUID().toString()) + .withBody(PART_2_DATA))); + + stub200SuccessPart3(); + + assertThatThrownBy(() -> multipartClient.getObject(b -> b.bucket(BUCKET).key(KEY), + AsyncResponseTransformer.toBytes()).join()) + .isInstanceOf(CompletionException.class) + .hasCauseInstanceOf(S3Exception.class) + .hasMessageContaining("We encountered an internal error. Please try again."); + + // Verify that part 1 was requested twice (initial 503 + retry) + verify(2, getRequestedFor(urlEqualTo(String.format("/%s/%s?partNumber=1", BUCKET, KEY)))); + // Verify that part 2 was requested once (no retry) + verify(1, getRequestedFor(urlEqualTo(String.format("/%s/%s?partNumber=2", BUCKET, KEY)))); + } + + @Test + public void getObject_iOError_shouldRetrySuccessfully() { + String requestId = UUID.randomUUID().toString(); + + stubFor(any(anyUrl()) + .inScenario("io-error") + .whenScenarioStateIs(Scenario.STARTED) + .willReturn(aResponse() + .withFault(Fault.CONNECTION_RESET_BY_PEER)) + .willSetStateTo("retry")); + + stubFor(any(anyUrl()) + .inScenario("io-error") + .whenScenarioStateIs("retry") + .willReturn(aResponse() + .withStatus(200) + .withHeader("x-amz-request-id", requestId) + .withBody("Hello World"))); + + ResponseBytes response = multipartClient.getObject(GetObjectRequest.builder() + .bucket(BUCKET) + .key(KEY) + .build(), + AsyncResponseTransformer.toBytes()).join(); + + assertArrayEquals("Hello World".getBytes(StandardCharsets.UTF_8), response.asByteArray()); + + verify(2, getRequestedFor(urlEqualTo("/" + BUCKET + "/" + KEY + "?partNumber=1"))); + + List responses = capturingInterceptor.getResponses(); + String finalRequestId = responses.get(responses.size() - 1) + .firstMatchingHeader("x-amz-request-id") + .orElse(null); + + assertEquals(requestId, finalRequestId); + } + + @Test + public void multipartDownload_errorDuringFirstPartAfterOnStream_shouldFailAndNotRetry() { + stub200SuccessPart1(); + stub200SuccessPart2(); + stub200SuccessPart3(); + + StreamingErrorTransformer failingTransformer = new StreamingErrorTransformer(); + + assertThatThrownBy(() -> multipartClient.getObject(b -> b.bucket(BUCKET).key(KEY), failingTransformer).join()) + .isInstanceOf(CompletionException.class) + .hasCauseInstanceOf(RuntimeException.class); + + assertTrue(failingTransformer.onStreamCalled.get()); + // Verify that the first part was requested only once and not retried + verify(1, getRequestedFor(urlEqualTo(String.format("/%s/%s?partNumber=1", BUCKET, KEY)))); + } + + private String errorBody(String errorCode, String errorMessage) { + return "\n" + + "\n" + + " " + errorCode + "\n" + + " " + errorMessage + "\n" + + ""; + } + + private String internalErrorBody() { + return errorBody("InternalError", "We encountered an internal error. Please try again."); + } + + private String slowdownErrorBody() { + return errorBody("SlowDown", "Please reduce your request rate."); + } + + private void stub200SuccessPart1() { + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=1", BUCKET, KEY))) + .willReturn(aResponse() + .withHeader("x-amz-request-id", String.valueOf(UUID.randomUUID())) + .withHeader("x-amz-mp-parts-count", String.valueOf(TOTAL_PARTS)) + .withStatus(200).withBody(PART_1_DATA))); + } + + private void stub200SuccessPart2() { + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=2", BUCKET, KEY))) + .willReturn(aResponse() + .withStatus(200) + .withHeader("x-amz-mp-parts-count", String.valueOf(TOTAL_PARTS)) + .withHeader("x-amz-request-id", UUID.randomUUID().toString()) + .withBody(PART_2_DATA))); + } + + private void stub200SuccessPart3() { + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=3", BUCKET, KEY))) + .willReturn(aResponse() + .withStatus(200) + .withHeader("x-amz-mp-parts-count", String.valueOf(TOTAL_PARTS)) + .withHeader("x-amz-request-id", UUID.randomUUID().toString()) + .withBody(PART_3_DATA))); + } + + private CompletableFuture> mock200Response(S3AsyncClient s3Client, int runNumber) { + String runId = runNumber + " success"; + + stubFor(any(anyUrl()) + .withHeader("RunNum", matching(runId)) + .inScenario(runId) + .whenScenarioStateIs(Scenario.STARTED) + .willReturn(aResponse().withStatus(200) + .withHeader("x-amz-request-id", String.valueOf(UUID.randomUUID())) + .withBody("Hello World"))); + + return s3Client.getObject(r -> r.bucket(BUCKET).key("key") + .overrideConfiguration(c -> c.putHeader("RunNum", runId)), + AsyncResponseTransformer.toBytes()); + } + + private CompletableFuture> mockRetryableErrorThen200Response(S3AsyncClient s3Client, int runNumber) { + String runId = String.valueOf(runNumber); + + stubFor(any(anyUrl()) + .withHeader("RunNum", matching(runId)) + .inScenario(runId) + .whenScenarioStateIs(Scenario.STARTED) + .willReturn(aResponse() + .withHeader("x-amz-request-id", String.valueOf(UUID.randomUUID())) + .withStatus(500) + .withBody(internalErrorBody()) + ) + .willSetStateTo("SecondAttempt" + runId)); + + stubFor(any(anyUrl()) + .inScenario(runId) + .withHeader("RunNum", matching(runId)) + .whenScenarioStateIs("SecondAttempt" + runId) + .willReturn(aResponse().withStatus(200) + .withHeader("x-amz-request-id", String.valueOf(UUID.randomUUID())) + .withBody("Hello World"))); + + return s3Client.getObject(r -> r.bucket(BUCKET).key("key") + .overrideConfiguration(c -> c.putHeader("RunNum", runId)), + AsyncResponseTransformer.toBytes()); + } + + static class CapturingInterceptor implements ExecutionInterceptor { + private final List responses = new ArrayList<>(); + + @Override + public void afterTransmission(Context.AfterTransmission context, ExecutionAttributes executionAttributes) { + responses.add(context.httpResponse()); + } + + public List getResponses() { + return new ArrayList<>(responses); + } + + public void clear() { + responses.clear(); + } + } + + /** + * Custom AsyncResponseTransformer that simulates an error occurring after onStream() has been called + */ + private static final class StreamingErrorTransformer + implements AsyncResponseTransformer> { + + private final CompletableFuture> future = new CompletableFuture<>(); + private final AtomicBoolean errorThrown = new AtomicBoolean(); + private final AtomicBoolean onStreamCalled = new AtomicBoolean(); + + @Override + public CompletableFuture> prepare() { + return future; + } + + @Override + public void onResponse(GetObjectResponse response) { + // + } + + @Override + public void onStream(SdkPublisher publisher) { + onStreamCalled.set(true); + publisher.subscribe(new Subscriber() { + private Subscription subscription; + + @Override + public void onSubscribe(Subscription s) { + this.subscription = s; + s.request(1); + } + + @Override + public void onNext(ByteBuffer byteBuffer) { + if (errorThrown.compareAndSet(false, true)) { + future.completeExceptionally(new RuntimeException()); + subscription.cancel(); + } + } + + @Override + public void onError(Throwable t) { + future.completeExceptionally(t); + } + + @Override + public void onComplete() { + // + } + }); + } + + @Override + public void exceptionOccurred(Throwable throwable) { + future.completeExceptionally(throwable); + } + } +}