Skip to content

Add Range-based Multipart download Subscriber for Pre-signed URLs #6331

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: feature/master/pre-signed-url-getobject
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,17 @@
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
import software.amazon.awssdk.services.s3.model.S3Request;

@SdkInternalApi
public final class MultipartDownloadUtils {

/**
 * Matches a fully specified Content-Range value of the form {@code bytes <start>-<end>/<total>}.
 * Capture groups: 1 = start byte, 2 = end byte, 3 = total size. All three groups must be decimal digits, so a
 * value whose total is not numeric does not match.
 */
private static final Pattern CONTENT_RANGE_PATTERN = Pattern.compile("bytes\\s+(\\d+)-(\\d+)/(\\d+)");

// Private no-op constructor: the class cannot be instantiated from outside.
private MultipartDownloadUtils() {
}

Expand Down Expand Up @@ -58,4 +63,52 @@ public static Optional<MultipartDownloadResumeContext> multipartDownloadResumeCo
.flatMap(conf -> Optional.ofNullable(conf.executionAttributes().getAttribute(MULTIPART_DOWNLOAD_RESUME_CONTEXT)));
}

/**
 * Looks up the multipart download resume context attached to a request.
 *
 * <p>The context is read from the request's override configuration, under the
 * {@link software.amazon.awssdk.services.s3.multipart.S3MultipartExecutionAttribute#MULTIPART_DOWNLOAD_RESUME_CONTEXT}
 * execution attribute.
 *
 * @param request the request whose execution attributes are inspected
 * @return the resume context, or an empty Optional when the request has no override configuration or the
 *         attribute is not set
 */
public static Optional<MultipartDownloadResumeContext> multipartDownloadResumeContext(S3Request request) {
    // Optional.map already turns a null mapping result into an empty Optional.
    return request.overrideConfiguration()
                  .map(conf -> conf.executionAttributes().getAttribute(MULTIPART_DOWNLOAD_RESUME_CONTEXT));
}

/**
 * Parses the start byte from a Content-Range header.
 *
 * <p>The value must match {@code bytes <start>-<end>/<total>} in full. Any value that does not match,
 * including one whose start offset does not fit in a {@code long}, is treated as a parse failure.
 *
 * @param contentRange the Content-Range header value (e.g., "bytes 0-1023/2048"), may be null
 * @return the start byte position, or -1 if the header is null or parsing fails
 */
public static long parseStartByteFromContentRange(String contentRange) {
    if (contentRange == null) {
        return -1;
    }
    Matcher matcher = CONTENT_RANGE_PATTERN.matcher(contentRange);
    if (!matcher.matches()) {
        return -1;
    }
    try {
        return Long.parseLong(matcher.group(1));
    } catch (NumberFormatException e) {
        // \d+ accepts arbitrarily long digit runs; a value beyond Long.MAX_VALUE is a parse failure, not a crash.
        return -1;
    }
}

/**
 * Parses the total size from a Content-Range header.
 *
 * <p>The value must match {@code bytes <start>-<end>/<total>} in full. Any value that does not match,
 * including one whose total does not fit in a {@code long}, is treated as a parse failure.
 *
 * @param contentRange the Content-Range header value (e.g., "bytes 0-1023/2048"), may be null
 * @return the total size, or empty if the header is null or parsing fails
 */
public static Optional<Long> parseContentRangeForTotalSize(String contentRange) {
    if (contentRange == null) {
        return Optional.empty();
    }
    Matcher matcher = CONTENT_RANGE_PATTERN.matcher(contentRange);
    if (!matcher.matches()) {
        return Optional.empty();
    }
    try {
        return Optional.of(Long.parseLong(matcher.group(3)));
    } catch (NumberFormatException e) {
        // \d+ accepts arbitrarily long digit runs; a total beyond Long.MAX_VALUE is a parse failure, not a crash.
        return Optional.empty();
    }
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.core.SplittingTransformerConfiguration;
import software.amazon.awssdk.core.async.AsyncResponseTransformer;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
import software.amazon.awssdk.services.s3.presignedurl.AsyncPresignedUrlExtension;
Expand Down Expand Up @@ -63,15 +64,25 @@ public <T> CompletableFuture<T> downloadObject(
.build();
AsyncResponseTransformer.SplitResult<GetObjectResponse, T> split =
asyncResponseTransformer.split(splittingConfig);
// TODO: PresignedUrlMultipartDownloaderSubscriber needs to be implemented in next PR
// PresignedUrlMultipartDownloaderSubscriber subscriber =
// new PresignedUrlMultipartDownloaderSubscriber(
// s3AsyncClient,
// presignedRequest,
// configuredPartSizeInBytes);
//
// split.publisher().subscribe(subscriber);
// return split.resultFuture();
throw new UnsupportedOperationException("Multipart presigned URL download not yet implemented - TODO in next PR");
PresignedUrlMultipartDownloaderSubscriber subscriber =
new PresignedUrlMultipartDownloaderSubscriber(
s3AsyncClient,
presignedRequest,
configuredPartSizeInBytes);

split.publisher().subscribe(subscriber);
return split.resultFuture();
}

/**
 * Builds the exception reported when a Content-Range header value cannot be parsed.
 *
 * @param contentRange the header value that was rejected; appended to the message as-is
 * @return a new {@link SdkClientException} describing the invalid header
 */
static SdkClientException invalidContentRangeHeader(String contentRange) {
    String message = "Invalid Content-Range header: " + contentRange;
    return SdkClientException.create(message);
}

/**
 * Builds the exception reported when a response carries no Content-Range header.
 *
 * @return a new {@link SdkClientException} describing the missing header
 */
static SdkClientException missingContentRangeHeader() {
    String message = "No Content-Range header in response";
    return SdkClientException.create(message);
}

/**
 * Builds the exception reported when a response's Content-Length is absent or not usable.
 *
 * @return a new {@link SdkClientException} describing the problem
 */
static SdkClientException invalidContentLength() {
    String message = "Invalid or missing Content-Length in response";
    return SdkClientException.create(message);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package software.amazon.awssdk.services.s3.internal.multipart;

import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicInteger;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import software.amazon.awssdk.annotations.Immutable;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.annotations.SdkTestInternalApi;
import software.amazon.awssdk.annotations.ThreadSafe;
import software.amazon.awssdk.core.async.AsyncResponseTransformer;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
import software.amazon.awssdk.services.s3.presignedurl.model.PresignedUrlDownloadRequest;
import software.amazon.awssdk.utils.Logger;

/**
* A subscriber implementation that will download all individual parts for a multipart presigned URL download request.
* It receives individual {@link AsyncResponseTransformer} instances which will be used to perform the individual
* range-based part requests using presigned URLs. This is a 'one-shot' class, it should <em>NOT</em> be reused
* for more than one multipart download.
*
* <p>Unlike the standard {@link MultipartDownloaderSubscriber} which uses S3's native multipart API with part numbers,
* this subscriber uses HTTP range requests against presigned URLs to achieve multipart download functionality.
* <p>This implementation is intended to be thread-safe. Parts are requested one at a time: the next part is
* requested only after the previous range response has been received and validated, which keeps parts in order.</p>
*/
@ThreadSafe
@Immutable
@SdkInternalApi
public class PresignedUrlMultipartDownloaderSubscriber
implements Subscriber<AsyncResponseTransformer<GetObjectResponse, GetObjectResponse>> {

private static final Logger log = Logger.loggerFor(PresignedUrlMultipartDownloaderSubscriber.class);
// Prefix of the range value sent with each part request; the full value is built as "bytes=<start>-<end>".
private static final String BYTES_RANGE_PREFIX = "bytes=";

private final S3AsyncClient s3AsyncClient;
// Template request: every part request is derived from it via toBuilder() with a range (and possibly If-Match).
private final PresignedUrlDownloadRequest presignedUrlDownloadRequest;
// Byte length used to compute each part's start and end offsets.
private final Long configuredPartSizeInBytes;
// Completed exceptionally by onError and normally by onComplete; exposed through future().
private final CompletableFuture<Void> future;
// Guards the check-and-increment in onNext and the request/cancel decisions made on the subscription.
private final Object lock = new Object();
// Despite the name, this is incremented in onNext when a part is handed off to be requested,
// so its value before the increment doubles as the index of the next part.
private final AtomicInteger completedParts;
// Number of range requests actually issued; compared with totalParts once no more parts remain.
private final AtomicInteger requestsSent;

// Object size parsed from a response's Content-Range header; null until the first successful parse.
private volatile Long totalContentLength;
// Part count derived from totalContentLength and the configured part size; null until then.
private volatile Integer totalParts;
// ETag taken from the first response processed; sent as If-Match on requests for parts after the first.
private volatile String eTag;
// Set in onSubscribe; a later subscription is cancelled instead of replacing it.
private Subscription subscription;

/**
 * Creates a subscriber for a single multipart download of a presigned URL.
 *
 * @param s3AsyncClient the client whose presigned URL extension performs the ranged part requests
 * @param presignedUrlDownloadRequest the original request; used as the template for each part request
 * @param configuredPartSizeInBytes the byte length of each requested range
 */
public PresignedUrlMultipartDownloaderSubscriber(
    S3AsyncClient s3AsyncClient,
    PresignedUrlDownloadRequest presignedUrlDownloadRequest,
    long configuredPartSizeInBytes) {
    // Bookkeeping state starts empty; nothing is requested until onSubscribe.
    this.future = new CompletableFuture<>();
    this.requestsSent = new AtomicInteger(0);
    this.completedParts = new AtomicInteger(0);

    this.configuredPartSizeInBytes = configuredPartSizeInBytes;
    this.presignedUrlDownloadRequest = presignedUrlDownloadRequest;
    this.s3AsyncClient = s3AsyncClient;
}

/**
 * Accepts the first subscription and requests the first transformer; any later subscription is cancelled.
 *
 * @param s the subscription offered by the publisher
 */
@Override
public void onSubscribe(Subscription s) {
    boolean firstSubscription = subscription == null;
    if (firstSubscription) {
        this.subscription = s;
        // Demand is issued one element at a time; further demand is signalled after each part response.
        s.request(1);
    } else {
        s.cancel();
    }
}

/**
 * Receives the transformer for the next part and issues the matching range request.
 *
 * <p>If the total number of parts is already known and every part has been handed off, the subscription is
 * cancelled and the transformer is not used.
 *
 * @param asyncResponseTransformer the transformer that will consume the next part's response
 * @throws NullPointerException if {@code asyncResponseTransformer} is null
 */
@Override
public void onNext(AsyncResponseTransformer<GetObjectResponse, GetObjectResponse> asyncResponseTransformer) {
    if (asyncResponseTransformer == null) {
        throw new NullPointerException("onNext must not be called with null asyncResponseTransformer");
    }

    int partIndex;
    synchronized (lock) {
        // The counter's current value is the index of the part about to be requested.
        partIndex = completedParts.get();
        boolean everythingRequested = totalParts != null && partIndex >= totalParts;
        if (!everythingRequested) {
            completedParts.incrementAndGet();
        } else {
            log.debug(() -> String.format("Completing multipart download after a total of %d parts downloaded.", totalParts));
            subscription.cancel();
            return;
        }
    }
    // The request itself is issued outside the lock.
    makeRangeRequest(partIndex, asyncResponseTransformer);
}

/**
 * Sends the range request for one part and wires its completion back into the subscriber.
 *
 * <p>On failure the error is routed to {@code handleError}; on success the response is passed to
 * {@code requestMoreIfNeeded}.
 *
 * @param partIndex zero-based index of the part to download
 * @param asyncResponseTransformer the transformer that consumes this part's response
 */
private void makeRangeRequest(int partIndex,
                              AsyncResponseTransformer<GetObjectResponse,
                                  GetObjectResponse> asyncResponseTransformer) {
    PresignedUrlDownloadRequest rangedRequest = createRangedGetRequest(partIndex);
    log.debug(() -> "Sending range request for part " + partIndex + " with range=" + rangedRequest.range());

    requestsSent.incrementAndGet();
    CompletableFuture<GetObjectResponse> partFuture =
        s3AsyncClient.presignedUrlExtension().getObject(rangedRequest, asyncResponseTransformer);
    partFuture.whenComplete((response, error) -> {
        if (error == null) {
            requestMoreIfNeeded(response, partIndex);
            return;
        }
        log.debug(() -> "Error encountered during part request for part " + partIndex);
        handleError(error);
    });
}

/**
 * Processes a completed part response: validates it, records the object's ETag and total size when first seen,
 * and then either requests the next transformer or cancels the subscription when no parts remain.
 *
 * <p>Validation runs before any field of the response is read. This method is invoked from a
 * {@code whenComplete} callback, so an exception thrown here (for example from dereferencing a null response)
 * would not reach {@code handleError} and the download would never be failed.
 *
 * @param response the response of the part request that just completed
 * @param partIndex zero-based index of that part
 */
private void requestMoreIfNeeded(GetObjectResponse response, int partIndex) {
    int totalComplete = completedParts.get();
    log.debug(() -> String.format("Completed part %d", totalComplete));

    // Validate first: validateResponse is where a null response and a missing Content-Range are rejected.
    Optional<SdkClientException> validationError = validateResponse(response, partIndex);
    if (validationError.isPresent()) {
        log.debug(() -> "Response validation failed", validationError.get());
        handleError(validationError.get());
        return;
    }

    String responseETag = response.eTag();
    String responseContentRange = response.contentRange();
    if (eTag == null) {
        // Remember the ETag of the first response so later part requests can be pinned to it.
        this.eTag = responseETag;
        log.debug(() -> String.format("Multipart object ETag: %s", this.eTag));
    }

    if (totalContentLength == null && responseContentRange != null) {
        Optional<Long> parsedContentLength = MultipartDownloadUtils.parseContentRangeForTotalSize(responseContentRange);
        if (!parsedContentLength.isPresent()) {
            SdkClientException error = PresignedUrlDownloadHelper.invalidContentRangeHeader(responseContentRange);
            log.debug(() -> "Failed to parse content range", error);
            handleError(error);
            return;
        }

        this.totalContentLength = parsedContentLength.get();
        this.totalParts = calculateTotalParts(totalContentLength, configuredPartSizeInBytes);
        log.debug(() -> String.format("Total content length: %d, Total parts: %d", totalContentLength, totalParts));
    }

    synchronized (lock) {
        if (hasMoreParts(totalComplete)) {
            subscription.request(1);
        } else {
            // Sanity check: the number of requests issued must equal the number of parts computed.
            if (totalParts != null && requestsSent.get() != totalParts) {
                handleError(new IllegalStateException(
                    "Request count mismatch. Expected: " + totalParts + ", sent: " + requestsSent.get()));
                return;
            }
            log.debug(() -> String.format("Completing multipart download after a total of %d parts downloaded.", totalParts));
            subscription.cancel();
        }
    }
}

/**
 * Checks that a part response covers exactly the byte range that was expected for the given part.
 *
 * <p>The expected range is clamped to the object size. When the size has not been recorded yet (the first
 * response), it is taken from the total in this response's own Content-Range header, so an object smaller
 * than one configured part is accepted instead of being rejected for returning fewer bytes than requested.
 *
 * @param response the part response to validate, may be null
 * @param partIndex zero-based index of the part the response belongs to
 * @return the validation failure, or an empty Optional when the response is acceptable
 */
private Optional<SdkClientException> validateResponse(GetObjectResponse response, int partIndex) {
    if (response == null) {
        return Optional.of(SdkClientException.create("Response cannot be null"));
    }

    String contentRange = response.contentRange();
    if (contentRange == null) {
        return Optional.of(PresignedUrlDownloadHelper.missingContentRangeHeader());
    }

    Long contentLength = response.contentLength();
    if (contentLength == null || contentLength < 0) {
        return Optional.of(PresignedUrlDownloadHelper.invalidContentLength());
    }

    // Use the recorded object size if there is one; otherwise read it from this response's header.
    Long objectSize = totalContentLength;
    if (objectSize == null) {
        Optional<Long> parsedSize = MultipartDownloadUtils.parseContentRangeForTotalSize(contentRange);
        if (!parsedSize.isPresent()) {
            return Optional.of(PresignedUrlDownloadHelper.invalidContentRangeHeader(contentRange));
        }
        objectSize = parsedSize.get();
    }

    long expectedStartByte = partIndex * configuredPartSizeInBytes;
    // A part never extends past the last byte of the object.
    long expectedEndByte = Math.min(expectedStartByte + configuredPartSizeInBytes - 1, objectSize - 1);

    String expectedRange = "bytes " + expectedStartByte + "-" + expectedEndByte + "/";
    if (!contentRange.startsWith(expectedRange)) {
        return Optional.of(SdkClientException.create(
            "Content-Range mismatch. Expected range starting with: " + expectedRange +
            ", but got: " + contentRange));
    }

    // Equals the configured part size for every part except a shorter final one.
    long expectedPartSize = expectedEndByte - expectedStartByte + 1;
    if (contentLength.longValue() != expectedPartSize) {
        return Optional.of(SdkClientException.create(
            "Part content length validation failed for part " + partIndex +
            ". Expected: " + expectedPartSize + ", but got: " + contentLength));
    }

    long actualStartByte = MultipartDownloadUtils.parseStartByteFromContentRange(contentRange);
    if (actualStartByte != expectedStartByte) {
        return Optional.of(SdkClientException.create(
            "Content range offset mismatch for part " + partIndex +
            ". Expected start: " + expectedStartByte + ", but got: " + actualStartByte));
    }

    return Optional.empty();
}

/**
 * Computes how many parts of {@code partSize} bytes are needed to cover {@code contentLength} bytes,
 * rounding up so that a trailing partial part is counted.
 *
 * @param contentLength total object size in bytes
 * @param partSize size of one part in bytes
 * @return the number of parts
 */
private int calculateTotalParts(long contentLength, long partSize) {
    double exactParts = (double) contentLength / partSize;
    double roundedUp = Math.ceil(exactParts);
    return (int) roundedUp;
}

/**
 * Tells whether another part still has to be requested.
 *
 * @param completedPartsCount the number of parts accounted for so far
 * @return {@code true} only when the part count is known, is greater than one, and exceeds
 *         {@code completedPartsCount}
 */
private boolean hasMoreParts(int completedPartsCount) {
    Integer knownTotal = totalParts;
    if (knownTotal == null || knownTotal <= 1) {
        // Unknown part count or a single-part object: nothing further to request.
        return false;
    }
    return completedPartsCount < knownTotal;
}

/**
 * Builds the request for one part by copying the original request and setting its byte range.
 *
 * <p>The range end is clamped to the last byte of the object once the total size is known. For every part
 * after the first, the recorded ETag (if any) is set as the If-Match value.
 *
 * @param partIndex zero-based index of the part
 * @return the ranged request for that part
 */
private PresignedUrlDownloadRequest createRangedGetRequest(int partIndex) {
    long firstByte = partIndex * configuredPartSizeInBytes;
    long lastByte = firstByte + configuredPartSizeInBytes - 1;
    if (totalContentLength != null) {
        lastByte = Math.min(lastByte, totalContentLength - 1);
    }

    PresignedUrlDownloadRequest.Builder rangedBuilder =
        presignedUrlDownloadRequest.toBuilder().range(BYTES_RANGE_PREFIX + firstByte + "-" + lastByte);

    boolean pinToETag = partIndex > 0 && eTag != null;
    if (pinToETag) {
        rangedBuilder.ifMatch(eTag);
        log.debug(() -> "Setting IfMatch header to: " + eTag + " for part " + partIndex);
    }
    return rangedBuilder.build();
}

/**
 * Fails the download: cancels the subscription if one is present, then reports the error through
 * {@link #onError(Throwable)}.
 *
 * @param t the failure to report
 */
private void handleError(Throwable t) {
    synchronized (lock) {
        Subscription current = subscription;
        if (current != null) {
            current.cancel();
        }
    }
    // Reported after the lock has been released.
    onError(t);
}

/**
 * Logs the failure at debug level and completes this subscriber's future exceptionally with it.
 *
 * @param t the error to record
 */
@Override
public void onError(Throwable t) {
log.debug(() -> "Error in multipart download", t);
future.completeExceptionally(t);
}

/**
 * Completes this subscriber's future normally; the future carries no value.
 */
@Override
public void onComplete() {
future.complete(null);
}

public CompletableFuture<Void> future() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this used anywhere?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

future() method is used for testing to verify subscriber completion/error states

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, can we add SdkTestInternalApi annotation?

return future;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ public <ReturnT> CompletableFuture<ReturnT> getObject(
PresignedUrlDownloadRequestWrapper internalRequest = PresignedUrlDownloadRequestWrapper.builder()
.url(presignedUrlDownloadRequest.presignedUrl())
.range(presignedUrlDownloadRequest.range())
.ifMatch(presignedUrlDownloadRequest.ifMatch())
.build();

MetricCollector apiCallMetricCollector = metricPublishers.isEmpty() ?
Expand Down
Loading