Skip to content

Commit 0505e35

Browse files
authored
feat(rt): ec2 imds region provider (#341)
1 parent cf81418 commit 0505e35

File tree

8 files changed

+167
-39
lines changed

8 files changed

+167
-39
lines changed

aws-runtime/aws-config/common/src/aws/sdk/kotlin/runtime/config/imds/ImdsClient.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ public class ImdsClient private constructor(builder: Builder) : Closeable {
5353

5454
private val maxRetries: UInt = builder.maxRetries
5555
private val endpointConfiguration: EndpointConfiguration = builder.endpointConfiguration
56-
private val tokenTtl: Duration = builder.tokenTTL
56+
private val tokenTtl: Duration = builder.tokenTtl
5757
private val clock: Clock = builder.clock
5858
private val platformProvider: PlatformProvider = builder.platformProvider
5959
private val httpClient: SdkHttpClient
@@ -160,7 +160,7 @@ public class ImdsClient private constructor(builder: Builder) : Closeable {
160160
/**
161161
* Override the time-to-live for the session token
162162
*/
163-
public var tokenTTL: Duration = Duration.seconds(DEFAULT_TOKEN_TTL_SECONDS)
163+
public var tokenTtl: Duration = Duration.seconds(DEFAULT_TOKEN_TTL_SECONDS)
164164

165165
/**
166166
* The HTTP engine to use to make requests with. This is here to facilitate testing and can otherwise be ignored

aws-runtime/aws-config/common/src/aws/sdk/kotlin/runtime/region/DefaultRegionProviderChain.kt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55

66
package aws.sdk.kotlin.runtime.region
77

8+
import aws.smithy.kotlin.runtime.io.Closeable
9+
810
/**
911
* [RegionProvider] that looks for region in this order:
1012
* 1. Check `aws.region` system property (JVM only)
1113
* 2. Check the `AWS_REGION` environment variable (JVM, Node, Native)
1214
* 3. Check the AWS config files/profile for region information
1315
* 4. If running on EC2, check the EC2 metadata service for region
1416
*/
15-
public expect class DefaultRegionProviderChain public constructor() : RegionProvider
17+
public expect class DefaultRegionProviderChain public constructor() : RegionProvider, Closeable
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0.
4+
*/
5+
6+
package aws.sdk.kotlin.runtime.region
7+
8+
import aws.sdk.kotlin.runtime.config.AwsSdkSetting
9+
import aws.sdk.kotlin.runtime.config.imds.ImdsClient
10+
import aws.sdk.kotlin.runtime.config.resolve
11+
import aws.smithy.kotlin.runtime.io.Closeable
12+
import aws.smithy.kotlin.runtime.util.Platform
13+
import aws.smithy.kotlin.runtime.util.PlatformEnvironProvider
14+
import aws.smithy.kotlin.runtime.util.asyncLazy
15+
16+
private const val REGION_PATH: String = "/latest/meta-data/placement/region"
17+
18+
/**
19+
* [RegionProvider] that uses EC2 instance metadata service (IMDS) to provider region information
20+
*
21+
* @param client the IMDS client to use to resolve region information with
22+
* @param platformProvider the [PlatformEnvironProvider] instance
23+
*/
24+
public class ImdsRegionProvider(
25+
private val client: Lazy<ImdsClient> = lazy { ImdsClient() },
26+
private val platformProvider: PlatformEnvironProvider = Platform,
27+
) : RegionProvider, Closeable {
28+
private val resolvedRegion = asyncLazy(::loadRegion)
29+
30+
override suspend fun getRegion(): String? {
31+
if (AwsSdkSetting.AwsEc2MetadataDisabled.resolve(platformProvider) == true) {
32+
return null
33+
}
34+
35+
return resolvedRegion.get()
36+
}
37+
38+
private suspend fun loadRegion(): String = client.value.get(REGION_PATH)
39+
40+
override fun close() {
41+
if (client.isInitialized()) {
42+
client.value.close()
43+
}
44+
}
45+
}

aws-runtime/aws-config/common/test/aws/sdk/kotlin/runtime/config/imds/ImdsClientTest.kt

Lines changed: 2 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,6 @@ import aws.sdk.kotlin.runtime.endpoint.Endpoint
1010
import aws.sdk.kotlin.runtime.testing.TestPlatformProvider
1111
import aws.sdk.kotlin.runtime.testing.runSuspendTest
1212
import aws.smithy.kotlin.runtime.http.*
13-
import aws.smithy.kotlin.runtime.http.content.ByteArrayContent
14-
import aws.smithy.kotlin.runtime.http.request.HttpRequest
15-
import aws.smithy.kotlin.runtime.http.request.url
16-
import aws.smithy.kotlin.runtime.http.response.HttpResponse
1713
import aws.smithy.kotlin.runtime.httptest.TestConnection
1814
import aws.smithy.kotlin.runtime.httptest.buildTestConnection
1915
import aws.smithy.kotlin.runtime.time.Instant
@@ -29,33 +25,6 @@ import kotlin.time.ExperimentalTime
2925
@OptIn(ExperimentalTime::class)
3026
class ImdsClientTest {
3127

32-
private fun tokenRequest(host: String, ttl: Int): HttpRequest = HttpRequest {
33-
val parsed = Url.parse(host)
34-
url(parsed)
35-
url.path = "/latest/api/token"
36-
headers.append(X_AWS_EC2_METADATA_TOKEN_TTL_SECONDS, ttl.toString())
37-
}
38-
39-
private fun tokenResponse(ttl: Int, token: String): HttpResponse = HttpResponse(
40-
HttpStatusCode.OK,
41-
Headers {
42-
append(X_AWS_EC2_METADATA_TOKEN_TTL_SECONDS, ttl.toString())
43-
},
44-
ByteArrayContent(token.encodeToByteArray())
45-
)
46-
47-
private fun imdsRequest(url: String, token: String): HttpRequest = HttpRequest {
48-
val parsed = Url.parse(url)
49-
url(parsed)
50-
headers.append(X_AWS_EC2_METADATA_TOKEN, token)
51-
}
52-
53-
private fun imdsResponse(body: String): HttpResponse = HttpResponse(
54-
HttpStatusCode.OK,
55-
Headers.Empty,
56-
ByteArrayContent(body.encodeToByteArray())
57-
)
58-
5928
@Test
6029
fun testInvalidEndpointOverrideFailsCreation() {
6130
val connection = TestConnection()
@@ -122,7 +91,7 @@ class ImdsClientTest {
12291
engine = connection
12392
endpointConfiguration = EndpointConfiguration.ModeOverride(EndpointMode.IPv6)
12493
clock = testClock
125-
tokenTTL = Duration.seconds(600)
94+
tokenTtl = Duration.seconds(600)
12695
}
12796

12897
val r1 = client.get("/latest/metadata")
@@ -169,7 +138,7 @@ class ImdsClientTest {
169138
engine = connection
170139
endpointConfiguration = EndpointConfiguration.ModeOverride(EndpointMode.IPv6)
171140
clock = testClock
172-
tokenTTL = Duration.seconds(600)
141+
tokenTtl = Duration.seconds(600)
173142
}
174143

175144
val r1 = client.get("/latest/metadata")
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0.
4+
*/
5+
6+
package aws.sdk.kotlin.runtime.config.imds
7+
8+
import aws.smithy.kotlin.runtime.http.Headers
9+
import aws.smithy.kotlin.runtime.http.HttpStatusCode
10+
import aws.smithy.kotlin.runtime.http.Url
11+
import aws.smithy.kotlin.runtime.http.content.ByteArrayContent
12+
import aws.smithy.kotlin.runtime.http.request.HttpRequest
13+
import aws.smithy.kotlin.runtime.http.request.url
14+
import aws.smithy.kotlin.runtime.http.response.HttpResponse
15+
16+
fun tokenRequest(host: String, ttl: Int): HttpRequest = HttpRequest {
17+
val parsed = Url.parse(host)
18+
url(parsed)
19+
url.path = "/latest/api/token"
20+
headers.append(X_AWS_EC2_METADATA_TOKEN_TTL_SECONDS, ttl.toString())
21+
}
22+
23+
fun tokenResponse(ttl: Int, token: String): HttpResponse = HttpResponse(
24+
HttpStatusCode.OK,
25+
Headers {
26+
append(X_AWS_EC2_METADATA_TOKEN_TTL_SECONDS, ttl.toString())
27+
},
28+
ByteArrayContent(token.encodeToByteArray())
29+
)
30+
31+
fun imdsRequest(url: String, token: String): HttpRequest = HttpRequest {
32+
val parsed = Url.parse(url)
33+
url(parsed)
34+
headers.append(X_AWS_EC2_METADATA_TOKEN, token)
35+
}
36+
37+
fun imdsResponse(body: String): HttpResponse = HttpResponse(
38+
HttpStatusCode.OK,
39+
Headers.Empty,
40+
ByteArrayContent(body.encodeToByteArray())
41+
)
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0.
4+
*/
5+
6+
package aws.sdk.kotlin.runtime.region
7+
8+
import aws.sdk.kotlin.runtime.config.AwsSdkSetting
9+
import aws.sdk.kotlin.runtime.config.imds.*
10+
import aws.sdk.kotlin.runtime.testing.TestPlatformProvider
11+
import aws.sdk.kotlin.runtime.testing.runSuspendTest
12+
import aws.smithy.kotlin.runtime.httptest.buildTestConnection
13+
import aws.smithy.kotlin.runtime.time.ManualClock
14+
import kotlin.test.Test
15+
import kotlin.test.assertEquals
16+
import kotlin.test.assertNull
17+
import kotlin.time.ExperimentalTime
18+
19+
class ImdsRegionProviderTest {
20+
21+
@Test
22+
fun testImdsDisabled() = runSuspendTest {
23+
val platform = TestPlatformProvider(
24+
env = mapOf(AwsSdkSetting.AwsEc2MetadataDisabled.environmentVariable to "true")
25+
)
26+
27+
val provider = ImdsRegionProvider(platformProvider = platform)
28+
assertNull(provider.getRegion())
29+
}
30+
31+
@OptIn(ExperimentalTime::class)
32+
@Test
33+
fun testResolveRegion() = runSuspendTest {
34+
35+
val connection = buildTestConnection {
36+
expect(
37+
tokenRequest("http://169.254.169.254", DEFAULT_TOKEN_TTL_SECONDS),
38+
tokenResponse(DEFAULT_TOKEN_TTL_SECONDS, "TOKEN_A")
39+
)
40+
expect(
41+
imdsRequest("http://169.254.169.254/latest/meta-data/placement/region", "TOKEN_A"),
42+
imdsResponse("us-east-2")
43+
)
44+
}
45+
46+
val testClock = ManualClock()
47+
48+
val client = ImdsClient {
49+
engine = connection
50+
clock = testClock
51+
}
52+
53+
val provider = ImdsRegionProvider(client = lazyOf(client))
54+
assertEquals("us-east-2", provider.getRegion())
55+
connection.assertRequests()
56+
57+
// test that it's cached, test connection would fail if it tries again
58+
assertEquals("us-east-2", provider.getRegion())
59+
}
60+
}

aws-runtime/aws-config/jvm/src/aws/sdk/kotlin/runtime/region/DefaultRegionProviderChainJVM.kt

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,20 @@
55

66
package aws.sdk.kotlin.runtime.region
77

8+
import aws.smithy.kotlin.runtime.io.Closeable
9+
810
public actual class DefaultRegionProviderChain public actual constructor() :
911
RegionProvider,
12+
Closeable,
1013
RegionProviderChain(
1114
JvmSystemPropRegionProvider(),
12-
EnvironmentRegionProvider()
13-
)
15+
EnvironmentRegionProvider(),
16+
// TODO - profile
17+
ImdsRegionProvider()
18+
) {
19+
override fun close() {
20+
providers.forEach { provider ->
21+
if (provider is Closeable) provider.close()
22+
}
23+
}
24+
}

aws-runtime/aws-types/common/src/aws/sdk/kotlin/runtime/region/RegionProviderChain.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import aws.smithy.kotlin.runtime.logging.Logger
1414
* @param providers the list of providers to delegate to
1515
*/
1616
public open class RegionProviderChain(
17-
private vararg val providers: RegionProvider
17+
protected vararg val providers: RegionProvider
1818
) : RegionProvider {
1919
private val logger = Logger.getLogger<RegionProviderChain>()
2020

0 commit comments

Comments
 (0)