Skip to content

Commit 19b969e

Browse files
authored
Fix bug in MultipartS3AsyncClient GetObject (#6320)
* S3 multipart client cancel GetObject futures if onError invoked * Add parameterized test for AsyncResponseTransformer types * Update test * Add changelog
1 parent 9c720b0 commit 19b969e

File tree

4 files changed

+200
-1
lines changed

4 files changed

+200
-1
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"type": "bugfix",
3+
"category": "Amazon S3",
4+
"contributor": "",
5+
"description": "Fix a bug in the Java based multipart client with GetObject incorrect retry behavior"
6+
}

core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingTransformer.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import software.amazon.awssdk.core.SplittingTransformerConfiguration;
2626
import software.amazon.awssdk.core.async.AsyncResponseTransformer;
2727
import software.amazon.awssdk.core.async.SdkPublisher;
28+
import software.amazon.awssdk.core.exception.NonRetryableException;
2829
import software.amazon.awssdk.utils.CompletableFutureUtils;
2930
import software.amazon.awssdk.utils.Logger;
3031
import software.amazon.awssdk.utils.Validate;
@@ -279,7 +280,9 @@ public CompletableFuture<ResponseT> prepare() {
279280
if (e == null) {
280281
return;
281282
}
282-
individualFuture.completeExceptionally(e);
283+
284+
individualFuture.completeExceptionally(NonRetryableException.create(
285+
"Error occurred during multipart download. Request will not be retried.", e));
283286
});
284287
individualFuture.whenComplete((r, e) -> {
285288
if (isCancelled.get()) {

services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloaderSubscriber.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515

1616
package software.amazon.awssdk.services.s3.internal.multipart;
1717

18+
import java.util.Queue;
1819
import java.util.concurrent.CompletableFuture;
20+
import java.util.concurrent.ConcurrentLinkedQueue;
1921
import java.util.concurrent.atomic.AtomicInteger;
2022
import org.reactivestreams.Subscriber;
2123
import org.reactivestreams.Subscription;
@@ -79,6 +81,11 @@ public class MultipartDownloaderSubscriber implements Subscriber<AsyncResponseTr
7981
*/
8082
private final Object lock = new Object();
8183

84+
/**
85+
* Store the GetObject futures so we can cancel them if onError() is invoked.
86+
*/
87+
private final Queue<CompletableFuture<GetObjectResponse>> getObjectFutures = new ConcurrentLinkedQueue<>();
88+
8289
public MultipartDownloaderSubscriber(S3AsyncClient s3, GetObjectRequest getObjectRequest) {
8390
this(s3, getObjectRequest, 0);
8491
}
@@ -119,6 +126,7 @@ public void onNext(AsyncResponseTransformer<GetObjectResponse, GetObjectResponse
119126
GetObjectRequest actualRequest = nextRequest(nextPartToGet);
120127
log.debug(() -> "Sending GetObjectRequest for next part with partNumber=" + nextPartToGet);
121128
CompletableFuture<GetObjectResponse> getObjectFuture = s3.getObject(actualRequest, asyncResponseTransformer);
129+
getObjectFutures.add(getObjectFuture);
122130
getObjectFuture.whenComplete((response, error) -> {
123131
if (error != null) {
124132
log.debug(() -> "Error encountered during GetObjectRequest with partNumber=" + nextPartToGet);
@@ -166,6 +174,10 @@ private void requestMoreIfNeeded(GetObjectResponse response) {
166174

167175
@Override
168176
public void onError(Throwable t) {
177+
CompletableFuture<GetObjectResponse> partFuture;
178+
while ((partFuture = getObjectFutures.poll()) != null) {
179+
partFuture.cancel(true);
180+
}
169181
future.completeExceptionally(t);
170182
}
171183

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License").
5+
* You may not use this file except in compliance with the License.
6+
* A copy of the License is located at
7+
*
8+
* http://aws.amazon.com/apache2.0
9+
*
10+
* or in the "license" file accompanying this file. This file is distributed
11+
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12+
* express or implied. See the License for the specific language governing
13+
* permissions and limitations under the License.
14+
*/
15+
16+
package software.amazon.awssdk.services.s3.internal.multipart;
17+
18+
import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
19+
import static com.github.tomakehurst.wiremock.client.WireMock.any;
20+
import static com.github.tomakehurst.wiremock.client.WireMock.anyUrl;
21+
import static com.github.tomakehurst.wiremock.client.WireMock.get;
22+
import static com.github.tomakehurst.wiremock.client.WireMock.getRequestedFor;
23+
import static com.github.tomakehurst.wiremock.client.WireMock.matching;
24+
import static com.github.tomakehurst.wiremock.client.WireMock.stubFor;
25+
import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo;
26+
import static com.github.tomakehurst.wiremock.client.WireMock.verify;
27+
import static org.assertj.core.api.Assertions.assertThat;
28+
import static org.assertj.core.api.Assertions.fail;
29+
30+
import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo;
31+
import com.github.tomakehurst.wiremock.junit5.WireMockTest;
32+
import com.github.tomakehurst.wiremock.stubbing.Scenario;
33+
import java.io.IOException;
34+
import java.io.UncheckedIOException;
35+
import java.net.URI;
36+
import java.nio.file.Files;
37+
import java.nio.file.Path;
38+
import java.time.Duration;
39+
import java.util.ArrayList;
40+
import java.util.List;
41+
import java.util.UUID;
42+
import java.util.concurrent.CompletableFuture;
43+
import java.util.concurrent.CompletionException;
44+
import java.util.concurrent.TimeUnit;
45+
import java.util.stream.Stream;
46+
import org.junit.jupiter.api.BeforeEach;
47+
import org.junit.jupiter.api.Timeout;
48+
import org.junit.jupiter.params.ParameterizedTest;
49+
import org.junit.jupiter.params.provider.MethodSource;
50+
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
51+
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
52+
import software.amazon.awssdk.core.async.AsyncResponseTransformer;
53+
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
54+
import software.amazon.awssdk.regions.Region;
55+
import software.amazon.awssdk.services.s3.S3AsyncClient;
56+
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
57+
import software.amazon.awssdk.services.s3.model.S3Exception;
58+
59+
@WireMockTest
60+
@Timeout(value = 45, unit = TimeUnit.SECONDS)
61+
public class S3MultipartClientGetObjectWiremockTest {
62+
private static final String BUCKET = "Example-Bucket";
63+
private static final String KEY = "Key";
64+
private static int fileCounter = 0;
65+
private S3AsyncClient multipartClient;
66+
67+
@BeforeEach
68+
public void setup(WireMockRuntimeInfo wm) {
69+
multipartClient = S3AsyncClient.builder()
70+
.region(Region.US_EAST_1)
71+
.endpointOverride(URI.create(wm.getHttpBaseUrl()))
72+
.multipartEnabled(true)
73+
.httpClientBuilder(NettyNioAsyncHttpClient.builder()
74+
.maxConcurrency(100)
75+
.connectionAcquisitionTimeout(Duration.ofSeconds(60)))
76+
.credentialsProvider(StaticCredentialsProvider.create(AwsBasicCredentials.create("key", "secret")))
77+
.build();
78+
}
79+
80+
private static Stream<TransformerFactory> responseTransformerFactories() {
81+
return Stream.of(
82+
AsyncResponseTransformer::toBytes,
83+
AsyncResponseTransformer::toBlockingInputStream,
84+
AsyncResponseTransformer::toPublisher,
85+
() -> {
86+
try {
87+
Path tempDir = Files.createTempDirectory("s3-test");
88+
Path tempFile = tempDir.resolve("testFile" + fileCounter + ".txt");
89+
fileCounter++;
90+
tempFile.toFile().deleteOnExit();
91+
return AsyncResponseTransformer.toFile(tempFile);
92+
} catch (IOException e) {
93+
throw new UncheckedIOException(e);
94+
}
95+
}
96+
);
97+
}
98+
99+
interface TransformerFactory {
100+
AsyncResponseTransformer<GetObjectResponse, ?> create();
101+
}
102+
103+
@ParameterizedTest
104+
@MethodSource("responseTransformerFactories")
105+
public void getObject_single500WithinMany200s_shouldNotRetryError(TransformerFactory transformerFactory) {
106+
List<CompletableFuture<?>> futures = new ArrayList<>();
107+
108+
int numRuns = 100;
109+
for (int i = 0; i < numRuns; i++) {
110+
CompletableFuture<?> resp = mock200Response(multipartClient, i, transformerFactory);
111+
futures.add(resp);
112+
}
113+
114+
String errorKey = "ErrorKey";
115+
stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=1", BUCKET, errorKey)))
116+
.inScenario("RetryableError")
117+
.whenScenarioStateIs(Scenario.STARTED)
118+
.willReturn(aResponse()
119+
.withHeader("x-amz-request-id", String.valueOf(UUID.randomUUID()))
120+
.withStatus(500)
121+
.withBody(internalErrorBody())
122+
)
123+
.willSetStateTo("RetryAttempt"));
124+
125+
stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=1", BUCKET, errorKey)))
126+
.inScenario("RetryableError")
127+
.whenScenarioStateIs("RetryAttempt")
128+
.willReturn(aResponse().withStatus(200)
129+
.withHeader("x-amz-request-id", String.valueOf(UUID.randomUUID()))
130+
.withBody("Hello World")));
131+
132+
CompletableFuture<?> requestWithRetryableError =
133+
multipartClient.getObject(r -> r.bucket(BUCKET).key(errorKey), transformerFactory.create());
134+
futures.add(requestWithRetryableError);
135+
136+
for (int i = 0; i < numRuns; i++) {
137+
CompletableFuture<?> resp = mock200Response(multipartClient, i + 1000, transformerFactory);
138+
futures.add(resp);
139+
}
140+
141+
try {
142+
CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join();
143+
fail("Expecting 500 error to fail request.");
144+
} catch (CompletionException e) {
145+
assertThat(e.getCause()).isInstanceOf(S3Exception.class);
146+
}
147+
148+
verify(1, getRequestedFor(urlEqualTo(String.format("/%s/%s?partNumber=1", BUCKET, errorKey))));
149+
}
150+
151+
private CompletableFuture<?> mock200Response(S3AsyncClient s3Client, int runNumber, TransformerFactory transformerFactory) {
152+
String runId = runNumber + " success";
153+
154+
stubFor(any(anyUrl())
155+
.withHeader("RunNum", matching(runId))
156+
.inScenario(runId)
157+
.whenScenarioStateIs(Scenario.STARTED)
158+
.willReturn(aResponse().withStatus(200)
159+
.withHeader("x-amz-request-id", String.valueOf(UUID.randomUUID()))
160+
.withBody("Hello World")));
161+
162+
return s3Client.getObject(r -> r.bucket(BUCKET).key(KEY)
163+
.overrideConfiguration(c -> c.putHeader("RunNum", runId)),
164+
transformerFactory.create());
165+
}
166+
167+
private String errorBody(String errorCode, String errorMessage) {
168+
return "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n"
169+
+ "<Error>\n"
170+
+ " <Code>" + errorCode + "</Code>\n"
171+
+ " <Message>" + errorMessage + "</Message>\n"
172+
+ "</Error>";
173+
}
174+
175+
private String internalErrorBody() {
176+
return errorBody("InternalError", "We encountered an internal error. Please try again.");
177+
}
178+
}

0 commit comments

Comments
 (0)