Skip to content

Commit adb9d72

Browse files
committed
Abstract AuthTokenGenerator and implement RDS AuthTokenGenerator
1 parent b2624ea commit adb9d72

File tree

3 files changed

+94
-26
lines changed

3 files changed

+94
-26
lines changed

services/dsql/common/src/aws/sdk/kotlin/services/dsql/AuthTokenGenerator.kt

Lines changed: 4 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,11 @@ package aws.sdk.kotlin.services.dsql
66

77
import aws.sdk.kotlin.runtime.auth.credentials.DefaultChainCredentialsProvider
88
import aws.smithy.kotlin.runtime.auth.awscredentials.Credentials
9-
import aws.smithy.kotlin.runtime.auth.awssigning.AwsSignatureType
10-
import aws.smithy.kotlin.runtime.auth.awssigning.AwsSigningConfig
11-
import aws.smithy.kotlin.runtime.auth.awssigning.DefaultAwsSigner
12-
import aws.smithy.kotlin.runtime.http.HttpMethod
13-
import aws.smithy.kotlin.runtime.http.request.HttpRequest
149
import aws.smithy.kotlin.runtime.net.url.Url
15-
import aws.smithy.kotlin.runtime.time.Clock
1610
import kotlinx.coroutines.runBlocking
1711
import kotlin.time.Duration
1812
import kotlin.time.Duration.Companion.seconds
13+
import aws.sdk.kotlin.runtime.auth.AuthTokenGenerator
1914

2015
/**
2116
* Generates an IAM authentication token for use with DSQL databases
@@ -24,7 +19,7 @@ import kotlin.time.Duration.Companion.seconds
2419
public class AuthTokenGenerator(
2520
public val credentials: Credentials? = runBlocking { DefaultChainCredentialsProvider().resolve() }
2621
) {
27-
private fun String.trimScheme() = removePrefix("http://").removePrefix("https://")
22+
private val generator = AuthTokenGenerator("dsql", credentials)
2823

2924
/**
3025
* Generates an auth token for the DbConnect action.
@@ -41,7 +36,7 @@ public class AuthTokenGenerator(
4136
}
4237
}.build()
4338

44-
return generateAuthToken(dbConnectEndpoint, region, expiration)
39+
return generator.generateAuthToken(dbConnectEndpoint, region, expiration)
4540
}
4641

4742
/**
@@ -59,23 +54,6 @@ public class AuthTokenGenerator(
5954
}
6055
}.build()
6156

62-
return generateAuthToken(dbConnectAdminEndpoint, region, expiration)
63-
}
64-
65-
private suspend fun generateAuthToken(endpoint: Url, region: String, expiration: Duration): String {
66-
val req = HttpRequest(HttpMethod.GET, endpoint)
67-
68-
val creds = credentials
69-
70-
val config = AwsSigningConfig {
71-
credentials = creds ?: DefaultChainCredentialsProvider().resolve()
72-
this.region = region
73-
service = "dsql"
74-
signingDate = Clock.System.now()
75-
expiresAfter = expiration
76-
signatureType = AwsSignatureType.HTTP_REQUEST_VIA_QUERY_PARAMS
77-
}
78-
79-
return DefaultAwsSigner.sign(req, config).output.url.toString().trimScheme()
57+
return generator.generateAuthToken(dbConnectAdminEndpoint, region, expiration)
8058
}
8159
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package aws.sdk.kotlin.services.rds
6+
7+
import aws.sdk.kotlin.runtime.auth.AuthTokenGenerator
8+
import aws.sdk.kotlin.runtime.auth.credentials.DefaultChainCredentialsProvider
9+
import aws.smithy.kotlin.runtime.auth.awscredentials.Credentials
10+
import aws.smithy.kotlin.runtime.net.url.Url
11+
import kotlinx.coroutines.runBlocking
12+
import kotlin.apply
13+
import kotlin.time.Duration
14+
import kotlin.time.Duration.Companion.seconds
15+
16+
/**
17+
* Generates an IAM authentication token for use with RDS databases
18+
* @param credentials The credentials to use when generating the auth token, defaults to resolving credentials from the [DefaultChainCredentialsProvider]
19+
*/
20+
public class AuthTokenGenerator(
21+
public val credentials: Credentials? = runBlocking { DefaultChainCredentialsProvider().resolve() }
22+
) {
23+
private val generator = AuthTokenGenerator("rds-db", credentials)
24+
25+
/**
26+
* Generates an auth token for the DbConnect action.
27+
* @param endpoint the endpoint of the database
28+
* @param region the region of the database
29+
* @param expiration how long the auth token should be valid for. Defaults to 900.seconds
30+
*/
31+
public suspend fun generateAuthToken(endpoint: Url, region: String, username: String, expiration: Duration = 900.seconds): String {
32+
val dbConnectEndpoint = endpoint.toBuilder().apply {
33+
parameters.apply {
34+
decodedParameters {
35+
add("Action", "connect")
36+
add("DBUser", username)
37+
}
38+
}
39+
}.build()
40+
41+
return generator.generateAuthToken(dbConnectEndpoint, region, expiration)
42+
}
43+
}
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package aws.sdk.kotlin.services.rds
6+
7+
import aws.smithy.kotlin.runtime.auth.awscredentials.Credentials
8+
import aws.smithy.kotlin.runtime.net.Host
9+
import aws.smithy.kotlin.runtime.net.url.Url
10+
import kotlinx.coroutines.test.runTest
11+
import kotlin.test.Test
12+
import kotlin.test.assertContains
13+
import kotlin.test.assertFalse
14+
import kotlin.test.assertTrue
15+
import kotlin.time.Duration.Companion.seconds
16+
17+
class AuthTokenGeneratorTest {
18+
@Test
19+
fun testGenerateDbConnectAuthToken() = runTest {
20+
val credentials = Credentials("akid", "secret")
21+
22+
val token = AuthTokenGenerator(credentials)
23+
.generateAuthToken(
24+
endpoint = Url {
25+
host = Host.parse("prod-instance.us-east-1.rds.amazonaws.com")
26+
port = 3306
27+
},
28+
region = "us-east-1",
29+
username = "peccy",
30+
expiration = 450.seconds
31+
)
32+
33+
// Token should have a parameter Action=DbConnect
34+
assertContains(token, "prod-instance.us-east-1.rds.amazonaws.com:3306?Action=connect&DBUser=peccy")
35+
36+
// Match the X-Amz-Credential parameter for any signing date
37+
val credentialRegex = Regex("X-Amz-Credential=akid%2F(\\d{8})%2Fus-east-1%2Frds-db%2Faws4_request")
38+
assertTrue(token.contains(credentialRegex))
39+
40+
assertContains(token, "X-Amz-Expires=450")
41+
42+
// Token should not contain a scheme
43+
listOf("http://", "https://").forEach {
44+
assertFalse(token.contains(it))
45+
}
46+
}
47+
}

0 commit comments

Comments
 (0)