Skip to content
6 changes: 6 additions & 0 deletions .changes/next-release/feature-AWSEC2-9b178a4.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"type": "feature",
"category": "AWS SDK for Java v2",
"contributor": "",
"description": "Include the account ID associated with the credentials retrieved from IMDS when available."
}
5 changes: 5 additions & 0 deletions core/auth/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@
<artifactId>regions</artifactId>
<version>${awsjavasdk.version}</version>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>imds</artifactId>
<version>${awsjavasdk.version}</version>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>profiles</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import software.amazon.awssdk.core.SdkSystemSetting;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.core.exception.SdkServiceException;
import software.amazon.awssdk.imds.Ec2MetadataClientException;
import software.amazon.awssdk.profiles.ProfileFile;
import software.amazon.awssdk.profiles.ProfileFileSupplier;
import software.amazon.awssdk.profiles.ProfileFileSystemSetting;
Expand Down Expand Up @@ -70,9 +71,24 @@ public final class InstanceProfileCredentialsProvider
private static final String PROVIDER_NAME = "InstanceProfileCredentialsProvider";
private static final String EC2_METADATA_TOKEN_HEADER = "x-aws-ec2-metadata-token";
private static final String SECURITY_CREDENTIALS_RESOURCE = "/latest/meta-data/iam/security-credentials/";
private static final String SECURITY_CREDENTIALS_EXTENDED_RESOURCE = "/latest/meta-data/iam/security-credentials-extended/";
private static final String TOKEN_RESOURCE = "/latest/api/token";
private static final String FAILED_TO_LOAD_CREDENTIALS_ERROR = "Failed to load credentials from IMDS.";

private enum ApiVersion {
UNKNOWN,
LEGACY,
EXTENDED
}

private static final String EC2_METADATA_TOKEN_TTL_HEADER = "x-aws-ec2-metadata-token-ttl-seconds";
private static final String DEFAULT_TOKEN_TTL = "21600";
private static final int MAX_PROFILE_RETRIES = 1;

// These fields are accessed from methods called by CachedSupplier which provides thread safety through its ReentrantLock
private ApiVersion apiVersion = ApiVersion.UNKNOWN;
private String resolvedProfile = null;
private int profileRetryCount = 0;

private final Clock clock;
private final String endpoint;
Expand Down Expand Up @@ -160,12 +176,33 @@ private RefreshResult<AwsCredentials> refreshCredentials() {
Instant expiration = credentials.getExpiration().orElse(null);
log.debug(() -> "Loaded credentials from IMDS with expiration time of " + expiration);

// Reset profile retry count after successful credential fetch
profileRetryCount = 0;

return RefreshResult.builder(credentials.getAwsCredentials())
.staleTime(staleTime(expiration))
.prefetchTime(prefetchTime(expiration))
.build();
} catch (Ec2MetadataClientException e) {
if (e.statusCode() == 404) {
log.debug(() -> "Resolved profile is no longer available. Resetting it and trying again.");
resolvedProfile = null;

if (apiVersion == ApiVersion.UNKNOWN) {
apiVersion = ApiVersion.LEGACY;
return refreshCredentials();
}

profileRetryCount++;
if (profileRetryCount <= MAX_PROFILE_RETRIES) {
log.debug(() -> "Profile name not found, retrying fetching the profile name again.");
return refreshCredentials();
}
throw SdkClientException.create(FAILED_TO_LOAD_CREDENTIALS_ERROR, e);
}
throw SdkClientException.create(FAILED_TO_LOAD_CREDENTIALS_ERROR, e);
} catch (RuntimeException e) {
throw SdkClientException.create("Failed to load credentials from IMDS.", e);
throw SdkClientException.create(FAILED_TO_LOAD_CREDENTIALS_ERROR, e);
}
}

Expand Down Expand Up @@ -207,14 +244,20 @@ public String toString() {
return ToString.create(PROVIDER_NAME);
}

private String getSecurityCredentialsResource() {
return apiVersion == ApiVersion.LEGACY ?
SECURITY_CREDENTIALS_RESOURCE :
SECURITY_CREDENTIALS_EXTENDED_RESOURCE;
}

private ResourcesEndpointProvider createEndpointProvider() {
String imdsHostname = getImdsEndpoint();
String token = getToken(imdsHostname);
String[] securityCredentials = getSecurityCredentials(imdsHostname, token);

String urlBase = getSecurityCredentialsResource();

return StaticResourcesEndpointProvider.builder()
.endpoint(URI.create(imdsHostname + SECURITY_CREDENTIALS_RESOURCE
+ securityCredentials[0]))
.endpoint(URI.create(imdsHostname + urlBase + securityCredentials[0]))
.headers(getTokenHeaders(token))
.connectionTimeout(Duration.ofMillis(
this.configProvider.serviceTimeout()))
Expand Down Expand Up @@ -285,21 +328,41 @@ private boolean isInsecureFallbackDisabled() {
}

private String[] getSecurityCredentials(String imdsHostname, String metadataToken) {
if (resolvedProfile != null) {
return new String[]{resolvedProfile};
}

String urlBase = getSecurityCredentialsResource();
ResourcesEndpointProvider securityCredentialsEndpoint =
StaticResourcesEndpointProvider.builder()
.endpoint(URI.create(imdsHostname + SECURITY_CREDENTIALS_RESOURCE))
.endpoint(URI.create(imdsHostname + urlBase))
.headers(getTokenHeaders(metadataToken))
.connectionTimeout(Duration.ofMillis(this.configProvider.serviceTimeout()))
.connectionTimeout(Duration.ofMillis(this.configProvider.serviceTimeout()))
.build();

String securityCredentialsList =
invokeSafely(() -> HttpResourcesUtils.instance().readResource(securityCredentialsEndpoint));
String[] securityCredentials = securityCredentialsList.trim().split("\n");
try {
String securityCredentialsList =
invokeSafely(() -> HttpResourcesUtils.instance().readResource(securityCredentialsEndpoint));
String[] securityCredentials = securityCredentialsList.trim().split("\n");

if (securityCredentials.length == 0) {
throw SdkClientException.builder().message("Unable to load credentials path").build();
if (securityCredentials.length == 0) {
throw SdkClientException.builder().message("Unable to load credentials path").build();
}

if (apiVersion == ApiVersion.UNKNOWN) {
apiVersion = ApiVersion.EXTENDED;
}
resolvedProfile = securityCredentials[0];
return securityCredentials;

} catch (Ec2MetadataClientException e) {
if (e.statusCode() == 404 && apiVersion == ApiVersion.UNKNOWN) {
apiVersion = ApiVersion.LEGACY;
log.debug(() -> "Instance does not support IMDS extended API. Falling back to legacy API.");
return getSecurityCredentials(imdsHostname, metadataToken);
}
throw SdkClientException.create(FAILED_TO_LOAD_CREDENTIALS_ERROR, e);
}
return securityCredentials;
}

private Map<String, String> getTokenHeaders(String metadataToken) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ public LoadedCredentials loadCredentials(ResourcesEndpointProvider endpoint) {
JsonNode secretKey = node.get("SecretAccessKey");
JsonNode token = node.get("Token");
JsonNode expiration = node.get("Expiration");
JsonNode accountId = node.get("AccountId");

Validate.notNull(accessKey, "Failed to load access key from metadata service.");
Validate.notNull(secretKey, "Failed to load secret key from metadata service.");
Expand All @@ -72,6 +73,7 @@ public LoadedCredentials loadCredentials(ResourcesEndpointProvider endpoint) {
secretKey.text(),
token != null ? token.text() : null,
expiration != null ? expiration.text() : null,
accountId != null ? accountId.text() : null,
providerName);
} catch (SdkClientException e) {
throw e;
Expand All @@ -89,12 +91,15 @@ public static final class LoadedCredentials {
private final String token;
private final Instant expiration;
private final String providerName;
private final String accountId;

private LoadedCredentials(String accessKeyId, String secretKey, String token, String expiration, String providerName) {
private LoadedCredentials(String accessKeyId, String secretKey, String token,
String expiration, String accountId, String providerName) {
this.accessKeyId = Validate.paramNotBlank(accessKeyId, "accessKeyId");
this.secretKey = Validate.paramNotBlank(secretKey, "secretKey");
this.token = token;
this.expiration = expiration == null ? null : parseExpiration(expiration);
this.accountId = accountId;
this.providerName = providerName;
}

Expand All @@ -105,11 +110,13 @@ public AwsCredentials getAwsCredentials() {
.secretAccessKey(secretKey)
.sessionToken(token)
.providerName(providerName)
.accountId(accountId)
.build() :
AwsBasicCredentials.builder()
.accessKeyId(accessKeyId)
.secretAccessKey(secretKey)
.providerName(providerName)
.accountId(accountId)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ public class EC2MetadataServiceMock {
"Content-Type: text/html\r\n" +
"Content-Length: ";
private static final String OUTPUT_END_OF_HEADERS = "\r\n\r\n";
private static final String EXTENDED_PATH = "/latest/meta-data/iam/security-credentials-extended/";
private final String securityCredentialsResource;
private EC2MockMetadataServiceListenerThread hosmMockServerThread;

Expand Down Expand Up @@ -140,6 +141,15 @@ public void run() {
String[] strings = requestLine.split(" ");
String resourcePath = strings[1];

// Return 404 for extended path when in legacy mode
if (!credentialsResource.equals(EXTENDED_PATH) &&
(resourcePath.equals(EXTENDED_PATH) || resourcePath.startsWith(EXTENDED_PATH))) {
String notFound = "HTTP/1.1 404 Not Found\r\n" +
"Content-Length: 0\r\n" +
"\r\n";
outputStream.write(notFound.getBytes());
continue;
}

String httpResponse = null;

Expand Down
Loading
Loading