Skip to content

Commit 856260c

Browse files
authored
Fix AWS SSO connection with RDS using IAM (#4278)
* Fix AWS SSO connection with RDS using IAM * Removed unUseed message in messageBundles file * addressing comments * Addressed test case * Exception Updated * detekt failed * fix failed test
1 parent 02fb061 commit 856260c

File tree

8 files changed

+123
-6
lines changed

8 files changed

+123
-6
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type" : "bugfix",
3+
"description" : "Fix AWS SSO connection when authenticating with RDS using an IAM Identity Center profile (#4145)"
4+
}

plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/profiles/ProfileCredentialProviderFactory.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ import software.aws.toolkits.jetbrains.core.credentials.profiles.SsoSessionConst
5151
import software.aws.toolkits.jetbrains.core.credentials.profiles.SsoSessionConstants.SSO_SESSION_SECTION_NAME
5252
import software.aws.toolkits.jetbrains.core.credentials.reauthConnectionIfNeeded
5353
import software.aws.toolkits.jetbrains.core.credentials.sso.SsoCache
54+
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.NoTokenInitializedException
5455
import software.aws.toolkits.jetbrains.settings.AwsSettings
5556
import software.aws.toolkits.jetbrains.settings.ProfilesNotification
5657
import software.aws.toolkits.jetbrains.utils.createNotificationExpiringAction
@@ -93,7 +94,7 @@ class ProfileCredentialsIdentifierSso @TestOnly constructor(
9394

9495
override fun handleValidationException(e: Exception): ConnectionState.RequiresUserAction? {
9596
// in the new SSO flow, we must attempt validation before knowing if user action is truly required
96-
if (findUpException<SsoOidcException>(e) || findUpException<IllegalStateException>(e)) {
97+
if (findUpException<SsoOidcException>(e) || findUpException<IllegalStateException>(e) || findUpException<NoTokenInitializedException>(e)) {
9798
return ConnectionState.RequiresUserAction(
9899
object : InteractiveCredential, CredentialIdentifier by this {
99100
override val userActionDisplayMessage = message("credentials.sso.display", displayName)

plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/sso/bearer/BearerTokenProvider.kt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ class InteractiveBearerTokenProvider(
137137
}
138138

139139
private fun refreshToken(): RefreshResult<out SdkToken> {
140-
val lastToken = lastToken.get() ?: error("Token refresh started before session initialized")
140+
val lastToken = lastToken.get() ?: throw NoTokenInitializedException("Token refresh started before session initialized")
141141
val token = if (Duration.between(Instant.now(), lastToken.expiresAt) > Duration.ofMinutes(30)) {
142142
lastToken
143143
} else {
@@ -167,7 +167,7 @@ class InteractiveBearerTokenProvider(
167167
* Only use if you know what you're doing.
168168
*/
169169
override fun refresh(): AccessToken {
170-
val lastToken = lastToken.get() ?: error("Token refresh started before session initialized")
170+
val lastToken = lastToken.get() ?: throw NoTokenInitializedException("Token refresh started before session initialized")
171171
return accessTokenProvider.refreshToken(lastToken).also {
172172
this.lastToken.set(it)
173173
}
@@ -189,6 +189,8 @@ class InteractiveBearerTokenProvider(
189189
}
190190
}
191191

192+
class NoTokenInitializedException(message: String) : Exception(message)
193+
192194
public enum class BearerTokenAuthState {
193195
AUTHORIZED,
194196
NEEDS_REFRESH,

plugins/core/jetbrains-community/tst/software/aws/toolkits/jetbrains/core/credentials/profiles/ProfileCredentialsIdentifierSsoTest.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import software.amazon.awssdk.services.ssooidc.model.SsoOidcException
1515
import software.aws.toolkits.jetbrains.core.MockClientManagerExtension
1616
import software.aws.toolkits.jetbrains.core.credentials.sso.DiskCache
1717
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.InteractiveBearerTokenProvider
18+
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.NoTokenInitializedException
1819

1920
@ExtendWith(ApplicationExtension::class)
2021
class ProfileCredentialsIdentifierSsoTest {
@@ -46,7 +47,7 @@ class ProfileCredentialsIdentifierSsoTest {
4647
mockClientManager.create<SsoOidcClient>()
4748

4849
// IllegalStateException instead of more general base Exception so we know if the type changes
49-
val exception = assertThrows<IllegalStateException> {
50+
val exception = assertThrows<NoTokenInitializedException> {
5051
InteractiveBearerTokenProvider("", "us-east-1", emptyList(), cache = cache, id = "test").resolveToken()
5152
}
5253
assertThat(sut.handleValidationException(exception)).isNotNull()

plugins/core/jetbrains-community/tstFixtures/software/aws/toolkits/jetbrains/core/credentials/MockCredentialsManager.kt

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import software.aws.toolkits.core.credentials.CredentialIdentifierBase
1616
import software.aws.toolkits.core.credentials.CredentialProviderFactory
1717
import software.aws.toolkits.core.credentials.CredentialSourceId
1818
import software.aws.toolkits.core.credentials.CredentialsChangeListener
19+
import software.aws.toolkits.core.credentials.SsoSessionIdentifier
1920
import software.aws.toolkits.core.credentials.ToolkitCredentialsProvider
2021
import software.aws.toolkits.core.region.AwsRegion
2122
import software.aws.toolkits.core.utils.test.aString
@@ -49,6 +50,20 @@ class MockCredentialsManager : CredentialManager() {
4950
addProvider(it)
5051
}
5152

53+
fun addCredentials(
54+
credentialIdentifier: CredentialIdentifier
55+
): CredentialIdentifier {
56+
addProvider(credentialIdentifier)
57+
return credentialIdentifier
58+
}
59+
60+
fun addSsoProvider(
61+
ssoSessionIdentifier: SsoSessionIdentifier
62+
): SsoSessionIdentifier {
63+
super.addSsoSession(ssoSessionIdentifier)
64+
return ssoSessionIdentifier
65+
}
66+
5267
fun createCredentialProvider(
5368
id: String = aString(),
5469
credentials: AwsCredentials,
@@ -113,6 +128,14 @@ open class MockCredentialManagerRule : ApplicationRule() {
113128
regionId: String? = null
114129
): MockCredentialsManager.MockCredentialIdentifier = credentialManager.addCredentials(id, credentials, regionId)
115130

131+
fun addCredentials(
132+
credentialIdentifier: CredentialIdentifier
133+
): CredentialIdentifier = credentialManager.addCredentials(credentialIdentifier)
134+
135+
fun addSsoProvider(
136+
ssoSessionIdentifier: SsoSessionIdentifier
137+
): SsoSessionIdentifier = credentialManager.addSsoProvider(ssoSessionIdentifier)
138+
116139
fun createCredentialProvider(
117140
id: String = aString(),
118141
credentials: AwsCredentials = AwsBasicCredentials.create("Access", "Secret"),

plugins/toolkit/jetbrains-ultimate/src/software/aws/toolkits/jetbrains/services/rds/auth/IamAuth.kt

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,14 @@ package software.aws.toolkits.jetbrains.services.rds.auth
55

66
import com.intellij.credentialStore.Credentials
77
import com.intellij.database.access.DatabaseCredentials
8+
import com.intellij.database.connection.throwable.KnownDatabaseException
9+
import com.intellij.database.connection.throwable.info.ErrorInfo
10+
import com.intellij.database.connection.throwable.info.SimpleErrorInfo
811
import com.intellij.database.dataSource.DatabaseAuthProvider.AuthWidget
912
import com.intellij.database.dataSource.DatabaseConnectionInterceptor.ProtoConnection
1013
import com.intellij.database.dataSource.DatabaseCredentialsAuthProvider
1114
import com.intellij.database.dataSource.LocalDataSource
15+
import com.intellij.openapi.actionSystem.DataContext
1216
import com.intellij.openapi.project.Project
1317
import kotlinx.coroutines.future.future
1418
import software.amazon.awssdk.regions.Region
@@ -17,6 +21,12 @@ import software.aws.toolkits.core.ConnectionSettings
1721
import software.aws.toolkits.core.utils.getLogger
1822
import software.aws.toolkits.core.utils.info
1923
import software.aws.toolkits.jetbrains.core.coroutines.projectCoroutineScope
24+
import software.aws.toolkits.jetbrains.core.credentials.CredentialManager
25+
import software.aws.toolkits.jetbrains.core.credentials.ToolkitAuthManager
26+
import software.aws.toolkits.jetbrains.core.credentials.UserConfigSsoSessionProfile
27+
import software.aws.toolkits.jetbrains.core.credentials.profiles.ProfileCredentialsIdentifierSso
28+
import software.aws.toolkits.jetbrains.core.credentials.reauthConnectionIfNeeded
29+
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.NoTokenInitializedException
2030
import software.aws.toolkits.jetbrains.datagrip.auth.compatability.DatabaseAuthProviderCompatabilityAdapter
2131
import software.aws.toolkits.jetbrains.datagrip.auth.compatability.project
2232
import software.aws.toolkits.jetbrains.datagrip.getAwsConnectionSettings
@@ -49,6 +59,17 @@ class IamAuth : DatabaseAuthProviderCompatabilityAdapter {
4959

5060
override fun createWidget(project: Project?, credentials: DatabaseCredentials, dataSource: LocalDataSource): AuthWidget? = IamAuthWidget()
5161

62+
inner class SsoNoTokenFix(val project: Project, val connection: ProtoConnection) : ErrorInfo.Fix {
63+
64+
override fun getName(): String = message("credentials.sso.login.session", getAuthInformation(connection).connectionSettings.credentials.id)
65+
66+
override fun isSilent(): Boolean = false
67+
68+
override fun apply(dataContext: DataContext) {
69+
handleSsoAuthentication(project, connection)
70+
}
71+
}
72+
5273
override fun intercept(
5374
connection: ProtoConnection,
5475
silent: Boolean
@@ -63,13 +84,43 @@ class IamAuth : DatabaseAuthProviderCompatabilityAdapter {
6384
DatabaseCredentialsAuthProvider.applyCredentials(connection, credentials, true)
6485
} catch (e: Throwable) {
6586
result = Result.Failed
66-
throw e
87+
if (e is NoTokenInitializedException) {
88+
val simpleErrorInfo = SimpleErrorInfo(
89+
message("rds.validation.iam_sso_connection.error_info"),
90+
e,
91+
listOf(SsoNoTokenFix(project, connection))
92+
)
93+
throw KnownDatabaseException(simpleErrorInfo)
94+
} else {
95+
throw e
96+
}
6797
} finally {
6898
RdsTelemetry.getCredentials(project = project, result = result, databaseCredentials = IAM, databaseEngine = connection.getDatabaseEngine())
6999
}
70100
}
71101
}
72102

103+
fun handleSsoAuthentication(project: Project, connection: ProtoConnection): ProtoConnection {
104+
val authInformation = getAuthInformation(connection)
105+
val profileCredentials =
106+
CredentialManager.getInstance().getCredentialIdentifierById(authInformation.connectionSettings.credentials.id) as ProfileCredentialsIdentifierSso
107+
108+
val session = CredentialManager.getInstance()
109+
.getSsoSessionIdentifiers()
110+
.first { it.id == profileCredentials.sessionIdentifier }
111+
val ssoConnection = ToolkitAuthManager.getInstance().getOrCreateSsoConnection(
112+
UserConfigSsoSessionProfile(
113+
configSessionName = profileCredentials.ssoSessionName,
114+
ssoRegion = session.ssoRegion,
115+
startUrl = session.startUrl,
116+
scopes = session.scopes.toList()
117+
)
118+
)
119+
120+
reauthConnectionIfNeeded(project, ssoConnection)
121+
return connection
122+
}
123+
73124
internal fun getAuthInformation(connection: ProtoConnection): RdsAuth {
74125
validateIamConfiguration(connection)
75126
val signingUrl = connection.connectionPoint.additionalProperties[RDS_SIGNING_HOST_PROPERTY]

plugins/toolkit/jetbrains-ultimate/tst/software/aws/toolkits/jetbrains/services/rds/auth/IamAuthTest.kt

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,25 @@ import org.mockito.kotlin.doAnswer
1919
import org.mockito.kotlin.doReturn
2020
import org.mockito.kotlin.mock
2121
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials
22+
import software.amazon.awssdk.services.ssooidc.SsoOidcClient
23+
import software.aws.toolkits.core.credentials.CredentialType
2224
import software.aws.toolkits.core.region.AwsRegion
2325
import software.aws.toolkits.core.utils.RuleUtils
2426
import software.aws.toolkits.core.utils.unwrap
27+
import software.aws.toolkits.jetbrains.core.MockClientManagerRule
2528
import software.aws.toolkits.jetbrains.core.credentials.MockCredentialManagerRule
29+
import software.aws.toolkits.jetbrains.core.credentials.MockToolkitAuthManagerRule
30+
import software.aws.toolkits.jetbrains.core.credentials.diskCache
31+
import software.aws.toolkits.jetbrains.core.credentials.profiles.ProfileCredentialsIdentifierSso
32+
import software.aws.toolkits.jetbrains.core.credentials.profiles.ProfileSsoSessionIdentifier
33+
import software.aws.toolkits.jetbrains.core.credentials.sso.DeviceAuthorizationGrantToken
2634
import software.aws.toolkits.jetbrains.core.region.MockRegionProviderRule
2735
import software.aws.toolkits.jetbrains.datagrip.CREDENTIAL_ID_PROPERTY
2836
import software.aws.toolkits.jetbrains.datagrip.REGION_ID_PROPERTY
2937
import software.aws.toolkits.jetbrains.datagrip.RequireSsl
3038
import software.aws.toolkits.jetbrains.datagrip.auth.compatability.project
3139
import software.aws.toolkits.resources.message
40+
import java.time.Instant
3241

3342
class IamAuthTest {
3443
@Rule
@@ -53,10 +62,34 @@ class IamAuthTest {
5362
@JvmField
5463
val credentialManager = MockCredentialManagerRule()
5564

65+
@Rule
66+
@JvmField
67+
val authManager = MockToolkitAuthManagerRule()
68+
69+
@JvmField
70+
@Rule
71+
val mockClientManager = MockClientManagerRule()
72+
73+
private lateinit var ssoClient: SsoOidcClient
74+
5675
@Before
5776
fun setUp() {
5877
credentialManager.addCredentials(credentialId, mockCreds)
5978
regionProvider.addRegion(AwsRegion(defaultRegion, RuleUtils.randomName(), RuleUtils.randomName()))
79+
ssoClient = mockClientManager.create()
80+
}
81+
82+
@Test
83+
fun `Handle Sso authentication no token present`() {
84+
val noTokenCredentialId = RuleUtils.randomName()
85+
val ssoUrl = RuleUtils.randomName()
86+
diskCache.saveAccessToken(ssoUrl, DeviceAuthorizationGrantToken(ssoUrl, "us-east-1", "access1", "refresh1", Instant.MAX))
87+
credentialManager.addCredentials(ProfileCredentialsIdentifierSso(noTokenCredentialId, noTokenCredentialId, "us-east-1", CredentialType.SsoProfile))
88+
credentialManager.addSsoProvider(ProfileSsoSessionIdentifier(noTokenCredentialId, ssoUrl, "us-east-1", setOf()))
89+
val conneciton = buildConnection(hasCredentials = true, credentialId = "profile:" + noTokenCredentialId)
90+
91+
val connection = iamAuth.handleSsoAuthentication(projectRule.project, conneciton)
92+
assertThat(connection).isNotNull
6093
}
6194

6295
@Test
@@ -168,7 +201,8 @@ class IamAuthTest {
168201
hasCredentials: Boolean = true,
169202
hasBadHost: Boolean = false,
170203
hasSslConfig: Boolean = true,
171-
dbmsType: Dbms = Dbms.POSTGRES
204+
dbmsType: Dbms = Dbms.POSTGRES,
205+
credentialId: String = this.credentialId
172206
): ProtoConnection {
173207
val mockConnection = mock<LocalDataSource> {
174208
on { url } doReturn "jdbc:postgresql://$dbHost:$connectionPort/dev"

plugins/toolkit/resources/resources/software/aws/toolkits/resources/MessagesBundle.properties

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1455,6 +1455,7 @@ rds.port=RDS Port:
14551455
rds.postgres=PostgreSQL
14561456
rds.url=RDS Host:
14571457
rds.validation.aurora_mysql_ssl_required=Aurora MySQL requires SSL to be enabled to connect
1458+
rds.validation.iam_sso_connection.error_info=No token found, please login and try again
14581459
rds.validation.no_iam_auth=Database {0} does not have IAM authentication enabled
14591460
rds.validation.no_instance_host=No RDS database host specified
14601461
rds.validation.no_instance_port=No RDS database port specified

0 commit comments

Comments
 (0)