diff --git a/.changes/next-release/feature-AmazonS3-7c1530f.json b/.changes/next-release/feature-AmazonS3-7c1530f.json new file mode 100644 index 000000000000..0d6ff52b5325 --- /dev/null +++ b/.changes/next-release/feature-AmazonS3-7c1530f.json @@ -0,0 +1,6 @@ +{ + "type": "feature", + "category": "Amazon S3", + "contributor": "", + "description": "Add retry support for Java based S3 multipart client download to Byte array" +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ByteArrayAsyncResponseTransformer.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ByteArrayAsyncResponseTransformer.java index d1103ea2a2de..79838f739d4a 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ByteArrayAsyncResponseTransformer.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ByteArrayAsyncResponseTransformer.java @@ -24,6 +24,7 @@ import org.reactivestreams.Subscription; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.core.ResponseBytes; +import software.amazon.awssdk.core.SplittingTransformerConfiguration; import software.amazon.awssdk.core.async.AsyncResponseTransformer; import software.amazon.awssdk.core.async.SdkPublisher; import software.amazon.awssdk.utils.BinaryUtils; @@ -65,6 +66,17 @@ public void exceptionOccurred(Throwable throwable) { cf.completeExceptionally(throwable); } + @Override + public SplitResult> split(SplittingTransformerConfiguration splitConfig) { + CompletableFuture> future = new CompletableFuture<>(); + SdkPublisher> transformer = + new ByteArraySplittingTransformer<>(this, future); + return AsyncResponseTransformer.SplitResult.>builder() + .publisher(transformer) + .resultFuture(future) + .build(); + } + @Override public String name() { return TransformerType.BYTES.getName(); diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ByteArraySplittingTransformer.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ByteArraySplittingTransformer.java new file mode 100644 index 000000000000..7d3b3a249067 --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ByteArraySplittingTransformer.java @@ -0,0 +1,239 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.async; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.ResponseBytes; +import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.core.async.SdkPublisher; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.utils.CompletableFutureUtils; +import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.async.SimplePublisher; + +@SdkInternalApi +public class ByteArraySplittingTransformer implements SdkPublisher> { + private static final Logger log = Logger.loggerFor(ByteArraySplittingTransformer.class); + private final AsyncResponseTransformer> upstreamResponseTransformer; + private final CompletableFuture> resultFuture; + private Subscriber> downstreamSubscriber; + private final AtomicInteger onNextSignalsSent = new AtomicInteger(0); + private final AtomicReference responseT = new AtomicReference<>(); + + private final SimplePublisher publisherToUpstream = new SimplePublisher<>(); + /** + * The amount requested by the downstream subscriber that is still left to fulfill. Updated when the + * {@link Subscription#request(long) request} method is called on the downstream subscriber's subscription. Corresponds to the + * number of {@code AsyncResponseTransformer} that will be published to the downstream subscriber. + */ + private final AtomicLong outstandingDemand = new AtomicLong(0); + + /** + * This flag stops the current thread from publishing transformers while another thread is already publishing. + */ + private final AtomicBoolean emitting = new AtomicBoolean(false); + + private final Object lock = new Object(); + + /** + * Set to true once {@code .onStream()} is called on the upstreamResponseTransformer + */ + private boolean onStreamCalled; + + /** + * Set to true once {@code .cancel()} is called in the subscription of the downstream subscriber, or if the + * {@code resultFuture} is cancelled. + */ + private final AtomicBoolean isCancelled = new AtomicBoolean(false); + + private final Map buffers; + + public ByteArraySplittingTransformer(AsyncResponseTransformer> + upstreamResponseTransformer, + CompletableFuture> resultFuture) { + this.upstreamResponseTransformer = upstreamResponseTransformer; + this.resultFuture = resultFuture; + this.buffers = new ConcurrentHashMap<>(); + } + + @Override + public void subscribe(Subscriber> subscriber) { + this.downstreamSubscriber = subscriber; + subscriber.onSubscribe(new DownstreamSubscription()); + } + + private final class DownstreamSubscription implements Subscription { + + @Override + public void request(long n) { + if (n <= 0) { + downstreamSubscriber.onError(new IllegalArgumentException("Amount requested must be positive")); + return; + } + long newDemand = outstandingDemand.updateAndGet(current -> { + if (Long.MAX_VALUE - current < n) { + return Long.MAX_VALUE; + } + return current + n; + }); + log.trace(() -> String.format("new outstanding demand: %s", newDemand)); + emit(); + } + + @Override + public void cancel() { + log.trace(() -> String.format("received cancel signal. Current cancel state is 'isCancelled=%s'", isCancelled.get())); + if (isCancelled.compareAndSet(false, true)) { + handleSubscriptionCancel(); + } + } + } + + private void emit() { + do { + if (!emitting.compareAndSet(false, true)) { + return; + } + try { + if (doEmit()) { + return; + } + } finally { + emitting.compareAndSet(true, false); + } + } while (outstandingDemand.get() > 0); + } + + private boolean doEmit() { + long demand = outstandingDemand.get(); + + while (demand > 0) { + if (isCancelled.get()) { + return true; + } + if (outstandingDemand.get() > 0) { + demand = outstandingDemand.decrementAndGet(); + downstreamSubscriber.onNext(new IndividualTransformer(onNextSignalsSent.incrementAndGet())); + } + } + return false; + } + + /** + * Handle the {@code .cancel()} signal received from the downstream subscription. Data that is being sent to the upstream + * transformer need to finish processing before we complete. One typical use case for this is completing the multipart + * download, the subscriber having reached the final part will signal that it doesn't need more + * {@link AsyncResponseTransformer}s by calling {@code .cancel()} on the subscription. + */ + private void handleSubscriptionCancel() { + synchronized (lock) { + if (downstreamSubscriber == null) { + log.trace(() -> "downstreamSubscriber already null, skipping downstreamSubscriber.onComplete()"); + return; + } + if (!onStreamCalled) { + // we never subscribe publisherToUpstream to the upstream, it would not complete + downstreamSubscriber = null; + return; + } + + // if result future is already complete (likely by exception propagation), skip. + if (resultFuture.isDone()) { + return; + } + + CompletableFuture> upstreamPrepareFuture = upstreamResponseTransformer.prepare(); + CompletableFutureUtils.forwardResultTo(upstreamPrepareFuture, resultFuture); + + upstreamResponseTransformer.onResponse(responseT.get()); + + try { + buffers.keySet().stream().sorted().forEach(index -> { + publisherToUpstream.send(buffers.get(index)).exceptionally(ex -> { + resultFuture.completeExceptionally(SdkClientException.create("unexpected error occurred", ex)); + return null; + }); + }); + + publisherToUpstream.complete().exceptionally(ex -> { + resultFuture.completeExceptionally(SdkClientException.create("unexpected error occurred", ex)); + return null; + }); + upstreamResponseTransformer.onStream(SdkPublisher.adapt(publisherToUpstream)); + + } catch (Throwable throwable) { + resultFuture.completeExceptionally(SdkClientException.create("unexpected error occurred", throwable)); + } + } + } + + private final class IndividualTransformer implements AsyncResponseTransformer { + private final int onNextCount; + private final ByteArrayAsyncResponseTransformer delegate = new ByteArrayAsyncResponseTransformer<>(); + + private CompletableFuture future; + private final List>> delegatePrepareFutures = new ArrayList<>(); + + private IndividualTransformer(int onNextCount) { + this.onNextCount = onNextCount; + } + + @Override + public CompletableFuture prepare() { + future = new CompletableFuture<>(); + CompletableFuture> prepare = delegate.prepare(); + CompletableFutureUtils.forwardExceptionTo(prepare, future); + delegatePrepareFutures.add(prepare); + return prepare.thenApply(responseTResponseBytes -> { + buffers.put(onNextCount, responseTResponseBytes.asByteBuffer()); + return responseTResponseBytes.response(); + }); + } + + @Override + public void onResponse(ResponseT response) { + responseT.compareAndSet(null, response); + delegate.onResponse(response); + } + + @Override + public void onStream(SdkPublisher publisher) { + delegate.onStream(publisher); + synchronized (lock) { + if (!onStreamCalled) { + onStreamCalled = true; + } + } + } + + @Override + public void exceptionOccurred(Throwable error) { + delegate.exceptionOccurred(error); + } + } +} diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/DownloadObjectHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/DownloadObjectHelper.java index 2d6fadc5f505..e09a065b29be 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/DownloadObjectHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/DownloadObjectHelper.java @@ -23,6 +23,7 @@ import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.model.GetObjectRequest; import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.utils.CompletableFutureUtils; import software.amazon.awssdk.utils.Logger; @SdkInternalApi @@ -49,7 +50,9 @@ public CompletableFuture downloadObject( .build()); MultipartDownloaderSubscriber subscriber = subscriber(getObjectRequest); split.publisher().subscribe(subscriber); - return split.resultFuture(); + CompletableFuture splitFuture = split.resultFuture(); + CompletableFutureUtils.forwardExceptionTo(subscriber.future(), splitFuture); + return splitFuture; } private MultipartDownloaderSubscriber subscriber(GetObjectRequest getObjectRequest) { diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/multipart/MultipartConfiguration.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/multipart/MultipartConfiguration.java index 9b074d6244f3..bb96fa9bcb23 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/multipart/MultipartConfiguration.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/multipart/MultipartConfiguration.java @@ -18,6 +18,7 @@ import software.amazon.awssdk.annotations.SdkPublicApi; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.core.internal.async.ByteArrayAsyncResponseTransformer; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.S3AsyncClientBuilder; import software.amazon.awssdk.services.s3.model.CopyObjectRequest; @@ -28,15 +29,21 @@ import software.amazon.awssdk.utils.builder.ToCopyableBuilder; /** - * Class that hold configuration properties related to multipart operation for a {@link S3AsyncClient}. Passing this class to the - * {@link S3AsyncClientBuilder#multipartConfiguration(MultipartConfiguration)} will enable automatic conversion of - * {@link S3AsyncClient#getObject(GetObjectRequest, AsyncResponseTransformer)}, - * {@link S3AsyncClient#putObject(PutObjectRequest, AsyncRequestBody)} and - * {@link S3AsyncClient#copyObject(CopyObjectRequest)} to their respective multipart operation. + * Class that holds configuration properties related to multipart operations for a {@link S3AsyncClient}. Passing this class to + * the {@link S3AsyncClientBuilder#multipartConfiguration(MultipartConfiguration)} will enable automatic conversion of the + * following operations to their respective multipart variants: + *
    + *
  • {@link S3AsyncClient#getObject(GetObjectRequest, AsyncResponseTransformer)}, + *
  • {@link S3AsyncClient#putObject(PutObjectRequest, AsyncRequestBody)} + *
  • {@link S3AsyncClient#copyObject(CopyObjectRequest)} + *
*

- * Note that multipart download fetch individual part of the object using {@link GetObjectRequest#partNumber() part number}, this - * means it will only download multiple parts if the - * object itself was uploaded as a {@link S3AsyncClient#createMultipartUpload(CreateMultipartUploadRequest) multipart object} + * Note that multipart download fetches individual parts of the object using {@link GetObjectRequest#partNumber() PartNumber}. + * This means the S3 client will only download multiple parts if the object itself was uploaded as a + * {@link S3AsyncClient#createMultipartUpload(CreateMultipartUploadRequest) multipart object} + *

+ * When performing multipart download, retry is only supported for downloading to byte array, i.e., when providing a + * {@link ByteArrayAsyncResponseTransformer} */ @SdkPublicApi public final class MultipartConfiguration implements ToCopyableBuilder { @@ -83,6 +90,10 @@ public Long minimumPartSizeInBytes() { /** * The maximum memory, in bytes, that the SDK will use to buffer requests content into memory. + *

+ * This setting is not supported and will be ignored when downloading to a byte array, i.e., when providing a + * {@link ByteArrayAsyncResponseTransformer}. + * * @return the value of the configured maximum memory usage. */ public Long apiCallBufferSizeInBytes() { @@ -152,6 +163,9 @@ public interface Builder extends CopyableBuilder * Default value: If not specified, the SDK will use the equivalent of four parts worth of memory, so 32 Mib by default. + *

+ * This setting is not supported and will be ignored when downloading to a byte array, i.e., when providing a + * {@link ByteArrayAsyncResponseTransformer}. * * @param apiCallBufferSizeInBytes the value of the maximum memory usage. * @return an instance of this builder. @@ -170,20 +184,24 @@ private static class DefaultMultipartConfigBuilder implements Builder { private Long minimumPartSizeInBytes; private Long apiCallBufferSizeInBytes; + @Override public Builder thresholdInBytes(Long thresholdInBytes) { this.thresholdInBytes = thresholdInBytes; return this; } + @Override public Long thresholdInBytes() { return this.thresholdInBytes; } + @Override public Builder minimumPartSizeInBytes(Long minimumPartSizeInBytes) { this.minimumPartSizeInBytes = minimumPartSizeInBytes; return this; } + @Override public Long minimumPartSizeInBytes() { return this.minimumPartSizeInBytes; } diff --git a/services/s3/src/main/resources/codegen-resources/customization.config b/services/s3/src/main/resources/codegen-resources/customization.config index 8acd43509eab..44484c5ff0c2 100644 --- a/services/s3/src/main/resources/codegen-resources/customization.config +++ b/services/s3/src/main/resources/codegen-resources/customization.config @@ -289,8 +289,8 @@ "useS3ExpressSessionAuth": true, "multipartCustomization": { "multipartConfigurationClass": "software.amazon.awssdk.services.s3.multipart.MultipartConfiguration", - "multipartConfigMethodDoc": "Configuration for multipart operation of this client.", - "multipartEnableMethodDoc": "Enables automatic conversion of GET, PUT and COPY methods to their equivalent multipart operation. CRC32 checksum will be enabled for PUT, unless the checksum is specified or checksum validation is disabled.", + "multipartConfigMethodDoc": "Configuration for multipart operation of this client.

When performing multipart download, retry is only supported for downloading to byte array, i.e., when providing a {@code ByteArrayAsyncResponseTransformer}", + "multipartEnableMethodDoc": "Enables automatic conversion of GET, PUT and COPY methods to their equivalent multipart operation. CRC32 checksum will be enabled for PUT, unless the checksum is specified or checksum validation is disabled.

When performing multipart download, retry is only supported for downloading to byte array, i.e., when providing a {@code ByteArrayAsyncResponseTransformer}", "contextParamEnabledKey": "S3AsyncClientDecorator.MULTIPART_ENABLED_KEY", "contextParamConfigKey": "S3AsyncClientDecorator.MULTIPART_CONFIGURATION_KEY" }, diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloaderSubscriberWiremockTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloaderSubscriberWiremockTest.java deleted file mode 100644 index 1c6eb666a9c2..000000000000 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloaderSubscriberWiremockTest.java +++ /dev/null @@ -1,182 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://aws.amazon.com/apache2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package software.amazon.awssdk.services.s3.internal.multipart; - -import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; -import static com.github.tomakehurst.wiremock.client.WireMock.get; -import static com.github.tomakehurst.wiremock.client.WireMock.stubFor; -import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.junit.jupiter.api.Assertions.assertArrayEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.params.provider.Arguments.arguments; -import static software.amazon.awssdk.services.s3.internal.multipart.MultipartDownloadTestUtil.transformersSuppliers; - -import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; -import com.github.tomakehurst.wiremock.junit5.WireMockTest; -import java.net.URI; -import java.util.Arrays; -import java.util.List; -import java.util.UUID; -import java.util.stream.Stream; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; -import org.reactivestreams.Subscriber; -import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; -import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; -import software.amazon.awssdk.core.SplittingTransformerConfiguration; -import software.amazon.awssdk.core.async.AsyncResponseTransformer; -import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.services.s3.S3AsyncClient; -import software.amazon.awssdk.services.s3.S3Configuration; -import software.amazon.awssdk.services.s3.model.GetObjectRequest; -import software.amazon.awssdk.services.s3.model.GetObjectResponse; -import software.amazon.awssdk.services.s3.utils.AsyncResponseTransformerTestSupplier; -import software.amazon.awssdk.utils.Pair; - -@WireMockTest -class MultipartDownloaderSubscriberWiremockTest { - - private final String testBucket = "test-bucket"; - private final String testKey = "test-key"; - - private S3AsyncClient s3AsyncClient; - private MultipartDownloadTestUtil util; - - @BeforeEach - public void init(WireMockRuntimeInfo wiremock) { - s3AsyncClient = S3AsyncClient.builder() - .credentialsProvider(StaticCredentialsProvider.create( - AwsBasicCredentials.create("key", "secret"))) - .region(Region.US_WEST_2) - .endpointOverride(URI.create("http://localhost:" + wiremock.getHttpPort())) - .serviceConfiguration(S3Configuration.builder() - .pathStyleAccessEnabled(true) - .build()) - .build(); - util = new MultipartDownloadTestUtil(testBucket, testKey, UUID.randomUUID().toString()); - } - - @ParameterizedTest - @MethodSource("argumentsProvider") - void happyPath_shouldReceiveAllBodyPartInCorrectOrder(AsyncResponseTransformerTestSupplier supplier, - int amountOfPartToTest, - int partSize) { - byte[] expectedBody = util.stubAllParts(testBucket, testKey, amountOfPartToTest, partSize); - AsyncResponseTransformer transformer = supplier.transformer(); - AsyncResponseTransformer.SplitResult split = transformer.split( - SplittingTransformerConfiguration.builder() - .bufferSizeInBytes(1024 * 32L) - .build()); - Subscriber> subscriber = new MultipartDownloaderSubscriber( - s3AsyncClient, - GetObjectRequest.builder() - .bucket(testBucket) - .key(testKey) - .build()); - - split.publisher().subscribe(subscriber); - T response = split.resultFuture().join(); - - byte[] body = supplier.body(response); - assertArrayEquals(expectedBody, body); - util.verifyCorrectAmountOfRequestsMade(amountOfPartToTest); - } - - @ParameterizedTest - @MethodSource("argumentsProvider") - void errorOnFirstRequest_shouldCompleteExceptionally(AsyncResponseTransformerTestSupplier supplier, - int amountOfPartToTest, - int partSize) { - stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=1", testBucket, testKey))).willReturn( - aResponse() - .withStatus(400) - .withBody("400test error message"))); - AsyncResponseTransformer transformer = supplier.transformer(); - AsyncResponseTransformer.SplitResult split = transformer.split( - SplittingTransformerConfiguration.builder() - .bufferSizeInBytes(1024 * 32L) - .build()); - Subscriber> subscriber = new MultipartDownloaderSubscriber( - s3AsyncClient, - GetObjectRequest.builder() - .bucket(testBucket) - .key(testKey) - .build()); - - split.publisher().subscribe(subscriber); - assertThatThrownBy(() -> split.resultFuture().join()) - .hasMessageContaining("test error message"); - } - - @ParameterizedTest - @MethodSource("argumentsProvider") - void errorOnThirdRequest_shouldCompleteExceptionallyOnlyPartsGreaterThanTwo( - AsyncResponseTransformerTestSupplier supplier, - int amountOfPartToTest, - int partSize) { - util.stubForPart(testBucket, testKey, 1, 3, partSize); - util.stubForPart(testBucket, testKey, 2, 3, partSize); - stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=3", testBucket, testKey))).willReturn( - aResponse() - .withStatus(400) - .withBody("400test error message"))); - AsyncResponseTransformer transformer = supplier.transformer(); - AsyncResponseTransformer.SplitResult split = transformer.split( - SplittingTransformerConfiguration.builder() - .bufferSizeInBytes(1024 * 32L) - .build()); - Subscriber> subscriber = new MultipartDownloaderSubscriber( - s3AsyncClient, - GetObjectRequest.builder() - .bucket(testBucket) - .key(testKey) - .build()); - - if (partSize > 1) { - split.publisher().subscribe(subscriber); - assertThatThrownBy(() -> { - T res = split.resultFuture().join(); - supplier.body(res); - }).hasMessageContaining("test error message"); - } else { - T res = split.resultFuture().join(); - assertNotNull(supplier.body(res)); - } - } - - private static Stream argumentsProvider() { - // amount of part, individual part size - List> partSizes = Arrays.asList( - Pair.of(4, 16), - Pair.of(1, 1024), - Pair.of(31, 1243), - Pair.of(16, 16 * 1024), - Pair.of(1, 1024 * 1024), - Pair.of(4, 1024 * 1024), - Pair.of(1, 4 * 1024 * 1024), - Pair.of(4, 6 * 1024 * 1024), - Pair.of(7, 5 * 3752) - ); - - Stream.Builder sb = Stream.builder(); - transformersSuppliers().forEach(tr -> partSizes.forEach(p -> sb.accept(arguments(tr, p.left(), p.right())))); - return sb.build(); - } - -} \ No newline at end of file diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/S3MultipartClientGetObjectToBytesWiremockTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/S3MultipartClientGetObjectToBytesWiremockTest.java new file mode 100644 index 000000000000..2a29d385999c --- /dev/null +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/S3MultipartClientGetObjectToBytesWiremockTest.java @@ -0,0 +1,495 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + +import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.any; +import static com.github.tomakehurst.wiremock.client.WireMock.anyUrl; +import static com.github.tomakehurst.wiremock.client.WireMock.get; +import static com.github.tomakehurst.wiremock.client.WireMock.getRequestedFor; +import static com.github.tomakehurst.wiremock.client.WireMock.matching; +import static com.github.tomakehurst.wiremock.client.WireMock.stubFor; +import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo; +import static com.github.tomakehurst.wiremock.client.WireMock.verify; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static software.amazon.awssdk.services.s3.internal.multipart.utils.MultipartDownloadTestUtils.internalErrorBody; +import static software.amazon.awssdk.services.s3.internal.multipart.utils.MultipartDownloadTestUtils.slowdownErrorBody; + +import com.github.tomakehurst.wiremock.http.Fault; +import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; +import com.github.tomakehurst.wiremock.junit5.WireMockTest; +import com.github.tomakehurst.wiremock.stubbing.Scenario; +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.TimeUnit; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.core.ResponseBytes; +import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.core.interceptor.Context; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; +import software.amazon.awssdk.http.SdkHttpResponse; +import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.S3Exception; + +@WireMockTest +@Timeout(value = 30, unit = TimeUnit.SECONDS) +public class S3MultipartClientGetObjectToBytesWiremockTest { + private static final CapturingInterceptor capturingInterceptor = new CapturingInterceptor(); + private static final String BUCKET = "Example-Bucket"; + private static final String KEY = "Key"; + private static final int MAX_ATTEMPTS = 7; + private static final int TOTAL_PARTS = 3; + private static final int PART_SIZE = 1024; + private static final byte[] PART_1_DATA = new byte[PART_SIZE]; + private static final byte[] PART_2_DATA = new byte[PART_SIZE]; + private static final byte[] PART_3_DATA = new byte[PART_SIZE]; + private static byte[] expectedBody; + private S3AsyncClient multipartClient; + + @BeforeAll + public static void init() { + new Random().nextBytes(PART_1_DATA); + new Random().nextBytes(PART_2_DATA); + new Random().nextBytes(PART_3_DATA); + + expectedBody = new byte[TOTAL_PARTS * PART_SIZE]; + System.arraycopy(PART_1_DATA, 0, expectedBody, 0, PART_SIZE); + System.arraycopy(PART_2_DATA, 0, expectedBody, PART_SIZE, PART_SIZE); + System.arraycopy(PART_3_DATA, 0, expectedBody, 2 * PART_SIZE, PART_SIZE); + } + + @BeforeEach + public void setup(WireMockRuntimeInfo wm) { + capturingInterceptor.clear(); + multipartClient = S3AsyncClient.builder() + .region(Region.US_EAST_1) + .endpointOverride(URI.create(wm.getHttpBaseUrl())) + .multipartEnabled(true) + .httpClientBuilder(NettyNioAsyncHttpClient.builder().maxConcurrency(100).connectionAcquisitionTimeout(Duration.ofSeconds(100))) + .credentialsProvider(StaticCredentialsProvider.create(AwsBasicCredentials.create("key", "secret"))) + .overrideConfiguration( + o -> o.retryStrategy(b -> b.maxAttempts(MAX_ATTEMPTS)) + .addExecutionInterceptor(capturingInterceptor)) + .build(); + } + + @Test + public void getObject_concurrentCallsReturn200_shouldSucceed() { + List>> futures = new ArrayList<>(); + + int numRuns = 1000; + for (int i = 0; i < numRuns; i++) { + CompletableFuture> resp = mock200Response(multipartClient, i); + futures.add(resp); + } + + CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join(); + } + + @Test + public void getObject_single500WithinMany200s_shouldRetrySuccessfully() { + List>> futures = new ArrayList<>(); + + int numRuns = 1000; + for (int i = 0; i < numRuns; i++) { + CompletableFuture> resp = mock200Response(multipartClient, i); + futures.add(resp); + } + + CompletableFuture> requestWithRetryableError = + mockRetryableErrorThen200Response(multipartClient, 1); + futures.add(requestWithRetryableError); + + for (int i = 0; i < numRuns; i++) { + CompletableFuture> resp = mock200Response(multipartClient, i + 1000); + futures.add(resp); + } + + CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join(); + } + + @Test + public void getObject_concurrent503s_shouldRetrySuccessfully() { + List>> futures = new ArrayList<>(); + + int numRuns = 100; + for (int i = 0; i < numRuns; i++) { + CompletableFuture> resp = mockRetryableErrorThen200Response(multipartClient, i); + futures.add(resp); + } + + CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join(); + } + + @Test + public void getObject_5xxErrorResponses_shouldNotReuseInitialRequestId() { + String firstRequestId = UUID.randomUUID().toString(); + String secondRequestId = UUID.randomUUID().toString(); + + stubFor(any(anyUrl()) + .inScenario("errors") + .whenScenarioStateIs(Scenario.STARTED) + .willReturn(aResponse() + .withHeader("x-amz-request-id", firstRequestId) + .withStatus(503) + .withBody(internalErrorBody())) + .willSetStateTo("SecondAttempt")); + + stubFor(any(anyUrl()) + .inScenario("errors") + .whenScenarioStateIs("SecondAttempt") + .willReturn(aResponse() + .withHeader("x-amz-request-id", secondRequestId) + .withStatus(500))); + + assertThatThrownBy(() -> multipartClient.getObject(b -> b.bucket(BUCKET).key(KEY), + AsyncResponseTransformer.toBytes()).join()) + .isInstanceOf(CompletionException.class) + .hasCauseInstanceOf(S3Exception.class); + + + List responses = capturingInterceptor.getResponses(); + assertEquals(MAX_ATTEMPTS, responses.size()); + + String actualFirstRequestId = responses.get(0).firstMatchingHeader("x-amz-request-id").orElse(null); + String actualSecondRequestId = responses.get(1).firstMatchingHeader("x-amz-request-id").orElse(null); + + assertNotNull(actualFirstRequestId); + assertNotNull(actualSecondRequestId); + + assertNotEquals(actualFirstRequestId, actualSecondRequestId); + + assertEquals(firstRequestId, actualFirstRequestId); + assertEquals(secondRequestId, actualSecondRequestId); + + assertEquals(503, responses.get(0).statusCode()); + assertEquals(500, responses.get(1).statusCode()); + + verify(MAX_ATTEMPTS, getRequestedFor(urlEqualTo(String.format("/%s/%s?partNumber=1", BUCKET, KEY)))); + verify(0, getRequestedFor(urlEqualTo(String.format("/%s/%s?partNumber=2", BUCKET, KEY)))); + verify(0, getRequestedFor(urlEqualTo(String.format("/%s/%s?partNumber=3", BUCKET, KEY)))); + } + + + @Test + public void multipartDownload_200Response_shouldSucceed() { + stub200SuccessPart1(); + stub200SuccessPart2(); + stub200SuccessPart3(); + + CompletableFuture> future = + multipartClient.getObject(GetObjectRequest.builder().bucket(BUCKET).key(KEY).build(), + AsyncResponseTransformer.toBytes()); + + ResponseBytes response = future.join(); + byte[] actualBody = response.asByteArray(); + assertArrayEquals(expectedBody, actualBody); + + // Verify that all 3 parts were requested only once + verify(1, getRequestedFor(urlEqualTo(String.format("/%s/%s?partNumber=1", BUCKET, KEY)))); + verify(1, getRequestedFor(urlEqualTo(String.format("/%s/%s?partNumber=2", BUCKET, KEY)))); + verify(1, getRequestedFor(urlEqualTo(String.format("/%s/%s?partNumber=3", BUCKET, KEY)))); + } + + @Test + public void multipartDownload_secondPartNonRetryableError_shouldFail() { + stub200SuccessPart1(); + stubError(2, internalErrorBody()); + + CompletableFuture> future = + multipartClient.getObject(GetObjectRequest.builder().bucket(BUCKET).key(KEY).build(), + AsyncResponseTransformer.toBytes()); + + assertThatThrownBy(future::join).hasCauseInstanceOf(S3Exception.class) + .hasMessageContaining("We encountered an internal error. Please try again. (Service: S3, Status Code: 500"); + + verify(1, getRequestedFor(urlEqualTo(String.format("/%s/%s?partNumber=1", BUCKET, KEY)))); + verify(MAX_ATTEMPTS, getRequestedFor(urlEqualTo(String.format("/%s/%s?partNumber=2", BUCKET, KEY)))); + verify(0, getRequestedFor(urlEqualTo(String.format("/%s/%s?partNumber=3", BUCKET, KEY)))); + } + + @Test + public void multipartDownload_503OnFirstPart_shouldRetrySuccessfully() { + // Stub Part 1 - 503 on first attempt, 200 on retry + String part1Scenario = "part1-retry"; + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=1", BUCKET, KEY))) + .inScenario(part1Scenario) + .whenScenarioStateIs(Scenario.STARTED) + .willReturn(aResponse() + .withStatus(503) + .withHeader("x-amz-request-id", UUID.randomUUID().toString()) + .withBody("\n" + + "\n" + + " SlowDown\n" + + " Please reduce your request rate.\n" + + "")) + .willSetStateTo("retry-attempt")); + + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=1", BUCKET, KEY))) + .inScenario(part1Scenario) + .whenScenarioStateIs("retry-attempt") + .willReturn(aResponse() + .withStatus(200) + .withHeader("x-amz-mp-parts-count", String.valueOf(TOTAL_PARTS)) + .withHeader("x-amz-request-id", UUID.randomUUID().toString()) + .withBody(PART_1_DATA))); + + stub200SuccessPart2(); + stub200SuccessPart3(); + + CompletableFuture> future = + multipartClient.getObject(GetObjectRequest.builder().bucket(BUCKET).key(KEY).build(), + AsyncResponseTransformer.toBytes()); + + ResponseBytes response = future.join(); + byte[] actualBody = response.asByteArray(); + assertArrayEquals(expectedBody, actualBody, "Downloaded body should match expected combined parts"); + + // Verify that part 1 was requested twice (initial 503 + retry) + verify(2, getRequestedFor(urlEqualTo(String.format("/%s/%s?partNumber=1", BUCKET, KEY)))); + verify(1, getRequestedFor(urlEqualTo(String.format("/%s/%s?partNumber=2", BUCKET, KEY)))); + verify(1, getRequestedFor(urlEqualTo(String.format("/%s/%s?partNumber=3", BUCKET, KEY)))); + } + + @Test + public void multipartDownload_503OnFirstPartAndSecondPart_shouldRetrySuccessfully() { + // Stub Part 1 - 503 on first attempt, 200 on retry + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=1", BUCKET, KEY))) + .inScenario("part1-retry") + .whenScenarioStateIs(Scenario.STARTED) + .willReturn(aResponse() + .withStatus(503) + .withHeader("x-amz-request-id", UUID.randomUUID().toString()) + .withBody(slowdownErrorBody())) + .willSetStateTo("retry-attempt")); + + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=1", BUCKET, KEY))) + .inScenario("part1-retry") + .whenScenarioStateIs("retry-attempt") + .willReturn(aResponse() + .withStatus(200) + .withHeader("x-amz-mp-parts-count", String.valueOf(TOTAL_PARTS)) + .withHeader("x-amz-request-id", UUID.randomUUID().toString()) + .withBody(PART_1_DATA))); + + + // Stub Part 2 - 503 on first attempt, 200 on retry + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=2", BUCKET, KEY))) + .inScenario("part2-retry") + .whenScenarioStateIs(Scenario.STARTED) + .willReturn(aResponse() + .withStatus(500) + .withHeader("x-amz-request-id", UUID.randomUUID().toString()) + .withBody(internalErrorBody())) + .willSetStateTo("retry-attempt")); + + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=2", BUCKET, KEY))) + .inScenario("part2-retry") + .whenScenarioStateIs("retry-attempt") + .willReturn(aResponse() + .withStatus(200) + .withHeader("x-amz-mp-parts-count", String.valueOf(TOTAL_PARTS)) + .withHeader("x-amz-request-id", UUID.randomUUID().toString()) + .withBody(PART_2_DATA))); + + stub200SuccessPart3(); + ResponseBytes response = multipartClient.getObject(b -> b.bucket(BUCKET).key(KEY), + AsyncResponseTransformer.toBytes()).join(); + + byte[] actualBody = response.asByteArray(); + assertArrayEquals(expectedBody, actualBody); + + // Verify that part 1 was requested twice (initial 503 + retry) + verify(2, getRequestedFor(urlEqualTo(String.format("/%s/%s?partNumber=1", BUCKET, KEY)))); + // Verify that part 2 was requested once (no retry) + verify(2, getRequestedFor(urlEqualTo(String.format("/%s/%s?partNumber=2", BUCKET, KEY)))); + verify(1, getRequestedFor(urlEqualTo(String.format("/%s/%s?partNumber=3", BUCKET, KEY)))); + } + + @Test + public void getObject_ioExceptionOnly_shouldExhaustRetriesAndFail() { + stubIoError(1); + stub200SuccessPart2(); + stub200SuccessPart3(); + assertThatThrownBy(() -> multipartClient.getObject(b -> b.bucket(BUCKET).key(KEY), + AsyncResponseTransformer.toBytes()).join()) + .isInstanceOf(CompletionException.class) + .hasCauseInstanceOf(SdkClientException.class); + + verify(MAX_ATTEMPTS, getRequestedFor(urlEqualTo(String.format("/%s/%s?partNumber=1", BUCKET, KEY)))); + verify(0, getRequestedFor(urlEqualTo(String.format("/%s/%s?partNumber=2", BUCKET, KEY)))); + verify(0, getRequestedFor(urlEqualTo(String.format("/%s/%s?partNumber=3", BUCKET, KEY)))); + } + + @Test + public void getObject_iOErrorThen200Response_shouldRetrySuccessfully() { + String requestId = UUID.randomUUID().toString(); + + stubFor(any(anyUrl()) + .inScenario("io-error") + .whenScenarioStateIs(Scenario.STARTED) + .willReturn(aResponse() + .withFault(Fault.CONNECTION_RESET_BY_PEER)) + .willSetStateTo("retry")); + + stubFor(any(anyUrl()) + .inScenario("io-error") + .whenScenarioStateIs("retry") + .willReturn(aResponse() + .withStatus(200) + .withHeader("x-amz-request-id", requestId) + .withBody("Hello World"))); + + ResponseBytes response = multipartClient.getObject(GetObjectRequest.builder() + .bucket(BUCKET) + .key(KEY) + .build(), + AsyncResponseTransformer.toBytes()).join(); + + assertArrayEquals("Hello World".getBytes(StandardCharsets.UTF_8), response.asByteArray()); + + verify(2, getRequestedFor(urlEqualTo("/" + BUCKET + "/" + KEY + "?partNumber=1"))); + + List responses = capturingInterceptor.getResponses(); + String finalRequestId = responses.get(responses.size() - 1) + .firstMatchingHeader("x-amz-request-id") + .orElse(null); + + assertEquals(requestId, finalRequestId); + } + + private void stubError(int partNumber, String errorBody) { + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=%d", BUCKET, KEY, partNumber))) + .willReturn(aResponse() + .withHeader("x-amz-request-id", String.valueOf(UUID.randomUUID())) + .withHeader("x-amz-mp-parts-count", String.valueOf(TOTAL_PARTS)) + .withStatus(500).withBody(errorBody))); + } + + private void stubIoError(int partNumber) { + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=%d", BUCKET, KEY, partNumber))) + .willReturn(aResponse() + .withFault(Fault.CONNECTION_RESET_BY_PEER))); + } + + private void stub200SuccessPart1() { + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=1", BUCKET, KEY))) + .willReturn(aResponse() + .withHeader("x-amz-request-id", String.valueOf(UUID.randomUUID())) + .withHeader("x-amz-mp-parts-count", String.valueOf(TOTAL_PARTS)) + .withStatus(200).withBody(PART_1_DATA))); + } + + private void stub200SuccessPart2() { + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=2", BUCKET, KEY))) + .willReturn(aResponse() + .withStatus(200) + .withHeader("x-amz-mp-parts-count", String.valueOf(TOTAL_PARTS)) + .withHeader("x-amz-request-id", UUID.randomUUID().toString()) + .withBody(PART_2_DATA))); + } + + private void stub200SuccessPart3() { + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=3", BUCKET, KEY))) + .willReturn(aResponse() + .withStatus(200) + .withHeader("x-amz-mp-parts-count", String.valueOf(TOTAL_PARTS)) + .withHeader("x-amz-request-id", UUID.randomUUID().toString()) + .withBody(PART_3_DATA))); + } + + private CompletableFuture> mock200Response(S3AsyncClient s3Client, int runNumber) { + String runId = runNumber + " success"; + + stubFor(any(anyUrl()) + .withHeader("RunNum", matching(runId)) + .inScenario(runId) + .whenScenarioStateIs(Scenario.STARTED) + .willReturn(aResponse().withStatus(200) + .withHeader("x-amz-request-id", String.valueOf(UUID.randomUUID())) + .withBody("Hello World"))); + + return s3Client.getObject(r -> r.bucket(BUCKET).key("key") + .overrideConfiguration(c -> c.putHeader("RunNum", runId)), + AsyncResponseTransformer.toBytes()); + } + + private CompletableFuture> mockRetryableErrorThen200Response(S3AsyncClient s3Client, int runNumber) { + String runId = String.valueOf(runNumber); + + stubFor(any(anyUrl()) + .withHeader("RunNum", matching(runId)) + .inScenario(runId) + .whenScenarioStateIs(Scenario.STARTED) + .willReturn(aResponse() + .withHeader("x-amz-request-id", String.valueOf(UUID.randomUUID())) + .withStatus(500) + .withBody(internalErrorBody()) + ) + .willSetStateTo("SecondAttempt" + runId)); + + stubFor(any(anyUrl()) + .inScenario(runId) + .withHeader("RunNum", matching(runId)) + .whenScenarioStateIs("SecondAttempt" + runId) + .willReturn(aResponse().withStatus(200) + .withHeader("x-amz-request-id", String.valueOf(UUID.randomUUID())) + .withBody("Hello World"))); + + return s3Client.getObject(r -> r.bucket(BUCKET).key("key") + .overrideConfiguration(c -> c.putHeader("RunNum", runId)), + AsyncResponseTransformer.toBytes()); + } + + static class CapturingInterceptor implements ExecutionInterceptor { + private final List responses = new ArrayList<>(); + + @Override + public void afterTransmission(Context.AfterTransmission context, ExecutionAttributes executionAttributes) { + responses.add(context.httpResponse()); + } + + public List getResponses() { + return new ArrayList<>(responses); + } + + public void clear() { + responses.clear(); + } + } +} diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/S3MultipartClientGetObjectWiremockTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/S3MultipartClientGetObjectWiremockTest.java index 5869b1a82733..fb93ab8ec0ce 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/S3MultipartClientGetObjectWiremockTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/S3MultipartClientGetObjectWiremockTest.java @@ -25,18 +25,22 @@ import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo; import static com.github.tomakehurst.wiremock.client.WireMock.verify; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Assertions.fail; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.params.provider.Arguments.arguments; +import static software.amazon.awssdk.services.s3.internal.multipart.utils.MultipartDownloadTestUtils.internalErrorBody; +import static software.amazon.awssdk.services.s3.internal.multipart.utils.MultipartDownloadTestUtils.transformersSuppliers; +import com.github.tomakehurst.wiremock.http.Fault; import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; import com.github.tomakehurst.wiremock.junit5.WireMockTest; import com.github.tomakehurst.wiremock.stubbing.Scenario; -import java.io.IOException; -import java.io.UncheckedIOException; import java.net.URI; -import java.nio.file.Files; -import java.nio.file.Path; import java.time.Duration; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.UUID; import java.util.concurrent.CompletableFuture; @@ -46,26 +50,40 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Timeout; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.core.SplittingTransformerConfiguration; import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.core.internal.async.ByteArrayAsyncResponseTransformer; +import software.amazon.awssdk.core.internal.async.FileAsyncResponseTransformer; +import software.amazon.awssdk.core.internal.async.InputStreamResponseTransformer; +import software.amazon.awssdk.core.internal.async.PublisherAsyncResponseTransformer; import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.internal.multipart.utils.MultipartDownloadTestUtils; import software.amazon.awssdk.services.s3.model.GetObjectResponse; import software.amazon.awssdk.services.s3.model.S3Exception; +import software.amazon.awssdk.services.s3.utils.AsyncResponseTransformerTestSupplier; +import software.amazon.awssdk.utils.Pair; @WireMockTest @Timeout(value = 45, unit = TimeUnit.SECONDS) public class S3MultipartClientGetObjectWiremockTest { private static final String BUCKET = "Example-Bucket"; private static final String KEY = "Key"; - private static int fileCounter = 0; + private static final int MAX_ATTEMPTS = 3; private S3AsyncClient multipartClient; + private MultipartDownloadTestUtils util; @BeforeEach - public void setup(WireMockRuntimeInfo wm) { + public void setup(WireMockRuntimeInfo wm) + { + wm.getWireMock().resetRequests(); + wm.getWireMock().resetScenarios(); + wm.getWireMock().resetMappings(); multipartClient = S3AsyncClient.builder() .region(Region.US_EAST_1) .endpointOverride(URI.create(wm.getHttpBaseUrl())) @@ -74,40 +92,96 @@ public void setup(WireMockRuntimeInfo wm) { .maxConcurrency(100) .connectionAcquisitionTimeout(Duration.ofSeconds(60))) .credentialsProvider(StaticCredentialsProvider.create(AwsBasicCredentials.create("key", "secret"))) + .overrideConfiguration(o -> o.retryStrategy(b -> b.maxAttempts(MAX_ATTEMPTS))) .build(); + util = new MultipartDownloadTestUtils(BUCKET, KEY, UUID.randomUUID().toString()); } - private static Stream responseTransformerFactories() { - return Stream.of( - AsyncResponseTransformer::toBytes, - AsyncResponseTransformer::toBlockingInputStream, - AsyncResponseTransformer::toPublisher, - () -> { - try { - Path tempDir = Files.createTempDirectory("s3-test"); - Path tempFile = tempDir.resolve("testFile" + fileCounter + ".txt"); - fileCounter++; - tempFile.toFile().deleteOnExit(); - return AsyncResponseTransformer.toFile(tempFile); - } catch (IOException e) { - throw new UncheckedIOException(e); - } - } - ); + @ParameterizedTest + @MethodSource("partSizeAndTransformerParams") + void happyPath_shouldReceiveAllBodyPartInCorrectOrder(AsyncResponseTransformerTestSupplier supplier, + int amountOfPartToTest, + int partSize) { + byte[] expectedBody = util.stubAllParts(BUCKET, KEY, amountOfPartToTest, partSize); + AsyncResponseTransformer transformer = supplier.transformer(); + + T response = multipartClient.getObject(b -> b.bucket(BUCKET).key(KEY), transformer).join(); + + byte[] body = supplier.body(response); + assertArrayEquals(expectedBody, body); + util.verifyCorrectAmountOfRequestsMade(amountOfPartToTest); } - interface TransformerFactory { - AsyncResponseTransformer create(); + @ParameterizedTest + @MethodSource("partSizeAndTransformerParams") + void nonRetryableErrorOnThirdPart_shouldCompleteExceptionallyOnlyPartsGreaterThanTwo( + AsyncResponseTransformerTestSupplier supplier, + int amountOfPartToTest, + int partSize) { + util.stubForPart(BUCKET, KEY, 1, 3, partSize); + util.stubForPart(BUCKET, KEY, 2, 3, partSize); + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=3", BUCKET, KEY))).willReturn( + aResponse() + .withStatus(400) + .withBody("400test error message"))); + AsyncResponseTransformer transformer = supplier.transformer(); + AsyncResponseTransformer.SplitResult split = transformer.split( + SplittingTransformerConfiguration.builder() + .bufferSizeInBytes(1024 * 32L) + .build()); + + if (partSize > 1) { + assertThatThrownBy(() -> { + T res = multipartClient.getObject(b -> b.bucket(BUCKET).key(KEY), transformer).join(); + supplier.body(res); + }).hasMessageContaining("test error message"); + } else { + T res = split.resultFuture().join(); + assertNotNull(supplier.body(res)); + } } @ParameterizedTest - @MethodSource("responseTransformerFactories") - public void getObject_single500WithinMany200s_shouldNotRetryError(TransformerFactory transformerFactory) { + @MethodSource("responseTransformers") + void nonRetryableErrorOnFirstPart_shouldFail(AsyncResponseTransformerTestSupplier supplier) { + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=1", BUCKET, KEY))).willReturn( + aResponse() + .withStatus(400) + .withBody("400test error message"))); + AsyncResponseTransformer transformer = supplier.transformer(); + + assertThatThrownBy(() -> multipartClient.getObject(b -> b.bucket(BUCKET).key(KEY), transformer).join()) + .hasMessageContaining("test error message"); + } + + @ParameterizedTest + @MethodSource("responseTransformers") + public void ioError_shouldFailAndNotRetry() { + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=1", BUCKET, KEY))) + .willReturn(aResponse() + .withFault(Fault.CONNECTION_RESET_BY_PEER))); + + assertThatThrownBy(() -> multipartClient.getObject(b -> b.bucket(BUCKET).key(KEY), + AsyncResponseTransformer.toBlockingInputStream()).join()) + .satisfiesAnyOf( + throwable -> assertThat(throwable) + .hasMessageContaining("The connection was closed during the request"), + + throwable -> assertThat(throwable) + .hasMessageContaining("Connection reset") + ); + + verify(1, getRequestedFor(urlEqualTo(String.format("/%s/%s?partNumber=1", BUCKET, KEY)))); + } + + @ParameterizedTest + @MethodSource("responseTransformers") + public void getObject_single500WithinMany200s_shouldNotRetryError(AsyncResponseTransformerTestSupplier transformerSupplier) { List> futures = new ArrayList<>(); - int numRuns = 100; + int numRuns = 50; for (int i = 0; i < numRuns; i++) { - CompletableFuture resp = mock200Response(multipartClient, i, transformerFactory); + CompletableFuture resp = mock200Response(multipartClient, i, transformerSupplier); futures.add(resp); } @@ -130,11 +204,11 @@ public void getObject_single500WithinMany200s_shouldNotRetryError(TransformerFac .withBody("Hello World"))); CompletableFuture requestWithRetryableError = - multipartClient.getObject(r -> r.bucket(BUCKET).key(errorKey), transformerFactory.create()); + multipartClient.getObject(r -> r.bucket(BUCKET).key(errorKey), transformerSupplier.transformer()); futures.add(requestWithRetryableError); for (int i = 0; i < numRuns; i++) { - CompletableFuture resp = mock200Response(multipartClient, i + 1000, transformerFactory); + CompletableFuture resp = mock200Response(multipartClient, i + 1000, transformerSupplier); futures.add(resp); } @@ -148,7 +222,42 @@ public void getObject_single500WithinMany200s_shouldNotRetryError(TransformerFac verify(1, getRequestedFor(urlEqualTo(String.format("/%s/%s?partNumber=1", BUCKET, errorKey)))); } - private CompletableFuture mock200Response(S3AsyncClient s3Client, int runNumber, TransformerFactory transformerFactory) { + private static Stream partSizeAndTransformerParams() { + // amount of part, individual part size + List> partSizes = Arrays.asList( + Pair.of(4, 16), + Pair.of(1, 1024), + Pair.of(31, 1243), + Pair.of(16, 16 * 1024), + Pair.of(1, 1024 * 1024), + Pair.of(4, 1024 * 1024), + Pair.of(1, 4 * 1024 * 1024), + Pair.of(4, 6 * 1024 * 1024), + Pair.of(7, 5 * 3752) + ); + + Stream.Builder sb = Stream.builder(); + transformersSuppliers().forEach(tr -> partSizes.forEach(p -> sb.accept(arguments(tr, p.left(), p.right())))); + return sb.build(); + } + + + /** + * Testing {@link PublisherAsyncResponseTransformer}, {@link InputStreamResponseTransformer}, and + * {@link FileAsyncResponseTransformer} + *

+ * + * Retry for multipart download is supported for {@link ByteArrayAsyncResponseTransformer}, tested in + * {@link S3MultipartClientGetObjectToBytesWiremockTest}. + */ + private static Stream> responseTransformers() { + return Stream.of(new AsyncResponseTransformerTestSupplier.InputStreamArtSupplier(), + new AsyncResponseTransformerTestSupplier.PublisherArtSupplier(), + new AsyncResponseTransformerTestSupplier.FileArtSupplier()); + } + + private CompletableFuture mock200Response(S3AsyncClient s3Client, int runNumber, + AsyncResponseTransformerTestSupplier transformerSupplier) { String runId = runNumber + " success"; stubFor(any(anyUrl()) @@ -161,18 +270,6 @@ private CompletableFuture mock200Response(S3AsyncClient s3Client, int runNumb return s3Client.getObject(r -> r.bucket(BUCKET).key(KEY) .overrideConfiguration(c -> c.putHeader("RunNum", runId)), - transformerFactory.create()); - } - - private String errorBody(String errorCode, String errorMessage) { - return "\n" - + "\n" - + " " + errorCode + "\n" - + " " + errorMessage + "\n" - + ""; - } - - private String internalErrorBody() { - return errorBody("InternalError", "We encountered an internal error. Please try again."); + transformerSupplier.transformer()); } } diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/UploadObjectHelperTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/UploadObjectHelperTest.java index 284a392086df..20ae807dc334 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/UploadObjectHelperTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/UploadObjectHelperTest.java @@ -24,10 +24,10 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static software.amazon.awssdk.services.s3.internal.multipart.MpuTestUtils.s3ResumeToken; -import static software.amazon.awssdk.services.s3.internal.multipart.MpuTestUtils.stubSuccessfulCompleteMultipartCall; -import static software.amazon.awssdk.services.s3.internal.multipart.MpuTestUtils.stubSuccessfulCreateMultipartCall; -import static software.amazon.awssdk.services.s3.internal.multipart.MpuTestUtils.stubSuccessfulUploadPartCalls; +import static software.amazon.awssdk.services.s3.internal.multipart.utils.MultipartUploadTestUtils.s3ResumeToken; +import static software.amazon.awssdk.services.s3.internal.multipart.utils.MultipartUploadTestUtils.stubSuccessfulCompleteMultipartCall; +import static software.amazon.awssdk.services.s3.internal.multipart.utils.MultipartUploadTestUtils.stubSuccessfulCreateMultipartCall; +import static software.amazon.awssdk.services.s3.internal.multipart.utils.MultipartUploadTestUtils.stubSuccessfulUploadPartCalls; import static software.amazon.awssdk.services.s3.multipart.S3MultipartExecutionAttribute.PAUSE_OBSERVABLE; import static software.amazon.awssdk.services.s3.multipart.S3MultipartExecutionAttribute.RESUME_TOKEN; diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelperTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelperTest.java index 972f0b86241a..90c5bfed038c 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelperTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelperTest.java @@ -22,15 +22,13 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static software.amazon.awssdk.services.s3.internal.multipart.MpuTestUtils.stubSuccessfulCompleteMultipartCall; -import static software.amazon.awssdk.services.s3.internal.multipart.MpuTestUtils.stubSuccessfulCreateMultipartCall; -import static software.amazon.awssdk.services.s3.internal.multipart.MpuTestUtils.stubSuccessfulUploadPartCalls; +import static software.amazon.awssdk.services.s3.internal.multipart.utils.MultipartUploadTestUtils.stubSuccessfulCompleteMultipartCall; +import static software.amazon.awssdk.services.s3.internal.multipart.utils.MultipartUploadTestUtils.stubSuccessfulCreateMultipartCall; +import static software.amazon.awssdk.services.s3.internal.multipart.utils.MultipartUploadTestUtils.stubSuccessfulUploadPartCalls; -import java.io.ByteArrayInputStream; import java.io.FileInputStream; import java.io.FileNotFoundException; import java.io.IOException; -import java.io.InputStream; import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; @@ -52,12 +50,10 @@ import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.CompletedPart; -import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.PutObjectRequest; import software.amazon.awssdk.services.s3.model.PutObjectResponse; import software.amazon.awssdk.services.s3.model.UploadPartRequest; import software.amazon.awssdk.testutils.RandomTempFile; -import software.amazon.awssdk.utils.StringInputStream; public class UploadWithUnknownContentLengthHelperTest { private static final String BUCKET = "bucket"; diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloadTestUtil.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/utils/MultipartDownloadTestUtils.java similarity index 73% rename from services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloadTestUtil.java rename to services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/utils/MultipartDownloadTestUtils.java index 708972b6b0d7..ac667912a33f 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloadTestUtil.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/utils/MultipartDownloadTestUtils.java @@ -13,7 +13,7 @@ * permissions and limitations under the License. */ -package software.amazon.awssdk.services.s3.internal.multipart; +package software.amazon.awssdk.services.s3.internal.multipart.utils; import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; import static com.github.tomakehurst.wiremock.client.WireMock.get; @@ -28,18 +28,14 @@ import java.util.Random; import software.amazon.awssdk.services.s3.utils.AsyncResponseTransformerTestSupplier; -public class MultipartDownloadTestUtil { +public class MultipartDownloadTestUtils { - private static final String RETRY_SCENARIO = "retry"; - private static final String SUCCESS_STATE = "success"; - private static final String FAILED_STATE = "failed"; + private final String testBucket; + private final String testKey; + private final String eTag; + private final Random random = new Random(); - private String testBucket; - private String testKey; - private String eTag; - private Random random = new Random(); - - public MultipartDownloadTestUtil(String testBucket, String testKey, String eTag) { + public MultipartDownloadTestUtils(String testBucket, String testKey, String eTag) { this.testBucket = testBucket; this.testKey = testKey; this.eTag = eTag; @@ -82,17 +78,19 @@ public void verifyCorrectAmountOfRequestsMade(int amountOfPartToTest) { verify(0, getRequestedFor(urlMatching(String.format(urlTemplate, amountOfPartToTest + 1)))); } - public byte[] stubForPartSuccess(int part, int totalPart, int partSize) { - byte[] body = new byte[partSize]; - random.nextBytes(body); - stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=%d", testBucket, testKey, part))) - .inScenario(RETRY_SCENARIO) - .whenScenarioStateIs(SUCCESS_STATE) - .willReturn( - aResponse() - .withHeader("x-amz-mp-parts-count", totalPart + "") - .withHeader("ETag", eTag) - .withBody(body))); - return body; + public static String errorBody(String errorCode, String errorMessage) { + return "\n" + + "\n" + + " " + errorCode + "\n" + + " " + errorMessage + "\n" + + ""; + } + + public static String internalErrorBody() { + return errorBody("InternalError", "We encountered an internal error. Please try again."); + } + + public static String slowdownErrorBody() { + return errorBody("SlowDown", "Please reduce your request rate."); } } diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MpuTestUtils.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/utils/MultipartUploadTestUtils.java similarity index 96% rename from services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MpuTestUtils.java rename to services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/utils/MultipartUploadTestUtils.java index 23fe07ab2743..9a97bf70c1e8 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MpuTestUtils.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/utils/MultipartUploadTestUtils.java @@ -13,7 +13,7 @@ * permissions and limitations under the License. */ -package software.amazon.awssdk.services.s3.internal.multipart; +package software.amazon.awssdk.services.s3.internal.multipart.utils; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.when; @@ -33,9 +33,9 @@ import software.amazon.awssdk.services.s3.model.UploadPartResponse; import software.amazon.awssdk.services.s3.multipart.S3ResumeToken; -public final class MpuTestUtils { +public final class MultipartUploadTestUtils { - private MpuTestUtils() { + private MultipartUploadTestUtils() { } public static void stubSuccessfulHeadObjectCall(long contentLength, S3AsyncClient s3AsyncClient) {