diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloadUtils.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloadUtils.java index 807b6a8bbbc..8be5bb8ee5b 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloadUtils.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloadUtils.java @@ -20,12 +20,17 @@ import java.util.Collections; import java.util.List; import java.util.Optional; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.S3Request; @SdkInternalApi public final class MultipartDownloadUtils { + private static final Pattern CONTENT_RANGE_PATTERN = Pattern.compile("bytes\\s+(\\d+)-(\\d+)/(\\d+)"); + private MultipartDownloadUtils() { } @@ -58,4 +63,52 @@ public static Optional multipartDownloadResumeCo .flatMap(conf -> Optional.ofNullable(conf.executionAttributes().getAttribute(MULTIPART_DOWNLOAD_RESUME_CONTEXT))); } + /** + * This method checks the + * {@link software.amazon.awssdk.services.s3.multipart.S3MultipartExecutionAttribute#MULTIPART_DOWNLOAD_RESUME_CONTEXT} + * execution attributes for a context object and returns it if it finds one. Otherwise, returns an empty Optional. + * + * @param request the request to look for execution attributes + * @return the MultipartDownloadResumeContext if one is found, otherwise an empty Optional. + */ + public static Optional multipartDownloadResumeContext(S3Request request) { + return request + .overrideConfiguration() + .flatMap(conf -> Optional.ofNullable(conf.executionAttributes().getAttribute(MULTIPART_DOWNLOAD_RESUME_CONTEXT))); + } + + /** + * Parses the start byte from a Content-Range header. + * + * @param contentRange the Content-Range header value (e.g., "bytes 0-1023/2048") + * @return the start byte position, or -1 if parsing fails + */ + public static long parseStartByteFromContentRange(String contentRange) { + if (contentRange == null) { + return -1; + } + Matcher matcher = CONTENT_RANGE_PATTERN.matcher(contentRange); + if (!matcher.matches()) { + return -1; + } + return Long.parseLong(matcher.group(1)); + } + + /** + * Parses the total size from a Content-Range header. + * + * @param contentRange the Content-Range header value (e.g., "bytes 0-1023/2048") + * @return the total size, or empty if parsing fails + */ + public static Optional parseContentRangeForTotalSize(String contentRange) { + if (contentRange == null) { + return Optional.empty(); + } + Matcher matcher = CONTENT_RANGE_PATTERN.matcher(contentRange); + if (!matcher.matches()) { + return Optional.empty(); + } + return Optional.of(Long.parseLong(matcher.group(3))); + } + } diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/PresignedUrlDownloadHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/PresignedUrlDownloadHelper.java index cfb94e8a024..7ee520b76c2 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/PresignedUrlDownloadHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/PresignedUrlDownloadHelper.java @@ -19,6 +19,7 @@ import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.core.SplittingTransformerConfiguration; import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.model.GetObjectResponse; import software.amazon.awssdk.services.s3.presignedurl.AsyncPresignedUrlExtension; @@ -63,15 +64,25 @@ public CompletableFuture downloadObject( .build(); AsyncResponseTransformer.SplitResult split = asyncResponseTransformer.split(splittingConfig); - // TODO: PresignedUrlMultipartDownloaderSubscriber needs to be implemented in next PR - // PresignedUrlMultipartDownloaderSubscriber subscriber = - // new PresignedUrlMultipartDownloaderSubscriber( - // s3AsyncClient, - // presignedRequest, - // configuredPartSizeInBytes); - // - // split.publisher().subscribe(subscriber); - // return split.resultFuture(); - throw new UnsupportedOperationException("Multipart presigned URL download not yet implemented - TODO in next PR"); + PresignedUrlMultipartDownloaderSubscriber subscriber = + new PresignedUrlMultipartDownloaderSubscriber( + s3AsyncClient, + presignedRequest, + configuredPartSizeInBytes); + + split.publisher().subscribe(subscriber); + return split.resultFuture(); + } + + static SdkClientException invalidContentRangeHeader(String contentRange) { + return SdkClientException.create("Invalid Content-Range header: " + contentRange); + } + + static SdkClientException missingContentRangeHeader() { + return SdkClientException.create("No Content-Range header in response"); + } + + static SdkClientException invalidContentLength() { + return SdkClientException.create("Invalid or missing Content-Length in response"); } } diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/PresignedUrlMultipartDownloaderSubscriber.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/PresignedUrlMultipartDownloaderSubscriber.java new file mode 100644 index 00000000000..69b210cb287 --- /dev/null +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/PresignedUrlMultipartDownloaderSubscriber.java @@ -0,0 +1,275 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicInteger; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.annotations.Immutable; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.annotations.ThreadSafe; +import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.presignedurl.model.PresignedUrlDownloadRequest; +import software.amazon.awssdk.utils.Logger; + +/** + * A subscriber implementation that will download all individual parts for a multipart presigned URL download request. + * It receives individual {@link AsyncResponseTransformer} instances which will be used to perform the individual + * range-based part requests using presigned URLs. This is a 'one-shot' class, it should NOT be reused + * for more than one multipart download. + * + *

Unlike the standard {@link MultipartDownloaderSubscriber} which uses S3's native multipart API with part numbers, + * this subscriber uses HTTP range requests against presigned URLs to achieve multipart download functionality. + *

This implementation is thread-safe and handles concurrent part downloads while maintaining proper + * ordering and validation of responses.

+ */ +@ThreadSafe +@Immutable +@SdkInternalApi +public class PresignedUrlMultipartDownloaderSubscriber + implements Subscriber> { + + private static final Logger log = Logger.loggerFor(PresignedUrlMultipartDownloaderSubscriber.class); + private static final String BYTES_RANGE_PREFIX = "bytes="; + + private final S3AsyncClient s3AsyncClient; + private final PresignedUrlDownloadRequest presignedUrlDownloadRequest; + private final Long configuredPartSizeInBytes; + private final CompletableFuture future; + private final Object lock = new Object(); + private final AtomicInteger completedParts; + private final AtomicInteger requestsSent; + + private volatile Long totalContentLength; + private volatile Integer totalParts; + private volatile String eTag; + private Subscription subscription; + + public PresignedUrlMultipartDownloaderSubscriber( + S3AsyncClient s3AsyncClient, + PresignedUrlDownloadRequest presignedUrlDownloadRequest, + long configuredPartSizeInBytes) { + this.s3AsyncClient = s3AsyncClient; + this.presignedUrlDownloadRequest = presignedUrlDownloadRequest; + this.configuredPartSizeInBytes = configuredPartSizeInBytes; + this.completedParts = new AtomicInteger(0); + this.requestsSent = new AtomicInteger(0); + this.future = new CompletableFuture<>(); + } + + @Override + public void onSubscribe(Subscription s) { + if (subscription != null) { + s.cancel(); + return; + } + this.subscription = s; + s.request(1); + } + + @Override + public void onNext(AsyncResponseTransformer asyncResponseTransformer) { + if (asyncResponseTransformer == null) { + throw new NullPointerException("onNext must not be called with null asyncResponseTransformer"); + } + + int nextPartIndex; + synchronized (lock) { + nextPartIndex = completedParts.get(); + if (totalParts != null && nextPartIndex >= totalParts) { + log.debug(() -> String.format("Completing multipart download after a total of %d parts downloaded.", totalParts)); + subscription.cancel(); + return; + } + completedParts.incrementAndGet(); + } + makeRangeRequest(nextPartIndex, asyncResponseTransformer); + } + + private void makeRangeRequest(int partIndex, + AsyncResponseTransformer asyncResponseTransformer) { + PresignedUrlDownloadRequest partRequest = createRangedGetRequest(partIndex); + log.debug(() -> "Sending range request for part " + partIndex + " with range=" + partRequest.range()); + + requestsSent.incrementAndGet(); + s3AsyncClient.presignedUrlExtension() + .getObject(partRequest, asyncResponseTransformer) + .whenComplete((response, error) -> { + if (error != null) { + log.debug(() -> "Error encountered during part request for part " + partIndex); + handleError(error); + return; + } + requestMoreIfNeeded(response, partIndex); + }); + } + + private void requestMoreIfNeeded(GetObjectResponse response, int partIndex) { + int totalComplete = completedParts.get(); + log.debug(() -> String.format("Completed part %d", totalComplete)); + + String responseETag = response.eTag(); + String responseContentRange = response.contentRange(); + if (eTag == null) { + this.eTag = responseETag; + log.debug(() -> String.format("Multipart object ETag: %s", this.eTag)); + } + + Optional validationError = validateResponse(response, partIndex); + if (validationError.isPresent()) { + log.debug(() -> "Response validation failed", validationError.get()); + handleError(validationError.get()); + return; + } + + if (totalContentLength == null && responseContentRange != null) { + Optional parsedContentLength = MultipartDownloadUtils.parseContentRangeForTotalSize(responseContentRange); + if (!parsedContentLength.isPresent()) { + SdkClientException error = PresignedUrlDownloadHelper.invalidContentRangeHeader(responseContentRange); + log.debug(() -> "Failed to parse content range", error); + handleError(error); + return; + } + + this.totalContentLength = parsedContentLength.get(); + this.totalParts = calculateTotalParts(totalContentLength, configuredPartSizeInBytes); + log.debug(() -> String.format("Total content length: %d, Total parts: %d", totalContentLength, totalParts)); + } + + synchronized (lock) { + if (hasMoreParts(totalComplete)) { + subscription.request(1); + } else { + if (totalParts != null && requestsSent.get() != totalParts) { + handleError(new IllegalStateException( + "Request count mismatch. Expected: " + totalParts + ", sent: " + requestsSent.get())); + return; + } + log.debug(() -> String.format("Completing multipart download after a total of %d parts downloaded.", totalParts)); + subscription.cancel(); + } + } + } + + private Optional validateResponse(GetObjectResponse response, int partIndex) { + if (response == null) { + return Optional.of(SdkClientException.create("Response cannot be null")); + } + + String contentRange = response.contentRange(); + if (contentRange == null) { + return Optional.of(PresignedUrlDownloadHelper.missingContentRangeHeader()); + } + + Long contentLength = response.contentLength(); + if (contentLength == null || contentLength < 0) { + return Optional.of(PresignedUrlDownloadHelper.invalidContentLength()); + } + + long expectedStartByte = partIndex * configuredPartSizeInBytes; + long expectedEndByte; + if (totalContentLength != null) { + expectedEndByte = Math.min(expectedStartByte + configuredPartSizeInBytes - 1, totalContentLength - 1); + } else { + expectedEndByte = expectedStartByte + configuredPartSizeInBytes - 1; + } + + String expectedRange = "bytes " + expectedStartByte + "-" + expectedEndByte + "/"; + if (!contentRange.startsWith(expectedRange)) { + return Optional.of(SdkClientException.create( + "Content-Range mismatch. Expected range starting with: " + expectedRange + + ", but got: " + contentRange)); + } + + long expectedPartSize; + if (totalContentLength != null && partIndex == totalParts - 1) { + expectedPartSize = totalContentLength - (partIndex * configuredPartSizeInBytes); + } else { + expectedPartSize = configuredPartSizeInBytes; + } + + if (!contentLength.equals(expectedPartSize)) { + return Optional.of(SdkClientException.create( + "Part content length validation failed for part " + partIndex + + ". Expected: " + expectedPartSize + ", but got: " + contentLength)); + } + + long actualStartByte = MultipartDownloadUtils.parseStartByteFromContentRange(contentRange); + if (actualStartByte != expectedStartByte) { + return Optional.of(SdkClientException.create( + "Content range offset mismatch for part " + partIndex + + ". Expected start: " + expectedStartByte + ", but got: " + actualStartByte)); + } + + return Optional.empty(); + } + + private int calculateTotalParts(long contentLength, long partSize) { + return (int) Math.ceil((double) contentLength / partSize); + } + + private boolean hasMoreParts(int completedPartsCount) { + return totalParts != null && totalParts > 1 && completedPartsCount < totalParts; + } + + private PresignedUrlDownloadRequest createRangedGetRequest(int partIndex) { + long startByte = partIndex * configuredPartSizeInBytes; + long endByte; + if (totalContentLength != null) { + endByte = Math.min(startByte + configuredPartSizeInBytes - 1, totalContentLength - 1); + } else { + endByte = startByte + configuredPartSizeInBytes - 1; + } + String rangeHeader = BYTES_RANGE_PREFIX + startByte + "-" + endByte; + PresignedUrlDownloadRequest.Builder builder = presignedUrlDownloadRequest.toBuilder() + .range(rangeHeader); + if (partIndex > 0 && eTag != null) { + builder.ifMatch(eTag); + log.debug(() -> "Setting IfMatch header to: " + eTag + " for part " + partIndex); + } + return builder.build(); + } + + private void handleError(Throwable t) { + synchronized (lock) { + if (subscription != null) { + subscription.cancel(); + } + } + onError(t); + } + + @Override + public void onError(Throwable t) { + log.debug(() -> "Error in multipart download", t); + future.completeExceptionally(t); + } + + @Override + public void onComplete() { + future.complete(null); + } + + public CompletableFuture future() { + return future; + } +} \ No newline at end of file diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/presignedurl/DefaultAsyncPresignedUrlExtension.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/presignedurl/DefaultAsyncPresignedUrlExtension.java index db7515b24b5..d640482c5c3 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/presignedurl/DefaultAsyncPresignedUrlExtension.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/presignedurl/DefaultAsyncPresignedUrlExtension.java @@ -90,6 +90,7 @@ public CompletableFuture getObject( PresignedUrlDownloadRequestWrapper internalRequest = PresignedUrlDownloadRequestWrapper.builder() .url(presignedUrlDownloadRequest.presignedUrl()) .range(presignedUrlDownloadRequest.range()) + .ifMatch(presignedUrlDownloadRequest.ifMatch()) .build(); MetricCollector apiCallMetricCollector = metricPublishers.isEmpty() ? diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/presignedurl/model/PresignedUrlDownloadRequestWrapper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/presignedurl/model/PresignedUrlDownloadRequestWrapper.java index 983c6b9c13e..5d556952396 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/presignedurl/model/PresignedUrlDownloadRequestWrapper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/presignedurl/model/PresignedUrlDownloadRequestWrapper.java @@ -48,18 +48,27 @@ public final class PresignedUrlDownloadRequestWrapper extends S3Request { .traits(LocationTrait.builder().location(MarshallLocation.HEADER).locationName("Range") .unmarshallLocationName("Range").build()).build(); + private static final SdkField IF_MATCH_FIELD = SdkField + .builder(MarshallingType.STRING) + .memberName("IfMatch") + .getter(getter(PresignedUrlDownloadRequestWrapper::ifMatch)) + .traits(LocationTrait.builder().location(MarshallLocation.HEADER).locationName("If-Match") + .unmarshallLocationName("If-Match").build()).build(); + private static final List> SDK_FIELDS = Collections.unmodifiableList( - Arrays.asList(RANGE_FIELD)); + Arrays.asList(RANGE_FIELD, IF_MATCH_FIELD)); private static final Map> SDK_NAME_TO_FIELD = memberNameToFieldInitializer(); private final URL url; private final String range; + private final String ifMatch; private PresignedUrlDownloadRequestWrapper(Builder builder) { super(builder); this.url = builder.url; this.range = builder.range; + this.ifMatch = builder.ifMatch; } public URL url() { @@ -70,6 +79,10 @@ public String range() { return range; } + public String ifMatch() { + return ifMatch; + } + @Override public List> sdkFields() { return SDK_FIELDS; @@ -87,6 +100,7 @@ private static Function getter(Function> memberNameToFieldInitializer() { Map> map = new HashMap<>(); map.put("Range", RANGE_FIELD); + map.put("IfMatch", IF_MATCH_FIELD); return Collections.unmodifiableMap(map); } @@ -111,7 +125,7 @@ public boolean equals(Object obj) { return false; } PresignedUrlDownloadRequestWrapper that = (PresignedUrlDownloadRequestWrapper) obj; - return Objects.equals(url, that.url) && Objects.equals(range, that.range); + return Objects.equals(url, that.url) && Objects.equals(range, that.range) && Objects.equals(ifMatch, that.ifMatch); } @Override @@ -119,12 +133,14 @@ public int hashCode() { int result = Objects.hashCode(super.hashCode()); result = 31 * result + Objects.hashCode(url); result = 31 * result + Objects.hashCode(range); + result = 31 * result + Objects.hashCode(ifMatch); return result; } public static final class Builder extends S3Request.BuilderImpl { private URL url; private String range; + private String ifMatch; public Builder() { } @@ -133,6 +149,7 @@ public Builder() { super(request); this.url = request.url(); this.range = request.range(); + this.ifMatch = request.ifMatch(); } public Builder url(URL url) { @@ -145,6 +162,11 @@ public Builder range(String range) { return this; } + public Builder ifMatch(String ifMatch) { + this.ifMatch = ifMatch; + return this; + } + @Override public PresignedUrlDownloadRequestWrapper build() { return new PresignedUrlDownloadRequestWrapper(this); diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/presignedurl/model/PresignedUrlDownloadRequest.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/presignedurl/model/PresignedUrlDownloadRequest.java index 380c234b82c..6c19159251f 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/presignedurl/model/PresignedUrlDownloadRequest.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/presignedurl/model/PresignedUrlDownloadRequest.java @@ -31,10 +31,12 @@ public final class PresignedUrlDownloadRequest implements ToCopyableBuilder { private final URL presignedUrl; private final String range; + private final String ifMatch; private PresignedUrlDownloadRequest(BuilderImpl builder) { this.presignedUrl = builder.presignedUrl; this.range = builder.range; + this.ifMatch = builder.ifMatch; } /** @@ -65,6 +67,18 @@ public String range() { return range; } + /** + *

+ * Return the object only if its entity tag (ETag) is the same as the one specified in this header, + * otherwise return a 412 (precondition failed) error. + *

+ * + * @return The If-Match header value, or null if not specified. + */ + public String ifMatch() { + return ifMatch; + } + @Override public Builder toBuilder() { return new BuilderImpl(this); @@ -83,6 +97,7 @@ public int hashCode() { int hashCode = 1; hashCode = 31 * hashCode + Objects.hashCode(presignedUrl()); hashCode = 31 * hashCode + Objects.hashCode(range()); + hashCode = 31 * hashCode + Objects.hashCode(ifMatch()); return hashCode; } @@ -96,7 +111,8 @@ public boolean equals(Object obj) { } PresignedUrlDownloadRequest other = (PresignedUrlDownloadRequest) obj; return Objects.equals(presignedUrl(), other.presignedUrl()) && - Objects.equals(range(), other.range()); + Objects.equals(range(), other.range()) && + Objects.equals(ifMatch(), other.ifMatch()); } @Override @@ -104,6 +120,7 @@ public String toString() { return ToString.builder("PresignedUrlDownloadRequest") .add("PresignedUrl", presignedUrl()) .add("Range", range()) + .add("IfMatch", ifMatch()) .build(); } @@ -121,11 +138,19 @@ public interface Builder extends CopyableBuilder> { + + private S3AsyncClient s3mock; + + public PresignedUrlMultipartDownloaderSubscriberTckTest() { + super(new TestEnvironment()); + this.s3mock = Mockito.mock(S3AsyncClient.class); + } + + @Override + public Subscriber> + createSubscriber(WhiteboxSubscriberProbe> probe) { + AsyncPresignedUrlExtension presignedUrlExtension = Mockito.mock(AsyncPresignedUrlExtension.class); + when(s3mock.presignedUrlExtension()).thenReturn(presignedUrlExtension); + + CompletableFuture firstPartResponse = CompletableFuture.completedFuture( + GetObjectResponse.builder() + .contentRange("bytes 0-8388607/33554432") + .contentLength(8388608L) // 8MB + .eTag("\"test-etag-12345\"") + .build() + ); + + CompletableFuture subsequentPartResponse = CompletableFuture.completedFuture( + GetObjectResponse.builder() + .contentRange("bytes 8388608-16777215/33554432") + .contentLength(8388608L) // 8MB + .eTag("\"test-etag-12345\"") + .build() + ); + + when(presignedUrlExtension.getObject(any(PresignedUrlDownloadRequest.class), any(AsyncResponseTransformer.class))) + .thenReturn(firstPartResponse) + .thenReturn(subsequentPartResponse) + .thenReturn(subsequentPartResponse) + .thenReturn(subsequentPartResponse); + + return new PresignedUrlMultipartDownloaderSubscriber( + s3mock, + createTestPresignedUrlRequest(), + 8 * 1024 * 1024L + ) { + @Override + public void onError(Throwable throwable) { + super.onError(throwable); + probe.registerOnError(throwable); + } + + @Override + public void onSubscribe(Subscription subscription) { + super.onSubscribe(subscription); + probe.registerOnSubscribe(new SubscriberPuppet() { + @Override + public void triggerRequest(long elements) { + subscription.request(elements); + } + + @Override + public void signalCancel() { + subscription.cancel(); + } + }); + } + + @Override + public void onNext(AsyncResponseTransformer item) { + super.onNext(item); + probe.registerOnNext(item); + } + + @Override + public void onComplete() { + super.onComplete(); + probe.registerOnComplete(); + } + }; + } + + @Override + public AsyncResponseTransformer createElement(int element) { + return new TestAsyncResponseTransformer(); + } + + private PresignedUrlDownloadRequest createTestPresignedUrlRequest() { + try { + return PresignedUrlDownloadRequest.builder() + .presignedUrl(java.net.URI.create("https://test-bucket.s3.amazonaws.com/test-key").toURL()) + .build(); + } catch (MalformedURLException e) { + throw new RuntimeException("Failed to create test URL", e); + } + } + + private static class TestAsyncResponseTransformer implements AsyncResponseTransformer { + private CompletableFuture future; + + @Override + public CompletableFuture prepare() { + this.future = new CompletableFuture<>(); + return this.future; + } + + @Override + public void onResponse(GetObjectResponse response) { + this.future.complete(response); + } + + @Override + public void onStream(SdkPublisher publisher) { + } + + @Override + public void exceptionOccurred(Throwable error) { + future.completeExceptionally(error); + } + } +} diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/PresignedUrlMultipartDownloaderSubscriberWiremockTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/PresignedUrlMultipartDownloaderSubscriberWiremockTest.java new file mode 100644 index 00000000000..87e3dc87256 --- /dev/null +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/PresignedUrlMultipartDownloaderSubscriberWiremockTest.java @@ -0,0 +1,252 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + +import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.get; +import static com.github.tomakehurst.wiremock.client.WireMock.stubFor; +import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo; +import static org.apache.commons.lang3.RandomStringUtils.randomAscii; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; + +import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; +import com.github.tomakehurst.wiremock.junit5.WireMockTest; +import java.io.IOException; +import java.net.MalformedURLException; +import java.net.URI; +import java.net.URL; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Arrays; +import java.util.UUID; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.S3Exception; +import software.amazon.awssdk.services.s3.multipart.MultipartConfiguration; +import software.amazon.awssdk.services.s3.presignedurl.model.PresignedUrlDownloadRequest; + +@WireMockTest +class PresignedUrlMultipartDownloaderSubscriberWiremockTest { + + private static final String PRESIGNED_URL_PATH = "/presigned-url"; + private static final byte[] TEST_DATA = randomAscii(5 * 1024 * 1024).getBytes(StandardCharsets.UTF_8); + + private S3AsyncClient s3AsyncClient; + private String presignedUrlBase; + private URL presignedUrl; + private Path tempFile; + + @BeforeEach + public void setup(WireMockRuntimeInfo wiremock) throws MalformedURLException { + MultipartConfiguration multipartConfig = MultipartConfiguration.builder() + .minimumPartSizeInBytes(16L) + .build(); + s3AsyncClient = S3AsyncClient.builder() + .endpointOverride(URI.create("http://localhost:" + wiremock.getHttpPort())) + .multipartEnabled(true) + .multipartConfiguration(multipartConfig) + .build(); + presignedUrlBase = "http://localhost:" + wiremock.getHttpPort(); + presignedUrl = createPresignedUrl(); + } + + @Test + void presignedUrlDownload_withMultipartData_shouldReceiveCompleteBody() { + stubSuccessfulPresignedUrlResponse(); + byte[] result = s3AsyncClient.presignedUrlExtension() + .getObject(PresignedUrlDownloadRequest.builder() + .presignedUrl(presignedUrl) + .build(), + AsyncResponseTransformer.toBytes()) + .join() + .asByteArray(); + assertArrayEquals(TEST_DATA, result); + } + + @Test + void presignedUrlDownload_withRangeHeader_shouldReceivePartialContent() { + stubSuccessfulRangeResponse(); + byte[] result = s3AsyncClient.presignedUrlExtension() + .getObject(PresignedUrlDownloadRequest.builder() + .presignedUrl(presignedUrl) + .range("bytes=0-10") + .build(), + AsyncResponseTransformer.toBytes()) + .join() + .asByteArray(); + byte[] expectedPartial = Arrays.copyOfRange(TEST_DATA, 0, 11); + assertArrayEquals(expectedPartial, result); + } + + @Test + void presignedUrlDownload_whenRequestFails_shouldThrowException() { + stubFailedPresignedUrlResponse(); + assertThatThrownBy(() -> s3AsyncClient.presignedUrlExtension() + .getObject(PresignedUrlDownloadRequest.builder() + .presignedUrl(presignedUrl) + .build(), + AsyncResponseTransformer.toBytes()) + .join()) + .hasRootCauseInstanceOf(S3Exception.class); + } + + @Test + void presignedUrlDownload_withFileTransformer_shouldWork() throws IOException { + stubSuccessfulPresignedUrlResponse(); + tempFile = createUniqueTempFile(); + s3AsyncClient.presignedUrlExtension() + .getObject(PresignedUrlDownloadRequest.builder() + .presignedUrl(presignedUrl) + .build(), + AsyncResponseTransformer.toFile(tempFile)) + .join(); + assertThat(tempFile.toFile()).exists(); + assertThat(tempFile.toFile().length()).isGreaterThan(0); + } + + @Test + void presignedUrlDownload_whenFirstRequestFails_shouldThrowException() { + stubInternalServerError(); + assertThatThrownBy(() -> s3AsyncClient.presignedUrlExtension() + .getObject(PresignedUrlDownloadRequest.builder() + .presignedUrl(presignedUrl) + .build(), + AsyncResponseTransformer.toBytes()) + .join()) + .hasRootCauseInstanceOf(S3Exception.class); + } + + @Test + void presignedUrlDownload_whenSecondRequestFails_shouldThrowException() { + stubPartialFailureScenario(); + assertThatThrownBy(() -> s3AsyncClient.presignedUrlExtension() + .getObject(PresignedUrlDownloadRequest.builder() + .presignedUrl(presignedUrl) + .build(), + AsyncResponseTransformer.toBytes()) + .join()) + .hasRootCauseInstanceOf(S3Exception.class); + } + + + @Test + void presignedUrlDownload_withNullTransformer_shouldThrowException() { + PresignedUrlMultipartDownloaderSubscriber subscriber = + new PresignedUrlMultipartDownloaderSubscriber( + s3AsyncClient, + PresignedUrlDownloadRequest.builder().presignedUrl(presignedUrl).build(), + 1024); + + assertThatThrownBy(() -> subscriber.onNext(null)) + .isInstanceOf(NullPointerException.class) + .hasMessageContaining("onNext must not be called with null asyncResponseTransformer"); + } + + @AfterEach + void cleanup() { + if (tempFile != null && Files.exists(tempFile)) { + try { + Files.delete(tempFile); + } catch (IOException e) { + } + } + } + + private static Path createUniqueTempFile() throws IOException { + String uniqueName = "test-" + UUID.randomUUID().toString(); + Path tempFile = Files.createTempFile(uniqueName, ".tmp"); + Files.deleteIfExists(tempFile); + return tempFile; + } + + private void stubSuccessfulPresignedUrlResponse() { + stubFor(get(urlEqualTo(PRESIGNED_URL_PATH)) + .willReturn(aResponse() + .withStatus(200) + .withHeader("Content-Type", "application/octet-stream") + .withHeader("Content-Length", String.valueOf(TEST_DATA.length)) + .withHeader("ETag", "\"test-etag\"") + .withBody(TEST_DATA))); + } + + private void stubSuccessfulRangeResponse() { + byte[] partialData = Arrays.copyOfRange(TEST_DATA, 0, 11); + stubFor(get(urlEqualTo(PRESIGNED_URL_PATH)) + .willReturn(aResponse() + .withStatus(206) + .withHeader("Content-Type", "application/octet-stream") + .withHeader("Content-Length", String.valueOf(partialData.length)) + .withHeader("Content-Range", "bytes 0-10/" + TEST_DATA.length) + .withHeader("ETag", "\"test-etag\"") + .withBody(partialData))); + } + + private void stubFailedPresignedUrlResponse() { + stubFor(get(urlEqualTo(PRESIGNED_URL_PATH)) + .willReturn(aResponse() + .withStatus(404) + .withBody("NoSuchKeyThe specified key does not exist."))); + } + + private void stubInternalServerError() { + stubFor(get(urlEqualTo(PRESIGNED_URL_PATH)) + .willReturn(aResponse() + .withStatus(500) + .withBody("InternalErrorInternal Server Error"))); + } + + private void stubPartialFailureScenario() { + stubFor(get(urlEqualTo(PRESIGNED_URL_PATH)) + .inScenario("partial-failure") + .whenScenarioStateIs("Started") + .willReturn(aResponse() + .withStatus(206) + .withHeader("Content-Type", "application/octet-stream") + .withHeader("Content-Length", "16") + .withHeader("Content-Range", "bytes 0-15/" + TEST_DATA.length) + .withHeader("ETag", "\"test-etag\"") + .withBody(Arrays.copyOfRange(TEST_DATA, 0, 16))) + .willSetStateTo("first-success")); + + stubFor(get(urlEqualTo(PRESIGNED_URL_PATH)) + .inScenario("partial-failure") + .whenScenarioStateIs("first-success") + .willReturn(aResponse() + .withStatus(500) + .withBody("InternalErrorSecond request failed"))); + } + + private void stubSinglePartResponse() { + stubFor(get(urlEqualTo(PRESIGNED_URL_PATH)) + .willReturn(aResponse() + .withStatus(206) + .withHeader("Content-Type", "application/octet-stream") + .withHeader("Content-Length", String.valueOf(TEST_DATA.length)) + .withHeader("Content-Range", "bytes 0-" + (TEST_DATA.length - 1) + "/" + TEST_DATA.length) + .withHeader("ETag", "\"test-etag\"") + .withBody(TEST_DATA))); + } + + private URL createPresignedUrl() throws MalformedURLException { + return new URL(presignedUrlBase + PRESIGNED_URL_PATH); + } +} \ No newline at end of file