Skip to content

Add validations for upload in S3 multipart client #6282

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 31, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/next-release/bugfix-AmazonS3-6522f77.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"type": "bugfix",
"category": "Amazon S3",
"contributor": "",
"description": "Add additional validations for multipart upload operations in the Java multipart S3 client."
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.core.async.listener.PublisherListener;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse;
import software.amazon.awssdk.services.s3.model.CompletedPart;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
Expand All @@ -54,10 +55,10 @@ public class KnownContentLengthAsyncRequestBodySubscriber implements Subscriber<
private final AtomicBoolean failureActionInitiated = new AtomicBoolean(false);
private final AtomicInteger partNumber = new AtomicInteger(1);
private final MultipartUploadHelper multipartUploadHelper;
private final long contentLength;
private final long totalSize;
private final long partSize;
private final int partCount;
private final int numExistingParts;
private final int expectedNumParts;
private final int existingNumParts;
private final String uploadId;
private final Collection<CompletableFuture<CompletedPart>> futures = new ConcurrentLinkedQueue<>();
private final PutObjectRequest putObjectRequest;
Expand All @@ -77,25 +78,21 @@ public class KnownContentLengthAsyncRequestBodySubscriber implements Subscriber<
KnownContentLengthAsyncRequestBodySubscriber(MpuRequestContext mpuRequestContext,
CompletableFuture<PutObjectResponse> returnFuture,
MultipartUploadHelper multipartUploadHelper) {
this.contentLength = mpuRequestContext.contentLength();
this.totalSize = mpuRequestContext.contentLength();
this.partSize = mpuRequestContext.partSize();
this.partCount = determinePartCount(contentLength, partSize);
this.expectedNumParts = mpuRequestContext.expectedNumParts();
this.putObjectRequest = mpuRequestContext.request().left();
this.returnFuture = returnFuture;
this.uploadId = mpuRequestContext.uploadId();
this.existingParts = mpuRequestContext.existingParts() == null ? new HashMap<>() : mpuRequestContext.existingParts();
this.numExistingParts = NumericUtils.saturatedCast(mpuRequestContext.numPartsCompleted());
this.completedParts = new AtomicReferenceArray<>(partCount);
this.existingNumParts = NumericUtils.saturatedCast(mpuRequestContext.numPartsCompleted());
this.completedParts = new AtomicReferenceArray<>(expectedNumParts);
this.multipartUploadHelper = multipartUploadHelper;
this.progressListener = putObjectRequest.overrideConfiguration().map(c -> c.executionAttributes()
.getAttribute(JAVA_PROGRESS_LISTENER))
.orElseGet(PublisherListener::noOp);
}

private int determinePartCount(long contentLength, long partSize) {
return (int) Math.ceil(contentLength / (double) partSize);
}

public S3ResumeToken pause() {
isPaused = true;

Expand All @@ -119,8 +116,8 @@ public S3ResumeToken pause() {
return S3ResumeToken.builder()
.uploadId(uploadId)
.partSize(partSize)
.totalNumParts((long) partCount)
.numPartsCompleted(numPartsCompleted + numExistingParts)
.totalNumParts((long) expectedNumParts)
.numPartsCompleted(numPartsCompleted + existingNumParts)
.build();
}

Expand All @@ -145,21 +142,23 @@ public void onSubscribe(Subscription s) {

@Override
public void onNext(AsyncRequestBody asyncRequestBody) {
if (isPaused) {
if (isPaused || isDone) {
return;
}

if (existingParts.containsKey(partNumber.get())) {
partNumber.getAndIncrement();
int currentPartNum = partNumber.getAndIncrement();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please help me understand why we previously called containsKey on partNumber.get(), whereas now we increment first and then do the containsKey check?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know why we did that earlier, but the reason I changed it is to avoid another atomic integer get call (a micro perf optimization).

if (existingParts.containsKey(currentPartNum)) {
asyncRequestBody.subscribe(new CancelledSubscriber<>());
subscription.request(1);
asyncRequestBody.contentLength().ifPresent(progressListener::subscriberOnNext);
return;
}

validatePart(asyncRequestBody, currentPartNum);

asyncRequestBodyInFlight.incrementAndGet();
UploadPartRequest uploadRequest = SdkPojoConversionUtils.toUploadPartRequest(putObjectRequest,
partNumber.getAndIncrement(),
currentPartNum,
uploadId);

Consumer<CompletedPart> completedPartConsumer = completedPart -> completedParts.set(completedPart.partNumber() - 1,
Expand All @@ -179,6 +178,49 @@ public void onNext(AsyncRequestBody asyncRequestBody) {
subscription.request(1);
}

private void validatePart(AsyncRequestBody asyncRequestBody, int currentPartNum) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see we already have JUnit tests for each class. However, I was wondering how we end up with the validation failures below from an external API perspective, such as when users pass invalid AsyncRequestBody or files get corrupted in transit. Is it possible to write end-to-end test cases for these scenarios so that we can test them with UnknownContentLength publisher or with S3CrtClient?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, good point, added

if (!asyncRequestBody.contentLength().isPresent()) {
SdkClientException e = SdkClientException.create("Content length must be present on the AsyncRequestBody");
multipartUploadHelper.failRequestsElegantly(futures, e, uploadId, returnFuture, putObjectRequest);
return;
}

Long currentPartSize = asyncRequestBody.contentLength().get();
if (currentPartNum > expectedNumParts) {
SdkClientException exception = SdkClientException.create(String.format("The number of parts divided is "
+ "not equal to the expected number of "
+ "parts. Expected: %d, Actual: %d",
expectedNumParts, currentPartNum));
multipartUploadHelper.failRequestsElegantly(futures, exception, uploadId, returnFuture, putObjectRequest);
return;
}

if (currentPartNum == expectedNumParts) {
validateLastPartSize(currentPartSize);
return;
}

if (currentPartSize != partSize) {
SdkClientException e = SdkClientException.create(String.format("Content length must be equal to the "
+ "part size. Expected: %d, Actual: %d",
partSize,
currentPartSize));
multipartUploadHelper.failRequestsElegantly(futures, e, uploadId, returnFuture, putObjectRequest);
}
}

/**
 * Checks the size of the final part against the size implied by {@code totalSize} and {@code partSize}.
 * <p>
 * The expected size is {@code totalSize % partSize}, or a full {@code partSize} when the total divides evenly.
 * On a mismatch the request is failed through {@code multipartUploadHelper.failRequestsElegantly}.
 *
 * @param currentPartSize content length reported by the last part
 */
private void validateLastPartSize(Long currentPartSize) {
    long leftover = totalSize % partSize;
    long expectedLastPartSize = partSize;
    if (leftover != 0) {
        expectedLastPartSize = leftover;
    }

    // Long compared with long: the boxed value is unboxed, so this is a numeric comparison.
    if (currentPartSize == expectedLastPartSize) {
        return;
    }

    String message = "Content length of the last part must be equal to the "
                     + "expected last part size. Expected: " + expectedLastPartSize
                     + ", Actual: " + currentPartSize;
    SdkClientException lastPartSizeMismatch = SdkClientException.create(message);
    multipartUploadHelper.failRequestsElegantly(futures, lastPartSizeMismatch, uploadId, returnFuture,
                                                putObjectRequest);
}

/**
 * Claims the failure action and reports whether the request should be failed.
 * <p>
 * The {@code failureActionInitiated} flag is flipped from {@code false} to {@code true} first; {@code isPaused}
 * is consulted only when that flip succeeded.
 *
 * @return {@code true} if this call flipped the flag and the subscriber is not paused
 */
private boolean shouldFailRequest() {
    if (!failureActionInitiated.compareAndSet(false, true)) {
        return false;
    }
    return !isPaused;
}
Expand All @@ -187,6 +229,7 @@ private boolean shouldFailRequest() {
public void onError(Throwable t) {
    log.debug(() -> "Received onError ", t);

    // Only the call that flips failureActionInitiated performs the failure handling; later calls are ignored.
    boolean firstFailure = failureActionInitiated.compareAndSet(false, true);
    if (!firstFailure) {
        return;
    }

    // Mark the subscriber as done before failing the request; onNext returns early once isDone is set.
    isDone = true;
    multipartUploadHelper.failRequestsElegantly(futures, t, uploadId, returnFuture, putObjectRequest);
}
Expand All @@ -203,6 +246,7 @@ public void onComplete() {
private void completeMultipartUploadIfFinished(int requestsInFlight) {
if (isDone && requestsInFlight == 0 && completedMultipartInitiated.compareAndSet(false, true)) {
CompletedPart[] parts;

if (existingParts.isEmpty()) {
parts =
IntStream.range(0, completedParts.length())
Expand All @@ -213,14 +257,14 @@ private void completeMultipartUploadIfFinished(int requestsInFlight) {
parts = mergeCompletedParts();
}
completeMpuFuture = multipartUploadHelper.completeMultipartUpload(returnFuture, uploadId, parts, putObjectRequest,
contentLength);
totalSize);
}
}

private CompletedPart[] mergeCompletedParts() {
CompletedPart[] merged = new CompletedPart[partCount];
CompletedPart[] merged = new CompletedPart[expectedNumParts];
int currPart = 1;
while (currPart < partCount + 1) {
while (currPart < expectedNumParts + 1) {
CompletedPart completedPart = existingParts.containsKey(currPart) ? existingParts.get(currPart) :
completedParts.get(currPart - 1);
merged[currPart - 1] = completedPart;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import software.amazon.awssdk.services.s3.model.CompletedPart;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
import software.amazon.awssdk.utils.Pair;
import software.amazon.awssdk.utils.Validate;

@SdkInternalApi
public class MpuRequestContext {
Expand All @@ -32,6 +33,7 @@ public class MpuRequestContext {
private final Long numPartsCompleted;
private final String uploadId;
private final Map<Integer, CompletedPart> existingParts;
private final int expectedNumParts;

protected MpuRequestContext(Builder builder) {
this.request = builder.request;
Expand All @@ -40,6 +42,8 @@ protected MpuRequestContext(Builder builder) {
this.uploadId = builder.uploadId;
this.existingParts = builder.existingParts;
this.numPartsCompleted = builder.numPartsCompleted;
this.expectedNumParts = Validate.paramNotNull(builder.expectedNumParts,
"expectedNumParts");
}

public static Builder builder() {
Expand All @@ -56,9 +60,13 @@ public boolean equals(Object o) {
}
MpuRequestContext that = (MpuRequestContext) o;

return Objects.equals(request, that.request) && Objects.equals(contentLength, that.contentLength)
&& Objects.equals(partSize, that.partSize) && Objects.equals(numPartsCompleted, that.numPartsCompleted)
&& Objects.equals(uploadId, that.uploadId) && Objects.equals(existingParts, that.existingParts);
return expectedNumParts == that.expectedNumParts
&& Objects.equals(request, that.request)
&& Objects.equals(contentLength, that.contentLength)
&& Objects.equals(partSize, that.partSize)
&& Objects.equals(numPartsCompleted, that.numPartsCompleted)
&& Objects.equals(uploadId, that.uploadId)
&& Objects.equals(existingParts, that.existingParts);
}

@Override
Expand All @@ -69,6 +77,7 @@ public int hashCode() {
result = 31 * result + (contentLength != null ? contentLength.hashCode() : 0);
result = 31 * result + (partSize != null ? partSize.hashCode() : 0);
result = 31 * result + (numPartsCompleted != null ? numPartsCompleted.hashCode() : 0);
result = 31 * result + expectedNumParts;
return result;
}

Expand All @@ -92,6 +101,10 @@ public String uploadId() {
return uploadId;
}

/**
 * @return the expected number of parts that was supplied to the builder
 */
public int expectedNumParts() {
    return this.expectedNumParts;
}

/**
 * Returns the map that was supplied to the builder, without copying it.
 *
 * @return the existing parts map, or {@code null} if none was set on the builder
 */
public Map<Integer, CompletedPart> existingParts() {
    return this.existingParts;
}
Expand All @@ -103,6 +116,7 @@ public static final class Builder {
private Long numPartsCompleted;
private String uploadId;
private Map<Integer, CompletedPart> existingParts;
private Integer expectedNumParts;

private Builder() {
}
Expand Down Expand Up @@ -137,6 +151,11 @@ public Builder existingParts(Map<Integer, CompletedPart> existingParts) {
return this;
}

/**
 * Sets the number of parts the upload is expected to consist of.
 * <p>
 * The {@code MpuRequestContext} constructor passes this value through {@code Validate.paramNotNull}, so it is
 * expected to be set to a non-null value before {@link #build()} is called.
 *
 * @param expectedNumParts the expected part count
 * @return this builder
 */
public Builder expectedNumParts(Integer expectedNumParts) {
this.expectedNumParts = expectedNumParts;
return this;
}

/**
 * Creates a new {@code MpuRequestContext} from the values currently held by this builder.
 *
 * @return the newly constructed context
 */
public MpuRequestContext build() {
    MpuRequestContext context = new MpuRequestContext(this);
    return context;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.core.async.listener.PublisherListener;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse;
import software.amazon.awssdk.services.s3.model.CompletedPart;
Expand Down Expand Up @@ -123,11 +124,18 @@ void failRequestsElegantly(Collection<CompletableFuture<CompletedPart>> futures,
String uploadId,
CompletableFuture<PutObjectResponse> returnFuture,
PutObjectRequest putObjectRequest) {
genericMultipartHelper.handleException(returnFuture, () -> "Failed to send multipart upload requests", t);
if (uploadId != null) {
genericMultipartHelper.cleanUpParts(uploadId, toAbortMultipartUploadRequest(putObjectRequest));

try {
genericMultipartHelper.handleException(returnFuture, () -> "Failed to send multipart upload requests", t);
if (uploadId != null) {
genericMultipartHelper.cleanUpParts(uploadId, toAbortMultipartUploadRequest(putObjectRequest));
}
cancelingOtherOngoingRequests(futures, t);
} catch (Throwable throwable) {
returnFuture.completeExceptionally(SdkClientException.create("Unexpected error occurred while handling the upstream "
+ "exception.", throwable));
}
cancelingOtherOngoingRequests(futures, t);

}

static void cancelingOtherOngoingRequests(Collection<CompletableFuture<CompletedPart>> futures, Throwable t) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ private void uploadFromBeginning(Pair<PutObjectRequest, AsyncRequestBody> reques
.partSize(partSize)
.uploadId(uploadId)
.numPartsCompleted(numPartsCompleted)
.expectedNumParts(partCount)
.build();

splitAndSubscribe(mpuRequestContext, returnFuture);
Expand Down Expand Up @@ -170,6 +171,7 @@ private void resumePausedUpload(ResumeRequestContext resumeContext) {
.partSize(resumeToken.partSize())
.uploadId(uploadId)
.existingParts(existingParts)
.expectedNumParts(Math.toIntExact(resumeToken.totalNumParts()))
.numPartsCompleted(resumeToken.numPartsCompleted())
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@ public void onSubscribe(Subscription s) {

@Override
public void onNext(AsyncRequestBody asyncRequestBody) {
if (isDone) {
return;
}
int currentPartNum = partNumber.incrementAndGet();
log.trace(() -> "Received asyncRequestBody " + asyncRequestBody.contentLength());
asyncRequestBodyInFlight.incrementAndGet();
Expand Down Expand Up @@ -211,7 +214,17 @@ private void sendUploadPartRequest(String uploadId,
SdkClientException e = SdkClientException.create("Content length must be present on the AsyncRequestBody");
multipartUploadHelper.failRequestsElegantly(futures, e, uploadId, returnFuture, putObjectRequest);
}
this.contentLength.getAndAdd(contentLength.get());

Long contentLengthCurrentPart = contentLength.get();
if (contentLengthCurrentPart > partSizeInBytes) {
SdkClientException e = SdkClientException.create(String.format("Content length must not be greater than the "
+ "part size. Expected: %d, Actual: %d",
partSizeInBytes,
contentLengthCurrentPart));
multipartUploadHelper.failRequestsElegantly(futures, e, uploadId, returnFuture, putObjectRequest);
}

this.contentLength.getAndAdd(contentLengthCurrentPart);

multipartUploadHelper
.sendIndividualUploadPartRequest(uploadId, completedParts::add, futures,
Expand All @@ -235,13 +248,15 @@ private Pair<UploadPartRequest, AsyncRequestBody> uploadPart(AsyncRequestBody as
SdkPojoConversionUtils.toUploadPartRequest(putObjectRequest,
partNum,
uploadId);

return Pair.of(uploadRequest, asyncRequestBody);
}

/**
 * Handles an error signalled by the upstream publisher.
 * <p>
 * The failure handling runs at most once, guarded by {@code failureActionInitiated}. When it runs, the
 * subscriber is marked as done and the request is failed through
 * {@code multipartUploadHelper.failRequestsElegantly}.
 *
 * @param t the error reported by the publisher
 */
@Override
public void onError(Throwable t) {
    log.debug(() -> "Received onError() ", t);

    boolean firstFailure = failureActionInitiated.compareAndSet(false, true);
    if (!firstFailure) {
        return;
    }

    // Mark the subscriber as done before failing the request; onNext returns early once isDone is set.
    isDone = true;
    multipartUploadHelper.failRequestsElegantly(futures, t, uploadId, returnFuture, putObjectRequest);
}
Expand All @@ -264,8 +279,19 @@ private void completeMultipartUploadIfFinish(int requestsInFlight) {
CompletedPart[] parts = completedParts.stream()
.sorted(Comparator.comparingInt(CompletedPart::partNumber))
.toArray(CompletedPart[]::new);

long totalLength = contentLength.get();
int expectedNumPart = genericMultipartHelper.determinePartCount(totalLength, partSizeInBytes);
if (parts.length != expectedNumPart) {
SdkClientException exception = SdkClientException.create(
String.format("The number of UploadParts requests is not equal to the expected number of parts. "
+ "Expected: %d, Actual: %d", expectedNumPart, parts.length));
multipartUploadHelper.failRequestsElegantly(futures, exception, uploadId, returnFuture, putObjectRequest);
return;
}

multipartUploadHelper.completeMultipartUpload(returnFuture, uploadId, parts, putObjectRequest,
this.contentLength.get());
totalLength);
}
}
}
Expand Down
Loading
Loading