|
23 | 23 | import org.reactivestreams.Subscription;
|
24 | 24 | import software.amazon.awssdk.annotations.SdkInternalApi;
|
25 | 25 | import software.amazon.awssdk.core.async.AsyncResponseTransformer;
|
| 26 | +import software.amazon.awssdk.core.exception.SdkClientException; |
26 | 27 | import software.amazon.awssdk.services.s3.S3AsyncClient;
|
27 | 28 | import software.amazon.awssdk.services.s3.model.GetObjectRequest;
|
28 | 29 | import software.amazon.awssdk.services.s3.model.GetObjectResponse;
|
@@ -76,6 +77,16 @@ public class MultipartDownloaderSubscriber implements Subscriber<AsyncResponseTr
|
76 | 77 | */
|
77 | 78 | private volatile String eTag;
|
78 | 79 |
|
| 80 | + /** |
| 81 | + * The size of each part of the object being downloaded. |
| 82 | + */ |
| 83 | + private volatile Long partSize; |
| 84 | + |
| 85 | + /** |
| 86 | + * The total size of the object being downloaded. |
| 87 | + */ |
| 88 | + private volatile Long totalContentLength; |
| 89 | + |
79 | 90 | /**
|
80 | 91 | * The Subscription lock
|
81 | 92 | */
|
@@ -117,6 +128,7 @@ public void onNext(AsyncResponseTransformer<GetObjectResponse, GetObjectResponse
|
117 | 128 |
|
118 | 129 | synchronized (lock) {
|
119 | 130 | if (totalParts != null && nextPartToGet > totalParts) {
|
| 131 | + validatePartsCount(completedParts.get()); |
120 | 132 | log.debug(() -> String.format("Completing multipart download after a total of %d parts downloaded.", totalParts));
|
121 | 133 | subscription.cancel();
|
122 | 134 | return;
|
@@ -162,10 +174,20 @@ private void requestMoreIfNeeded(GetObjectResponse response) {
|
162 | 174 | totalParts = partCount;
|
163 | 175 | }
|
164 | 176 |
|
| 177 | + String actualContentRange = response.contentRange(); |
| 178 | + if (actualContentRange != null && partSize == null) { |
| 179 | + getRangeInfo(actualContentRange); |
| 180 | + log.debug(() -> String.format("Part size of the object to download: " + partSize)); |
| 181 | + log.debug(() -> String.format("Total Content Length of the object to download: " + totalContentLength)); |
| 182 | + } |
| 183 | + |
| 184 | + validateContentRange(totalComplete, actualContentRange); |
| 185 | + |
165 | 186 | synchronized (lock) {
|
166 | 187 | if (totalParts != null && totalParts > 1 && totalComplete < totalParts) {
|
167 | 188 | subscription.request(1);
|
168 | 189 | } else {
|
| 190 | + validatePartsCount(completedParts.get()); |
169 | 191 | log.debug(() -> String.format("Completing multipart download after a total of %d parts downloaded.", totalParts));
|
170 | 192 | subscription.cancel();
|
171 | 193 | }
|
@@ -198,4 +220,45 @@ private GetObjectRequest nextRequest(int nextPartToGet) {
|
198 | 220 | }
|
199 | 221 | });
|
200 | 222 | }
|
| 223 | + |
| 224 | + private void validatePartsCount(int currentGetCount) { |
| 225 | + if (totalParts != null && currentGetCount != totalParts) { |
| 226 | + String errorMessage = "PartsCount validation failed. Expected " + totalParts + ", downloaded" |
| 227 | + + " " + currentGetCount + " parts."; |
| 228 | + log.error(() -> errorMessage); |
| 229 | + subscription.cancel(); |
| 230 | + SdkClientException exception = SdkClientException.create(errorMessage); |
| 231 | + onError(exception); |
| 232 | + } |
| 233 | + } |
| 234 | + |
| 235 | + private void validateContentRange(int partNumber, String contentRange) { |
| 236 | + if (contentRange == null) { |
| 237 | + return; |
| 238 | + } |
| 239 | + |
| 240 | + long expectedStart = (partNumber - 1) * partSize; |
| 241 | + long expectedEnd = partNumber == totalParts ? totalContentLength - 1 : expectedStart + partSize - 1; |
| 242 | + |
| 243 | + String expectedContentRange = String.format("bytes %d-%d/%d", expectedStart, expectedEnd, totalContentLength); |
| 244 | + |
| 245 | + if (!expectedContentRange.equals(contentRange)) { |
| 246 | + String errorMessage = String.format( |
| 247 | + "Content-Range validation failed for part %d. Expected: %s, Actual: %s", |
| 248 | + partNumber, expectedContentRange, contentRange); |
| 249 | + log.error(() -> errorMessage); |
| 250 | + onError(SdkClientException.create(errorMessage)); |
| 251 | + } |
| 252 | + } |
| 253 | + |
| 254 | + private void getRangeInfo(String contentRange) { |
| 255 | + String rangeInfo = contentRange.substring(6); |
| 256 | + String[] parts = rangeInfo.split("/"); |
| 257 | + |
| 258 | + this.totalContentLength = Long.parseLong(parts[1]); |
| 259 | + String[] rangeParts = parts[0].split("-"); |
| 260 | + long startByte = Long.parseLong(rangeParts[0]); |
| 261 | + long endByte = Long.parseLong(rangeParts[1]); |
| 262 | + this.partSize = endByte - startByte + 1; |
| 263 | + } |
201 | 264 | }
|
0 commit comments