-
Notifications
You must be signed in to change notification settings - Fork 932
Add validations for upload in S3 multipart client #6282
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
{ | ||
"type": "bugfix", | ||
"category": "Amazon S3", | ||
"contributor": "", | ||
"description": "Add additional validations for multipart upload operations in the Java multipart S3 client." | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,6 +32,7 @@ | |
import software.amazon.awssdk.annotations.SdkInternalApi; | ||
import software.amazon.awssdk.core.async.AsyncRequestBody; | ||
import software.amazon.awssdk.core.async.listener.PublisherListener; | ||
import software.amazon.awssdk.core.exception.SdkClientException; | ||
import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse; | ||
import software.amazon.awssdk.services.s3.model.CompletedPart; | ||
import software.amazon.awssdk.services.s3.model.PutObjectRequest; | ||
|
@@ -54,10 +55,10 @@ public class KnownContentLengthAsyncRequestBodySubscriber implements Subscriber< | |
private final AtomicBoolean failureActionInitiated = new AtomicBoolean(false); | ||
private final AtomicInteger partNumber = new AtomicInteger(1); | ||
private final MultipartUploadHelper multipartUploadHelper; | ||
private final long contentLength; | ||
private final long totalSize; | ||
private final long partSize; | ||
private final int partCount; | ||
private final int numExistingParts; | ||
private final int expectedNumParts; | ||
private final int existingNumParts; | ||
private final String uploadId; | ||
private final Collection<CompletableFuture<CompletedPart>> futures = new ConcurrentLinkedQueue<>(); | ||
private final PutObjectRequest putObjectRequest; | ||
|
@@ -77,25 +78,21 @@ public class KnownContentLengthAsyncRequestBodySubscriber implements Subscriber< | |
KnownContentLengthAsyncRequestBodySubscriber(MpuRequestContext mpuRequestContext, | ||
CompletableFuture<PutObjectResponse> returnFuture, | ||
MultipartUploadHelper multipartUploadHelper) { | ||
this.contentLength = mpuRequestContext.contentLength(); | ||
this.totalSize = mpuRequestContext.contentLength(); | ||
this.partSize = mpuRequestContext.partSize(); | ||
this.partCount = determinePartCount(contentLength, partSize); | ||
this.expectedNumParts = mpuRequestContext.expectedNumParts(); | ||
this.putObjectRequest = mpuRequestContext.request().left(); | ||
this.returnFuture = returnFuture; | ||
this.uploadId = mpuRequestContext.uploadId(); | ||
this.existingParts = mpuRequestContext.existingParts() == null ? new HashMap<>() : mpuRequestContext.existingParts(); | ||
this.numExistingParts = NumericUtils.saturatedCast(mpuRequestContext.numPartsCompleted()); | ||
this.completedParts = new AtomicReferenceArray<>(partCount); | ||
this.existingNumParts = NumericUtils.saturatedCast(mpuRequestContext.numPartsCompleted()); | ||
this.completedParts = new AtomicReferenceArray<>(expectedNumParts); | ||
this.multipartUploadHelper = multipartUploadHelper; | ||
this.progressListener = putObjectRequest.overrideConfiguration().map(c -> c.executionAttributes() | ||
.getAttribute(JAVA_PROGRESS_LISTENER)) | ||
.orElseGet(PublisherListener::noOp); | ||
} | ||
|
||
private int determinePartCount(long contentLength, long partSize) { | ||
return (int) Math.ceil(contentLength / (double) partSize); | ||
} | ||
|
||
public S3ResumeToken pause() { | ||
isPaused = true; | ||
|
||
|
@@ -119,8 +116,8 @@ public S3ResumeToken pause() { | |
return S3ResumeToken.builder() | ||
.uploadId(uploadId) | ||
.partSize(partSize) | ||
.totalNumParts((long) partCount) | ||
.numPartsCompleted(numPartsCompleted + numExistingParts) | ||
.totalNumParts((long) expectedNumParts) | ||
.numPartsCompleted(numPartsCompleted + existingNumParts) | ||
.build(); | ||
} | ||
|
||
|
@@ -145,21 +142,23 @@ public void onSubscribe(Subscription s) { | |
|
||
@Override | ||
public void onNext(AsyncRequestBody asyncRequestBody) { | ||
if (isPaused) { | ||
if (isPaused || isDone) { | ||
return; | ||
} | ||
|
||
if (existingParts.containsKey(partNumber.get())) { | ||
partNumber.getAndIncrement(); | ||
int currentPartNum = partNumber.getAndIncrement(); | ||
if (existingParts.containsKey(currentPartNum)) { | ||
asyncRequestBody.subscribe(new CancelledSubscriber<>()); | ||
subscription.request(1); | ||
asyncRequestBody.contentLength().ifPresent(progressListener::subscriberOnNext); | ||
return; | ||
} | ||
|
||
validatePart(asyncRequestBody, currentPartNum); | ||
|
||
asyncRequestBodyInFlight.incrementAndGet(); | ||
UploadPartRequest uploadRequest = SdkPojoConversionUtils.toUploadPartRequest(putObjectRequest, | ||
partNumber.getAndIncrement(), | ||
currentPartNum, | ||
uploadId); | ||
|
||
Consumer<CompletedPart> completedPartConsumer = completedPart -> completedParts.set(completedPart.partNumber() - 1, | ||
|
@@ -179,6 +178,49 @@ public void onNext(AsyncRequestBody asyncRequestBody) { | |
subscription.request(1); | ||
} | ||
|
||
private void validatePart(AsyncRequestBody asyncRequestBody, int currentPartNum) { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I see we already have JUnit tests for each class. However, I was wondering how the validation failures below can be triggered from an external API perspective, for example when users pass an invalid AsyncRequestBody or files get corrupted in transit. Is it possible to write end-to-end test cases for these scenarios so that we can test them with the UnknownContentLength publisher or with S3CrtClient? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yup, good point, added. |
||
if (!asyncRequestBody.contentLength().isPresent()) { | ||
SdkClientException e = SdkClientException.create("Content length must be present on the AsyncRequestBody"); | ||
multipartUploadHelper.failRequestsElegantly(futures, e, uploadId, returnFuture, putObjectRequest); | ||
return; | ||
} | ||
|
||
Long currentPartSize = asyncRequestBody.contentLength().get(); | ||
if (currentPartNum > expectedNumParts) { | ||
SdkClientException exception = SdkClientException.create(String.format("The number of parts divided is " | ||
+ "not equal to the expected number of " | ||
+ "parts. Expected: %d, Actual: %d", | ||
expectedNumParts, currentPartNum)); | ||
multipartUploadHelper.failRequestsElegantly(futures, exception, uploadId, returnFuture, putObjectRequest); | ||
return; | ||
} | ||
|
||
if (currentPartNum == expectedNumParts) { | ||
validateLastPartSize(currentPartSize); | ||
return; | ||
} | ||
|
||
if (currentPartSize != partSize) { | ||
SdkClientException e = SdkClientException.create(String.format("Content length must be equal to the " | ||
+ "part size. Expected: %d, Actual: %d", | ||
partSize, | ||
currentPartSize)); | ||
multipartUploadHelper.failRequestsElegantly(futures, e, uploadId, returnFuture, putObjectRequest); | ||
} | ||
} | ||
|
||
private void validateLastPartSize(Long currentPartSize) { | ||
long remainder = totalSize % partSize; | ||
long expectedLastPartSize = remainder == 0 ? partSize : remainder; | ||
if (currentPartSize != expectedLastPartSize) { | ||
SdkClientException exception = | ||
SdkClientException.create("Content length of the last part must be equal to the " | ||
+ "expected last part size. Expected: " + expectedLastPartSize | ||
+ ", Actual: " + currentPartSize); | ||
multipartUploadHelper.failRequestsElegantly(futures, exception, uploadId, returnFuture, putObjectRequest); | ||
} | ||
} | ||
|
||
private boolean shouldFailRequest() { | ||
return failureActionInitiated.compareAndSet(false, true) && !isPaused; | ||
} | ||
|
@@ -187,6 +229,7 @@ private boolean shouldFailRequest() { | |
public void onError(Throwable t) { | ||
log.debug(() -> "Received onError ", t); | ||
if (failureActionInitiated.compareAndSet(false, true)) { | ||
isDone = true; | ||
multipartUploadHelper.failRequestsElegantly(futures, t, uploadId, returnFuture, putObjectRequest); | ||
} | ||
} | ||
|
@@ -203,6 +246,7 @@ public void onComplete() { | |
private void completeMultipartUploadIfFinished(int requestsInFlight) { | ||
if (isDone && requestsInFlight == 0 && completedMultipartInitiated.compareAndSet(false, true)) { | ||
CompletedPart[] parts; | ||
|
||
if (existingParts.isEmpty()) { | ||
parts = | ||
IntStream.range(0, completedParts.length()) | ||
|
@@ -213,14 +257,14 @@ private void completeMultipartUploadIfFinished(int requestsInFlight) { | |
parts = mergeCompletedParts(); | ||
} | ||
completeMpuFuture = multipartUploadHelper.completeMultipartUpload(returnFuture, uploadId, parts, putObjectRequest, | ||
contentLength); | ||
totalSize); | ||
} | ||
} | ||
|
||
private CompletedPart[] mergeCompletedParts() { | ||
CompletedPart[] merged = new CompletedPart[partCount]; | ||
CompletedPart[] merged = new CompletedPart[expectedNumParts]; | ||
int currPart = 1; | ||
while (currPart < partCount + 1) { | ||
while (currPart < expectedNumParts + 1) { | ||
CompletedPart completedPart = existingParts.containsKey(currPart) ? existingParts.get(currPart) : | ||
completedParts.get(currPart - 1); | ||
merged[currPart - 1] = completedPart; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you please help me understand why we previously called containsKey on partNumber.get(), whereas now we first increment and then do the containsKey check?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't know why we did that earlier, but the reason I changed it is to avoid another AtomicInteger get call (a micro performance optimization).