diff --git a/.changes/next-release/bugfix-AmazonS3-6522f77.json b/.changes/next-release/bugfix-AmazonS3-6522f77.json
new file mode 100644
index 000000000000..3f4b2cb0c6ff
--- /dev/null
+++ b/.changes/next-release/bugfix-AmazonS3-6522f77.json
@@ -0,0 +1,6 @@
+{
+ "type": "bugfix",
+ "category": "Amazon S3",
+ "contributor": "",
+ "description": "Add additional validations for multipart upload operations in the Java multipart S3 client."
+}
diff --git a/bom-internal/pom.xml b/bom-internal/pom.xml
index 70c0248c0df8..8bceb73e63f9 100644
--- a/bom-internal/pom.xml
+++ b/bom-internal/pom.xml
@@ -235,6 +235,12 @@
${rxjava.version}
test
+
+ io.reactivex.rxjava3
+ rxjava
+ ${rxjava3.version}
+ test
+
commons-lang3
org.apache.commons
diff --git a/pom.xml b/pom.xml
index 43aedc0ee458..ac9378780d2f 100644
--- a/pom.xml
+++ b/pom.xml
@@ -124,6 +124,7 @@
3.10.0
3.5.101
2.2.21
+ 3.1.5
1.17.1
1.37
0.38.1
diff --git a/services/s3/pom.xml b/services/s3/pom.xml
index f0ca2e658780..b8b76dfba495 100644
--- a/services/s3/pom.xml
+++ b/services/s3/pom.xml
@@ -230,5 +230,10 @@
jimfs
test
+
+ io.reactivex.rxjava3
+ rxjava
+ test
+
diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriber.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriber.java
index 7a27e60e31dc..93bc0dfeb6f8 100644
--- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriber.java
+++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriber.java
@@ -15,11 +15,14 @@
package software.amazon.awssdk.services.s3.internal.multipart;
+import static software.amazon.awssdk.services.s3.internal.multipart.MultipartUploadHelper.contentLengthMismatchForPart;
+import static software.amazon.awssdk.services.s3.internal.multipart.MultipartUploadHelper.partNumMismatch;
import static software.amazon.awssdk.services.s3.multipart.S3MultipartExecutionAttribute.JAVA_PROGRESS_LISTENER;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
+import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicBoolean;
@@ -32,6 +35,7 @@
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.core.async.listener.PublisherListener;
+import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse;
import software.amazon.awssdk.services.s3.model.CompletedPart;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
@@ -54,10 +58,10 @@ public class KnownContentLengthAsyncRequestBodySubscriber implements Subscriber<
private final AtomicBoolean failureActionInitiated = new AtomicBoolean(false);
private final AtomicInteger partNumber = new AtomicInteger(1);
private final MultipartUploadHelper multipartUploadHelper;
- private final long contentLength;
+ private final long totalSize;
private final long partSize;
- private final int partCount;
- private final int numExistingParts;
+ private final int expectedNumParts;
+ private final int existingNumParts;
private final String uploadId;
private final Collection> futures = new ConcurrentLinkedQueue<>();
private final PutObjectRequest putObjectRequest;
@@ -77,25 +81,21 @@ public class KnownContentLengthAsyncRequestBodySubscriber implements Subscriber<
KnownContentLengthAsyncRequestBodySubscriber(MpuRequestContext mpuRequestContext,
CompletableFuture returnFuture,
MultipartUploadHelper multipartUploadHelper) {
- this.contentLength = mpuRequestContext.contentLength();
+ this.totalSize = mpuRequestContext.contentLength();
this.partSize = mpuRequestContext.partSize();
- this.partCount = determinePartCount(contentLength, partSize);
+ this.expectedNumParts = mpuRequestContext.expectedNumParts();
this.putObjectRequest = mpuRequestContext.request().left();
this.returnFuture = returnFuture;
this.uploadId = mpuRequestContext.uploadId();
this.existingParts = mpuRequestContext.existingParts() == null ? new HashMap<>() : mpuRequestContext.existingParts();
- this.numExistingParts = NumericUtils.saturatedCast(mpuRequestContext.numPartsCompleted());
- this.completedParts = new AtomicReferenceArray<>(partCount);
+ this.existingNumParts = NumericUtils.saturatedCast(mpuRequestContext.numPartsCompleted());
+ this.completedParts = new AtomicReferenceArray<>(expectedNumParts);
this.multipartUploadHelper = multipartUploadHelper;
this.progressListener = putObjectRequest.overrideConfiguration().map(c -> c.executionAttributes()
.getAttribute(JAVA_PROGRESS_LISTENER))
.orElseGet(PublisherListener::noOp);
}
- private int determinePartCount(long contentLength, long partSize) {
- return (int) Math.ceil(contentLength / (double) partSize);
- }
-
public S3ResumeToken pause() {
isPaused = true;
@@ -119,8 +119,8 @@ public S3ResumeToken pause() {
return S3ResumeToken.builder()
.uploadId(uploadId)
.partSize(partSize)
- .totalNumParts((long) partCount)
- .numPartsCompleted(numPartsCompleted + numExistingParts)
+ .totalNumParts((long) expectedNumParts)
+ .numPartsCompleted(numPartsCompleted + existingNumParts)
.build();
}
@@ -145,21 +145,32 @@ public void onSubscribe(Subscription s) {
@Override
public void onNext(AsyncRequestBody asyncRequestBody) {
- if (isPaused) {
+ if (isPaused || isDone) {
return;
}
- if (existingParts.containsKey(partNumber.get())) {
- partNumber.getAndIncrement();
+ int currentPartNum = partNumber.getAndIncrement();
+ if (existingParts.containsKey(currentPartNum)) {
asyncRequestBody.subscribe(new CancelledSubscriber<>());
subscription.request(1);
asyncRequestBody.contentLength().ifPresent(progressListener::subscriberOnNext);
return;
}
+ Optional sdkClientException = validatePart(asyncRequestBody, currentPartNum);
+ if (sdkClientException.isPresent()) {
+ multipartUploadHelper.failRequestsElegantly(futures,
+ sdkClientException.get(),
+ uploadId,
+ returnFuture,
+ putObjectRequest);
+ subscription.cancel();
+ return;
+ }
+
asyncRequestBodyInFlight.incrementAndGet();
UploadPartRequest uploadRequest = SdkPojoConversionUtils.toUploadPartRequest(putObjectRequest,
- partNumber.getAndIncrement(),
+ currentPartNum,
uploadId);
Consumer completedPartConsumer = completedPart -> completedParts.set(completedPart.partNumber() - 1,
@@ -179,6 +190,39 @@ public void onNext(AsyncRequestBody asyncRequestBody) {
subscription.request(1);
}
+ private Optional validatePart(AsyncRequestBody asyncRequestBody, int currentPartNum) {
+ if (!asyncRequestBody.contentLength().isPresent()) {
+ return Optional.of(MultipartUploadHelper.contentLengthMissingForPart(currentPartNum));
+ }
+
+ Long currentPartSize = asyncRequestBody.contentLength().get();
+
+ if (currentPartNum > expectedNumParts) {
+ return Optional.of(partNumMismatch(expectedNumParts, currentPartNum));
+ }
+
+ if (currentPartNum == expectedNumParts) {
+ return validateLastPartSize(currentPartSize);
+ }
+
+ if (currentPartSize != partSize) {
+ return Optional.of(contentLengthMismatchForPart(partSize, currentPartSize));
+ }
+ return Optional.empty();
+ }
+
+ private Optional validateLastPartSize(Long currentPartSize) {
+ long remainder = totalSize % partSize;
+ long expectedLastPartSize = remainder == 0 ? partSize : remainder;
+ if (currentPartSize != expectedLastPartSize) {
+ return Optional.of(
+ SdkClientException.create("Content length of the last part must be equal to the "
+ + "expected last part size. Expected: " + expectedLastPartSize
+ + ", Actual: " + currentPartSize));
+ }
+ return Optional.empty();
+ }
+
private boolean shouldFailRequest() {
return failureActionInitiated.compareAndSet(false, true) && !isPaused;
}
@@ -187,6 +231,7 @@ private boolean shouldFailRequest() {
public void onError(Throwable t) {
log.debug(() -> "Received onError ", t);
if (failureActionInitiated.compareAndSet(false, true)) {
+ isDone = true;
multipartUploadHelper.failRequestsElegantly(futures, t, uploadId, returnFuture, putObjectRequest);
}
}
@@ -203,6 +248,7 @@ public void onComplete() {
private void completeMultipartUploadIfFinished(int requestsInFlight) {
if (isDone && requestsInFlight == 0 && completedMultipartInitiated.compareAndSet(false, true)) {
CompletedPart[] parts;
+
if (existingParts.isEmpty()) {
parts =
IntStream.range(0, completedParts.length())
@@ -212,15 +258,23 @@ private void completeMultipartUploadIfFinished(int requestsInFlight) {
// List of CompletedParts needs to be in ascending order
parts = mergeCompletedParts();
}
+
+ int actualNumParts = partNumber.get() - 1;
+ if (actualNumParts != expectedNumParts) {
+ SdkClientException exception = partNumMismatch(expectedNumParts, actualNumParts);
+ multipartUploadHelper.failRequestsElegantly(futures, exception, uploadId, returnFuture, putObjectRequest);
+ return;
+ }
+
completeMpuFuture = multipartUploadHelper.completeMultipartUpload(returnFuture, uploadId, parts, putObjectRequest,
- contentLength);
+ totalSize);
}
}
private CompletedPart[] mergeCompletedParts() {
- CompletedPart[] merged = new CompletedPart[partCount];
+ CompletedPart[] merged = new CompletedPart[expectedNumParts];
int currPart = 1;
- while (currPart < partCount + 1) {
+ while (currPart < expectedNumParts + 1) {
CompletedPart completedPart = existingParts.containsKey(currPart) ? existingParts.get(currPart) :
completedParts.get(currPart - 1);
merged[currPart - 1] = completedPart;
diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MpuRequestContext.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MpuRequestContext.java
index b9d47f6b6c71..d0965f2ecc0e 100644
--- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MpuRequestContext.java
+++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MpuRequestContext.java
@@ -22,6 +22,7 @@
import software.amazon.awssdk.services.s3.model.CompletedPart;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
import software.amazon.awssdk.utils.Pair;
+import software.amazon.awssdk.utils.Validate;
@SdkInternalApi
public class MpuRequestContext {
@@ -32,6 +33,7 @@ public class MpuRequestContext {
private final Long numPartsCompleted;
private final String uploadId;
private final Map existingParts;
+ private final int expectedNumParts;
protected MpuRequestContext(Builder builder) {
this.request = builder.request;
@@ -40,6 +42,8 @@ protected MpuRequestContext(Builder builder) {
this.uploadId = builder.uploadId;
this.existingParts = builder.existingParts;
this.numPartsCompleted = builder.numPartsCompleted;
+ this.expectedNumParts = Validate.paramNotNull(builder.expectedNumParts,
+ "expectedNumParts");
}
public static Builder builder() {
@@ -56,9 +60,13 @@ public boolean equals(Object o) {
}
MpuRequestContext that = (MpuRequestContext) o;
- return Objects.equals(request, that.request) && Objects.equals(contentLength, that.contentLength)
- && Objects.equals(partSize, that.partSize) && Objects.equals(numPartsCompleted, that.numPartsCompleted)
- && Objects.equals(uploadId, that.uploadId) && Objects.equals(existingParts, that.existingParts);
+ return expectedNumParts == that.expectedNumParts
+ && Objects.equals(request, that.request)
+ && Objects.equals(contentLength, that.contentLength)
+ && Objects.equals(partSize, that.partSize)
+ && Objects.equals(numPartsCompleted, that.numPartsCompleted)
+ && Objects.equals(uploadId, that.uploadId)
+ && Objects.equals(existingParts, that.existingParts);
}
@Override
@@ -69,6 +77,7 @@ public int hashCode() {
result = 31 * result + (contentLength != null ? contentLength.hashCode() : 0);
result = 31 * result + (partSize != null ? partSize.hashCode() : 0);
result = 31 * result + (numPartsCompleted != null ? numPartsCompleted.hashCode() : 0);
+ result = 31 * result + expectedNumParts;
return result;
}
@@ -92,6 +101,10 @@ public String uploadId() {
return uploadId;
}
+ public int expectedNumParts() {
+ return expectedNumParts;
+ }
+
public Map existingParts() {
return existingParts;
}
@@ -103,6 +116,7 @@ public static final class Builder {
private Long numPartsCompleted;
private String uploadId;
private Map existingParts;
+ private Integer expectedNumParts;
private Builder() {
}
@@ -137,6 +151,11 @@ public Builder existingParts(Map existingParts) {
return this;
}
+ public Builder expectedNumParts(Integer expectedNumParts) {
+ this.expectedNumParts = expectedNumParts;
+ return this;
+ }
+
public MpuRequestContext build() {
return new MpuRequestContext(this);
}
diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java
index e752a40e6262..d25d5b6fa7fa 100644
--- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java
+++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java
@@ -25,6 +25,7 @@
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.core.async.listener.PublisherListener;
+import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse;
import software.amazon.awssdk.services.s3.model.CompletedPart;
@@ -47,18 +48,15 @@ public final class MultipartUploadHelper {
private static final Logger log = Logger.loggerFor(MultipartUploadHelper.class);
private final S3AsyncClient s3AsyncClient;
- private final long partSizeInBytes;
private final GenericMultipartHelper genericMultipartHelper;
private final long maxMemoryUsageInBytes;
private final long multipartUploadThresholdInBytes;
public MultipartUploadHelper(S3AsyncClient s3AsyncClient,
- long partSizeInBytes,
long multipartUploadThresholdInBytes,
long maxMemoryUsageInBytes) {
this.s3AsyncClient = s3AsyncClient;
- this.partSizeInBytes = partSizeInBytes;
this.genericMultipartHelper = new GenericMultipartHelper<>(s3AsyncClient,
SdkPojoConversionUtils::toAbortMultipartUploadRequest,
SdkPojoConversionUtils::toPutObjectResponse);
@@ -123,11 +121,18 @@ void failRequestsElegantly(Collection> futures,
String uploadId,
CompletableFuture returnFuture,
PutObjectRequest putObjectRequest) {
- genericMultipartHelper.handleException(returnFuture, () -> "Failed to send multipart upload requests", t);
- if (uploadId != null) {
- genericMultipartHelper.cleanUpParts(uploadId, toAbortMultipartUploadRequest(putObjectRequest));
+
+ try {
+ genericMultipartHelper.handleException(returnFuture, () -> "Failed to send multipart upload requests", t);
+ if (uploadId != null) {
+ genericMultipartHelper.cleanUpParts(uploadId, toAbortMultipartUploadRequest(putObjectRequest));
+ }
+ cancelingOtherOngoingRequests(futures, t);
+ } catch (Throwable throwable) {
+ returnFuture.completeExceptionally(SdkClientException.create("Unexpected error occurred while handling the upstream "
+ + "exception.", throwable));
}
- cancelingOtherOngoingRequests(futures, t);
+
}
static void cancelingOtherOngoingRequests(Collection> futures, Throwable t) {
@@ -152,4 +157,22 @@ void uploadInOneChunk(PutObjectRequest putObjectRequest,
CompletableFutureUtils.forwardExceptionTo(returnFuture, putObjectResponseCompletableFuture);
CompletableFutureUtils.forwardResultTo(putObjectResponseCompletableFuture, returnFuture);
}
+
+ static SdkClientException contentLengthMissingForPart(int currentPartNum) {
+ return SdkClientException.create("Content length is missing on the AsyncRequestBody for part number " + currentPartNum);
+ }
+
+ static SdkClientException contentLengthMismatchForPart(long expected, long actual) {
+ return SdkClientException.create(String.format("Content length must not be greater than "
+ + "part size. Expected: %d, Actual: %d",
+ expected,
+ actual));
+ }
+
+ static SdkClientException partNumMismatch(int expectedNumParts, int actualNumParts) {
+ return SdkClientException.create(String.format("The number of parts divided is "
+ + "not equal to the expected number of "
+ + "parts. Expected: %d, Actual: %d",
+ expectedNumParts, actualNumParts));
+ }
}
diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java
index cd4ebc4a88b9..04690677c92b 100644
--- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java
+++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java
@@ -62,7 +62,7 @@ public UploadWithKnownContentLengthHelper(S3AsyncClient s3AsyncClient,
SdkPojoConversionUtils::toPutObjectResponse);
this.maxMemoryUsageInBytes = maxMemoryUsageInBytes;
this.multipartUploadThresholdInBytes = multipartUploadThresholdInBytes;
- this.multipartUploadHelper = new MultipartUploadHelper(s3AsyncClient, partSizeInBytes, multipartUploadThresholdInBytes,
+ this.multipartUploadHelper = new MultipartUploadHelper(s3AsyncClient, multipartUploadThresholdInBytes,
maxMemoryUsageInBytes);
}
@@ -137,6 +137,7 @@ private void uploadFromBeginning(Pair reques
.partSize(partSize)
.uploadId(uploadId)
.numPartsCompleted(numPartsCompleted)
+ .expectedNumParts(partCount)
.build();
splitAndSubscribe(mpuRequestContext, returnFuture);
@@ -170,6 +171,7 @@ private void resumePausedUpload(ResumeRequestContext resumeContext) {
.partSize(resumeToken.partSize())
.uploadId(uploadId)
.existingParts(existingParts)
+ .expectedNumParts(Math.toIntExact(resumeToken.totalNumParts()))
.numPartsCompleted(resumeToken.numPartsCompleted())
.build();
diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelper.java
index 745e7a9d3981..520625ad90b0 100644
--- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelper.java
+++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelper.java
@@ -16,6 +16,8 @@
package software.amazon.awssdk.services.s3.internal.multipart;
+import static software.amazon.awssdk.services.s3.internal.multipart.MultipartUploadHelper.contentLengthMismatchForPart;
+import static software.amazon.awssdk.services.s3.internal.multipart.MultipartUploadHelper.contentLengthMissingForPart;
import static software.amazon.awssdk.services.s3.multipart.S3MultipartExecutionAttribute.JAVA_PROGRESS_LISTENER;
import java.util.Collection;
@@ -71,7 +73,7 @@ public UploadWithUnknownContentLengthHelper(S3AsyncClient s3AsyncClient,
SdkPojoConversionUtils::toPutObjectResponse);
this.maxMemoryUsageInBytes = maxMemoryUsageInBytes;
this.multipartUploadThresholdInBytes = multipartUploadThresholdInBytes;
- this.multipartUploadHelper = new MultipartUploadHelper(s3AsyncClient, partSizeInBytes, multipartUploadThresholdInBytes,
+ this.multipartUploadHelper = new MultipartUploadHelper(s3AsyncClient, multipartUploadThresholdInBytes,
maxMemoryUsageInBytes);
}
@@ -160,10 +162,21 @@ public void onSubscribe(Subscription s) {
@Override
public void onNext(AsyncRequestBody asyncRequestBody) {
+ if (isDone) {
+ return;
+ }
int currentPartNum = partNumber.incrementAndGet();
log.trace(() -> "Received asyncRequestBody " + asyncRequestBody.contentLength());
asyncRequestBodyInFlight.incrementAndGet();
+ Optional sdkClientException = validatePart(asyncRequestBody, currentPartNum);
+ if (sdkClientException.isPresent()) {
+ multipartUploadHelper.failRequestsElegantly(futures, sdkClientException.get(), uploadId, returnFuture,
+ putObjectRequest);
+ subscription.cancel();
+ return;
+ }
+
if (isFirstAsyncRequestBody.compareAndSet(true, false)) {
log.trace(() -> "Received first async request body");
// If this is the first AsyncRequestBody received, request another one because we don't know if there is more
@@ -203,15 +216,25 @@ public void onNext(AsyncRequestBody asyncRequestBody) {
}
}
- private void sendUploadPartRequest(String uploadId,
- AsyncRequestBody asyncRequestBody,
- int currentPartNum) {
+ private Optional validatePart(AsyncRequestBody asyncRequestBody, int currentPartNum) {
Optional contentLength = asyncRequestBody.contentLength();
if (!contentLength.isPresent()) {
- SdkClientException e = SdkClientException.create("Content length must be present on the AsyncRequestBody");
- multipartUploadHelper.failRequestsElegantly(futures, e, uploadId, returnFuture, putObjectRequest);
+ return Optional.of(contentLengthMissingForPart(currentPartNum));
+ }
+
+ Long contentLengthCurrentPart = contentLength.get();
+ if (contentLengthCurrentPart > partSizeInBytes) {
+ return Optional.of(contentLengthMismatchForPart(partSizeInBytes, contentLengthCurrentPart));
+
}
- this.contentLength.getAndAdd(contentLength.get());
+ return Optional.empty();
+ }
+
+ private void sendUploadPartRequest(String uploadId,
+ AsyncRequestBody asyncRequestBody,
+ int currentPartNum) {
+ Long contentLengthCurrentPart = asyncRequestBody.contentLength().get();
+ this.contentLength.getAndAdd(contentLengthCurrentPart);
multipartUploadHelper
.sendIndividualUploadPartRequest(uploadId, completedParts::add, futures,
@@ -235,6 +258,7 @@ private Pair uploadPart(AsyncRequestBody as
SdkPojoConversionUtils.toUploadPartRequest(putObjectRequest,
partNum,
uploadId);
+
return Pair.of(uploadRequest, asyncRequestBody);
}
@@ -242,6 +266,7 @@ private Pair uploadPart(AsyncRequestBody as
public void onError(Throwable t) {
log.debug(() -> "Received onError() ", t);
if (failureActionInitiated.compareAndSet(false, true)) {
+ isDone = true;
multipartUploadHelper.failRequestsElegantly(futures, t, uploadId, returnFuture, putObjectRequest);
}
}
@@ -264,8 +289,19 @@ private void completeMultipartUploadIfFinish(int requestsInFlight) {
CompletedPart[] parts = completedParts.stream()
.sorted(Comparator.comparingInt(CompletedPart::partNumber))
.toArray(CompletedPart[]::new);
+
+ long totalLength = contentLength.get();
+ int expectedNumParts = genericMultipartHelper.determinePartCount(totalLength, partSizeInBytes);
+ if (parts.length != expectedNumParts) {
+ SdkClientException exception = SdkClientException.create(
+ String.format("The number of UploadParts requests is not equal to the expected number of parts. "
+ + "Expected: %d, Actual: %d", expectedNumParts, parts.length));
+ multipartUploadHelper.failRequestsElegantly(futures, exception, uploadId, returnFuture, putObjectRequest);
+ return;
+ }
+
multipartUploadHelper.completeMultipartUpload(returnFuture, uploadId, parts, putObjectRequest,
- this.contentLength.get());
+ totalLength);
}
}
}
diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriberTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriberTest.java
index 68c54bd4b7c1..4faf9d4a04b0 100644
--- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriberTest.java
+++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriberTest.java
@@ -17,18 +17,32 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyList;
+import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.io.IOException;
+import java.util.Collection;
+import java.util.HashMap;
import java.util.Map;
+import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
+import org.mockito.ArgumentCaptor;
+import org.reactivestreams.Subscriber;
+import org.reactivestreams.Subscription;
import software.amazon.awssdk.core.async.AsyncRequestBody;
+import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse;
import software.amazon.awssdk.services.s3.model.CompletedPart;
@@ -46,10 +60,15 @@ public class KnownContentLengthAsyncRequestBodySubscriberTest {
private static final int TOTAL_NUM_PARTS = 4;
private static final String UPLOAD_ID = "1234";
private static RandomTempFile testFile;
+
private AsyncRequestBody asyncRequestBody;
private PutObjectRequest putObjectRequest;
private S3AsyncClient s3AsyncClient;
private MultipartUploadHelper multipartUploadHelper;
+ private CompletableFuture returnFuture;
+ private KnownContentLengthAsyncRequestBodySubscriber subscriber;
+ private Collection> futures;
+ private Subscription subscription;
@BeforeAll
public static void beforeAll() throws IOException {
@@ -67,13 +86,97 @@ public void beforeEach() {
multipartUploadHelper = mock(MultipartUploadHelper.class);
asyncRequestBody = AsyncRequestBody.fromFile(testFile);
putObjectRequest = PutObjectRequest.builder().bucket("bucket").key("key").build();
+
+ returnFuture = new CompletableFuture<>();
+ futures = new ConcurrentLinkedQueue<>();
+ subscription = mock(Subscription.class);
+
+ when(multipartUploadHelper.sendIndividualUploadPartRequest(eq(UPLOAD_ID), any(), any(), any(), any()))
+ .thenReturn(CompletableFuture.completedFuture(CompletedPart.builder().build()));
+
+ subscriber = createSubscriber(createDefaultMpuRequestContext());
+ subscriber.onSubscribe(subscription);
+ }
+
+ @Test
+ void validatePart_withMissingContentLength_shouldFailRequest() {
+ subscriber.onNext(createMockAsyncRequestBodyWithEmptyContentLength());
+ verifyFailRequestsElegantly("Content length is missing on the AsyncRequestBody");
+ }
+
+ @Test
+ void validatePart_withPartSizeExceedingLimit_shouldFailRequest() {
+ subscriber.onNext(createMockAsyncRequestBody(PART_SIZE + 1));
+ verifyFailRequestsElegantly("Content length must not be greater than part size");
+ }
+
+ @Test
+ void validateLastPartSize_withIncorrectSize_shouldFailRequest() {
+ long expectedLastPartSize = MPU_CONTENT_SIZE % PART_SIZE;
+ long incorrectLastPartSize = expectedLastPartSize + 1;
+
+ KnownContentLengthAsyncRequestBodySubscriber lastPartSubscriber = createSubscriber(createDefaultMpuRequestContext());
+ lastPartSubscriber.onSubscribe(subscription);
+
+ for (int i = 0; i < TOTAL_NUM_PARTS - 1; i++) {
+ lastPartSubscriber.onNext(createMockAsyncRequestBody(PART_SIZE));
+ }
+
+ lastPartSubscriber.onNext(createMockAsyncRequestBody(incorrectLastPartSize));
+
+ verifyFailRequestsElegantly("Content length of the last part must be equal to the expected last part size");
+ }
+
+ @Test
+ void validateTotalPartNum_receivedMoreParts_shouldFail() {
+ long expectedLastPartSize = MPU_CONTENT_SIZE % PART_SIZE;
+
+ KnownContentLengthAsyncRequestBodySubscriber lastPartSubscriber = createSubscriber(createDefaultMpuRequestContext());
+ lastPartSubscriber.onSubscribe(subscription);
+
+ for (int i = 0; i < TOTAL_NUM_PARTS - 1; i++) {
+ AsyncRequestBody regularPart = createMockAsyncRequestBody(PART_SIZE);
+ when(multipartUploadHelper.sendIndividualUploadPartRequest(eq(UPLOAD_ID), any(), any(), any(), any()))
+ .thenReturn(CompletableFuture.completedFuture(null));
+ lastPartSubscriber.onNext(regularPart);
+ }
+
+ when(multipartUploadHelper.sendIndividualUploadPartRequest(eq(UPLOAD_ID), any(), any(), any(), any()))
+ .thenReturn(CompletableFuture.completedFuture(null));
+ lastPartSubscriber.onNext(createMockAsyncRequestBody(expectedLastPartSize));
+ lastPartSubscriber.onNext(createMockAsyncRequestBody(expectedLastPartSize));
+
+ verifyFailRequestsElegantly("The number of parts divided is not equal to the expected number of parts");
+ }
+
+ @Test
+ void validateLastPartSize_withCorrectSize_shouldNotFail() {
+ long expectedLastPartSize = MPU_CONTENT_SIZE % PART_SIZE;
+
+ KnownContentLengthAsyncRequestBodySubscriber subscriber = createSubscriber(createDefaultMpuRequestContext());
+ subscriber.onSubscribe(subscription);
+
+ for (int i = 0; i < TOTAL_NUM_PARTS - 1; i++) {
+ AsyncRequestBody regularPart = createMockAsyncRequestBody(PART_SIZE);
+ when(multipartUploadHelper.sendIndividualUploadPartRequest(eq(UPLOAD_ID), any(), any(), any(), any()))
+ .thenReturn(CompletableFuture.completedFuture(null));
+ subscriber.onNext(regularPart);
+ }
+
+ when(multipartUploadHelper.sendIndividualUploadPartRequest(eq(UPLOAD_ID), any(), any(), any(), any()))
+ .thenReturn(CompletableFuture.completedFuture(null));
+ subscriber.onNext(createMockAsyncRequestBody(expectedLastPartSize));
+ subscriber.onComplete();
+
+ assertThat(returnFuture).isNotCompletedExceptionally();
}
@Test
void pause_withOngoingCompleteMpuFuture_shouldReturnTokenAndCancelFuture() {
CompletableFuture completeMpuFuture = new CompletableFuture<>();
int numExistingParts = 2;
- S3ResumeToken resumeToken = configureSubscriberAndPause(numExistingParts, completeMpuFuture);
+
+ S3ResumeToken resumeToken = testPauseScenario(numExistingParts, completeMpuFuture);
verifyResumeToken(resumeToken, numExistingParts);
assertThat(completeMpuFuture).isCancelled();
@@ -84,56 +187,92 @@ void pause_withCompletedCompleteMpuFuture_shouldReturnNullToken() {
CompletableFuture completeMpuFuture =
CompletableFuture.completedFuture(CompleteMultipartUploadResponse.builder().build());
int numExistingParts = 2;
- S3ResumeToken resumeToken = configureSubscriberAndPause(numExistingParts, completeMpuFuture);
+
+ S3ResumeToken resumeToken = testPauseScenario(numExistingParts, completeMpuFuture);
assertThat(resumeToken).isNull();
}
@Test
void pause_withUninitiatedCompleteMpuFuture_shouldReturnToken() {
- CompletableFuture completeMpuFuture = null;
int numExistingParts = 2;
- S3ResumeToken resumeToken = configureSubscriberAndPause(numExistingParts, completeMpuFuture);
+
+ S3ResumeToken resumeToken = testPauseScenario(numExistingParts, null);
verifyResumeToken(resumeToken, numExistingParts);
}
-
- private S3ResumeToken configureSubscriberAndPause(int numExistingParts,
- CompletableFuture completeMpuFuture) {
- Map existingParts = existingParts(numExistingParts);
- KnownContentLengthAsyncRequestBodySubscriber subscriber = subscriber(putObjectRequest, asyncRequestBody, existingParts,
- new CompletableFuture<>());
+
+ private S3ResumeToken testPauseScenario(int numExistingParts,
+ CompletableFuture completeMpuFuture) {
+ KnownContentLengthAsyncRequestBodySubscriber subscriber =
+ createSubscriber(createMpuRequestContextWithExistingParts(numExistingParts));
when(multipartUploadHelper.completeMultipartUpload(any(CompletableFuture.class), any(String.class),
- any(CompletedPart[].class), any(PutObjectRequest.class),
- any(Long.class)))
+ any(CompletedPart[].class), any(PutObjectRequest.class),
+ any(Long.class)))
.thenReturn(completeMpuFuture);
+
+ simulateOnNextForAllParts(subscriber);
subscriber.onComplete();
+ assertThat(returnFuture).isNotCompletedExceptionally();
return subscriber.pause();
}
- private KnownContentLengthAsyncRequestBodySubscriber subscriber(PutObjectRequest putObjectRequest,
- AsyncRequestBody asyncRequestBody,
- Map existingParts,
- CompletableFuture returnFuture) {
+ private MpuRequestContext createDefaultMpuRequestContext() {
+ return MpuRequestContext.builder()
+ .request(Pair.of(putObjectRequest, AsyncRequestBody.fromFile(testFile)))
+ .contentLength(MPU_CONTENT_SIZE)
+ .partSize(PART_SIZE)
+ .uploadId(UPLOAD_ID)
+ .numPartsCompleted(0L)
+ .expectedNumParts(TOTAL_NUM_PARTS)
+ .build();
+ }
- MpuRequestContext mpuRequestContext = MpuRequestContext.builder()
- .request(Pair.of(putObjectRequest, asyncRequestBody))
- .contentLength(MPU_CONTENT_SIZE)
- .partSize(PART_SIZE)
- .uploadId(UPLOAD_ID)
- .existingParts(existingParts)
- .numPartsCompleted((long) existingParts.size())
- .build();
+ private MpuRequestContext createMpuRequestContextWithExistingParts(int numExistingParts) {
+ Map existingParts = createExistingParts(numExistingParts);
+ return MpuRequestContext.builder()
+ .request(Pair.of(putObjectRequest, asyncRequestBody))
+ .contentLength(MPU_CONTENT_SIZE)
+ .partSize(PART_SIZE)
+ .uploadId(UPLOAD_ID)
+ .existingParts(existingParts)
+ .expectedNumParts(TOTAL_NUM_PARTS)
+ .numPartsCompleted((long) existingParts.size())
+ .build();
+ }
+ private KnownContentLengthAsyncRequestBodySubscriber createSubscriber(MpuRequestContext mpuRequestContext) {
return new KnownContentLengthAsyncRequestBodySubscriber(mpuRequestContext, returnFuture, multipartUploadHelper);
}
- private Map existingParts(int numExistingParts) {
- Map existingParts = new ConcurrentHashMap<>();
- for (int i = 1; i <= numExistingParts; i++) {
- existingParts.put(i, CompletedPart.builder().partNumber(i).build());
- }
+ private AsyncRequestBody createMockAsyncRequestBody(long contentLength) {
+ AsyncRequestBody mockBody = mock(AsyncRequestBody.class);
+ when(mockBody.contentLength()).thenReturn(Optional.of(contentLength));
+ return mockBody;
+ }
+
+ private AsyncRequestBody createMockAsyncRequestBodyWithEmptyContentLength() {
+ AsyncRequestBody mockBody = mock(AsyncRequestBody.class);
+ when(mockBody.contentLength()).thenReturn(Optional.empty());
+ return mockBody;
+ }
+
+ private void verifyFailRequestsElegantly(String expectedErrorMessage) {
+ ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Throwable.class);
+ verify(multipartUploadHelper).failRequestsElegantly(any(), exceptionCaptor.capture(), eq(UPLOAD_ID), eq(returnFuture), eq(putObjectRequest));
+
+ Throwable exception = exceptionCaptor.getValue();
+ assertThat(exception).isInstanceOf(SdkClientException.class);
+ assertThat(exception.getMessage()).contains(expectedErrorMessage);
+ verify(subscription).cancel();
+ }
+
+ private Map createExistingParts(int numExistingParts) {
+ Map existingParts =
+ IntStream.range(0, numExistingParts)
+ .boxed().collect(Collectors.toMap(Function.identity(),
+ i -> CompletedPart.builder().partNumber(i).build(), (a, b) -> b));
return existingParts;
}
@@ -144,4 +283,13 @@ private void verifyResumeToken(S3ResumeToken s3ResumeToken, int numExistingParts
assertThat(s3ResumeToken.totalNumParts()).isEqualTo(TOTAL_NUM_PARTS);
assertThat(s3ResumeToken.numPartsCompleted()).isEqualTo(numExistingParts);
}
+
+ private void simulateOnNextForAllParts(KnownContentLengthAsyncRequestBodySubscriber subscriber) {
+ subscriber.onSubscribe(subscription);
+
+ for (int i = 0; i < TOTAL_NUM_PARTS; i++) {
+ subscriber.onNext(createMockAsyncRequestBody(PART_SIZE));
+ }
+ }
+
}
diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MpuRequestContextTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MpuRequestContextTest.java
index c858e7e8e9ec..46b03a90d2cf 100644
--- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MpuRequestContextTest.java
+++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MpuRequestContextTest.java
@@ -34,6 +34,7 @@ public class MpuRequestContextTest {
private static final long NUM_PARTS_COMPLETED = 3;
private static final String UPLOAD_ID = "55555";
private static final Map EXISTING_PARTS = new ConcurrentHashMap<>();
+ public static final int EXPECTED_NUM_PARTS = 10;
@Test
public void mpuRequestContext_withValues_buildsCorrectly() {
@@ -44,6 +45,7 @@ public void mpuRequestContext_withValues_buildsCorrectly() {
.uploadId(UPLOAD_ID)
.existingParts(EXISTING_PARTS)
.numPartsCompleted(NUM_PARTS_COMPLETED)
+ .expectedNumParts(EXPECTED_NUM_PARTS)
.build();
assertThat(mpuRequestContext.request()).isEqualTo(REQUEST);
@@ -52,11 +54,14 @@ public void mpuRequestContext_withValues_buildsCorrectly() {
assertThat(mpuRequestContext.uploadId()).isEqualTo(UPLOAD_ID);
assertThat(mpuRequestContext.existingParts()).isEqualTo(EXISTING_PARTS);
assertThat(mpuRequestContext.numPartsCompleted()).isEqualTo(NUM_PARTS_COMPLETED);
+ assertThat(mpuRequestContext.expectedNumParts()).isEqualTo(EXPECTED_NUM_PARTS);
}
@Test
public void mpuRequestContext_default_buildsCorrectly() {
- MpuRequestContext mpuRequestContext = MpuRequestContext.builder().build();
+ MpuRequestContext mpuRequestContext = MpuRequestContext.builder()
+ .expectedNumParts(1)
+ .build();
assertThat(mpuRequestContext.request()).isNull();
assertThat(mpuRequestContext.contentLength()).isNull();
diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/S3MultipartClientPutObjectWiremockTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/S3MultipartClientPutObjectWiremockTest.java
index 19bf3988e3ec..90e14dcff2dd 100644
--- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/S3MultipartClientPutObjectWiremockTest.java
+++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/S3MultipartClientPutObjectWiremockTest.java
@@ -26,25 +26,40 @@
import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo;
import com.github.tomakehurst.wiremock.junit5.WireMockTest;
+import io.reactivex.rxjava3.core.Flowable;
import java.io.InputStream;
import java.net.URI;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
+import java.util.stream.Stream;
+import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+import org.reactivestreams.Subscriber;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.core.async.AsyncRequestBody;
+import software.amazon.awssdk.core.async.AsyncRequestBodySplitConfiguration;
import software.amazon.awssdk.core.async.BlockingInputStreamAsyncRequestBody;
+import software.amazon.awssdk.core.async.SdkPublisher;
+import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.http.crt.AwsCrtAsyncHttpClient;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.PutObjectResponse;
import software.amazon.awssdk.services.s3.model.S3Exception;
+import software.amazon.awssdk.utils.async.SimplePublisher;
@WireMockTest
-@Timeout(10)
+@Timeout(100)
public class S3MultipartClientPutObjectWiremockTest {
private static final String BUCKET = "Example-Bucket";
@@ -56,9 +71,25 @@ public class S3MultipartClientPutObjectWiremockTest {
+ "";
private S3AsyncClient s3AsyncClient;
+ public static Stream invalidAsyncRequestBodies() {
+ return Stream.of(
+ Arguments.of("knownContentLength_nullPartSize", new TestPublisherWithIncorrectSplitImpl(20L, null),
+ "Content length is missing on the AsyncRequestBody for part number"),
+ Arguments.of("unknownContentLength_nullPartSize", new TestPublisherWithIncorrectSplitImpl(null, null),
+ "Content length is missing on the AsyncRequestBody for part number"),
+ Arguments.of("knownContentLength_partSizeIncorrect", new TestPublisherWithIncorrectSplitImpl(20L, 11L),
+ "Content length must not be greater than part size"),
+ Arguments.of("unknownContentLength_partSizeIncorrect", new TestPublisherWithIncorrectSplitImpl(null, 11L),
+ "Content length must not be greater than part size"),
+ Arguments.of("knownContentLength_sendMoreParts", new TestPublisherWithIncorrectSplitImpl(20L, 10L, 3),
+ "The number of parts divided is not equal to the expected number of parts"),
+ Arguments.of("knownContentLength_sendFewerParts", new TestPublisherWithIncorrectSplitImpl(20L, 10L, 1),
+ "The number of parts divided is not equal to the expected number of parts"));
+ }
+
@BeforeEach
public void setup(WireMockRuntimeInfo wiremock) {
- stubPutObjectCalls();
+ stubFailedPutObjectCalls();
s3AsyncClient = S3AsyncClient.builder()
.region(Region.US_EAST_1)
.endpointOverride(URI.create("http://localhost:" + wiremock.getHttpPort()))
@@ -66,17 +97,23 @@ public void setup(WireMockRuntimeInfo wiremock) {
StaticCredentialsProvider.create(AwsBasicCredentials.create("key", "secret")))
.multipartEnabled(true)
.multipartConfiguration(b -> b.minimumPartSizeInBytes(10L).apiCallBufferSizeInBytes(10L))
- .httpClientBuilder(AwsCrtAsyncHttpClient.builder())
+ .httpClientBuilder(AwsCrtAsyncHttpClient.builder())
.build();
}
- private void stubPutObjectCalls() {
+ private void stubFailedPutObjectCalls() {
stubFor(post(anyUrl()).willReturn(aResponse().withStatus(200).withBody(CREATE_MULTIPART_PAYLOAD)));
stubFor(put(anyUrl()).willReturn(aResponse().withStatus(404)));
stubFor(put(urlEqualTo("/Example-Bucket/Example-Object?partNumber=1&uploadId=string")).willReturn(aResponse().withStatus(200)));
stubFor(delete(anyUrl()).willReturn(aResponse().withStatus(200)));
}
+ private void stubSuccessfulPutObjectCalls() {
+ stubFor(post(anyUrl()).willReturn(aResponse().withStatus(200).withBody(CREATE_MULTIPART_PAYLOAD)));
+ stubFor(put(anyUrl()).willReturn(aResponse().withStatus(200)));
+ }
+
+
// https://github.com/aws/aws-sdk-java-v2/issues/4801
@Test
void uploadWithUnknownContentLength_onePartFails_shouldCancelUpstream() {
@@ -110,6 +147,19 @@ void uploadWithKnownContentLength_onePartFails_shouldCancelUpstream() {
assertThatThrownBy(() -> putObjectResponse.join()).hasRootCauseInstanceOf(S3Exception.class);
}
+ @ParameterizedTest(name = "{index} {0}")
+ @MethodSource("invalidAsyncRequestBodies")
+ void uploadWithIncorrectAsyncRequestBodySplit_contentLengthMismatch_shouldThrowException(String description,
+ TestPublisherWithIncorrectSplitImpl asyncRequestBody,
+ String errorMsg) {
+ stubSuccessfulPutObjectCalls();
+ CompletableFuture putObjectResponse = s3AsyncClient.putObject(
+ r -> r.bucket(BUCKET).key(KEY), asyncRequestBody);
+
+ assertThatThrownBy(() -> putObjectResponse.join()).hasMessageContaining(errorMsg)
+ .hasRootCauseInstanceOf(SdkClientException.class);
+ }
+
private InputStream createUnlimitedInputStream() {
return new InputStream() {
@Override
@@ -118,4 +168,65 @@ public int read() {
}
};
}
+
+ private static class TestPublisherWithIncorrectSplitImpl implements AsyncRequestBody {
+ private SimplePublisher simplePublisher = new SimplePublisher<>();
+ private Long totalSize;
+ private Long partSize;
+ private Integer numParts;
+
+ private TestPublisherWithIncorrectSplitImpl(Long totalSize, Long partSize) {
+ this.totalSize = totalSize;
+ this.partSize = partSize;
+ }
+
+ private TestPublisherWithIncorrectSplitImpl(Long totalSize, long partSize, int numParts) {
+ this.totalSize = totalSize;
+ this.partSize = partSize;
+ this.numParts = numParts;
+ }
+
+ @Override
+ public Optional contentLength() {
+ return Optional.ofNullable(totalSize);
+ }
+
+ @Override
+ public void subscribe(Subscriber super ByteBuffer> s) {
+ simplePublisher.subscribe(s);
+ }
+
+ @Override
+ public SdkPublisher split(AsyncRequestBodySplitConfiguration splitConfiguration) {
+ List requestBodies = new ArrayList<>();
+ int numAsyncRequestBodies = numParts == null ? 1 : numParts;
+ for (int i = 0; i < numAsyncRequestBodies; i++) {
+ requestBodies.add(new TestAsyncRequestBody(partSize));
+ }
+
+ return SdkPublisher.adapt(Flowable.fromArray(requestBodies.toArray(new AsyncRequestBody[requestBodies.size()])));
+ }
+ }
+
+ private static class TestAsyncRequestBody implements AsyncRequestBody {
+ private Long partSize;
+ private SimplePublisher simplePublisher = new SimplePublisher<>();
+
+ public TestAsyncRequestBody(Long partSize) {
+ this.partSize = partSize;
+ }
+
+ @Override
+ public Optional contentLength() {
+ return Optional.ofNullable(partSize);
+ }
+
+ @Override
+ public void subscribe(Subscriber super ByteBuffer> s) {
+ simplePublisher.subscribe(s);
+ simplePublisher.send(ByteBuffer.wrap(
+ RandomStringUtils.randomAscii(Math.toIntExact(partSize)).getBytes()));
+ simplePublisher.complete();
+ }
+ }
}
diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelperTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelperTest.java
index 7be4ae7c3135..972f0b86241a 100644
--- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelperTest.java
+++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelperTest.java
@@ -16,17 +16,25 @@
package software.amazon.awssdk.services.s3.internal.multipart;
import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
import static software.amazon.awssdk.services.s3.internal.multipart.MpuTestUtils.stubSuccessfulCompleteMultipartCall;
import static software.amazon.awssdk.services.s3.internal.multipart.MpuTestUtils.stubSuccessfulCreateMultipartCall;
import static software.amazon.awssdk.services.s3.internal.multipart.MpuTestUtils.stubSuccessfulUploadPartCalls;
+import java.io.ByteArrayInputStream;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
+import java.io.InputStream;
import java.util.List;
+import java.util.Optional;
import java.util.concurrent.CompletableFuture;
+import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.junit.jupiter.api.AfterAll;
@@ -35,24 +43,29 @@
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
+import org.reactivestreams.Subscriber;
+import org.reactivestreams.Subscription;
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.core.async.BlockingInputStreamAsyncRequestBody;
+import software.amazon.awssdk.core.async.SdkPublisher;
+import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest;
import software.amazon.awssdk.services.s3.model.CompletedPart;
+import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
import software.amazon.awssdk.services.s3.model.PutObjectResponse;
import software.amazon.awssdk.services.s3.model.UploadPartRequest;
import software.amazon.awssdk.testutils.RandomTempFile;
+import software.amazon.awssdk.utils.StringInputStream;
public class UploadWithUnknownContentLengthHelperTest {
private static final String BUCKET = "bucket";
private static final String KEY = "key";
private static final String UPLOAD_ID = "1234";
-
- // Should contain 126 parts
private static final long MPU_CONTENT_SIZE = 1005 * 1024;
private static final long PART_SIZE = 8 * 1024;
+ private static final int NUM_TOTAL_PARTS = 126;
private UploadWithUnknownContentLengthHelper helper;
private S3AsyncClient s3AsyncClient;
@@ -81,55 +94,115 @@ void upload_blockingInputStream_shouldInOrder() throws FileNotFoundException {
stubSuccessfulCompleteMultipartCall(BUCKET, KEY, s3AsyncClient);
BlockingInputStreamAsyncRequestBody body = AsyncRequestBody.forBlockingInputStream(null);
-
- CompletableFuture future = helper.uploadObject(putObjectRequest(), body);
-
+ CompletableFuture future = helper.uploadObject(createPutObjectRequest(), body);
body.writeInputStream(new FileInputStream(testFile));
-
future.join();
ArgumentCaptor requestArgumentCaptor = ArgumentCaptor.forClass(UploadPartRequest.class);
ArgumentCaptor requestBodyArgumentCaptor = ArgumentCaptor.forClass(AsyncRequestBody.class);
- int numTotalParts = 126;
- verify(s3AsyncClient, times(numTotalParts)).uploadPart(requestArgumentCaptor.capture(),
- requestBodyArgumentCaptor.capture());
+ verify(s3AsyncClient, times(NUM_TOTAL_PARTS)).uploadPart(requestArgumentCaptor.capture(),
+ requestBodyArgumentCaptor.capture());
List actualRequests = requestArgumentCaptor.getAllValues();
List actualRequestBodies = requestBodyArgumentCaptor.getAllValues();
- assertThat(actualRequestBodies).hasSize(numTotalParts);
- assertThat(actualRequests).hasSize(numTotalParts);
+ assertThat(actualRequestBodies).hasSize(NUM_TOTAL_PARTS);
+ assertThat(actualRequests).hasSize(NUM_TOTAL_PARTS);
+
+ verifyUploadPartRequests(actualRequests, actualRequestBodies);
+ verifyCompleteMultipartUploadRequest();
+ }
+
+ @Test
+ void uploadObject_withMissingContentLength_shouldFailRequest() {
+ AsyncRequestBody asyncRequestBody = createMockAsyncRequestBodyWithEmptyContentLength();
+ CompletableFuture future = setupAndTriggerUploadFailure(asyncRequestBody);
+ verifyFailureWithMessage(future, "Content length is missing on the AsyncRequestBody for part number");
+ }
+
+ @Test
+ void uploadObject_withPartSizeExceedingLimit_shouldFailRequest() {
+ AsyncRequestBody asyncRequestBody = createMockAsyncRequestBody(PART_SIZE + 1);
+ CompletableFuture future = setupAndTriggerUploadFailure(asyncRequestBody);
+ verifyFailureWithMessage(future, "Content length must not be greater than part size");
+ }
+
+ private PutObjectRequest createPutObjectRequest() {
+ return PutObjectRequest.builder()
+ .bucket(BUCKET)
+ .key(KEY)
+ .build();
+ }
+ private List createCompletedParts(int totalNumParts) {
+ return IntStream.range(1, totalNumParts + 1)
+ .mapToObj(i -> CompletedPart.builder().partNumber(i).build())
+ .collect(Collectors.toList());
+ }
+
+ private AsyncRequestBody createMockAsyncRequestBody(long contentLength) {
+ AsyncRequestBody mockBody = mock(AsyncRequestBody.class);
+ when(mockBody.contentLength()).thenReturn(Optional.of(contentLength));
+ return mockBody;
+ }
+
+ private AsyncRequestBody createMockAsyncRequestBodyWithEmptyContentLength() {
+ AsyncRequestBody mockBody = mock(AsyncRequestBody.class);
+ when(mockBody.contentLength()).thenReturn(Optional.empty());
+ return mockBody;
+ }
+
+ private CompletableFuture setupAndTriggerUploadFailure(AsyncRequestBody asyncRequestBody) {
+ SdkPublisher mockPublisher = mock(SdkPublisher.class);
+ when(asyncRequestBody.split(any(Consumer.class))).thenReturn(mockPublisher);
+
+ ArgumentCaptor> subscriberCaptor = ArgumentCaptor.forClass(Subscriber.class);
+ CompletableFuture future = helper.uploadObject(createPutObjectRequest(), asyncRequestBody);
+
+ verify(mockPublisher).subscribe(subscriberCaptor.capture());
+ Subscriber subscriber = subscriberCaptor.getValue();
+
+ Subscription subscription = mock(Subscription.class);
+ subscriber.onSubscribe(subscription);
+ subscriber.onNext(asyncRequestBody);
+
+ stubSuccessfulCreateMultipartCall(UPLOAD_ID, s3AsyncClient);
+ subscriber.onNext(asyncRequestBody);
+
+ return future;
+ }
+
+ private void verifyFailureWithMessage(CompletableFuture future, String expectedErrorMessage) {
+ assertThat(future).isCompletedExceptionally();
+ future.exceptionally(throwable -> {
+ assertThat(throwable).isInstanceOf(SdkClientException.class);
+ assertThat(throwable.getMessage()).contains(expectedErrorMessage);
+ return null;
+ }).join();
+ }
+
+ private void verifyUploadPartRequests(List actualRequests,
+ List actualRequestBodies) {
for (int i = 0; i < actualRequests.size(); i++) {
UploadPartRequest request = actualRequests.get(i);
AsyncRequestBody requestBody = actualRequestBodies.get(i);
- assertThat(request.partNumber()).isEqualTo( i + 1);
+ assertThat(request.partNumber()).isEqualTo(i + 1);
assertThat(request.bucket()).isEqualTo(BUCKET);
assertThat(request.key()).isEqualTo(KEY);
if (i == actualRequests.size() - 1) {
assertThat(requestBody.contentLength()).hasValue(5120L);
- } else{
+ } else {
assertThat(requestBody.contentLength()).hasValue(PART_SIZE);
}
}
+ }
- ArgumentCaptor completeMpuArgumentCaptor = ArgumentCaptor.forClass(CompleteMultipartUploadRequest.class);
+ private void verifyCompleteMultipartUploadRequest() {
+ ArgumentCaptor completeMpuArgumentCaptor = ArgumentCaptor
+ .forClass(CompleteMultipartUploadRequest.class);
verify(s3AsyncClient).completeMultipartUpload(completeMpuArgumentCaptor.capture());
CompleteMultipartUploadRequest actualRequest = completeMpuArgumentCaptor.getValue();
- assertThat(actualRequest.multipartUpload().parts()).isEqualTo(completedParts(numTotalParts));
-
+ assertThat(actualRequest.multipartUpload().parts()).isEqualTo(createCompletedParts(NUM_TOTAL_PARTS));
}
-
- private static PutObjectRequest putObjectRequest() {
- return PutObjectRequest.builder()
- .bucket(BUCKET)
- .key(KEY)
- .build();
- }
-
- private List completedParts(int totalNumParts) {
- return IntStream.range(1, totalNumParts + 1).mapToObj(i -> CompletedPart.builder().partNumber(i).build()).collect(Collectors.toList());
- }
-
}
diff --git a/test/ruleset-testing-core/pom.xml b/test/ruleset-testing-core/pom.xml
index 2f411c79bfb8..ff67838fc5fc 100644
--- a/test/ruleset-testing-core/pom.xml
+++ b/test/ruleset-testing-core/pom.xml
@@ -100,7 +100,7 @@
io.reactivex.rxjava3
rxjava
- 3.1.5
+ compile