|
29 | 29 | import static org.junit.jupiter.api.Assertions.assertNotEquals;
|
30 | 30 | import static org.junit.jupiter.api.Assertions.assertNotNull;
|
31 | 31 | import static org.junit.jupiter.api.Assertions.assertThrows;
|
| 32 | +import static org.junit.jupiter.api.Assertions.assertTrue; |
32 | 33 |
|
33 | 34 | import com.github.tomakehurst.wiremock.http.Fault;
|
34 | 35 | import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo;
|
35 | 36 | import com.github.tomakehurst.wiremock.junit5.WireMockTest;
|
36 | 37 | import com.github.tomakehurst.wiremock.stubbing.Scenario;
|
37 | 38 | import java.net.URI;
|
| 39 | +import java.nio.ByteBuffer; |
38 | 40 | import java.nio.charset.StandardCharsets;
|
39 | 41 | import java.time.Duration;
|
40 | 42 | import java.util.ArrayList;
|
|
43 | 45 | import java.util.UUID;
|
44 | 46 | import java.util.concurrent.CompletableFuture;
|
45 | 47 | import java.util.concurrent.CompletionException;
|
| 48 | +import java.util.concurrent.atomic.AtomicBoolean; |
46 | 49 | import org.junit.jupiter.api.BeforeEach;
|
47 | 50 | import org.junit.jupiter.api.Test;
|
| 51 | +import org.reactivestreams.Subscriber; |
| 52 | +import org.reactivestreams.Subscription; |
48 | 53 | import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
|
49 | 54 | import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
|
50 | 55 | import software.amazon.awssdk.awscore.retry.AwsRetryStrategy;
|
51 | 56 | import software.amazon.awssdk.core.ResponseBytes;
|
52 | 57 | import software.amazon.awssdk.core.async.AsyncResponseTransformer;
|
| 58 | +import software.amazon.awssdk.core.async.SdkPublisher; |
53 | 59 | import software.amazon.awssdk.core.interceptor.Context;
|
54 | 60 | import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
|
55 | 61 | import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
|
@@ -163,9 +169,8 @@ public void getObject_503Response_shouldNotReuseInitialRequestId() {
|
163 | 169 | .withHeader("x-amz-request-id", secondRequestId)
|
164 | 170 | .withStatus(503)));
|
165 | 171 |
|
166 |
| - assertThrows(CompletionException.class, () -> { |
167 |
| - multipartClient.getObject(b -> b.bucket(BUCKET).key(KEY), AsyncResponseTransformer.toBytes()).join(); |
168 |
| - }); |
| 172 | + assertThrows(CompletionException.class, () -> |
| 173 | + multipartClient.getObject(b -> b.bucket(BUCKET).key(KEY), AsyncResponseTransformer.toBytes()).join()); |
169 | 174 |
|
170 | 175 | List<SdkHttpResponse> responses = capturingInterceptor.getResponses();
|
171 | 176 | assertEquals(MAX_ATTEMPTS, responses.size(), () -> String.format("Expected exactly %s responses", MAX_ATTEMPTS));
|
@@ -343,6 +348,87 @@ public void getObject_iOError_shouldRetrySuccessfully() {
|
343 | 348 | assertEquals(requestId, finalRequestId);
|
344 | 349 | }
|
345 | 350 |
|
| 351 | + @Test |
| 352 | + public void multipartDownload_errorDuringFirstPartAfterOnStream_shouldFailAndNotRetry() { |
| 353 | + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=1", BUCKET, KEY))) |
| 354 | + .willReturn(aResponse() |
| 355 | + .withHeader("x-amz-mp-parts-count", String.valueOf(2)) |
| 356 | + .withStatus(200) |
| 357 | + .withBody("Hello "))); |
| 358 | + |
| 359 | + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=2", BUCKET, KEY))) |
| 360 | + .willReturn(aResponse() |
| 361 | + .withStatus(200) |
| 362 | + .withHeader("x-amz-mp-parts-count", "2") |
| 363 | + .withBody("World"))); |
| 364 | + |
| 365 | + StreamingErrorTransformer failingTransformer = new StreamingErrorTransformer(); |
| 366 | + assertThrows(CompletionException.class, () -> |
| 367 | + multipartClient.getObject(b -> b.bucket(BUCKET).key(KEY), failingTransformer).join()); |
| 368 | + |
| 369 | + assertTrue(failingTransformer.onStreamCalled.get()); |
| 370 | + // Verify that the first part was requested only once and not retried |
| 371 | + verify(1, getRequestedFor(urlEqualTo(String.format("/%s/%s?partNumber=1", BUCKET, KEY)))); |
| 372 | + } |
| 373 | + |
| 374 | + /** |
| 375 | + * Custom AsyncResponseTransformer that simulates an error occurring after onStream() has been called |
| 376 | + */ |
| 377 | + private static final class StreamingErrorTransformer |
| 378 | + implements AsyncResponseTransformer<GetObjectResponse, ResponseBytes<GetObjectResponse>> { |
| 379 | + |
| 380 | + private final CompletableFuture<ResponseBytes<GetObjectResponse>> future = new CompletableFuture<>(); |
| 381 | + private final AtomicBoolean errorThrown = new AtomicBoolean(); |
| 382 | + private final AtomicBoolean onStreamCalled = new AtomicBoolean(); |
| 383 | + |
| 384 | + @Override |
| 385 | + public CompletableFuture<ResponseBytes<GetObjectResponse>> prepare() { |
| 386 | + return future; |
| 387 | + } |
| 388 | + |
| 389 | + @Override |
| 390 | + public void onResponse(GetObjectResponse response) { |
| 391 | + // |
| 392 | + } |
| 393 | + |
| 394 | + @Override |
| 395 | + public void onStream(SdkPublisher<ByteBuffer> publisher) { |
| 396 | + onStreamCalled.set(true); |
| 397 | + publisher.subscribe(new Subscriber<ByteBuffer>() { |
| 398 | + private Subscription subscription; |
| 399 | + |
| 400 | + @Override |
| 401 | + public void onSubscribe(Subscription s) { |
| 402 | + this.subscription = s; |
| 403 | + s.request(1); |
| 404 | + } |
| 405 | + |
| 406 | + @Override |
| 407 | + public void onNext(ByteBuffer byteBuffer) { |
| 408 | + if (errorThrown.compareAndSet(false, true)) { |
| 409 | + future.completeExceptionally(new RuntimeException()); |
| 410 | + subscription.cancel(); |
| 411 | + } |
| 412 | + } |
| 413 | + |
| 414 | + @Override |
| 415 | + public void onError(Throwable t) { |
| 416 | + future.completeExceptionally(t); |
| 417 | + } |
| 418 | + |
| 419 | + @Override |
| 420 | + public void onComplete() { |
| 421 | + // |
| 422 | + } |
| 423 | + }); |
| 424 | + } |
| 425 | + |
| 426 | + @Override |
| 427 | + public void exceptionOccurred(Throwable throwable) { |
| 428 | + future.completeExceptionally(throwable); |
| 429 | + } |
| 430 | + } |
| 431 | + |
346 | 432 | private CompletableFuture<ResponseBytes<GetObjectResponse>> mock200Response(S3AsyncClient s3Client, int runNumber) {
|
347 | 433 | String runId = runNumber + " success";
|
348 | 434 |
|
|
0 commit comments