Skip to content

Commit 97f9314

Browse files
committed
EC2 IMDS Changes to Support Account ID
1 parent a7801ce commit 97f9314

File tree

7 files changed

+271
-20
lines changed

7 files changed

+271
-20
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"type": "feature",
3+
"category": "AWS EC2",
4+
"contributor": "",
5+
"description": "EC2 IMDS Changes to Support Account ID"
6+
}

core/auth/src/it/java/software/amazon/awssdk/auth/credentials/InstanceProfileCredentialsProviderIntegrationTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public class InstanceProfileCredentialsProviderIntegrationTest {
3535
/** Starts up the mock EC2 Instance Metadata Service. */
3636
@Before
3737
public void setUp() throws Exception {
38-
mockServer = new EC2MetadataServiceMock("/latest/meta-data/iam/security-credentials/");
38+
mockServer = new EC2MetadataServiceMock("/latest/meta-data/iam/security-credentials-extended/");
3939
mockServer.start();
4040
}
4141

core/auth/src/main/java/software/amazon/awssdk/auth/credentials/InstanceProfileCredentialsProvider.java

Lines changed: 62 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,14 @@
3737
import software.amazon.awssdk.core.SdkSystemSetting;
3838
import software.amazon.awssdk.core.exception.SdkClientException;
3939
import software.amazon.awssdk.core.exception.SdkServiceException;
40+
import software.amazon.awssdk.imds.Ec2MetadataClientException;
4041
import software.amazon.awssdk.profiles.ProfileFile;
4142
import software.amazon.awssdk.profiles.ProfileFileSupplier;
4243
import software.amazon.awssdk.profiles.ProfileFileSystemSetting;
4344
import software.amazon.awssdk.profiles.ProfileProperty;
4445
import software.amazon.awssdk.regions.util.HttpResourcesUtils;
4546
import software.amazon.awssdk.regions.util.ResourcesEndpointProvider;
47+
import software.amazon.awssdk.utils.Lazy;
4648
import software.amazon.awssdk.utils.Logger;
4749
import software.amazon.awssdk.utils.ToString;
4850
import software.amazon.awssdk.utils.Validate;
@@ -70,9 +72,19 @@ public final class InstanceProfileCredentialsProvider
7072
private static final String PROVIDER_NAME = "InstanceProfileCredentialsProvider";
7173
private static final String EC2_METADATA_TOKEN_HEADER = "x-aws-ec2-metadata-token";
7274
private static final String SECURITY_CREDENTIALS_RESOURCE = "/latest/meta-data/iam/security-credentials/";
75+
private static final String SECURITY_CREDENTIALS_EXTENDED_RESOURCE = "/latest/meta-data/iam/security-credentials-extended/";
7376
private static final String TOKEN_RESOURCE = "/latest/api/token";
77+
78+
private enum ApiVersion {
79+
UNKNOWN,
80+
LEGACY,
81+
EXTENDED
82+
}
83+
7484
private static final String EC2_METADATA_TOKEN_TTL_HEADER = "x-aws-ec2-metadata-token-ttl-seconds";
7585
private static final String DEFAULT_TOKEN_TTL = "21600";
86+
private Lazy<ApiVersion> apiVersion = new Lazy<>(() -> ApiVersion.UNKNOWN);
87+
private Lazy<String> resolvedProfile = new Lazy<>(() -> null);
7688

7789
private final Clock clock;
7890
private final String endpoint;
@@ -157,14 +169,27 @@ private RefreshResult<AwsCredentials> refreshCredentials() {
157169

158170
try {
159171
LoadedCredentials credentials = httpCredentialsLoader.loadCredentials(createEndpointProvider());
172+
ApiVersion currentVersion = apiVersion.getValue();
173+
if (currentVersion == ApiVersion.UNKNOWN) {
174+
apiVersion = Lazy.withValue(ApiVersion.EXTENDED);
175+
}
176+
160177
Instant expiration = credentials.getExpiration().orElse(null);
161178
log.debug(() -> "Loaded credentials from IMDS with expiration time of " + expiration);
162179

163180
return RefreshResult.builder(credentials.getAwsCredentials())
164181
.staleTime(staleTime(expiration))
165182
.prefetchTime(prefetchTime(expiration))
166183
.build();
184+
} catch (Ec2MetadataClientException e) {
185+
if (apiVersion.getValue() == ApiVersion.UNKNOWN) {
186+
apiVersion = Lazy.withValue(ApiVersion.LEGACY);
187+
resolvedProfile = new Lazy<>(() -> null);
188+
return refreshCredentials();
189+
}
190+
throw SdkClientException.create("Failed to load credentials from IMDS.", e);
167191
} catch (RuntimeException e) {
192+
resolvedProfile = new Lazy<>(() -> null);
168193
throw SdkClientException.create("Failed to load credentials from IMDS.", e);
169194
}
170195
}
@@ -207,14 +232,20 @@ public String toString() {
207232
return ToString.create(PROVIDER_NAME);
208233
}
209234

235+
private String getSecurityCredentialsResource() {
236+
return apiVersion.getValue() == ApiVersion.LEGACY ?
237+
SECURITY_CREDENTIALS_RESOURCE :
238+
SECURITY_CREDENTIALS_EXTENDED_RESOURCE;
239+
}
240+
210241
private ResourcesEndpointProvider createEndpointProvider() {
211242
String imdsHostname = getImdsEndpoint();
212243
String token = getToken(imdsHostname);
213244
String[] securityCredentials = getSecurityCredentials(imdsHostname, token);
214-
245+
String urlBase = getSecurityCredentialsResource();
246+
215247
return StaticResourcesEndpointProvider.builder()
216-
.endpoint(URI.create(imdsHostname + SECURITY_CREDENTIALS_RESOURCE
217-
+ securityCredentials[0]))
248+
.endpoint(URI.create(imdsHostname + urlBase + securityCredentials[0]))
218249
.headers(getTokenHeaders(token))
219250
.connectionTimeout(Duration.ofMillis(
220251
this.configProvider.serviceTimeout()))
@@ -285,21 +316,41 @@ private boolean isInsecureFallbackDisabled() {
285316
}
286317

287318
private String[] getSecurityCredentials(String imdsHostname, String metadataToken) {
319+
if (resolvedProfile.hasValue()) {
320+
return new String[]{resolvedProfile.getValue()};
321+
}
322+
323+
String urlBase = getSecurityCredentialsResource();
288324
ResourcesEndpointProvider securityCredentialsEndpoint =
289325
StaticResourcesEndpointProvider.builder()
290-
.endpoint(URI.create(imdsHostname + SECURITY_CREDENTIALS_RESOURCE))
326+
.endpoint(URI.create(imdsHostname + urlBase))
291327
.headers(getTokenHeaders(metadataToken))
292-
.connectionTimeout(Duration.ofMillis(this.configProvider.serviceTimeout()))
328+
.connectionTimeout(Duration.ofMillis(this.configProvider.serviceTimeout()))
293329
.build();
294330

295-
String securityCredentialsList =
296-
invokeSafely(() -> HttpResourcesUtils.instance().readResource(securityCredentialsEndpoint));
297-
String[] securityCredentials = securityCredentialsList.trim().split("\n");
331+
try {
332+
String securityCredentialsList =
333+
invokeSafely(() -> HttpResourcesUtils.instance().readResource(securityCredentialsEndpoint));
334+
String[] securityCredentials = securityCredentialsList.trim().split("\n");
335+
336+
if (securityCredentials.length == 0) {
337+
throw SdkClientException.builder().message("Unable to load credentials path").build();
338+
}
298339

299-
if (securityCredentials.length == 0) {
300-
throw SdkClientException.builder().message("Unable to load credentials path").build();
340+
ApiVersion currentVersion = apiVersion.getValue();
341+
if (currentVersion == ApiVersion.UNKNOWN) {
342+
apiVersion = Lazy.withValue(ApiVersion.EXTENDED);
343+
}
344+
resolvedProfile = new Lazy<>(() -> securityCredentials[0]);
345+
return securityCredentials;
346+
347+
} catch (Ec2MetadataClientException e) {
348+
if (apiVersion.getValue() == ApiVersion.UNKNOWN) {
349+
apiVersion = Lazy.withValue(ApiVersion.LEGACY);
350+
return getSecurityCredentials(imdsHostname, metadataToken);
351+
}
352+
throw SdkClientException.create("Failed to load credentials from IMDS.", e);
301353
}
302-
return securityCredentials;
303354
}
304355

305356
private Map<String, String> getTokenHeaders(String metadataToken) {

core/auth/src/main/java/software/amazon/awssdk/auth/credentials/internal/HttpCredentialsLoader.java

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,15 +64,17 @@ public LoadedCredentials loadCredentials(ResourcesEndpointProvider endpoint) {
6464
JsonNode secretKey = node.get("SecretAccessKey");
6565
JsonNode token = node.get("Token");
6666
JsonNode expiration = node.get("Expiration");
67+
JsonNode accountId = node.get("AccountId");
6768

6869
Validate.notNull(accessKey, "Failed to load access key from metadata service.");
6970
Validate.notNull(secretKey, "Failed to load secret key from metadata service.");
7071

7172
return new LoadedCredentials(accessKey.text(),
72-
secretKey.text(),
73-
token != null ? token.text() : null,
74-
expiration != null ? expiration.text() : null,
75-
providerName);
73+
secretKey.text(),
74+
token != null ? token.text() : null,
75+
expiration != null ? expiration.text() : null,
76+
accountId != null ? accountId.text() : null,
77+
providerName);
7678
} catch (SdkClientException e) {
7779
throw e;
7880
} catch (RuntimeException | IOException e) {
@@ -89,12 +91,15 @@ public static final class LoadedCredentials {
8991
private final String token;
9092
private final Instant expiration;
9193
private final String providerName;
94+
private final String accountId;
9295

93-
private LoadedCredentials(String accessKeyId, String secretKey, String token, String expiration, String providerName) {
96+
private LoadedCredentials(String accessKeyId, String secretKey, String token,
97+
String expiration, String accountId, String providerName) {
9498
this.accessKeyId = Validate.paramNotBlank(accessKeyId, "accessKeyId");
9599
this.secretKey = Validate.paramNotBlank(secretKey, "secretKey");
96100
this.token = token;
97101
this.expiration = expiration == null ? null : parseExpiration(expiration);
102+
this.accountId = accountId;
98103
this.providerName = providerName;
99104
}
100105

@@ -105,11 +110,13 @@ public AwsCredentials getAwsCredentials() {
105110
.secretAccessKey(secretKey)
106111
.sessionToken(token)
107112
.providerName(providerName)
113+
.accountId(accountId)
108114
.build() :
109115
AwsBasicCredentials.builder()
110116
.accessKeyId(accessKeyId)
111117
.secretAccessKey(secretKey)
112118
.providerName(providerName)
119+
.accountId(accountId)
113120
.build();
114121
}
115122

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License").
5+
* You may not use this file except in compliance with the License.
6+
* A copy of the License is located at
7+
*
8+
* http://aws.amazon.com/apache2.0
9+
*
10+
* or in the "license" file accompanying this file. This file is distributed
11+
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12+
* express or implied. See the License for the specific language governing
13+
* permissions and limitations under the License.
14+
*/
15+
16+
package software.amazon.awssdk.auth.credentials;
17+
18+
import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
19+
import static com.github.tomakehurst.wiremock.client.WireMock.equalTo;
20+
import static com.github.tomakehurst.wiremock.client.WireMock.get;
21+
import static com.github.tomakehurst.wiremock.client.WireMock.getRequestedFor;
22+
import static com.github.tomakehurst.wiremock.client.WireMock.put;
23+
import static com.github.tomakehurst.wiremock.client.WireMock.putRequestedFor;
24+
import static com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo;
25+
import static com.github.tomakehurst.wiremock.client.WireMock.verify;
26+
import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig;
27+
import static org.assertj.core.api.Assertions.assertThat;
28+
29+
import com.github.tomakehurst.wiremock.junit5.WireMockExtension;
30+
import com.github.tomakehurst.wiremock.junit5.WireMockTest;
31+
import org.junit.jupiter.api.AfterAll;
32+
import org.junit.jupiter.api.BeforeEach;
33+
import org.junit.jupiter.api.Test;
34+
import org.junit.jupiter.api.extension.RegisterExtension;
35+
import software.amazon.awssdk.core.SdkSystemSetting;
36+
import software.amazon.awssdk.testutils.EnvironmentVariableHelper;
37+
import software.amazon.awssdk.utils.DateUtils;
38+
39+
import java.time.Duration;
40+
import java.time.Instant;
41+
42+
/**
43+
* Tests verifying IMDS credential resolution with account ID support.
44+
*/
45+
@WireMockTest
46+
public class InstanceProfileCredentialsProviderAccountIDTest {
47+
private static final String TOKEN_RESOURCE_PATH = "/latest/api/token";
48+
private static final String CREDENTIALS_RESOURCE_PATH = "/latest/meta-data/iam/security-credentials/";
49+
private static final String CREDENTIALS_EXTENDED_RESOURCE_PATH = "/latest/meta-data/iam/security-credentials-extended/";
50+
private static final String TOKEN_HEADER = "x-aws-ec2-metadata-token";
51+
private static final String TOKEN_STUB = "some-token";
52+
private static final String PROFILE_NAME = "some-profile";
53+
private static final String EC2_METADATA_TOKEN_TTL_HEADER = "x-aws-ec2-metadata-token-ttl-seconds";
54+
private static final String ACCOUNT_ID = "123456789012";
55+
private static final EnvironmentVariableHelper environmentVariableHelper = new EnvironmentVariableHelper();
56+
57+
@RegisterExtension
58+
static WireMockExtension wireMockServer = WireMockExtension.newInstance()
59+
.options(wireMockConfig().dynamicPort())
60+
.configureStaticDsl(true)
61+
.build();
62+
63+
@BeforeEach
64+
public void methodSetup() {
65+
environmentVariableHelper.reset();
66+
System.setProperty(SdkSystemSetting.AWS_EC2_METADATA_SERVICE_ENDPOINT.property(),
67+
"http://localhost:" + wireMockServer.getPort());
68+
}
69+
70+
@AfterAll
71+
public static void teardown() {
72+
System.clearProperty(SdkSystemSetting.AWS_EC2_METADATA_SERVICE_ENDPOINT.property());
73+
environmentVariableHelper.reset();
74+
}
75+
76+
@Test
77+
void resolveCredentials_usesExtendedEndpoint_withAccountId() {
78+
String credentialsWithAccountId = String.format(
79+
"{\"AccessKeyId\":\"ACCESS_KEY_ID\"," +
80+
"\"SecretAccessKey\":\"SECRET_ACCESS_KEY\"," +
81+
"\"Token\":\"SESSION_TOKEN\"," +
82+
"\"Expiration\":\"%s\"," +
83+
"\"AccountId\":\"%s\"}",
84+
DateUtils.formatIso8601Date(Instant.now().plus(Duration.ofDays(1))),
85+
ACCOUNT_ID
86+
);
87+
88+
stubSecureCredentialsResponse(aResponse().withBody(credentialsWithAccountId), true);
89+
InstanceProfileCredentialsProvider provider = InstanceProfileCredentialsProvider.builder().build();
90+
AwsCredentials credentials = provider.resolveCredentials();
91+
92+
assertThat(credentials.accessKeyId()).isEqualTo("ACCESS_KEY_ID");
93+
assertThat(credentials.secretAccessKey()).isEqualTo("SECRET_ACCESS_KEY");
94+
assertThat(((AwsSessionCredentials)credentials).sessionToken()).isEqualTo("SESSION_TOKEN");
95+
assertThat(credentials.accountId()).hasValue(ACCOUNT_ID);
96+
verifyImdsCallWithToken(true);
97+
}
98+
99+
@Test
100+
void resolveCredentials_fallsBackToLegacy_noAccountId() {
101+
String credentialsWithoutAccountId = String.format(
102+
"{\"AccessKeyId\":\"ACCESS_KEY_ID\"," +
103+
"\"SecretAccessKey\":\"SECRET_ACCESS_KEY\"," +
104+
"\"Token\":\"SESSION_TOKEN\"," +
105+
"\"Expiration\":\"%s\"," +
106+
"\"Code\":\"Success\"}", // No AccountId field at all
107+
DateUtils.formatIso8601Date(Instant.now().plus(Duration.ofDays(1)))
108+
);
109+
110+
stubSecureCredentialsResponse(aResponse().withBody(credentialsWithoutAccountId), false);
111+
InstanceProfileCredentialsProvider provider = InstanceProfileCredentialsProvider.builder().build();
112+
AwsCredentials credentials = provider.resolveCredentials();
113+
114+
assertThat(credentials.accessKeyId()).isEqualTo("ACCESS_KEY_ID");
115+
assertThat(credentials.secretAccessKey()).isEqualTo("SECRET_ACCESS_KEY");
116+
assertThat(((AwsSessionCredentials)credentials).sessionToken()).isEqualTo("SESSION_TOKEN");
117+
verifyImdsCallWithToken(false);
118+
}
119+
120+
@Test
121+
void resolveCredentials_cachesProfile_maintainsAccountId() {
122+
String credentialsWithAccountId = String.format(
123+
"{\"AccessKeyId\":\"ACCESS_KEY_ID\"," +
124+
"\"SecretAccessKey\":\"SECRET_ACCESS_KEY\"," +
125+
"\"Token\":\"SESSION_TOKEN\"," +
126+
"\"Expiration\":\"%s\"," +
127+
"\"AccountId\":\"%s\"}",
128+
DateUtils.formatIso8601Date(Instant.now().plus(Duration.ofDays(1))),
129+
ACCOUNT_ID
130+
);
131+
132+
stubSecureCredentialsResponse(aResponse().withBody(credentialsWithAccountId), true);
133+
InstanceProfileCredentialsProvider provider = InstanceProfileCredentialsProvider.builder().build();
134+
135+
// First call
136+
AwsCredentials creds1 = provider.resolveCredentials();
137+
assertThat(creds1.accountId()).hasValue(ACCOUNT_ID);
138+
139+
// Second call - should use cached profile
140+
AwsCredentials creds2 = provider.resolveCredentials();
141+
assertThat(creds2.accountId()).hasValue(ACCOUNT_ID);
142+
143+
// Verify profile discovery only called once
144+
verify(1, getRequestedFor(urlPathEqualTo(CREDENTIALS_EXTENDED_RESOURCE_PATH)));
145+
}
146+
147+
private void stubSecureCredentialsResponse(com.github.tomakehurst.wiremock.client.ResponseDefinitionBuilder responseDefinitionBuilder, boolean useExtended) {
148+
wireMockServer.stubFor(put(urlPathEqualTo(TOKEN_RESOURCE_PATH)).willReturn(aResponse().withBody(TOKEN_STUB)));
149+
String path = useExtended ? CREDENTIALS_EXTENDED_RESOURCE_PATH : CREDENTIALS_RESOURCE_PATH;
150+
151+
if (useExtended) {
152+
wireMockServer.stubFor(get(urlPathEqualTo(path)).willReturn(aResponse().withBody(PROFILE_NAME)));
153+
wireMockServer.stubFor(get(urlPathEqualTo(path + PROFILE_NAME)).willReturn(responseDefinitionBuilder));
154+
} else {
155+
// Extended endpoint fails, fallback to legacy
156+
wireMockServer.stubFor(get(urlPathEqualTo(CREDENTIALS_EXTENDED_RESOURCE_PATH))
157+
.willReturn(aResponse().withStatus(404)));
158+
wireMockServer.stubFor(get(urlPathEqualTo(CREDENTIALS_EXTENDED_RESOURCE_PATH + PROFILE_NAME))
159+
.willReturn(aResponse().withStatus(404)));
160+
wireMockServer.stubFor(get(urlPathEqualTo(CREDENTIALS_RESOURCE_PATH)).willReturn(aResponse().withBody(PROFILE_NAME)));
161+
wireMockServer.stubFor(get(urlPathEqualTo(CREDENTIALS_RESOURCE_PATH + PROFILE_NAME)).willReturn(responseDefinitionBuilder));
162+
}
163+
}
164+
165+
private void verifyImdsCallWithToken(boolean useExtended) {
166+
verify(putRequestedFor(urlPathEqualTo(TOKEN_RESOURCE_PATH))
167+
.withHeader(EC2_METADATA_TOKEN_TTL_HEADER, equalTo("21600")));
168+
169+
String path = useExtended ? CREDENTIALS_EXTENDED_RESOURCE_PATH : CREDENTIALS_RESOURCE_PATH;
170+
verify(getRequestedFor(urlPathEqualTo(path))
171+
.withHeader(TOKEN_HEADER, equalTo(TOKEN_STUB)));
172+
verify(getRequestedFor(urlPathEqualTo(path + PROFILE_NAME))
173+
.withHeader(TOKEN_HEADER, equalTo(TOKEN_STUB)));
174+
175+
if (useExtended) {
176+
// Verify extended endpoint was tried first
177+
verify(getRequestedFor(urlPathEqualTo(CREDENTIALS_EXTENDED_RESOURCE_PATH)));
178+
verify(getRequestedFor(urlPathEqualTo(CREDENTIALS_EXTENDED_RESOURCE_PATH + PROFILE_NAME)));
179+
}
180+
}
181+
}

0 commit comments

Comments
 (0)