Skip to content

Commit 653095f

Browse files
authored
Move legacy SSO profile users to scoped SSO (#4484)
1 parent 8761d5d commit 653095f

File tree

20 files changed

+317
-383
lines changed

20 files changed

+317
-383
lines changed

plugins/amazonq/chat/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/clients/AmazonQStreamingClientTest.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ class AmazonQStreamingClientTest : AmazonQTestBase() {
6565
connectionManager = mock {
6666
on {
6767
activeConnectionForFeature(any())
68-
} doReturn authManagerRule.createConnection(ManagedSsoProfile("us-east-1", aString(), emptyList())) as AwsBearerTokenConnection
68+
} doReturn authManagerRule.createConnection(ManagedSsoProfile("us-east-1", aString(), listOf("scopes"))) as AwsBearerTokenConnection
6969
}
7070

7171
projectRule.project.replaceService(ToolkitConnectionManager::class.java, connectionManager, disposableRule.disposable)

plugins/amazonq/chat/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonqFeatureDev/clients/FeatureDevClientTest.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ class FeatureDevClientTest : FeatureDevTestBase() {
100100
connectionManager = mock {
101101
on {
102102
activeConnectionForFeature(any())
103-
} doReturn authManagerRule.createConnection(ManagedSsoProfile("us-east-1", aString(), emptyList())) as AwsBearerTokenConnection
103+
} doReturn authManagerRule.createConnection(ManagedSsoProfile("us-east-1", aString(), listOf("scopes"))) as AwsBearerTokenConnection
104104
}
105105
projectRule.project.replaceService(ToolkitConnectionManager::class.java, connectionManager, disposableRule.disposable)
106106

plugins/amazonq/codetransform/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codemodernizer/CodeWhispererCodeModernizerGumbyClientTest.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ class CodeWhispererCodeModernizerGumbyClientTest : CodeWhispererCodeModernizerTe
9999
connectionManager = mock {
100100
on {
101101
activeConnectionForFeature(any())
102-
} doReturn authManagerRule.createConnection(ManagedSsoProfile("us-east-1", aString(), emptyList())) as AwsBearerTokenConnection
102+
} doReturn authManagerRule.createConnection(ManagedSsoProfile("us-east-1", aString(), listOf("scopes"))) as AwsBearerTokenConnection
103103
}
104104
projectRule.project.replaceService(ToolkitConnectionManager::class.java, connectionManager, disposableRule.disposable)
105105

plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/SsoRequiredInteractiveCredentials.kt

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,31 @@
44
package software.aws.toolkits.jetbrains.core.credentials
55

66
import com.intellij.openapi.actionSystem.AnAction
7+
import software.aws.toolkits.jetbrains.core.credentials.sono.IDENTITY_CENTER_ROLE_ACCESS_SCOPE
8+
import software.aws.toolkits.jetbrains.core.credentials.sso.LazyAccessTokenProvider
79
import software.aws.toolkits.jetbrains.core.credentials.sso.SsoCache
810
import software.aws.toolkits.resources.message
911

1012
interface SsoRequiredInteractiveCredentials : InteractiveCredential {
1113
val ssoCache: SsoCache
1214
val ssoUrl: String
15+
val ssoRegion: String
1316

1417
override val userActionDisplayMessage: String get() = message("credentials.sso.display", displayName)
1518
override val userActionShortDisplayMessage: String get() = message("credentials.sso.display.short")
1619

1720
override val userAction: AnAction get() = RefreshConnectionAction(message("credentials.sso.action"))
1821

19-
override fun userActionRequired(): Boolean = ssoCache.loadAccessToken(ssoUrl) == null
22+
private val lazyTokenProvider: LazyAccessTokenProvider
23+
get() = LazyAccessTokenProvider(
24+
ssoCache,
25+
ssoUrl,
26+
ssoRegion,
27+
listOf(IDENTITY_CENTER_ROLE_ACCESS_SCOPE)
28+
)
29+
30+
// assumes single scope if we're going through this interface
31+
override fun userActionRequired(): Boolean = lazyTokenProvider.resolveToken() == null
32+
33+
fun invalidateCurrentToken() = lazyTokenProvider.invalidate()
2034
}

plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/profiles/ProfileCredentialProviderFactory.kt

Lines changed: 49 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider
2222
import software.amazon.awssdk.auth.credentials.SystemPropertyCredentialsProvider
2323
import software.amazon.awssdk.profiles.Profile
2424
import software.amazon.awssdk.profiles.ProfileProperty
25+
import software.amazon.awssdk.services.sso.model.UnauthorizedException
2526
import software.amazon.awssdk.services.ssooidc.model.SsoOidcException
2627
import software.aws.toolkits.core.credentials.CredentialIdentifier
2728
import software.aws.toolkits.core.credentials.CredentialIdentifierBase
@@ -80,9 +81,30 @@ private class ProfileCredentialsIdentifierLegacySso(
8081
defaultRegionId: String?,
8182
override val ssoCache: SsoCache,
8283
override val ssoUrl: String,
84+
override val ssoRegion: String,
8385
credentialType: CredentialType?
8486
) : ProfileCredentialsIdentifier(profileName, defaultRegionId, credentialType),
85-
SsoRequiredInteractiveCredentials
87+
SsoRequiredInteractiveCredentials,
88+
PostValidateInteractiveCredential {
89+
// react to failure to use the local credential set
90+
override fun handleValidationException(e: Exception) = ifReAuthNeeded(e) {
91+
ConnectionState.RequiresUserAction(
92+
object : InteractiveCredential, CredentialIdentifier by this {
93+
override val userActionDisplayMessage = message("credentials.sso.display", displayName)
94+
override val userActionShortDisplayMessage = message("credentials.sso.display.short")
95+
override val userAction = object : AnAction(message("credentials.sso.action")), DumbAware {
96+
override fun actionPerformed(e: AnActionEvent) {
97+
invalidateCurrentToken()
98+
99+
RefreshConnectionAction().actionPerformed(e)
100+
}
101+
}
102+
103+
override fun userActionRequired() = true
104+
}
105+
)
106+
}
107+
}
86108

87109
class ProfileCredentialsIdentifierSso @TestOnly constructor(
88110
profileName: String,
@@ -92,10 +114,9 @@ class ProfileCredentialsIdentifierSso @TestOnly constructor(
92114
) : ProfileCredentialsIdentifier(profileName, defaultRegionId, credentialType), PostValidateInteractiveCredential, SsoSessionBackedCredentialIdentifier {
93115
override val sessionIdentifier = "$SSO_SESSION_SECTION_NAME:$ssoSessionName"
94116

95-
override fun handleValidationException(e: Exception): ConnectionState.RequiresUserAction? {
96-
// in the new SSO flow, we must attempt validation before knowing if user action is truly required
97-
if (findUpException<SsoOidcException>(e) || findUpException<IllegalStateException>(e) || findUpException<NoTokenInitializedException>(e)) {
98-
return ConnectionState.RequiresUserAction(
117+
override fun handleValidationException(e: Exception): ConnectionState.RequiresUserAction? =
118+
ifReAuthNeeded(e) {
119+
ConnectionState.RequiresUserAction(
99120
object : InteractiveCredential, CredentialIdentifier by this {
100121
override val userActionDisplayMessage = message("credentials.sso.display", displayName)
101122
override val userActionShortDisplayMessage = message("credentials.sso.display.short")
@@ -121,23 +142,33 @@ class ProfileCredentialsIdentifierSso @TestOnly constructor(
121142
}
122143
)
123144
}
145+
}
124146

125-
return null
147+
// in the new SSO flow, we must attempt validation before knowing if user action is truly required
148+
private fun ifReAuthNeeded(e: Exception, action: () -> ConnectionState.RequiresUserAction?): ConnectionState.RequiresUserAction? {
149+
if (findUpException<SsoOidcException>(e) ||
150+
findUpException<IllegalStateException>(e) ||
151+
findUpException<NoTokenInitializedException>(e) ||
152+
findUpException<UnauthorizedException>(e)
153+
) {
154+
return action()
126155
}
127156

128-
// true exception could be further up the chain
129-
private inline fun<reified T : Throwable> findUpException(e: Throwable?): Boolean {
130-
// inline fun can't use recursion
131-
var throwable = e
132-
while (throwable != null) {
133-
if (throwable is T) {
134-
return true
135-
}
136-
throwable = throwable.cause
137-
}
157+
return null
158+
}
138159

139-
return false
160+
// true exception could be further up the chain
161+
private inline fun<reified T : Throwable> findUpException(e: Throwable?): Boolean {
162+
// inline fun can't use recursion
163+
var throwable = e
164+
while (throwable != null) {
165+
if (throwable is T) {
166+
return true
167+
}
168+
throwable = throwable.cause
140169
}
170+
171+
return false
141172
}
142173

143174
private class NeverShowAgain : DumbAwareAction(message("settings.never_show_again")) {
@@ -410,6 +441,7 @@ class ProfileCredentialProviderFactory(private val ssoCache: SsoCache = diskCach
410441
defaultRegion,
411442
ssoCache,
412443
this.traverseCredentialChain(profiles).map { it.property(ProfileProperty.SSO_START_URL) }.first { it.isPresent }.get(),
444+
this.traverseCredentialChain(profiles).map { it.property(ProfileProperty.SSO_REGION) }.first { it.isPresent }.get(),
413445
requestedProfileType
414446
)
415447
this.requiresSso() -> ProfileCredentialsIdentifierSso(

plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/profiles/ProfileLegacySsoProvider.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import software.amazon.awssdk.services.sso.SsoClient
1313
import software.amazon.awssdk.services.ssooidc.SsoOidcClient
1414
import software.amazon.awssdk.utils.SdkAutoCloseable
1515
import software.aws.toolkits.jetbrains.core.AwsClientManager
16+
import software.aws.toolkits.jetbrains.core.credentials.sono.IDENTITY_CENTER_ROLE_ACCESS_SCOPE
1617
import software.aws.toolkits.jetbrains.core.credentials.sso.SsoAccessTokenProvider
1718
import software.aws.toolkits.jetbrains.core.credentials.sso.SsoCache
1819
import software.aws.toolkits.jetbrains.core.credentials.sso.SsoCredentialProvider
@@ -35,6 +36,7 @@ class ProfileLegacySsoProvider(ssoCache: SsoCache, profile: Profile) : AwsCreden
3536
ssoCache,
3637
ssoOidcClient,
3738
isAlwaysShowDeviceCode = true,
39+
scopes = listOf(IDENTITY_CENTER_ROLE_ACCESS_SCOPE),
3840
)
3941

4042
credentialsProvider = SsoCredentialProvider(

plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/sso/ClientRegistration.kt

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,19 +54,21 @@ data class PKCEClientRegistration(
5454
override fun toString(): String = redactedString(this)
5555
}
5656

57-
sealed interface ClientRegistrationCacheKey
57+
sealed interface ClientRegistrationCacheKey {
58+
val region: String
59+
}
5860

5961
// only applicable in scoped registration path
6062
// based on internal development branch @da780a4,L2574-2586
6163
data class DeviceAuthorizationClientRegistrationCacheKey(
6264
val startUrl: String,
6365
val scopes: List<String>,
64-
val region: String,
66+
override val region: String,
6567
) : ClientRegistrationCacheKey
6668

6769
data class PKCEClientRegistrationCacheKey(
6870
val issuerUrl: String,
69-
val region: String,
71+
override val region: String,
7072
val scopes: List<String>,
7173
// assume clientType, grantTypes, redirectUris are static, but throw them in just in case
7274
val clientType: String,

plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/sso/DiskCache.kt

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -93,20 +93,6 @@ class DiskCache(
9393
)
9494
}
9595

96-
override fun loadClientRegistration(ssoRegion: String): ClientRegistration? {
97-
LOG.debug { "loadClientRegistration for $ssoRegion" }
98-
val inputStream = clientRegistrationCache(ssoRegion).tryInputStreamIfExists() ?: return null
99-
return loadClientRegistration(inputStream)
100-
}
101-
102-
override fun saveClientRegistration(ssoRegion: String, registration: ClientRegistration) {
103-
LOG.debug { "saveClientRegistration for $ssoRegion" }
104-
val registrationCache = clientRegistrationCache(ssoRegion)
105-
writeKey(registrationCache) {
106-
objectMapper.writeValue(it, registration)
107-
}
108-
}
109-
11096
override fun invalidateClientRegistration(ssoRegion: String) {
11197
LOG.debug { "invalidateClientRegistration for $ssoRegion" }
11298
clientRegistrationCache(ssoRegion).tryDeleteIfExists()
@@ -133,22 +119,6 @@ class DiskCache(
133119
clientRegistrationCache(cacheKey).tryDeleteIfExists()
134120
}
135121

136-
override fun loadAccessToken(ssoUrl: String): AccessToken? {
137-
LOG.debug { "loadAccessToken for $ssoUrl" }
138-
val cacheFile = accessTokenCache(ssoUrl)
139-
val inputStream = cacheFile.tryInputStreamIfExists() ?: return null
140-
141-
return loadAccessToken(inputStream)
142-
}
143-
144-
override fun saveAccessToken(ssoUrl: String, accessToken: AccessToken) {
145-
LOG.debug { "saveAccessToken for $ssoUrl" }
146-
val accessTokenCache = accessTokenCache(ssoUrl)
147-
writeKey(accessTokenCache) {
148-
objectMapper.writeValue(it, accessToken)
149-
}
150-
}
151-
152122
override fun invalidateAccessToken(ssoUrl: String) {
153123
LOG.debug { "invalidateAccessToken for $ssoUrl" }
154124
accessTokenCache(ssoUrl).tryDeleteIfExists()

0 commit comments

Comments
 (0)