Skip to content

Commit 357e486

Browse files
committed
S3 multipart client cancel GetObject futures if onError invoked
1 parent b532117 commit 357e486

File tree

3 files changed

+167
-0
lines changed

3 files changed

+167
-0
lines changed

core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingTransformer.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import software.amazon.awssdk.core.SplittingTransformerConfiguration;
2626
import software.amazon.awssdk.core.async.AsyncResponseTransformer;
2727
import software.amazon.awssdk.core.async.SdkPublisher;
28+
import software.amazon.awssdk.core.exception.NonRetryableException;
2829
import software.amazon.awssdk.utils.CompletableFutureUtils;
2930
import software.amazon.awssdk.utils.Logger;
3031
import software.amazon.awssdk.utils.Validate;
@@ -279,6 +280,11 @@ public CompletableFuture<ResponseT> prepare() {
279280
if (e == null) {
280281
return;
281282
}
283+
284+
// This isn't necessary, might be good for debugging? Or can just log error
285+
/*e.addSuppressed(NonRetryableException.create(
286+
"Error occurred during multipart download. Request will not be retried."));*/
287+
282288
individualFuture.completeExceptionally(e);
283289
});
284290
individualFuture.whenComplete((r, e) -> {

services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloaderSubscriber.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515

1616
package software.amazon.awssdk.services.s3.internal.multipart;
1717

18+
import java.util.Queue;
1819
import java.util.concurrent.CompletableFuture;
20+
import java.util.concurrent.ConcurrentLinkedQueue;
1921
import java.util.concurrent.atomic.AtomicInteger;
2022
import org.reactivestreams.Subscriber;
2123
import org.reactivestreams.Subscription;
@@ -79,6 +81,11 @@ public class MultipartDownloaderSubscriber implements Subscriber<AsyncResponseTr
7981
*/
8082
private final Object lock = new Object();
8183

84+
/**
85+
* Store the GetObject futures so we can cancel them if onError() is invoked.
86+
*/
87+
private final Queue<CompletableFuture<GetObjectResponse>> getObjectFutures = new ConcurrentLinkedQueue<>();
88+
8289
public MultipartDownloaderSubscriber(S3AsyncClient s3, GetObjectRequest getObjectRequest) {
8390
this(s3, getObjectRequest, 0);
8491
}
@@ -119,6 +126,7 @@ public void onNext(AsyncResponseTransformer<GetObjectResponse, GetObjectResponse
119126
GetObjectRequest actualRequest = nextRequest(nextPartToGet);
120127
log.debug(() -> "Sending GetObjectRequest for next part with partNumber=" + nextPartToGet);
121128
CompletableFuture<GetObjectResponse> getObjectFuture = s3.getObject(actualRequest, asyncResponseTransformer);
129+
getObjectFutures.add(getObjectFuture);
122130
getObjectFuture.whenComplete((response, error) -> {
123131
if (error != null) {
124132
log.debug(() -> "Error encountered during GetObjectRequest with partNumber=" + nextPartToGet);
@@ -166,6 +174,10 @@ private void requestMoreIfNeeded(GetObjectResponse response) {
166174

167175
@Override
168176
public void onError(Throwable t) {
177+
CompletableFuture<GetObjectResponse> partFuture;
178+
while ((partFuture = getObjectFutures.poll()) != null) {
179+
partFuture.cancel(true);
180+
}
169181
future.completeExceptionally(t);
170182
}
171183

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License").
5+
* You may not use this file except in compliance with the License.
6+
* A copy of the License is located at
7+
*
8+
* http://aws.amazon.com/apache2.0
9+
*
10+
* or in the "license" file accompanying this file. This file is distributed
11+
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12+
* express or implied. See the License for the specific language governing
13+
* permissions and limitations under the License.
14+
*/
15+
16+
package software.amazon.awssdk.services.s3.internal.multipart;
17+
18+
import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
19+
import static com.github.tomakehurst.wiremock.client.WireMock.any;
20+
import static com.github.tomakehurst.wiremock.client.WireMock.anyUrl;
21+
import static com.github.tomakehurst.wiremock.client.WireMock.get;
22+
import static com.github.tomakehurst.wiremock.client.WireMock.getRequestedFor;
23+
import static com.github.tomakehurst.wiremock.client.WireMock.matching;
24+
import static com.github.tomakehurst.wiremock.client.WireMock.stubFor;
25+
import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo;
26+
import static com.github.tomakehurst.wiremock.client.WireMock.verify;
27+
import static org.assertj.core.api.Assertions.assertThat;
28+
import static org.assertj.core.api.Assertions.fail;
29+
30+
import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo;
31+
import com.github.tomakehurst.wiremock.junit5.WireMockTest;
32+
import com.github.tomakehurst.wiremock.stubbing.Scenario;
33+
import java.net.URI;
34+
import java.time.Duration;
35+
import java.util.ArrayList;
36+
import java.util.List;
37+
import java.util.UUID;
38+
import java.util.concurrent.CompletableFuture;
39+
import java.util.concurrent.CompletionException;
40+
import java.util.concurrent.TimeUnit;
41+
import org.junit.jupiter.api.BeforeEach;
42+
import org.junit.jupiter.api.Test;
43+
import org.junit.jupiter.api.Timeout;
44+
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
45+
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
46+
import software.amazon.awssdk.core.ResponseBytes;
47+
import software.amazon.awssdk.core.async.AsyncResponseTransformer;
48+
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
49+
import software.amazon.awssdk.regions.Region;
50+
import software.amazon.awssdk.services.s3.S3AsyncClient;
51+
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
52+
import software.amazon.awssdk.services.s3.model.S3Exception;
53+
54+
@WireMockTest
55+
@Timeout(value = 30, unit = TimeUnit.SECONDS)
56+
public class S3MultipartClientGetObjectWiremockTest {
57+
public static final String BUCKET = "Example-Bucket";
58+
public static final String KEY = "Key";
59+
private static final int MAX_ATTEMPTS = 7;
60+
private S3AsyncClient multipartClient;
61+
62+
@BeforeEach
63+
public void setup(WireMockRuntimeInfo wm) {
64+
multipartClient = S3AsyncClient.builder()
65+
.region(Region.US_EAST_1)
66+
.endpointOverride(URI.create(wm.getHttpBaseUrl()))
67+
.multipartEnabled(true)
68+
.httpClientBuilder(NettyNioAsyncHttpClient.builder()
69+
.maxConcurrency(100)
70+
.connectionAcquisitionTimeout(Duration.ofSeconds(100)))
71+
.credentialsProvider(StaticCredentialsProvider.create(AwsBasicCredentials.create("key", "secret")))
72+
.build();
73+
}
74+
75+
@Test
76+
public void getObject_single500WithinMany200s_shouldNotRetryError() {
77+
List<CompletableFuture<ResponseBytes<GetObjectResponse>>> futures = new ArrayList<>();
78+
79+
int numRuns = 1000;
80+
for (int i = 0; i < numRuns; i++) {
81+
CompletableFuture<ResponseBytes<GetObjectResponse>> resp = mock200Response(multipartClient, i);
82+
futures.add(resp);
83+
}
84+
85+
String errorKey = "ErrorKey";
86+
stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=1", BUCKET, errorKey)))
87+
.inScenario("RetryableError")
88+
.whenScenarioStateIs(Scenario.STARTED)
89+
.willReturn(aResponse()
90+
.withHeader("x-amz-request-id", String.valueOf(UUID.randomUUID()))
91+
.withStatus(500)
92+
.withBody(internalErrorBody())
93+
)
94+
.willSetStateTo("RetryAttempt"));
95+
96+
stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=1", BUCKET, errorKey)))
97+
.inScenario("RetryableError")
98+
.whenScenarioStateIs("RetryAttempt")
99+
.willReturn(aResponse().withStatus(200)
100+
.withHeader("x-amz-request-id", String.valueOf(UUID.randomUUID()))
101+
.withBody("Hello World")));
102+
103+
CompletableFuture<ResponseBytes<GetObjectResponse>> requestWithRetryableError =
104+
multipartClient.getObject(r -> r.bucket(BUCKET).key(errorKey), AsyncResponseTransformer.toBytes());
105+
futures.add(requestWithRetryableError);
106+
107+
for (int i = 0; i < numRuns; i++) {
108+
CompletableFuture<ResponseBytes<GetObjectResponse>> resp = mock200Response(multipartClient, i + 1000);
109+
futures.add(resp);
110+
}
111+
112+
try {
113+
CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join();
114+
fail("Expecting 500 error to fail request.");
115+
} catch (CompletionException e) {
116+
assertThat(e.getCause()).isInstanceOf(S3Exception.class);
117+
}
118+
119+
verify(1, getRequestedFor(urlEqualTo(String.format("/%s/%s?partNumber=1", BUCKET, errorKey))));
120+
}
121+
122+
private String errorBody(String errorCode, String errorMessage) {
123+
return "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n"
124+
+ "<Error>\n"
125+
+ " <Code>" + errorCode + "</Code>\n"
126+
+ " <Message>" + errorMessage + "</Message>\n"
127+
+ "</Error>";
128+
}
129+
130+
private String internalErrorBody() {
131+
return errorBody("InternalError", "We encountered an internal error. Please try again.");
132+
}
133+
134+
private CompletableFuture<ResponseBytes<GetObjectResponse>> mock200Response(S3AsyncClient s3Client, int runNumber) {
135+
String runId = runNumber + " success";
136+
137+
stubFor(any(anyUrl())
138+
.withHeader("RunNum", matching(runId))
139+
.inScenario(runId)
140+
.whenScenarioStateIs(Scenario.STARTED)
141+
.willReturn(aResponse().withStatus(200)
142+
.withHeader("x-amz-request-id", String.valueOf(UUID.randomUUID()))
143+
.withBody("Hello World")));
144+
145+
return s3Client.getObject(r -> r.bucket(BUCKET).key("key")
146+
.overrideConfiguration(c -> c.putHeader("RunNum", runId)),
147+
AsyncResponseTransformer.toBytes());
148+
}
149+
}

0 commit comments

Comments
 (0)