Skip to content

added part count and content range validation for download request #6353

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/next-release/bugfix-AWSSDKforJavav2-17f90b1.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"type": "bugfix",
"category": "AWS SDK for Java v2",
"contributor": "",
"description": "Added partCount and ContentRange validation for S3 transfer manager download requests"
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.reactivestreams.Subscription;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.core.async.AsyncResponseTransformer;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
Expand Down Expand Up @@ -76,6 +77,16 @@ public class MultipartDownloaderSubscriber implements Subscriber<AsyncResponseTr
*/
private volatile String eTag;

/**
 * The size of each part of the object being downloaded.
 * <p>
 * Starts out as {@code null}; in the code visible here it is only assigned by {@code getRangeInfo}, which derives it
 * from the byte range of a {@code Content-Range} header, and the caller only invokes that method while this field is
 * still {@code null}.
 */
private volatile Long partSize;

/**
 * The total size of the object being downloaded.
 * <p>
 * Starts out as {@code null}; in the code visible here it is only assigned by {@code getRangeInfo}, which reads it
 * from the value after the {@code /} in a {@code Content-Range} header.
 */
private volatile Long totalContentLength;

/**
* The Subscription lock
*/
Expand Down Expand Up @@ -117,6 +128,7 @@ public void onNext(AsyncResponseTransformer<GetObjectResponse, GetObjectResponse

synchronized (lock) {
if (totalParts != null && nextPartToGet > totalParts) {
validatePartsCount(completedParts.get());
log.debug(() -> String.format("Completing multipart download after a total of %d parts downloaded.", totalParts));
subscription.cancel();
return;
Expand Down Expand Up @@ -162,10 +174,20 @@ private void requestMoreIfNeeded(GetObjectResponse response) {
totalParts = partCount;
}

String actualContentRange = response.contentRange();
if (actualContentRange != null && partSize == null) {
getRangeInfo(actualContentRange);
log.debug(() -> String.format("Part size of the object to download: " + partSize));
log.debug(() -> String.format("Total Content Length of the object to download: " + totalContentLength));
}

validateContentRange(totalComplete, actualContentRange);
Comment on lines +177 to +184
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like there is a race on partSize and totalContentLength, since they are read outside of the lock. Or do we know that they won't change per part, so it's safe?

Can you explain why it's safe to have multiple get object results modifying the part size and total content length in parallel before we do (again, potentially parallel) validation on them?


synchronized (lock) {
if (totalParts != null && totalParts > 1 && totalComplete < totalParts) {
subscription.request(1);
} else {
validatePartsCount(completedParts.get());
log.debug(() -> String.format("Completing multipart download after a total of %d parts downloaded.", totalParts));
subscription.cancel();
}
Expand Down Expand Up @@ -198,4 +220,45 @@ private GetObjectRequest nextRequest(int nextPartToGet) {
}
});
}

/**
 * Compares the number of parts fetched so far against the expected total part count.
 * <p>
 * Nothing happens while {@code totalParts} is still unknown ({@code null}) or when the counts agree. On a mismatch
 * the error is logged, the subscription is cancelled and an {@link SdkClientException} is reported via
 * {@code onError}.
 *
 * @param currentGetCount the number of parts the caller counts as downloaded
 */
private void validatePartsCount(int currentGetCount) {
    if (totalParts == null || currentGetCount == totalParts) {
        return;
    }

    String errorMessage = "PartsCount validation failed. Expected " + totalParts
                          + ", downloaded " + currentGetCount + " parts.";
    log.error(() -> errorMessage);
    subscription.cancel();
    onError(SdkClientException.create(errorMessage));
}

/**
 * Checks that the {@code Content-Range} header returned for a part matches the range expected from the recorded part
 * size, total content length and total part count.
 * <p>
 * The expected range is {@code bytes <start>-<end>/<total>} where {@code start = (partNumber - 1) * partSize} and
 * {@code end} is {@code total - 1} for the last part, or {@code start + partSize - 1} otherwise. On a mismatch the
 * error is logged and an {@link SdkClientException} is reported via {@code onError}.
 * <p>
 * Validation is skipped when the response has no {@code Content-Range} header, or when the part size / total content
 * length have not been recorded yet (there is nothing to compare against).
 * <p>
 * NOTE(review): unlike {@code validatePartsCount}, this method does not cancel the subscription on failure; confirm
 * whether the caller is expected to stop requesting parts after an error has been reported.
 *
 * @param partNumber   1-based number of the part the response belongs to
 * @param contentRange the {@code Content-Range} header value of the response, may be {@code null}
 */
private void validateContentRange(int partNumber, String contentRange) {
    if (contentRange == null) {
        return;
    }

    // Read each volatile field exactly once so all computations below use one consistent snapshot and never
    // unbox a null.
    Long expectedPartSize = partSize;
    Long expectedTotalLength = totalContentLength;
    Integer expectedTotalParts = totalParts;

    if (expectedPartSize == null || expectedTotalLength == null) {
        log.debug(() -> "Skipping Content-Range validation for part " + partNumber
                        + " because the part size or total content length is not known yet.");
        return;
    }

    long expectedStart = (partNumber - 1L) * expectedPartSize;
    // totalParts is null when the response carried no parts count; such a part is never treated as "the last of N".
    boolean lastPart = expectedTotalParts != null && partNumber == expectedTotalParts;
    long expectedEnd = lastPart ? expectedTotalLength - 1 : expectedStart + expectedPartSize - 1;

    // Plain concatenation instead of String.format("%d"): the formatter localizes digits for some default locales,
    // which would make the expected value differ from the ASCII header even when the range is correct.
    String expectedContentRange = "bytes " + expectedStart + "-" + expectedEnd + "/" + expectedTotalLength;

    if (!expectedContentRange.equals(contentRange)) {
        String errorMessage = "Content-Range validation failed for part " + partNumber
                              + ". Expected: " + expectedContentRange + ", Actual: " + contentRange;
        log.error(() -> errorMessage);
        onError(SdkClientException.create(errorMessage));
    }
}

/**
 * Parses a {@code Content-Range} header of the form {@code bytes <start>-<end>/<total>} and records the total content
 * length and the part size ({@code end - start + 1}).
 * <p>
 * All three numbers are parsed before any field is written, so a malformed header never leaves the two fields in a
 * half-updated state. {@code totalContentLength} is written before {@code partSize} because the caller uses
 * {@code partSize == null} as its "not initialised yet" check.
 *
 * @param contentRange the {@code Content-Range} header value, must not be {@code null}
 * @throws SdkClientException if the header is not of the expected form, e.g. {@code bytes *}{@code /100} or an
 *                            unknown total length ({@code *})
 */
private void getRangeInfo(String contentRange) {
    String unitPrefix = "bytes ";
    int dashIndex = contentRange.indexOf('-');
    int slashIndex = contentRange.indexOf('/');

    // The dash must come after the unit prefix and the slash after the dash; anything else is not "start-end/total".
    if (!contentRange.startsWith(unitPrefix) || dashIndex < unitPrefix.length() || slashIndex < dashIndex) {
        throw SdkClientException.create("Unable to parse Content-Range header: " + contentRange);
    }

    long startByte;
    long endByte;
    long totalLength;
    try {
        startByte = Long.parseLong(contentRange.substring(unitPrefix.length(), dashIndex));
        endByte = Long.parseLong(contentRange.substring(dashIndex + 1, slashIndex));
        totalLength = Long.parseLong(contentRange.substring(slashIndex + 1));
    } catch (NumberFormatException e) {
        throw SdkClientException.create("Unable to parse Content-Range header: " + contentRange, e);
    }

    this.totalContentLength = totalLength;
    this.partSize = endByte - startByte + 1;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,21 @@ public byte[] stubForPart(String testBucket, String testKey,int part, int totalP
aResponse()
.withHeader("x-amz-mp-parts-count", totalPart + "")
.withHeader("ETag", eTag)
.withHeader("Content-Length", String.valueOf(body.length))
.withHeader("Content-Range", contentRange(part, totalPart, partSize))
.withBody(body)));
return body;
}

/**
 * Registers a stub for a single part request whose {@code Content-Range} header is deliberately computed with
 * {@code partSize + 1}, so it does not agree with the body that is actually returned.
 *
 * @param testBucket bucket used in the stubbed URL
 * @param testKey    key used in the stubbed URL
 * @param part       part number the stub answers to
 * @param totalPart  value reported in the {@code x-amz-mp-parts-count} header
 * @param partSize   number of random bytes in the response body
 * @return the random body served by the stub
 */
public byte[] stubForPartwithWrongContentRange(String testBucket, String testKey, int part, int totalPart, int partSize) {
    byte[] payload = new byte[partSize];
    random.nextBytes(payload);

    String partUrl = String.format("/%s/%s?partNumber=%d", testBucket, testKey, part);
    // Range computed for a part size that is one byte larger than the real body.
    String mismatchedRange = contentRange(part, totalPart, partSize + 1);

    stubFor(get(urlEqualTo(partUrl)).willReturn(
        aResponse()
            .withHeader("x-amz-mp-parts-count", String.valueOf(totalPart))
            .withHeader("ETag", eTag)
            .withHeader("Content-Length", String.valueOf(payload.length))
            .withHeader("Content-Range", mismatchedRange)
            .withBody(payload)));
    return payload;
}
Expand All @@ -95,4 +110,16 @@ public byte[] stubForPartSuccess(int part, int totalPart, int partSize) {
.withBody(body)));
return body;
}

/**
 * Builds the {@code Content-Range} header value for one part of an object made of {@code totalPart} parts that all
 * have {@code partSize} bytes.
 *
 * @param part      1-based part number
 * @param totalPart total number of parts
 * @param partSize  size of each part in bytes
 * @return a value of the form {@code bytes <first>-<last>/<objectSize>}
 */
private String contentRange(int part, int totalPart, int partSize) {
    long objectSize = (long) totalPart * partSize;
    long firstByte = (long) (part - 1) * partSize;
    // The final part always ends on the last byte of the object.
    long lastByte = part == totalPart ? objectSize - 1 : firstByte + partSize - 1;

    return String.format("bytes %d-%d/%d", firstByte, lastByte, objectSize);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,32 @@ <T> void errorOnThirdRequest_shouldCompleteExceptionallyOnlyPartsGreaterThanTwo(
}
}

/**
 * Stubs a three-part object whose second part answers with a mismatching {@code Content-Range} header, runs a
 * {@link MultipartDownloaderSubscriber} against it and waits for the split result future.
 * <p>
 * NOTE(review): this test makes no assertion about the outcome, the joined result is discarded and
 * {@code amountOfPartToTest} is unused; the method name also stops at "_should". Confirm the intended expectation
 * (presumably that the download fails) and assert it.
 */
@ParameterizedTest
@MethodSource("argumentsProvider")
<T> void wrongContentRangeOnSecondRequest_should(AsyncResponseTransformerTestSupplier<T> supplier,
                                                 int amountOfPartToTest,
                                                 int partSize) {
    // Parts 1 and 3 are well-formed; only part 2 carries the wrong range.
    util.stubForPart(testBucket, testKey, 1, 3, partSize);
    util.stubForPartwithWrongContentRange(testBucket, testKey, 2, 3, partSize);
    util.stubForPart(testBucket, testKey, 3, 3, partSize);

    AsyncResponseTransformer<GetObjectResponse, T> transformer = supplier.transformer();
    SplittingTransformerConfiguration splitConfiguration = SplittingTransformerConfiguration.builder()
                                                                                            .bufferSizeInBytes(1024 * 32L)
                                                                                            .build();
    AsyncResponseTransformer.SplitResult<GetObjectResponse, T> split = transformer.split(splitConfiguration);

    GetObjectRequest getObjectRequest = GetObjectRequest.builder()
                                                        .bucket(testBucket)
                                                        .key(testKey)
                                                        .build();
    Subscriber<AsyncResponseTransformer<GetObjectResponse, GetObjectResponse>> subscriber =
        new MultipartDownloaderSubscriber(s3AsyncClient, getObjectRequest);

    split.publisher().subscribe(subscriber);
    split.resultFuture().join();
}

private static Stream<Arguments> argumentsProvider() {
// amount of part, individual part size
List<Pair<Integer, Integer>> partSizes = Arrays.asList(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import software.amazon.awssdk.core.sync.ResponseTransformer;
import software.amazon.awssdk.metrics.publishers.emf.EmfMetricLoggingPublisher;
import software.amazon.awssdk.metrics.publishers.emf.internal.MetricEmfConverter;
import software.amazon.awssdk.services.s3.internal.multipart.MultipartDownloaderSubscriber;
import software.amazon.awssdk.utils.Logger;

/**
Expand All @@ -57,6 +58,7 @@ public class CodingConventionWithSuppressionTest {
private static final Set<Pattern> ALLOWED_ERROR_LOG_SUPPRESSION = new HashSet<>(
Arrays.asList(
ArchUtils.classNameToPattern(EmfMetricLoggingPublisher.class),
ArchUtils.classNameToPattern(MultipartDownloaderSubscriber.class),
ArchUtils.classWithInnerClassesToPattern(ResponseTransformer.class)));

@Test
Expand Down
Loading