Skip to content

Commit 8eef6e4

Browse files
committed
Additional Changes:
Adding the logic to close the client Addressing PR feedback
1 parent ad0fd27 commit 8eef6e4

File tree

7 files changed

+348
-61
lines changed

7 files changed

+348
-61
lines changed

core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/defaultsmode/AutoDefaultsModeDiscovery.java

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,15 +88,20 @@ private static Optional<String> queryImdsV2() {
8888
return Optional.empty();
8989
}
9090

91+
Ec2MetadataClient client = null;
9192
try {
92-
Ec2MetadataClient client = Ec2MetadataSharedClient.builder()
93-
.retryPolicy(Ec2MetadataRetryPolicy.none())
94-
.build();
93+
client = Ec2MetadataSharedClient.builder()
94+
.retryPolicy(Ec2MetadataRetryPolicy.none())
95+
.build();
9596

9697
String ec2InstanceRegion = client.get(EC2_METADATA_REGION_PATH).asString();
9798
return Optional.ofNullable(ec2InstanceRegion);
9899
} catch (Exception exception) {
99100
return Optional.empty();
101+
} finally {
102+
if (client != null) {
103+
Ec2MetadataSharedClient.decrementAndClose();
104+
}
100105
}
101106
}
102107

core/aws-core/src/test/java/software/amazon/awssdk/awscore/internal/defaultsmode/AutoDefaultsModeDiscoveryEc2MetadataClientTest.java

Lines changed: 66 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -23,22 +23,21 @@
2323
import static com.github.tomakehurst.wiremock.client.WireMock.putRequestedFor;
2424
import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo;
2525
import static com.github.tomakehurst.wiremock.client.WireMock.verify;
26+
import static com.github.tomakehurst.wiremock.client.WireMock.matching;
2627
import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig;
2728
import static org.assertj.core.api.Assertions.assertThat;
2829

29-
import com.github.tomakehurst.wiremock.junit.WireMockRule;
30+
import com.github.tomakehurst.wiremock.junit5.WireMockExtension;
3031
import java.lang.reflect.Field;
31-
import org.junit.After;
32-
import org.junit.Before;
33-
import org.junit.Rule;
34-
import org.junit.Test;
32+
import org.junit.jupiter.api.AfterEach;
33+
import org.junit.jupiter.api.BeforeAll;
34+
import org.junit.jupiter.api.BeforeEach;
35+
import org.junit.jupiter.api.Test;
36+
import org.junit.jupiter.api.extension.RegisterExtension;
3537
import software.amazon.awssdk.awscore.defaultsmode.DefaultsMode;
3638
import software.amazon.awssdk.core.SdkSystemSetting;
37-
import software.amazon.awssdk.http.SdkHttpClient;
38-
import software.amazon.awssdk.imds.internal.Ec2MetadataSharedClient;
3939
import software.amazon.awssdk.regions.Region;
4040
import software.amazon.awssdk.testutils.EnvironmentVariableHelper;
41-
import software.amazon.awssdk.utils.Lazy;
4241

4342
/**
4443
* Tests specifically for AutoDefaultsModeDiscovery's migration to use Ec2MetadataClient.
@@ -47,26 +46,29 @@
4746
public class AutoDefaultsModeDiscoveryEc2MetadataClientTest {
4847
private static final EnvironmentVariableHelper ENVIRONMENT_VARIABLE_HELPER = new EnvironmentVariableHelper();
4948

50-
@Rule
51-
public WireMockRule wireMock = new WireMockRule(wireMockConfig()
52-
.port(0)
53-
.httpsPort(-1));
49+
@RegisterExtension
50+
static WireMockExtension wireMock = WireMockExtension.newInstance()
51+
.options(wireMockConfig().dynamicPort().dynamicPort())
52+
.configureStaticDsl(true)
53+
.build();
5454

55-
@Before
56-
public void setup() {
55+
@BeforeAll
56+
static void setupClass() {
5757
System.setProperty(SdkSystemSetting.AWS_EC2_METADATA_SERVICE_ENDPOINT.property(),
58-
"http://localhost:" + wireMock.port());
58+
"http://localhost:" + wireMock.getPort());
59+
}
5960

61+
@BeforeEach
62+
public void setup() {
6063
clearEnvironmentVariable("AWS_EXECUTION_ENV");
6164
clearEnvironmentVariable("AWS_REGION");
6265
clearEnvironmentVariable("AWS_DEFAULT_REGION");
6366
}
6467

65-
@After
68+
@AfterEach
6669
public void cleanup() {
6770
wireMock.resetAll();
6871
ENVIRONMENT_VARIABLE_HELPER.reset();
69-
System.clearProperty(SdkSystemSetting.AWS_EC2_METADATA_SERVICE_ENDPOINT.property());
7072
}
7173

7274
// Clear an environment variable by setting it to null.
@@ -82,7 +84,10 @@ private void clearEnvironmentVariable(String name) {
8284
public void autoDefaultsModeDiscovery_shouldUseSharedHttpClient() throws Exception {
8385
// Stub successful IMDS responses
8486
stubFor(put("/latest/api/token")
85-
.willReturn(aResponse().withStatus(200).withBody("test-token")));
87+
.willReturn(aResponse()
88+
.withStatus(200)
89+
.withHeader("x-aws-ec2-metadata-token-ttl-seconds", "21600")
90+
.withBody("test-token")));
8691
stubFor(get("/latest/meta-data/placement/region")
8792
.willReturn(aResponse().withStatus(200).withBody("us-east-1")));
8893

@@ -92,23 +97,25 @@ public void autoDefaultsModeDiscovery_shouldUseSharedHttpClient() throws Excepti
9297
// Should return IN_REGION since client region matches IMDS region
9398
assertThat(result).isEqualTo(DefaultsMode.IN_REGION);
9499

95-
// Verify that the shared HTTP client was used
96-
Field sharedClientField = Ec2MetadataSharedClient.class.getDeclaredField("SHARED_HTTP_CLIENT");
97-
sharedClientField.setAccessible(true);
98-
Lazy<SdkHttpClient> sharedHttpClient = (Lazy<SdkHttpClient>) sharedClientField.get(null);
99-
100-
// Verify the shared HTTP client was initialized
101-
assertThat(sharedHttpClient.hasValue()).isTrue();
102-
103-
// Verify IMDS requests were made
100+
// Verify token request was made
104101
verify(putRequestedFor(urlEqualTo("/latest/api/token")));
105-
verify(getRequestedFor(urlEqualTo("/latest/meta-data/placement/region")));
102+
103+
// Verify region request was made with token header - IMDSv2
104+
verify(getRequestedFor(urlEqualTo("/latest/meta-data/placement/region"))
105+
.withHeader("x-aws-ec2-metadata-token", matching("test-token")));
106+
107+
// Verify no IMDSv1 requests were made
108+
verify(0, getRequestedFor(urlEqualTo("/latest/meta-data/placement/region"))
109+
.withoutHeader("x-aws-ec2-metadata-token"));
106110
}
107111

108112
@Test
109113
public void multipleDiscoveryInstances_shouldShareSameHttpClient() throws Exception {
110114
stubFor(put("/latest/api/token")
111-
.willReturn(aResponse().withStatus(200).withBody("test-token")));
115+
.willReturn(aResponse()
116+
.withStatus(200)
117+
.withHeader("x-aws-ec2-metadata-token-ttl-seconds", "21600")
118+
.withBody("test-token")));
112119
stubFor(get("/latest/meta-data/placement/region")
113120
.willReturn(aResponse().withStatus(200).withBody("us-west-2")));
114121

@@ -124,16 +131,16 @@ public void multipleDiscoveryInstances_shouldShareSameHttpClient() throws Except
124131
assertThat(result1).isEqualTo(DefaultsMode.CROSS_REGION);
125132
assertThat(result2).isEqualTo(DefaultsMode.CROSS_REGION);
126133

127-
// Verify shared HTTP client was used
128-
Field sharedClientField = Ec2MetadataSharedClient.class.getDeclaredField("SHARED_HTTP_CLIENT");
129-
sharedClientField.setAccessible(true);
130-
Lazy<SdkHttpClient> sharedHttpClient = (Lazy<SdkHttpClient>) sharedClientField.get(null);
131-
132-
assertThat(sharedHttpClient.hasValue()).isTrue();
133-
134-
// Verify IMDS requests were made
134+
// Verify token request was made
135135
verify(putRequestedFor(urlEqualTo("/latest/api/token")));
136-
verify(getRequestedFor(urlEqualTo("/latest/meta-data/placement/region")));
136+
137+
// Verify region request was made with token header - IMDSv2
138+
verify(getRequestedFor(urlEqualTo("/latest/meta-data/placement/region"))
139+
.withHeader("x-aws-ec2-metadata-token", matching("test-token")));
140+
141+
// Verify no IMDSv1 requests were made
142+
verify(0, getRequestedFor(urlEqualTo("/latest/meta-data/placement/region"))
143+
.withoutHeader("x-aws-ec2-metadata-token"));
137144
}
138145

139146
@Test
@@ -174,7 +181,10 @@ public void imdsFailure_shouldFallbackToStandardMode() {
174181
public void noRetryPolicy_shouldBeUsedByDefault() {
175182
// Stub token to succeed but region to fail with retryable error
176183
stubFor(put("/latest/api/token")
177-
.willReturn(aResponse().withStatus(200).withBody("test-token")));
184+
.willReturn(aResponse()
185+
.withStatus(200)
186+
.withHeader("x-aws-ec2-metadata-token-ttl-seconds", "21600")
187+
.withBody("test-token")));
178188
stubFor(get("/latest/meta-data/placement/region")
179189
.willReturn(aResponse().withStatus(500).withBody("Internal Server Error")));
180190

@@ -186,7 +196,14 @@ public void noRetryPolicy_shouldBeUsedByDefault() {
186196

187197
// Verify requests were made once (no retries)
188198
verify(1, putRequestedFor(urlEqualTo("/latest/api/token")));
189-
verify(1, getRequestedFor(urlEqualTo("/latest/meta-data/placement/region")));
199+
200+
// Verify region request was made with token header - IMDSv2
201+
verify(1, getRequestedFor(urlEqualTo("/latest/meta-data/placement/region"))
202+
.withHeader("x-aws-ec2-metadata-token", matching("test-token")));
203+
204+
// Verify no IMDSv1 requests were made
205+
verify(0, getRequestedFor(urlEqualTo("/latest/meta-data/placement/region"))
206+
.withoutHeader("x-aws-ec2-metadata-token"));
190207
}
191208

192209
@Test
@@ -205,9 +222,16 @@ public void imdsV1Fallback_shouldWorkWhenTokenFails() {
205222
// Should fall back to IMDSv1 and return IN_REGION
206223
assertThat(result).isEqualTo(DefaultsMode.IN_REGION);
207224

208-
// Verify both token request and region request were made
225+
// Verify token request was attempted
209226
verify(putRequestedFor(urlEqualTo("/latest/api/token")));
210-
verify(getRequestedFor(urlEqualTo("/latest/meta-data/placement/region")));
227+
228+
// Verify region request was made without token header - IMDSv1 fallback
229+
verify(getRequestedFor(urlEqualTo("/latest/meta-data/placement/region"))
230+
.withoutHeader("x-aws-ec2-metadata-token"));
231+
232+
// Verify no IMDSv2 requests were made
233+
verify(0, getRequestedFor(urlEqualTo("/latest/meta-data/placement/region"))
234+
.withHeader("x-aws-ec2-metadata-token", matching(".*")));
211235
}
212236

213237
@Test

core/imds/src/main/java/software/amazon/awssdk/imds/internal/DefaultEc2MetadataClientWithFallback.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ public Ec2MetadataResponse get(String path) {
124124
if (token == null || token.isExpired()) {
125125
token = tokenCache.get();
126126
}
127-
return sendRequest(path, token != null ? token.value() : null);
127+
return sendRequest(path, token == null ? null : token.value());
128128
} catch (UncheckedIOException | RetryableException e) {
129129
lastCause = e;
130130
int currentTry = attempt;

core/imds/src/main/java/software/amazon/awssdk/imds/internal/Ec2MetadataSharedClient.java

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,24 +16,24 @@
1616
package software.amazon.awssdk.imds.internal;
1717

1818
import java.time.Duration;
19-
import software.amazon.awssdk.annotations.SdkInternalApi;
19+
import software.amazon.awssdk.annotations.SdkProtectedApi;
2020
import software.amazon.awssdk.core.internal.http.loader.DefaultSdkHttpClientBuilder;
2121
import software.amazon.awssdk.http.SdkHttpClient;
2222
import software.amazon.awssdk.http.SdkHttpConfigurationOption;
2323
import software.amazon.awssdk.imds.Ec2MetadataClient;
2424
import software.amazon.awssdk.imds.Ec2MetadataRetryPolicy;
2525
import software.amazon.awssdk.utils.AttributeMap;
26-
import software.amazon.awssdk.utils.Lazy;
2726

2827
/**
2928
* Creates Ec2MetadataClient instances using a shared HTTP client internally.
3029
* This provides resource efficiency by sharing a single HTTP client across all IMDS-backed providers
3130
*/
32-
@SdkInternalApi
31+
@SdkProtectedApi
3332
public final class Ec2MetadataSharedClient {
34-
// Singleton HTTP client shared across all Ec2MetadataClient instances
35-
private static final Lazy<SdkHttpClient> SHARED_HTTP_CLIENT = new Lazy<>(() -> createImdsHttpClient());
36-
33+
34+
private static volatile SdkHttpClient sharedHttpClient;
35+
private static int referenceCount = 0;
36+
3737
private Ec2MetadataSharedClient() {
3838
// Prevent instantiation
3939
}
@@ -52,6 +52,17 @@ public static Builder builder() {
5252
public static Ec2MetadataClient create() {
5353
return builder().build();
5454
}
55+
56+
/**
57+
* Decrements the reference count and closes the shared HTTP client if no more references exist.
58+
*/
59+
public static synchronized void decrementAndClose() {
60+
referenceCount--;
61+
if (referenceCount == 0 && sharedHttpClient != null) {
62+
sharedHttpClient.close();
63+
sharedHttpClient = null;
64+
}
65+
}
5566

5667
private static SdkHttpClient createImdsHttpClient() {
5768
Duration metadataServiceTimeout = Ec2MetadataConfigProvider.instance().resolveServiceTimeout();
@@ -69,14 +80,21 @@ public static final class Builder {
6980
private Builder() {
7081
}
7182

72-
public Builder retryPolicy(Ec2MetadataRetryPolicy retryPolicy) {
83+
public synchronized Builder retryPolicy(Ec2MetadataRetryPolicy retryPolicy) {
7384
this.retryPolicy = retryPolicy;
7485
return this;
7586
}
7687

77-
public Ec2MetadataClient build() {
88+
public synchronized Ec2MetadataClient build() {
89+
90+
if (sharedHttpClient == null) {
91+
sharedHttpClient = createImdsHttpClient();
92+
}
93+
94+
referenceCount++;
95+
7896
return DefaultEc2MetadataClientWithFallback.builder()
79-
.httpClient(SHARED_HTTP_CLIENT.getValue())
97+
.httpClient(sharedHttpClient)
8098
.retryPolicy(retryPolicy)
8199
.build();
82100
}

core/imds/src/main/java/software/amazon/awssdk/imds/internal/RequestMarshaller.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,14 @@ public SdkHttpFullRequest createTokenRequest(Duration tokenTtl) {
5959

6060
public SdkHttpFullRequest createDataRequest(String path, String token, Duration tokenTtl) {
6161
URI resourcePath = URI.create(basePath + path);
62-
return defaulttHttpBuilder()
62+
SdkHttpFullRequest.Builder builder = defaulttHttpBuilder()
6363
.method(SdkHttpMethod.GET)
6464
.uri(resourcePath)
65-
.putHeader(EC2_METADATA_TOKEN_TTL_HEADER, String.valueOf(tokenTtl.getSeconds()))
66-
.putHeader(TOKEN_HEADER, token)
67-
.build();
65+
.putHeader(EC2_METADATA_TOKEN_TTL_HEADER, String.valueOf(tokenTtl.getSeconds()));
66+
if (token != null) {
67+
builder.putHeader(TOKEN_HEADER, token);
68+
}
69+
return builder.build();
6870
}
6971

7072
private SdkHttpFullRequest.Builder defaulttHttpBuilder() {

0 commit comments

Comments
 (0)