Skip to content

Commit 6afcafa

Browse files
Refactor code and add tests
1 parent 9ec0672 commit 6afcafa

File tree

3 files changed

+203
-168
lines changed

3 files changed

+203
-168
lines changed

services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/PresignedUrlMultipartDownloaderSubscriber.java

Lines changed: 92 additions & 149 deletions
Original file line numberDiff line numberDiff line change
@@ -54,27 +54,14 @@ public class PresignedUrlMultipartDownloaderSubscriber
5454
private final S3AsyncClient s3AsyncClient;
5555
private final PresignedUrlDownloadRequest presignedUrlDownloadRequest;
5656
private final long configuredPartSizeInBytes;
57-
private final int completedParts;
5857
private final CompletableFuture<Void> future;
5958
private final Object lock = new Object();
60-
private volatile MultipartDownloadState state;
61-
private Subscription subscription;
59+
private final AtomicInteger completedParts;
6260

63-
private static class MultipartDownloadState {
64-
final long totalContentLength;
65-
final long actualPartSizeInBytes;
66-
final int totalParts;
67-
final AtomicInteger completedParts;
68-
final String etag;
69-
70-
MultipartDownloadState(long totalLength, long partSize, int totalParts, String etag, int completedParts) {
71-
this.totalContentLength = totalLength;
72-
this.actualPartSizeInBytes = partSize;
73-
this.totalParts = totalParts;
74-
this.completedParts = new AtomicInteger(completedParts);
75-
this.etag = etag;
76-
}
77-
}
61+
private volatile Long totalContentLength;
62+
private volatile Integer totalParts;
63+
private volatile String eTag;
64+
private volatile Subscription subscription;
7865

7966
public PresignedUrlMultipartDownloaderSubscriber(
8067
S3AsyncClient s3AsyncClient,
@@ -83,7 +70,7 @@ public PresignedUrlMultipartDownloaderSubscriber(
8370
this.s3AsyncClient = s3AsyncClient;
8471
this.presignedUrlDownloadRequest = presignedUrlDownloadRequest;
8572
this.configuredPartSizeInBytes = configuredPartSizeInBytes;
86-
this.completedParts = 0;
73+
this.completedParts = new AtomicInteger(0);
8774
this.future = new CompletableFuture<>();
8875
}
8976

@@ -102,135 +89,87 @@ public void onSubscribe(Subscription s) {
10289
@Override
10390
public void onNext(AsyncResponseTransformer<GetObjectResponse, GetObjectResponse> asyncResponseTransformer) {
10491
if (asyncResponseTransformer == null) {
105-
subscription.cancel();
10692
throw new NullPointerException("onNext must not be called with null asyncResponseTransformer");
10793
}
94+
95+
int nextPartIndex;
10896
synchronized (lock) {
109-
if (state == null) {
110-
performSizeDiscoveryAndFirstPart(asyncResponseTransformer);
111-
} else {
112-
downloadNextPart(asyncResponseTransformer);
97+
nextPartIndex = completedParts.get();
98+
if (totalParts != null && nextPartIndex >= totalParts) {
99+
log.debug(() -> String.format("Completing multipart download after a total of %d parts downloaded.", totalParts));
100+
subscription.cancel();
101+
return;
113102
}
103+
completedParts.incrementAndGet();
114104
}
105+
106+
makeRangeRequest(nextPartIndex, asyncResponseTransformer);
115107
}
116108

117-
private void performSizeDiscoveryAndFirstPart(AsyncResponseTransformer<GetObjectResponse,
118-
GetObjectResponse> asyncResponseTransformer) {
119-
if (completedParts > 0) {
120-
performSizeDiscoveryOnly(asyncResponseTransformer);
121-
return;
122-
}
123-
long endByte = configuredPartSizeInBytes - 1;
124-
String firstPartRange = String.format("%s0-%d", BYTES_RANGE_PREFIX, endByte);
125-
PresignedUrlDownloadRequest firstPartRequest = presignedUrlDownloadRequest.toBuilder()
126-
.range(firstPartRange)
127-
.build();
128-
s3AsyncClient.presignedUrlExtension().getObject(firstPartRequest, asyncResponseTransformer)
129-
.whenComplete((response, error) -> {
130-
if (error != null) {
131-
log.debug(() -> "Error encountered during first part request");
132-
onError(error);
133-
return;
134-
}
135-
try {
136-
String contentRange = response.contentRange();
137-
if (contentRange == null) {
138-
onError(new IllegalStateException("No Content-Range header in response"));
139-
return;
140-
}
141-
long totalSize = parseContentRangeForTotalSize(contentRange);
142-
if (totalSize <= configuredPartSizeInBytes) {
143-
subscription.cancel();
144-
return;
145-
}
146-
String etag = response.eTag();
147-
initializeStateAfterFirstPart(totalSize, etag);
148-
if (state.totalParts > 1) {
149-
subscription.request(1);
150-
} else {
151-
subscription.cancel();
152-
}
153-
} catch (Exception e) {
154-
log.debug(() -> "Error during first part processing", e);
155-
onError(e);
156-
}
157-
});
158-
}
159-
160-
private void performSizeDiscoveryOnly(
161-
AsyncResponseTransformer<GetObjectResponse, GetObjectResponse> asyncResponseTransformer) {
162-
String sizeDiscoveryRange = String.format("%s0-0", BYTES_RANGE_PREFIX);
163-
PresignedUrlDownloadRequest sizeDiscoveryRequest = presignedUrlDownloadRequest.toBuilder()
164-
.range(sizeDiscoveryRange)
165-
.build();
166-
167-
s3AsyncClient.presignedUrlExtension().getObject(sizeDiscoveryRequest, asyncResponseTransformer)
168-
.whenComplete((response, error) -> {
169-
if (error != null) {
170-
log.debug(() -> "Error encountered during size discovery request");
171-
onError(error);
172-
return;
173-
}
174-
try {
175-
String contentRange = response.contentRange();
176-
if (contentRange == null) {
177-
onError(new IllegalStateException("No Content-Range header in response"));
178-
return;
179-
}
180-
long totalSize = parseContentRangeForTotalSize(contentRange);
181-
String etag = response.eTag();
182-
if (etag == null) {
183-
onError(new IllegalStateException("No ETag in response, cannot ensure consistency"));
184-
return;
185-
}
186-
int totalParts = calculateTotalParts(totalSize, configuredPartSizeInBytes);
187-
this.state = new MultipartDownloadState(totalSize, configuredPartSizeInBytes,
188-
totalParts, etag, completedParts);
189-
if (completedParts < state.totalParts) {
190-
subscription.request(1);
191-
} else {
192-
subscription.cancel();
193-
}
194-
} catch (Exception e) {
195-
log.debug(() -> "Error during size discovery processing", e);
196-
onError(e);
197-
}
198-
});
109+
private void makeRangeRequest(int partIndex,
110+
AsyncResponseTransformer<GetObjectResponse,
111+
GetObjectResponse> asyncResponseTransformer) {
112+
PresignedUrlDownloadRequest partRequest = createPartRequest(partIndex);
113+
log.debug(() -> "Sending range request for part " + partIndex + " with range=" + partRequest.range());
114+
115+
s3AsyncClient.presignedUrlExtension()
116+
.getObject(partRequest, asyncResponseTransformer)
117+
.whenComplete((response, error) -> {
118+
if (error != null) {
119+
log.debug(() -> "Error encountered during part request for part " + partIndex);
120+
handleError(error);
121+
return;
122+
}
123+
requestMoreIfNeeded(response);
124+
});
199125
}
200126

201-
private void downloadNextPart(AsyncResponseTransformer<GetObjectResponse, GetObjectResponse> transformer) {
202-
int nextPartIndex = state.completedParts.getAndIncrement();
203-
if (nextPartIndex >= state.totalParts) {
204-
subscription.cancel();
205-
return;
127+
private void requestMoreIfNeeded(GetObjectResponse response) {
128+
int totalComplete = completedParts.get();
129+
log.debug(() -> String.format("Completed part %d", totalComplete));
130+
131+
synchronized (lock) {
132+
if (eTag == null) {
133+
this.eTag = response.eTag();
134+
log.debug(() -> String.format("Multipart object ETag: %s", this.eTag));
135+
} else if (response.eTag() != null && !eTag.equals(response.eTag())) {
136+
handleError(new IllegalStateException("ETag mismatch - object may have changed during download"));
137+
return;
138+
}
139+
if (totalContentLength == null && response.contentRange() != null) {
140+
try {
141+
validateResponse(response);
142+
long totalSize = parseContentRangeForTotalSize(response.contentRange());
143+
int calculatedTotalParts = calculateTotalParts(totalSize, configuredPartSizeInBytes);
144+
this.totalContentLength = totalSize;
145+
this.totalParts = calculatedTotalParts;
146+
log.debug(() -> String.format("Total content length: %d, Total parts: %d", totalSize, calculatedTotalParts));
147+
} catch (Exception e) {
148+
log.debug(() -> "Failed to parse content range", e);
149+
handleError(e);
150+
return;
151+
}
152+
}
153+
if (totalParts != null && totalParts > 1 && totalComplete < totalParts) {
154+
subscription.request(1);
155+
} else {
156+
log.debug(() -> String.format("Completing multipart download after a total of %d parts downloaded.", totalParts));
157+
subscription.cancel();
158+
}
206159
}
207-
PresignedUrlDownloadRequest partRequest = createPartRequest(nextPartIndex);
208-
String expectedRange = partRequest.range();
209-
s3AsyncClient.presignedUrlExtension().getObject(partRequest, transformer)
210-
.whenComplete((response, error) -> {
211-
if (error != null) {
212-
log.debug(() -> "Error encountered during part request with range=" + expectedRange);
213-
onError(error);
214-
} else {
215-
try {
216-
validatePartResponse(response, nextPartIndex, expectedRange);
217-
int completedCount = nextPartIndex + 1;
218-
if (completedCount < state.totalParts) {
219-
subscription.request(1);
220-
} else {
221-
subscription.cancel();
222-
}
223-
} catch (Exception validationError) {
224-
log.debug(() -> "Validation failed for part " + (nextPartIndex + 1));
225-
onError(validationError);
226-
}
227-
}
228-
});
229160
}
230161

231-
private void initializeStateAfterFirstPart(long totalSize, String etag) {
232-
int totalParts = calculateTotalParts(totalSize, configuredPartSizeInBytes);
233-
this.state = new MultipartDownloadState(totalSize, configuredPartSizeInBytes, totalParts, etag, completedParts + 1);
162+
private void validateResponse(GetObjectResponse response) {
163+
if (response == null) {
164+
throw new IllegalStateException("Response cannot be null");
165+
}
166+
if (response.contentRange() == null) {
167+
throw new IllegalStateException("No Content-Range header in response");
168+
}
169+
Long contentLength = response.contentLength();
170+
if (contentLength == null || contentLength <= 0) {
171+
throw new IllegalStateException("Invalid or missing Content-Length in response");
172+
}
234173
}
235174

236175
private long parseContentRangeForTotalSize(String contentRange) {
@@ -246,15 +185,29 @@ private int calculateTotalParts(long contentLength, long partSize) {
246185
}
247186

248187
private PresignedUrlDownloadRequest createPartRequest(int partIndex) {
249-
long startByte = partIndex * state.actualPartSizeInBytes;
250-
long endByte = Math.min(startByte + state.actualPartSizeInBytes - 1, state.totalContentLength - 1);
188+
long startByte = partIndex * configuredPartSizeInBytes;
189+
long endByte;
190+
191+
if (totalContentLength != null) {
192+
endByte = Math.min(startByte + configuredPartSizeInBytes - 1, totalContentLength - 1);
193+
} else {
194+
endByte = startByte + configuredPartSizeInBytes - 1;
195+
}
251196
String rangeHeader = BYTES_RANGE_PREFIX + startByte + "-" + endByte;
252-
253197
return presignedUrlDownloadRequest.toBuilder()
254198
.range(rangeHeader)
255199
.build();
256200
}
257201

202+
private void handleError(Throwable t) {
203+
synchronized (lock) {
204+
if (subscription != null) {
205+
subscription.cancel();
206+
}
207+
}
208+
onError(t);
209+
}
210+
258211
@Override
259212
public void onError(Throwable t) {
260213
log.debug(() -> "Error in multipart download", t);
@@ -269,14 +222,4 @@ public void onComplete() {
269222
public CompletableFuture<Void> future() {
270223
return this.future;
271224
}
272-
273-
private void validatePartResponse(GetObjectResponse response, int partIndex, String expectedRange) {
274-
if (response == null) {
275-
throw new IllegalArgumentException("Response cannot be null");
276-
}
277-
String responseETag = response.eTag();
278-
if (responseETag != null && state.etag != null && !state.etag.equals(responseETag)) {
279-
throw new IllegalStateException("ETag mismatch - object may have changed during download");
280-
}
281-
}
282225
}

services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/PresignedUrlMultipartDownloadTestUtil.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
public class PresignedUrlMultipartDownloadTestUtil {
3535

3636
private static final String PRESIGNED_URL_PATH = "/presigned-url";
37-
private static final String DIFFERENT_ETAG = "different-etag-12345";
3837

3938
private final String presignedUrl;
4039
private final String eTag;
@@ -148,4 +147,6 @@ public void verifyNoRequestMadeForRange(long startByte, long endByte) {
148147
verify(0, getRequestedFor(urlEqualTo(PRESIGNED_URL_PATH))
149148
.withHeader("Range", new EqualToPattern(rangeHeader)));
150149
}
150+
151+
// Additional utility methods for error condition testing can be added here as needed
151152
}

0 commit comments

Comments
 (0)