diff --git a/.changes/next-release/bugfix-AmazonS3-6522f77.json b/.changes/next-release/bugfix-AmazonS3-6522f77.json new file mode 100644 index 000000000000..3f4b2cb0c6ff --- /dev/null +++ b/.changes/next-release/bugfix-AmazonS3-6522f77.json @@ -0,0 +1,6 @@ +{ + "type": "bugfix", + "category": "Amazon S3", + "contributor": "", + "description": "Add additional validations for multipart upload operations in the Java multipart S3 client." +} diff --git a/bom-internal/pom.xml b/bom-internal/pom.xml index 70c0248c0df8..8bceb73e63f9 100644 --- a/bom-internal/pom.xml +++ b/bom-internal/pom.xml @@ -235,6 +235,12 @@ ${rxjava.version} test + + io.reactivex.rxjava3 + rxjava + ${rxjava3.version} + test + commons-lang3 org.apache.commons diff --git a/pom.xml b/pom.xml index 43aedc0ee458..ac9378780d2f 100644 --- a/pom.xml +++ b/pom.xml @@ -124,6 +124,7 @@ 3.10.0 3.5.101 2.2.21 + 3.1.5 1.17.1 1.37 0.38.1 diff --git a/services/s3/pom.xml b/services/s3/pom.xml index f0ca2e658780..b8b76dfba495 100644 --- a/services/s3/pom.xml +++ b/services/s3/pom.xml @@ -230,5 +230,10 @@ jimfs test + + io.reactivex.rxjava3 + rxjava + test + diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriber.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriber.java index 7a27e60e31dc..93bc0dfeb6f8 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriber.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriber.java @@ -15,11 +15,14 @@ package software.amazon.awssdk.services.s3.internal.multipart; +import static software.amazon.awssdk.services.s3.internal.multipart.MultipartUploadHelper.contentLengthMismatchForPart; +import static software.amazon.awssdk.services.s3.internal.multipart.MultipartUploadHelper.partNumMismatch; import static software.amazon.awssdk.services.s3.multipart.S3MultipartExecutionAttribute.JAVA_PROGRESS_LISTENER; import java.util.Collection; import java.util.HashMap; import java.util.Map; +import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.atomic.AtomicBoolean; @@ -32,6 +35,7 @@ import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.async.listener.PublisherListener; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse; import software.amazon.awssdk.services.s3.model.CompletedPart; import software.amazon.awssdk.services.s3.model.PutObjectRequest; @@ -54,10 +58,10 @@ public class KnownContentLengthAsyncRequestBodySubscriber implements Subscriber< private final AtomicBoolean failureActionInitiated = new AtomicBoolean(false); private final AtomicInteger partNumber = new AtomicInteger(1); private final MultipartUploadHelper multipartUploadHelper; - private final long contentLength; + private final long totalSize; private final long partSize; - private final int partCount; - private final int numExistingParts; + private final int expectedNumParts; + private final int existingNumParts; private final String uploadId; private final Collection> futures = new ConcurrentLinkedQueue<>(); private final PutObjectRequest putObjectRequest; @@ -77,25 +81,21 @@ public class KnownContentLengthAsyncRequestBodySubscriber implements Subscriber< KnownContentLengthAsyncRequestBodySubscriber(MpuRequestContext mpuRequestContext, CompletableFuture returnFuture, MultipartUploadHelper multipartUploadHelper) { - this.contentLength = mpuRequestContext.contentLength(); + this.totalSize = mpuRequestContext.contentLength(); this.partSize = mpuRequestContext.partSize(); - this.partCount = determinePartCount(contentLength, partSize); + this.expectedNumParts = mpuRequestContext.expectedNumParts(); this.putObjectRequest = mpuRequestContext.request().left(); this.returnFuture = returnFuture; this.uploadId = mpuRequestContext.uploadId(); this.existingParts = mpuRequestContext.existingParts() == null ? new HashMap<>() : mpuRequestContext.existingParts(); - this.numExistingParts = NumericUtils.saturatedCast(mpuRequestContext.numPartsCompleted()); - this.completedParts = new AtomicReferenceArray<>(partCount); + this.existingNumParts = NumericUtils.saturatedCast(mpuRequestContext.numPartsCompleted()); + this.completedParts = new AtomicReferenceArray<>(expectedNumParts); this.multipartUploadHelper = multipartUploadHelper; this.progressListener = putObjectRequest.overrideConfiguration().map(c -> c.executionAttributes() .getAttribute(JAVA_PROGRESS_LISTENER)) .orElseGet(PublisherListener::noOp); } - private int determinePartCount(long contentLength, long partSize) { - return (int) Math.ceil(contentLength / (double) partSize); - } - public S3ResumeToken pause() { isPaused = true; @@ -119,8 +119,8 @@ public S3ResumeToken pause() { return S3ResumeToken.builder() .uploadId(uploadId) .partSize(partSize) - .totalNumParts((long) partCount) - .numPartsCompleted(numPartsCompleted + numExistingParts) + .totalNumParts((long) expectedNumParts) + .numPartsCompleted(numPartsCompleted + existingNumParts) .build(); } @@ -145,21 +145,32 @@ public void onSubscribe(Subscription s) { @Override public void onNext(AsyncRequestBody asyncRequestBody) { - if (isPaused) { + if (isPaused || isDone) { return; } - if (existingParts.containsKey(partNumber.get())) { - partNumber.getAndIncrement(); + int currentPartNum = partNumber.getAndIncrement(); + if (existingParts.containsKey(currentPartNum)) { asyncRequestBody.subscribe(new CancelledSubscriber<>()); subscription.request(1); asyncRequestBody.contentLength().ifPresent(progressListener::subscriberOnNext); return; } + Optional sdkClientException = validatePart(asyncRequestBody, currentPartNum); + if (sdkClientException.isPresent()) { + multipartUploadHelper.failRequestsElegantly(futures, + sdkClientException.get(), + uploadId, + returnFuture, + putObjectRequest); + subscription.cancel(); + return; + } + asyncRequestBodyInFlight.incrementAndGet(); UploadPartRequest uploadRequest = SdkPojoConversionUtils.toUploadPartRequest(putObjectRequest, - partNumber.getAndIncrement(), + currentPartNum, uploadId); Consumer completedPartConsumer = completedPart -> completedParts.set(completedPart.partNumber() - 1, @@ -179,6 +190,39 @@ public void onNext(AsyncRequestBody asyncRequestBody) { subscription.request(1); } + private Optional validatePart(AsyncRequestBody asyncRequestBody, int currentPartNum) { + if (!asyncRequestBody.contentLength().isPresent()) { + return Optional.of(MultipartUploadHelper.contentLengthMissingForPart(currentPartNum)); + } + + Long currentPartSize = asyncRequestBody.contentLength().get(); + + if (currentPartNum > expectedNumParts) { + return Optional.of(partNumMismatch(expectedNumParts, currentPartNum)); + } + + if (currentPartNum == expectedNumParts) { + return validateLastPartSize(currentPartSize); + } + + if (currentPartSize != partSize) { + return Optional.of(contentLengthMismatchForPart(partSize, currentPartSize)); + } + return Optional.empty(); + } + + private Optional validateLastPartSize(Long currentPartSize) { + long remainder = totalSize % partSize; + long expectedLastPartSize = remainder == 0 ? partSize : remainder; + if (currentPartSize != expectedLastPartSize) { + return Optional.of( + SdkClientException.create("Content length of the last part must be equal to the " + + "expected last part size. Expected: " + expectedLastPartSize + + ", Actual: " + currentPartSize)); + } + return Optional.empty(); + } + private boolean shouldFailRequest() { return failureActionInitiated.compareAndSet(false, true) && !isPaused; } @@ -187,6 +231,7 @@ private boolean shouldFailRequest() { public void onError(Throwable t) { log.debug(() -> "Received onError ", t); if (failureActionInitiated.compareAndSet(false, true)) { + isDone = true; multipartUploadHelper.failRequestsElegantly(futures, t, uploadId, returnFuture, putObjectRequest); } } @@ -203,6 +248,7 @@ public void onComplete() { private void completeMultipartUploadIfFinished(int requestsInFlight) { if (isDone && requestsInFlight == 0 && completedMultipartInitiated.compareAndSet(false, true)) { CompletedPart[] parts; + if (existingParts.isEmpty()) { parts = IntStream.range(0, completedParts.length()) @@ -212,15 +258,23 @@ private void completeMultipartUploadIfFinished(int requestsInFlight) { // List of CompletedParts needs to be in ascending order parts = mergeCompletedParts(); } + + int actualNumParts = partNumber.get() - 1; + if (actualNumParts != expectedNumParts) { + SdkClientException exception = partNumMismatch(expectedNumParts, actualNumParts); + multipartUploadHelper.failRequestsElegantly(futures, exception, uploadId, returnFuture, putObjectRequest); + return; + } + completeMpuFuture = multipartUploadHelper.completeMultipartUpload(returnFuture, uploadId, parts, putObjectRequest, - contentLength); + totalSize); } } private CompletedPart[] mergeCompletedParts() { - CompletedPart[] merged = new CompletedPart[partCount]; + CompletedPart[] merged = new CompletedPart[expectedNumParts]; int currPart = 1; - while (currPart < partCount + 1) { + while (currPart < expectedNumParts + 1) { CompletedPart completedPart = existingParts.containsKey(currPart) ? existingParts.get(currPart) : completedParts.get(currPart - 1); merged[currPart - 1] = completedPart; diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MpuRequestContext.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MpuRequestContext.java index b9d47f6b6c71..d0965f2ecc0e 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MpuRequestContext.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MpuRequestContext.java @@ -22,6 +22,7 @@ import software.amazon.awssdk.services.s3.model.CompletedPart; import software.amazon.awssdk.services.s3.model.PutObjectRequest; import software.amazon.awssdk.utils.Pair; +import software.amazon.awssdk.utils.Validate; @SdkInternalApi public class MpuRequestContext { @@ -32,6 +33,7 @@ public class MpuRequestContext { private final Long numPartsCompleted; private final String uploadId; private final Map existingParts; + private final int expectedNumParts; protected MpuRequestContext(Builder builder) { this.request = builder.request; @@ -40,6 +42,8 @@ protected MpuRequestContext(Builder builder) { this.uploadId = builder.uploadId; this.existingParts = builder.existingParts; this.numPartsCompleted = builder.numPartsCompleted; + this.expectedNumParts = Validate.paramNotNull(builder.expectedNumParts, + "expectedNumParts"); } public static Builder builder() { @@ -56,9 +60,13 @@ public boolean equals(Object o) { } MpuRequestContext that = (MpuRequestContext) o; - return Objects.equals(request, that.request) && Objects.equals(contentLength, that.contentLength) - && Objects.equals(partSize, that.partSize) && Objects.equals(numPartsCompleted, that.numPartsCompleted) - && Objects.equals(uploadId, that.uploadId) && Objects.equals(existingParts, that.existingParts); + return expectedNumParts == that.expectedNumParts + && Objects.equals(request, that.request) + && Objects.equals(contentLength, that.contentLength) + && Objects.equals(partSize, that.partSize) + && Objects.equals(numPartsCompleted, that.numPartsCompleted) + && Objects.equals(uploadId, that.uploadId) + && Objects.equals(existingParts, that.existingParts); } @Override @@ -69,6 +77,7 @@ public int hashCode() { result = 31 * result + (contentLength != null ? contentLength.hashCode() : 0); result = 31 * result + (partSize != null ? partSize.hashCode() : 0); result = 31 * result + (numPartsCompleted != null ? numPartsCompleted.hashCode() : 0); + result = 31 * result + expectedNumParts; return result; } @@ -92,6 +101,10 @@ public String uploadId() { return uploadId; } + public int expectedNumParts() { + return expectedNumParts; + } + public Map existingParts() { return existingParts; } @@ -103,6 +116,7 @@ public static final class Builder { private Long numPartsCompleted; private String uploadId; private Map existingParts; + private Integer expectedNumParts; private Builder() { } @@ -137,6 +151,11 @@ public Builder existingParts(Map existingParts) { return this; } + public Builder expectedNumParts(Integer expectedNumParts) { + this.expectedNumParts = expectedNumParts; + return this; + } + public MpuRequestContext build() { return new MpuRequestContext(this); } diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java index e752a40e6262..d25d5b6fa7fa 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java @@ -25,6 +25,7 @@ import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.async.listener.PublisherListener; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse; import software.amazon.awssdk.services.s3.model.CompletedPart; @@ -47,18 +48,15 @@ public final class MultipartUploadHelper { private static final Logger log = Logger.loggerFor(MultipartUploadHelper.class); private final S3AsyncClient s3AsyncClient; - private final long partSizeInBytes; private final GenericMultipartHelper genericMultipartHelper; private final long maxMemoryUsageInBytes; private final long multipartUploadThresholdInBytes; public MultipartUploadHelper(S3AsyncClient s3AsyncClient, - long partSizeInBytes, long multipartUploadThresholdInBytes, long maxMemoryUsageInBytes) { this.s3AsyncClient = s3AsyncClient; - this.partSizeInBytes = partSizeInBytes; this.genericMultipartHelper = new GenericMultipartHelper<>(s3AsyncClient, SdkPojoConversionUtils::toAbortMultipartUploadRequest, SdkPojoConversionUtils::toPutObjectResponse); @@ -123,11 +121,18 @@ void failRequestsElegantly(Collection> futures, String uploadId, CompletableFuture returnFuture, PutObjectRequest putObjectRequest) { - genericMultipartHelper.handleException(returnFuture, () -> "Failed to send multipart upload requests", t); - if (uploadId != null) { - genericMultipartHelper.cleanUpParts(uploadId, toAbortMultipartUploadRequest(putObjectRequest)); + + try { + genericMultipartHelper.handleException(returnFuture, () -> "Failed to send multipart upload requests", t); + if (uploadId != null) { + genericMultipartHelper.cleanUpParts(uploadId, toAbortMultipartUploadRequest(putObjectRequest)); + } + cancelingOtherOngoingRequests(futures, t); + } catch (Throwable throwable) { + returnFuture.completeExceptionally(SdkClientException.create("Unexpected error occurred while handling the upstream " + + "exception.", throwable)); } - cancelingOtherOngoingRequests(futures, t); + } static void cancelingOtherOngoingRequests(Collection> futures, Throwable t) { @@ -152,4 +157,22 @@ void uploadInOneChunk(PutObjectRequest putObjectRequest, CompletableFutureUtils.forwardExceptionTo(returnFuture, putObjectResponseCompletableFuture); CompletableFutureUtils.forwardResultTo(putObjectResponseCompletableFuture, returnFuture); } + + static SdkClientException contentLengthMissingForPart(int currentPartNum) { + return SdkClientException.create("Content length is missing on the AsyncRequestBody for part number " + currentPartNum); + } + + static SdkClientException contentLengthMismatchForPart(long expected, long actual) { + return SdkClientException.create(String.format("Content length must not be greater than " + + "part size. Expected: %d, Actual: %d", + expected, + actual)); + } + + static SdkClientException partNumMismatch(int expectedNumParts, int actualNumParts) { + return SdkClientException.create(String.format("The number of parts divided is " + + "not equal to the expected number of " + + "parts. Expected: %d, Actual: %d", + expectedNumParts, actualNumParts)); + } } diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java index cd4ebc4a88b9..04690677c92b 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java @@ -62,7 +62,7 @@ public UploadWithKnownContentLengthHelper(S3AsyncClient s3AsyncClient, SdkPojoConversionUtils::toPutObjectResponse); this.maxMemoryUsageInBytes = maxMemoryUsageInBytes; this.multipartUploadThresholdInBytes = multipartUploadThresholdInBytes; - this.multipartUploadHelper = new MultipartUploadHelper(s3AsyncClient, partSizeInBytes, multipartUploadThresholdInBytes, + this.multipartUploadHelper = new MultipartUploadHelper(s3AsyncClient, multipartUploadThresholdInBytes, maxMemoryUsageInBytes); } @@ -137,6 +137,7 @@ private void uploadFromBeginning(Pair reques .partSize(partSize) .uploadId(uploadId) .numPartsCompleted(numPartsCompleted) + .expectedNumParts(partCount) .build(); splitAndSubscribe(mpuRequestContext, returnFuture); @@ -170,6 +171,7 @@ private void resumePausedUpload(ResumeRequestContext resumeContext) { .partSize(resumeToken.partSize()) .uploadId(uploadId) .existingParts(existingParts) + .expectedNumParts(Math.toIntExact(resumeToken.totalNumParts())) .numPartsCompleted(resumeToken.numPartsCompleted()) .build(); diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelper.java index 745e7a9d3981..520625ad90b0 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelper.java @@ -16,6 +16,8 @@ package software.amazon.awssdk.services.s3.internal.multipart; +import static software.amazon.awssdk.services.s3.internal.multipart.MultipartUploadHelper.contentLengthMismatchForPart; +import static software.amazon.awssdk.services.s3.internal.multipart.MultipartUploadHelper.contentLengthMissingForPart; import static software.amazon.awssdk.services.s3.multipart.S3MultipartExecutionAttribute.JAVA_PROGRESS_LISTENER; import java.util.Collection; @@ -71,7 +73,7 @@ public UploadWithUnknownContentLengthHelper(S3AsyncClient s3AsyncClient, SdkPojoConversionUtils::toPutObjectResponse); this.maxMemoryUsageInBytes = maxMemoryUsageInBytes; this.multipartUploadThresholdInBytes = multipartUploadThresholdInBytes; - this.multipartUploadHelper = new MultipartUploadHelper(s3AsyncClient, partSizeInBytes, multipartUploadThresholdInBytes, + this.multipartUploadHelper = new MultipartUploadHelper(s3AsyncClient, multipartUploadThresholdInBytes, maxMemoryUsageInBytes); } @@ -160,10 +162,21 @@ public void onSubscribe(Subscription s) { @Override public void onNext(AsyncRequestBody asyncRequestBody) { + if (isDone) { + return; + } int currentPartNum = partNumber.incrementAndGet(); log.trace(() -> "Received asyncRequestBody " + asyncRequestBody.contentLength()); asyncRequestBodyInFlight.incrementAndGet(); + Optional sdkClientException = validatePart(asyncRequestBody, currentPartNum); + if (sdkClientException.isPresent()) { + multipartUploadHelper.failRequestsElegantly(futures, sdkClientException.get(), uploadId, returnFuture, + putObjectRequest); + subscription.cancel(); + return; + } + if (isFirstAsyncRequestBody.compareAndSet(true, false)) { log.trace(() -> "Received first async request body"); // If this is the first AsyncRequestBody received, request another one because we don't know if there is more @@ -203,15 +216,25 @@ public void onNext(AsyncRequestBody asyncRequestBody) { } } - private void sendUploadPartRequest(String uploadId, - AsyncRequestBody asyncRequestBody, - int currentPartNum) { + private Optional validatePart(AsyncRequestBody asyncRequestBody, int currentPartNum) { Optional contentLength = asyncRequestBody.contentLength(); if (!contentLength.isPresent()) { - SdkClientException e = SdkClientException.create("Content length must be present on the AsyncRequestBody"); - multipartUploadHelper.failRequestsElegantly(futures, e, uploadId, returnFuture, putObjectRequest); + return Optional.of(contentLengthMissingForPart(currentPartNum)); + } + + Long contentLengthCurrentPart = contentLength.get(); + if (contentLengthCurrentPart > partSizeInBytes) { + return Optional.of(contentLengthMismatchForPart(partSizeInBytes, contentLengthCurrentPart)); + } - this.contentLength.getAndAdd(contentLength.get()); + return Optional.empty(); + } + + private void sendUploadPartRequest(String uploadId, + AsyncRequestBody asyncRequestBody, + int currentPartNum) { + Long contentLengthCurrentPart = asyncRequestBody.contentLength().get(); + this.contentLength.getAndAdd(contentLengthCurrentPart); multipartUploadHelper .sendIndividualUploadPartRequest(uploadId, completedParts::add, futures, @@ -235,6 +258,7 @@ private Pair uploadPart(AsyncRequestBody as SdkPojoConversionUtils.toUploadPartRequest(putObjectRequest, partNum, uploadId); + return Pair.of(uploadRequest, asyncRequestBody); } @@ -242,6 +266,7 @@ private Pair uploadPart(AsyncRequestBody as public void onError(Throwable t) { log.debug(() -> "Received onError() ", t); if (failureActionInitiated.compareAndSet(false, true)) { + isDone = true; multipartUploadHelper.failRequestsElegantly(futures, t, uploadId, returnFuture, putObjectRequest); } } @@ -264,8 +289,19 @@ private void completeMultipartUploadIfFinish(int requestsInFlight) { CompletedPart[] parts = completedParts.stream() .sorted(Comparator.comparingInt(CompletedPart::partNumber)) .toArray(CompletedPart[]::new); + + long totalLength = contentLength.get(); + int expectedNumParts = genericMultipartHelper.determinePartCount(totalLength, partSizeInBytes); + if (parts.length != expectedNumParts) { + SdkClientException exception = SdkClientException.create( + String.format("The number of UploadParts requests is not equal to the expected number of parts. " + + "Expected: %d, Actual: %d", expectedNumParts, parts.length)); + multipartUploadHelper.failRequestsElegantly(futures, exception, uploadId, returnFuture, putObjectRequest); + return; + } + multipartUploadHelper.completeMultipartUpload(returnFuture, uploadId, parts, putObjectRequest, - this.contentLength.get()); + totalLength); } } } diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriberTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriberTest.java index 68c54bd4b7c1..4faf9d4a04b0 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriberTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriberTest.java @@ -17,18 +17,32 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyList; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import java.io.IOException; +import java.util.Collection; +import java.util.HashMap; import java.util.Map; +import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.IntStream; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse; import software.amazon.awssdk.services.s3.model.CompletedPart; @@ -46,10 +60,15 @@ public class KnownContentLengthAsyncRequestBodySubscriberTest { private static final int TOTAL_NUM_PARTS = 4; private static final String UPLOAD_ID = "1234"; private static RandomTempFile testFile; + private AsyncRequestBody asyncRequestBody; private PutObjectRequest putObjectRequest; private S3AsyncClient s3AsyncClient; private MultipartUploadHelper multipartUploadHelper; + private CompletableFuture returnFuture; + private KnownContentLengthAsyncRequestBodySubscriber subscriber; + private Collection> futures; + private Subscription subscription; @BeforeAll public static void beforeAll() throws IOException { @@ -67,13 +86,97 @@ public void beforeEach() { multipartUploadHelper = mock(MultipartUploadHelper.class); asyncRequestBody = AsyncRequestBody.fromFile(testFile); putObjectRequest = PutObjectRequest.builder().bucket("bucket").key("key").build(); + + returnFuture = new CompletableFuture<>(); + futures = new ConcurrentLinkedQueue<>(); + subscription = mock(Subscription.class); + + when(multipartUploadHelper.sendIndividualUploadPartRequest(eq(UPLOAD_ID), any(), any(), any(), any())) + .thenReturn(CompletableFuture.completedFuture(CompletedPart.builder().build())); + + subscriber = createSubscriber(createDefaultMpuRequestContext()); + subscriber.onSubscribe(subscription); + } + + @Test + void validatePart_withMissingContentLength_shouldFailRequest() { + subscriber.onNext(createMockAsyncRequestBodyWithEmptyContentLength()); + verifyFailRequestsElegantly("Content length is missing on the AsyncRequestBody"); + } + + @Test + void validatePart_withPartSizeExceedingLimit_shouldFailRequest() { + subscriber.onNext(createMockAsyncRequestBody(PART_SIZE + 1)); + verifyFailRequestsElegantly("Content length must not be greater than part size"); + } + + @Test + void validateLastPartSize_withIncorrectSize_shouldFailRequest() { + long expectedLastPartSize = MPU_CONTENT_SIZE % PART_SIZE; + long incorrectLastPartSize = expectedLastPartSize + 1; + + KnownContentLengthAsyncRequestBodySubscriber lastPartSubscriber = createSubscriber(createDefaultMpuRequestContext()); + lastPartSubscriber.onSubscribe(subscription); + + for (int i = 0; i < TOTAL_NUM_PARTS - 1; i++) { + lastPartSubscriber.onNext(createMockAsyncRequestBody(PART_SIZE)); + } + + lastPartSubscriber.onNext(createMockAsyncRequestBody(incorrectLastPartSize)); + + verifyFailRequestsElegantly("Content length of the last part must be equal to the expected last part size"); + } + + @Test + void validateTotalPartNum_receivedMoreParts_shouldFail() { + long expectedLastPartSize = MPU_CONTENT_SIZE % PART_SIZE; + + KnownContentLengthAsyncRequestBodySubscriber lastPartSubscriber = createSubscriber(createDefaultMpuRequestContext()); + lastPartSubscriber.onSubscribe(subscription); + + for (int i = 0; i < TOTAL_NUM_PARTS - 1; i++) { + AsyncRequestBody regularPart = createMockAsyncRequestBody(PART_SIZE); + when(multipartUploadHelper.sendIndividualUploadPartRequest(eq(UPLOAD_ID), any(), any(), any(), any())) + .thenReturn(CompletableFuture.completedFuture(null)); + lastPartSubscriber.onNext(regularPart); + } + + when(multipartUploadHelper.sendIndividualUploadPartRequest(eq(UPLOAD_ID), any(), any(), any(), any())) + .thenReturn(CompletableFuture.completedFuture(null)); + lastPartSubscriber.onNext(createMockAsyncRequestBody(expectedLastPartSize)); + lastPartSubscriber.onNext(createMockAsyncRequestBody(expectedLastPartSize)); + + verifyFailRequestsElegantly("The number of parts divided is not equal to the expected number of parts"); + } + + @Test + void validateLastPartSize_withCorrectSize_shouldNotFail() { + long expectedLastPartSize = MPU_CONTENT_SIZE % PART_SIZE; + + KnownContentLengthAsyncRequestBodySubscriber subscriber = createSubscriber(createDefaultMpuRequestContext()); + subscriber.onSubscribe(subscription); + + for (int i = 0; i < TOTAL_NUM_PARTS - 1; i++) { + AsyncRequestBody regularPart = createMockAsyncRequestBody(PART_SIZE); + when(multipartUploadHelper.sendIndividualUploadPartRequest(eq(UPLOAD_ID), any(), any(), any(), any())) + .thenReturn(CompletableFuture.completedFuture(null)); + subscriber.onNext(regularPart); + } + + when(multipartUploadHelper.sendIndividualUploadPartRequest(eq(UPLOAD_ID), any(), any(), any(), any())) + .thenReturn(CompletableFuture.completedFuture(null)); + subscriber.onNext(createMockAsyncRequestBody(expectedLastPartSize)); + subscriber.onComplete(); + + assertThat(returnFuture).isNotCompletedExceptionally(); } @Test void pause_withOngoingCompleteMpuFuture_shouldReturnTokenAndCancelFuture() { CompletableFuture completeMpuFuture = new CompletableFuture<>(); int numExistingParts = 2; - S3ResumeToken resumeToken = configureSubscriberAndPause(numExistingParts, completeMpuFuture); + + S3ResumeToken resumeToken = testPauseScenario(numExistingParts, completeMpuFuture); verifyResumeToken(resumeToken, numExistingParts); assertThat(completeMpuFuture).isCancelled(); @@ -84,56 +187,92 @@ void pause_withCompletedCompleteMpuFuture_shouldReturnNullToken() { CompletableFuture completeMpuFuture = CompletableFuture.completedFuture(CompleteMultipartUploadResponse.builder().build()); int numExistingParts = 2; - S3ResumeToken resumeToken = configureSubscriberAndPause(numExistingParts, completeMpuFuture); + + S3ResumeToken resumeToken = testPauseScenario(numExistingParts, completeMpuFuture); assertThat(resumeToken).isNull(); } @Test void pause_withUninitiatedCompleteMpuFuture_shouldReturnToken() { - CompletableFuture completeMpuFuture = null; int numExistingParts = 2; - S3ResumeToken resumeToken = configureSubscriberAndPause(numExistingParts, completeMpuFuture); + + S3ResumeToken resumeToken = testPauseScenario(numExistingParts, null); verifyResumeToken(resumeToken, numExistingParts); } - - private S3ResumeToken configureSubscriberAndPause(int numExistingParts, - CompletableFuture completeMpuFuture) { - Map existingParts = existingParts(numExistingParts); - KnownContentLengthAsyncRequestBodySubscriber subscriber = subscriber(putObjectRequest, asyncRequestBody, existingParts, - new CompletableFuture<>()); + + private S3ResumeToken testPauseScenario(int numExistingParts, + CompletableFuture completeMpuFuture) { + KnownContentLengthAsyncRequestBodySubscriber subscriber = + createSubscriber(createMpuRequestContextWithExistingParts(numExistingParts)); when(multipartUploadHelper.completeMultipartUpload(any(CompletableFuture.class), any(String.class), - any(CompletedPart[].class), any(PutObjectRequest.class), - any(Long.class))) + any(CompletedPart[].class), any(PutObjectRequest.class), + any(Long.class))) .thenReturn(completeMpuFuture); + + simulateOnNextForAllParts(subscriber); subscriber.onComplete(); + assertThat(returnFuture).isNotCompletedExceptionally(); return subscriber.pause(); } - private KnownContentLengthAsyncRequestBodySubscriber subscriber(PutObjectRequest putObjectRequest, - AsyncRequestBody asyncRequestBody, - Map existingParts, - CompletableFuture returnFuture) { + private MpuRequestContext createDefaultMpuRequestContext() { + return MpuRequestContext.builder() + .request(Pair.of(putObjectRequest, AsyncRequestBody.fromFile(testFile))) + .contentLength(MPU_CONTENT_SIZE) + .partSize(PART_SIZE) + .uploadId(UPLOAD_ID) + .numPartsCompleted(0L) + .expectedNumParts(TOTAL_NUM_PARTS) + .build(); + } - MpuRequestContext mpuRequestContext = MpuRequestContext.builder() - .request(Pair.of(putObjectRequest, asyncRequestBody)) - .contentLength(MPU_CONTENT_SIZE) - .partSize(PART_SIZE) - .uploadId(UPLOAD_ID) - .existingParts(existingParts) - .numPartsCompleted((long) existingParts.size()) - .build(); + private MpuRequestContext createMpuRequestContextWithExistingParts(int numExistingParts) { + Map existingParts = createExistingParts(numExistingParts); + return MpuRequestContext.builder() + .request(Pair.of(putObjectRequest, asyncRequestBody)) + .contentLength(MPU_CONTENT_SIZE) + .partSize(PART_SIZE) + .uploadId(UPLOAD_ID) + .existingParts(existingParts) + .expectedNumParts(TOTAL_NUM_PARTS) + .numPartsCompleted((long) existingParts.size()) + .build(); + } + private KnownContentLengthAsyncRequestBodySubscriber createSubscriber(MpuRequestContext mpuRequestContext) { return new KnownContentLengthAsyncRequestBodySubscriber(mpuRequestContext, returnFuture, multipartUploadHelper); } - private Map existingParts(int numExistingParts) { - Map existingParts = new ConcurrentHashMap<>(); - for (int i = 1; i <= numExistingParts; i++) { - existingParts.put(i, CompletedPart.builder().partNumber(i).build()); - } + private AsyncRequestBody createMockAsyncRequestBody(long contentLength) { + AsyncRequestBody mockBody = mock(AsyncRequestBody.class); + when(mockBody.contentLength()).thenReturn(Optional.of(contentLength)); + return mockBody; + } + + private AsyncRequestBody createMockAsyncRequestBodyWithEmptyContentLength() { + AsyncRequestBody mockBody = mock(AsyncRequestBody.class); + when(mockBody.contentLength()).thenReturn(Optional.empty()); + return mockBody; + } + + private void verifyFailRequestsElegantly(String expectedErrorMessage) { + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Throwable.class); + verify(multipartUploadHelper).failRequestsElegantly(any(), exceptionCaptor.capture(), eq(UPLOAD_ID), eq(returnFuture), eq(putObjectRequest)); + + Throwable exception = exceptionCaptor.getValue(); + assertThat(exception).isInstanceOf(SdkClientException.class); + assertThat(exception.getMessage()).contains(expectedErrorMessage); + verify(subscription).cancel(); + } + + private Map createExistingParts(int numExistingParts) { + Map existingParts = + IntStream.range(0, numExistingParts) + .boxed().collect(Collectors.toMap(Function.identity(), + i -> CompletedPart.builder().partNumber(i).build(), (a, b) -> b)); return existingParts; } @@ -144,4 +283,13 @@ private void verifyResumeToken(S3ResumeToken s3ResumeToken, int numExistingParts assertThat(s3ResumeToken.totalNumParts()).isEqualTo(TOTAL_NUM_PARTS); assertThat(s3ResumeToken.numPartsCompleted()).isEqualTo(numExistingParts); } + + private void simulateOnNextForAllParts(KnownContentLengthAsyncRequestBodySubscriber subscriber) { + subscriber.onSubscribe(subscription); + + for (int i = 0; i < TOTAL_NUM_PARTS; i++) { + subscriber.onNext(createMockAsyncRequestBody(PART_SIZE)); + } + } + } diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MpuRequestContextTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MpuRequestContextTest.java index c858e7e8e9ec..46b03a90d2cf 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MpuRequestContextTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MpuRequestContextTest.java @@ -34,6 +34,7 @@ public class MpuRequestContextTest { private static final long NUM_PARTS_COMPLETED = 3; private static final String UPLOAD_ID = "55555"; private static final Map EXISTING_PARTS = new ConcurrentHashMap<>(); + public static final int EXPECTED_NUM_PARTS = 10; @Test public void mpuRequestContext_withValues_buildsCorrectly() { @@ -44,6 +45,7 @@ public void mpuRequestContext_withValues_buildsCorrectly() { .uploadId(UPLOAD_ID) .existingParts(EXISTING_PARTS) .numPartsCompleted(NUM_PARTS_COMPLETED) + .expectedNumParts(EXPECTED_NUM_PARTS) .build(); assertThat(mpuRequestContext.request()).isEqualTo(REQUEST); @@ -52,11 +54,14 @@ public void mpuRequestContext_withValues_buildsCorrectly() { assertThat(mpuRequestContext.uploadId()).isEqualTo(UPLOAD_ID); assertThat(mpuRequestContext.existingParts()).isEqualTo(EXISTING_PARTS); assertThat(mpuRequestContext.numPartsCompleted()).isEqualTo(NUM_PARTS_COMPLETED); + assertThat(mpuRequestContext.expectedNumParts()).isEqualTo(EXPECTED_NUM_PARTS); } @Test public void mpuRequestContext_default_buildsCorrectly() { - MpuRequestContext mpuRequestContext = MpuRequestContext.builder().build(); + MpuRequestContext mpuRequestContext = MpuRequestContext.builder() + .expectedNumParts(1) + .build(); assertThat(mpuRequestContext.request()).isNull(); assertThat(mpuRequestContext.contentLength()).isNull(); diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/S3MultipartClientPutObjectWiremockTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/S3MultipartClientPutObjectWiremockTest.java index 19bf3988e3ec..90e14dcff2dd 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/S3MultipartClientPutObjectWiremockTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/S3MultipartClientPutObjectWiremockTest.java @@ -26,25 +26,40 @@ import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; import com.github.tomakehurst.wiremock.junit5.WireMockTest; +import io.reactivex.rxjava3.core.Flowable; import java.io.InputStream; import java.net.URI; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; import java.util.concurrent.CancellationException; import java.util.concurrent.CompletableFuture; +import java.util.stream.Stream; +import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.reactivestreams.Subscriber; import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.async.AsyncRequestBodySplitConfiguration; import software.amazon.awssdk.core.async.BlockingInputStreamAsyncRequestBody; +import software.amazon.awssdk.core.async.SdkPublisher; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.http.crt.AwsCrtAsyncHttpClient; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.model.PutObjectResponse; import software.amazon.awssdk.services.s3.model.S3Exception; +import software.amazon.awssdk.utils.async.SimplePublisher; @WireMockTest -@Timeout(10) +@Timeout(100) public class S3MultipartClientPutObjectWiremockTest { private static final String BUCKET = "Example-Bucket"; @@ -56,9 +71,25 @@ public class S3MultipartClientPutObjectWiremockTest { + ""; private S3AsyncClient s3AsyncClient; + public static Stream invalidAsyncRequestBodies() { + return Stream.of( + Arguments.of("knownContentLength_nullPartSize", new TestPublisherWithIncorrectSplitImpl(20L, null), + "Content length is missing on the AsyncRequestBody for part number"), + Arguments.of("unknownContentLength_nullPartSize", new TestPublisherWithIncorrectSplitImpl(null, null), + "Content length is missing on the AsyncRequestBody for part number"), + Arguments.of("knownContentLength_partSizeIncorrect", new TestPublisherWithIncorrectSplitImpl(20L, 11L), + "Content length must not be greater than part size"), + Arguments.of("unknownContentLength_partSizeIncorrect", new TestPublisherWithIncorrectSplitImpl(null, 11L), + "Content length must not be greater than part size"), + Arguments.of("knownContentLength_sendMoreParts", new TestPublisherWithIncorrectSplitImpl(20L, 10L, 3), + "The number of parts divided is not equal to the expected number of parts"), + Arguments.of("knownContentLength_sendFewerParts", new TestPublisherWithIncorrectSplitImpl(20L, 10L, 1), + "The number of parts divided is not equal to the expected number of parts")); + } + @BeforeEach public void setup(WireMockRuntimeInfo wiremock) { - stubPutObjectCalls(); + stubFailedPutObjectCalls(); s3AsyncClient = S3AsyncClient.builder() .region(Region.US_EAST_1) .endpointOverride(URI.create("http://localhost:" + wiremock.getHttpPort())) @@ -66,17 +97,23 @@ public void setup(WireMockRuntimeInfo wiremock) { StaticCredentialsProvider.create(AwsBasicCredentials.create("key", "secret"))) .multipartEnabled(true) .multipartConfiguration(b -> b.minimumPartSizeInBytes(10L).apiCallBufferSizeInBytes(10L)) - .httpClientBuilder(AwsCrtAsyncHttpClient.builder()) + .httpClientBuilder(AwsCrtAsyncHttpClient.builder()) .build(); } - private void stubPutObjectCalls() { + private void stubFailedPutObjectCalls() { stubFor(post(anyUrl()).willReturn(aResponse().withStatus(200).withBody(CREATE_MULTIPART_PAYLOAD))); stubFor(put(anyUrl()).willReturn(aResponse().withStatus(404))); stubFor(put(urlEqualTo("/Example-Bucket/Example-Object?partNumber=1&uploadId=string")).willReturn(aResponse().withStatus(200))); stubFor(delete(anyUrl()).willReturn(aResponse().withStatus(200))); } + private void stubSuccessfulPutObjectCalls() { + stubFor(post(anyUrl()).willReturn(aResponse().withStatus(200).withBody(CREATE_MULTIPART_PAYLOAD))); + stubFor(put(anyUrl()).willReturn(aResponse().withStatus(200))); + } + + // https://github.com/aws/aws-sdk-java-v2/issues/4801 @Test void uploadWithUnknownContentLength_onePartFails_shouldCancelUpstream() { @@ -110,6 +147,19 @@ void uploadWithKnownContentLength_onePartFails_shouldCancelUpstream() { assertThatThrownBy(() -> putObjectResponse.join()).hasRootCauseInstanceOf(S3Exception.class); } + @ParameterizedTest(name = "{index} {0}") + @MethodSource("invalidAsyncRequestBodies") + void uploadWithIncorrectAsyncRequestBodySplit_contentLengthMismatch_shouldThrowException(String description, + TestPublisherWithIncorrectSplitImpl asyncRequestBody, + String errorMsg) { + stubSuccessfulPutObjectCalls(); + CompletableFuture putObjectResponse = s3AsyncClient.putObject( + r -> r.bucket(BUCKET).key(KEY), asyncRequestBody); + + assertThatThrownBy(() -> putObjectResponse.join()).hasMessageContaining(errorMsg) + .hasRootCauseInstanceOf(SdkClientException.class); + } + private InputStream createUnlimitedInputStream() { return new InputStream() { @Override @@ -118,4 +168,65 @@ public int read() { } }; } + + private static class TestPublisherWithIncorrectSplitImpl implements AsyncRequestBody { + private SimplePublisher simplePublisher = new SimplePublisher<>(); + private Long totalSize; + private Long partSize; + private Integer numParts; + + private TestPublisherWithIncorrectSplitImpl(Long totalSize, Long partSize) { + this.totalSize = totalSize; + this.partSize = partSize; + } + + private TestPublisherWithIncorrectSplitImpl(Long totalSize, long partSize, int numParts) { + this.totalSize = totalSize; + this.partSize = partSize; + this.numParts = numParts; + } + + @Override + public Optional contentLength() { + return Optional.ofNullable(totalSize); + } + + @Override + public void subscribe(Subscriber s) { + simplePublisher.subscribe(s); + } + + @Override + public SdkPublisher split(AsyncRequestBodySplitConfiguration splitConfiguration) { + List requestBodies = new ArrayList<>(); + int numAsyncRequestBodies = numParts == null ? 1 : numParts; + for (int i = 0; i < numAsyncRequestBodies; i++) { + requestBodies.add(new TestAsyncRequestBody(partSize)); + } + + return SdkPublisher.adapt(Flowable.fromArray(requestBodies.toArray(new AsyncRequestBody[requestBodies.size()]))); + } + } + + private static class TestAsyncRequestBody implements AsyncRequestBody { + private Long partSize; + private SimplePublisher simplePublisher = new SimplePublisher<>(); + + public TestAsyncRequestBody(Long partSize) { + this.partSize = partSize; + } + + @Override + public Optional contentLength() { + return Optional.ofNullable(partSize); + } + + @Override + public void subscribe(Subscriber s) { + simplePublisher.subscribe(s); + simplePublisher.send(ByteBuffer.wrap( + RandomStringUtils.randomAscii(Math.toIntExact(partSize)).getBytes())); + simplePublisher.complete(); + } + } } diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelperTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelperTest.java index 7be4ae7c3135..972f0b86241a 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelperTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelperTest.java @@ -16,17 +16,25 @@ package software.amazon.awssdk.services.s3.internal.multipart; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import static software.amazon.awssdk.services.s3.internal.multipart.MpuTestUtils.stubSuccessfulCompleteMultipartCall; import static software.amazon.awssdk.services.s3.internal.multipart.MpuTestUtils.stubSuccessfulCreateMultipartCall; import static software.amazon.awssdk.services.s3.internal.multipart.MpuTestUtils.stubSuccessfulUploadPartCalls; +import java.io.ByteArrayInputStream; import java.io.FileInputStream; import java.io.FileNotFoundException; import java.io.IOException; +import java.io.InputStream; import java.util.List; +import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; import java.util.stream.Collectors; import java.util.stream.IntStream; import org.junit.jupiter.api.AfterAll; @@ -35,24 +43,29 @@ import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; import org.mockito.Mockito; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.async.BlockingInputStreamAsyncRequestBody; +import software.amazon.awssdk.core.async.SdkPublisher; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.CompletedPart; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.PutObjectRequest; import software.amazon.awssdk.services.s3.model.PutObjectResponse; import software.amazon.awssdk.services.s3.model.UploadPartRequest; import software.amazon.awssdk.testutils.RandomTempFile; +import software.amazon.awssdk.utils.StringInputStream; public class UploadWithUnknownContentLengthHelperTest { private static final String BUCKET = "bucket"; private static final String KEY = "key"; private static final String UPLOAD_ID = "1234"; - - // Should contain 126 parts private static final long MPU_CONTENT_SIZE = 1005 * 1024; private static final long PART_SIZE = 8 * 1024; + private static final int NUM_TOTAL_PARTS = 126; private UploadWithUnknownContentLengthHelper helper; private S3AsyncClient s3AsyncClient; @@ -81,55 +94,115 @@ void upload_blockingInputStream_shouldInOrder() throws FileNotFoundException { stubSuccessfulCompleteMultipartCall(BUCKET, KEY, s3AsyncClient); BlockingInputStreamAsyncRequestBody body = AsyncRequestBody.forBlockingInputStream(null); - - CompletableFuture future = helper.uploadObject(putObjectRequest(), body); - + CompletableFuture future = helper.uploadObject(createPutObjectRequest(), body); body.writeInputStream(new FileInputStream(testFile)); - future.join(); ArgumentCaptor requestArgumentCaptor = ArgumentCaptor.forClass(UploadPartRequest.class); ArgumentCaptor requestBodyArgumentCaptor = ArgumentCaptor.forClass(AsyncRequestBody.class); - int numTotalParts = 126; - verify(s3AsyncClient, times(numTotalParts)).uploadPart(requestArgumentCaptor.capture(), - requestBodyArgumentCaptor.capture()); + verify(s3AsyncClient, times(NUM_TOTAL_PARTS)).uploadPart(requestArgumentCaptor.capture(), + requestBodyArgumentCaptor.capture()); List actualRequests = requestArgumentCaptor.getAllValues(); List actualRequestBodies = requestBodyArgumentCaptor.getAllValues(); - assertThat(actualRequestBodies).hasSize(numTotalParts); - assertThat(actualRequests).hasSize(numTotalParts); + assertThat(actualRequestBodies).hasSize(NUM_TOTAL_PARTS); + assertThat(actualRequests).hasSize(NUM_TOTAL_PARTS); + + verifyUploadPartRequests(actualRequests, actualRequestBodies); + verifyCompleteMultipartUploadRequest(); + } + + @Test + void uploadObject_withMissingContentLength_shouldFailRequest() { + AsyncRequestBody asyncRequestBody = createMockAsyncRequestBodyWithEmptyContentLength(); + CompletableFuture future = setupAndTriggerUploadFailure(asyncRequestBody); + verifyFailureWithMessage(future, "Content length is missing on the AsyncRequestBody for part number"); + } + + @Test + void uploadObject_withPartSizeExceedingLimit_shouldFailRequest() { + AsyncRequestBody asyncRequestBody = createMockAsyncRequestBody(PART_SIZE + 1); + CompletableFuture future = setupAndTriggerUploadFailure(asyncRequestBody); + verifyFailureWithMessage(future, "Content length must not be greater than part size"); + } + + private PutObjectRequest createPutObjectRequest() { + return PutObjectRequest.builder() + .bucket(BUCKET) + .key(KEY) + .build(); + } + private List createCompletedParts(int totalNumParts) { + return IntStream.range(1, totalNumParts + 1) + .mapToObj(i -> CompletedPart.builder().partNumber(i).build()) + .collect(Collectors.toList()); + } + + private AsyncRequestBody createMockAsyncRequestBody(long contentLength) { + AsyncRequestBody mockBody = mock(AsyncRequestBody.class); + when(mockBody.contentLength()).thenReturn(Optional.of(contentLength)); + return mockBody; + } + + private AsyncRequestBody createMockAsyncRequestBodyWithEmptyContentLength() { + AsyncRequestBody mockBody = mock(AsyncRequestBody.class); + when(mockBody.contentLength()).thenReturn(Optional.empty()); + return mockBody; + } + + private CompletableFuture setupAndTriggerUploadFailure(AsyncRequestBody asyncRequestBody) { + SdkPublisher mockPublisher = mock(SdkPublisher.class); + when(asyncRequestBody.split(any(Consumer.class))).thenReturn(mockPublisher); + + ArgumentCaptor> subscriberCaptor = ArgumentCaptor.forClass(Subscriber.class); + CompletableFuture future = helper.uploadObject(createPutObjectRequest(), asyncRequestBody); + + verify(mockPublisher).subscribe(subscriberCaptor.capture()); + Subscriber subscriber = subscriberCaptor.getValue(); + + Subscription subscription = mock(Subscription.class); + subscriber.onSubscribe(subscription); + subscriber.onNext(asyncRequestBody); + + stubSuccessfulCreateMultipartCall(UPLOAD_ID, s3AsyncClient); + subscriber.onNext(asyncRequestBody); + + return future; + } + + private void verifyFailureWithMessage(CompletableFuture future, String expectedErrorMessage) { + assertThat(future).isCompletedExceptionally(); + future.exceptionally(throwable -> { + assertThat(throwable).isInstanceOf(SdkClientException.class); + assertThat(throwable.getMessage()).contains(expectedErrorMessage); + return null; + }).join(); + } + + private void verifyUploadPartRequests(List actualRequests, + List actualRequestBodies) { for (int i = 0; i < actualRequests.size(); i++) { UploadPartRequest request = actualRequests.get(i); AsyncRequestBody requestBody = actualRequestBodies.get(i); - assertThat(request.partNumber()).isEqualTo( i + 1); + assertThat(request.partNumber()).isEqualTo(i + 1); assertThat(request.bucket()).isEqualTo(BUCKET); assertThat(request.key()).isEqualTo(KEY); if (i == actualRequests.size() - 1) { assertThat(requestBody.contentLength()).hasValue(5120L); - } else{ + } else { assertThat(requestBody.contentLength()).hasValue(PART_SIZE); } } + } - ArgumentCaptor completeMpuArgumentCaptor = ArgumentCaptor.forClass(CompleteMultipartUploadRequest.class); + private void verifyCompleteMultipartUploadRequest() { + ArgumentCaptor completeMpuArgumentCaptor = ArgumentCaptor + .forClass(CompleteMultipartUploadRequest.class); verify(s3AsyncClient).completeMultipartUpload(completeMpuArgumentCaptor.capture()); CompleteMultipartUploadRequest actualRequest = completeMpuArgumentCaptor.getValue(); - assertThat(actualRequest.multipartUpload().parts()).isEqualTo(completedParts(numTotalParts)); - + assertThat(actualRequest.multipartUpload().parts()).isEqualTo(createCompletedParts(NUM_TOTAL_PARTS)); } - - private static PutObjectRequest putObjectRequest() { - return PutObjectRequest.builder() - .bucket(BUCKET) - .key(KEY) - .build(); - } - - private List completedParts(int totalNumParts) { - return IntStream.range(1, totalNumParts + 1).mapToObj(i -> CompletedPart.builder().partNumber(i).build()).collect(Collectors.toList()); - } - } diff --git a/test/ruleset-testing-core/pom.xml b/test/ruleset-testing-core/pom.xml index 2f411c79bfb8..ff67838fc5fc 100644 --- a/test/ruleset-testing-core/pom.xml +++ b/test/ruleset-testing-core/pom.xml @@ -100,7 +100,7 @@ io.reactivex.rxjava3 rxjava - 3.1.5 + compile