Skip to content

Commit 84468f6

Browse files
authored
feat(rt): support static stability for IMDS credentials (#719)
1 parent fccaffa commit 84468f6

File tree

6 files changed

+488
-21
lines changed

6 files changed

+488
-21
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"id": "8c2a1262-90c7-49a1-8c0a-29b269ef0f19",
3+
"type": "feature",
4+
"description": "Support static stability for IMDS credentials",
5+
"issues": ["#707"]
6+
}

aws-runtime/aws-config/common/src/aws/sdk/kotlin/runtime/auth/credentials/CachedCredentialsProvider.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ public class CachedCredentialsProvider(
5858
logger.trace { "refreshing credentials cache" }
5959
val providerCreds = source.getCredentials()
6060
if (providerCreds.expiration != null) {
61-
ExpiringValue(providerCreds, providerCreds.expiration!!)
61+
val expiration = minOf(providerCreds.expiration!!, (clock.now() + expireCredentialsAfter))
62+
ExpiringValue(providerCreds, expiration)
6263
} else {
6364
val expiration = clock.now() + expireCredentialsAfter
6465
val creds = providerCreds.copy(expiration = expiration)

aws-runtime/aws-config/common/src/aws/sdk/kotlin/runtime/auth/credentials/ImdsCredentialsProvider.kt

Lines changed: 62 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,21 @@ import aws.smithy.kotlin.runtime.http.HttpStatusCode
1616
import aws.smithy.kotlin.runtime.io.Closeable
1717
import aws.smithy.kotlin.runtime.logging.Logger
1818
import aws.smithy.kotlin.runtime.serde.json.JsonDeserializer
19+
import aws.smithy.kotlin.runtime.time.Clock
20+
import aws.smithy.kotlin.runtime.time.Instant
1921
import aws.smithy.kotlin.runtime.util.Platform
2022
import aws.smithy.kotlin.runtime.util.PlatformEnvironProvider
2123
import aws.smithy.kotlin.runtime.util.asyncLazy
24+
import kotlinx.coroutines.sync.Mutex
25+
import kotlinx.coroutines.sync.withLock
26+
import kotlin.time.Duration.Companion.seconds
2227

2328
private const val CREDENTIALS_BASE_PATH: String = "/latest/meta-data/iam/security-credentials"
2429
private const val CODE_ASSUME_ROLE_UNAUTHORIZED_ACCESS: String = "AssumeRoleUnauthorizedAccess"
2530
private const val PROVIDER_NAME = "IMDSv2"
2631

32+
internal expect class SdkIOException : Exception // FIXME move this to the proper place when we do the larger KMP Exception refactor
33+
2734
/**
2835
* [CredentialsProvider] that uses EC2 instance metadata service (IMDS) to provide credentials information.
2936
* This provider requires that the EC2 instance has an [instance profile](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/iam-roles-for-amazon-ec2.html#ec2-instance-profile)
@@ -43,8 +50,16 @@ public class ImdsCredentialsProvider(
4350
private val profileOverride: String? = null,
4451
private val client: Lazy<InstanceMetadataProvider> = lazy { ImdsClient() },
4552
private val platformProvider: PlatformEnvironProvider = Platform,
53+
private val clock: Clock = Clock.System,
4654
) : CredentialsProvider, Closeable {
4755
private val logger = Logger.getLogger<ImdsCredentialsProvider>()
56+
private var previousCredentials: Credentials? = null
57+
58+
// the time to refresh the Credentials. If set, it will take precedence over the Credentials' expiration time
59+
private var nextRefresh: Instant? = null
60+
61+
// protects previousCredentials and nextRefresh
62+
private val mu = Mutex()
4863

4964
private val profile = asyncLazy {
5065
if (profileOverride != null) return@asyncLazy profileOverride
@@ -56,23 +71,52 @@ public class ImdsCredentialsProvider(
5671
throw CredentialsNotLoadedException("AWS EC2 metadata is explicitly disabled; credentials not loaded")
5772
}
5873

74+
// if we have previously served IMDS credentials and it's not time for a refresh, just return the previous credentials
75+
mu.withLock {
76+
previousCredentials?.run {
77+
nextRefresh?.takeIf { clock.now() < it }?.run {
78+
return previousCredentials!!
79+
}
80+
}
81+
}
82+
5983
val profileName = try {
6084
profile.get()
6185
} catch (ex: Exception) {
62-
throw CredentialsProviderException("failed to load instance profile", ex)
86+
return useCachedCredentials(ex) ?: throw CredentialsProviderException("failed to load instance profile", ex)
87+
}
88+
89+
val payload = try {
90+
client.value.get("$CREDENTIALS_BASE_PATH/$profileName")
91+
} catch (ex: Exception) {
92+
return useCachedCredentials(ex) ?: throw CredentialsProviderException("failed to load credentials", ex)
6393
}
6494

65-
val payload = client.value.get("$CREDENTIALS_BASE_PATH/$profileName")
6695
val deserializer = JsonDeserializer(payload.encodeToByteArray())
6796

6897
return when (val resp = deserializeJsonCredentials(deserializer)) {
69-
is JsonCredentialsResponse.SessionCredentials -> Credentials(
70-
resp.accessKeyId,
71-
resp.secretAccessKey,
72-
resp.sessionToken,
73-
resp.expiration,
74-
PROVIDER_NAME,
75-
)
98+
is JsonCredentialsResponse.SessionCredentials -> {
99+
nextRefresh = if (resp.expiration < clock.now()) {
100+
logger.warn {
101+
"Attempting credential expiration extension due to a credential service availability issue. " +
102+
"A refresh of these credentials will be attempted again in " +
103+
"${ DEFAULT_CREDENTIALS_REFRESH_SECONDS / 60 } minutes."
104+
}
105+
clock.now() + DEFAULT_CREDENTIALS_REFRESH_SECONDS.seconds
106+
} else null
107+
108+
val creds = Credentials(
109+
resp.accessKeyId,
110+
resp.secretAccessKey,
111+
resp.sessionToken,
112+
resp.expiration,
113+
PROVIDER_NAME,
114+
)
115+
116+
creds.also {
117+
mu.withLock { previousCredentials = it }
118+
}
119+
}
76120
is JsonCredentialsResponse.Error -> {
77121
when (resp.code) {
78122
CODE_ASSUME_ROLE_UNAUTHORIZED_ACCESS -> throw ProviderConfigurationException("Incorrect IMDS/IAM configuration: [${resp.code}] ${resp.message}. Hint: Does this role have a trust relationship with EC2?")
@@ -98,4 +142,13 @@ public class ImdsCredentialsProvider(
98142
throw ex
99143
}
100144
}
145+
146+
private suspend fun useCachedCredentials(ex: Exception): Credentials? = when {
147+
ex is SdkIOException || ex is EC2MetadataError && ex.statusCode == HttpStatusCode.InternalServerError.value -> {
148+
mu.withLock {
149+
previousCredentials?.apply { nextRefresh = clock.now() + DEFAULT_CREDENTIALS_REFRESH_SECONDS.seconds }
150+
}
151+
}
152+
else -> null
153+
}
101154
}

aws-runtime/aws-config/common/test/aws/sdk/kotlin/runtime/auth/credentials/CachedCredentialsProviderTest.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ class CachedCredentialsProviderTest {
8383
@Test
8484
fun testRefreshBufferWindow() = runTest {
8585
val source = TestCredentialsProvider(expiration = testExpiration)
86-
val provider = CachedCredentialsProvider(source, clock = testClock)
86+
val provider = CachedCredentialsProvider(source, clock = testClock, expireCredentialsAfter = 60.minutes)
8787
val creds = provider.getCredentials()
8888
val expected = Credentials("AKID", "secret", expiration = testExpiration)
8989
assertEquals(expected, creds)

0 commit comments

Comments
 (0)