diff --git a/.changes/next-release/bugfix-566abe4c-547f-413b-b1c8-838282782550.json b/.changes/next-release/bugfix-566abe4c-547f-413b-b1c8-838282782550.json new file mode 100644 index 00000000000..3a5324ebc40 --- /dev/null +++ b/.changes/next-release/bugfix-566abe4c-547f-413b-b1c8-838282782550.json @@ -0,0 +1,4 @@ +{ + "type" : "bugfix", + "description" : "Fix issue where a user may get stuck while attempting to login to Builder ID" +} diff --git a/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/DefaultToolkitConnectionManager.kt b/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/DefaultToolkitConnectionManager.kt index 51b0d0f3126..ba68478dfcc 100644 --- a/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/DefaultToolkitConnectionManager.kt +++ b/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/DefaultToolkitConnectionManager.kt @@ -19,6 +19,16 @@ import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenPr // TODO: unify with AwsConnectionManager @State(name = "connectionManager", storages = [Storage("aws.xml")]) class DefaultToolkitConnectionManager : ToolkitConnectionManager, PersistentStateComponent { + private val project: Project? + + constructor(project: Project) { + this.project = project + } + + constructor() { + this.project = null + } + init { ApplicationManager.getApplication().messageBus.connect(this).subscribe( BearerTokenProviderListener.TOPIC, @@ -33,19 +43,10 @@ class DefaultToolkitConnectionManager : ToolkitConnectionManager, PersistentStat ) } - private val project: Project? - - constructor(project: Project) { - this.project = project - } - - constructor() { - this.project = null - } - private var connection: ToolkitConnection? = null - private val pinningManager: ConnectionPinningManager = ConnectionPinningManager.getInstance() + private val pinningManager + get() = ConnectionPinningManager.getInstance() private val defaultConnection: ToolkitConnection? get() { @@ -60,6 +61,10 @@ class DefaultToolkitConnectionManager : ToolkitConnectionManager, PersistentStat return null } + @Deprecated( + "Fragile API. Probably leads to unexpected behavior. Use only for toolkit explorer dropdown state.", + replaceWith = ReplaceWith("activeConnectionForFeature(feature)") + ) @Synchronized override fun activeConnection() = connection ?: defaultConnection @@ -71,11 +76,11 @@ class DefaultToolkitConnectionManager : ToolkitConnectionManager, PersistentStat } return connection?.let { - if (feature.supportsConnectionType(it)) { - return it + return@let if (feature.supportsConnectionType(it)) { + it + } else { + null } - - null } ?: defaultConnection?.let { if (ApplicationInfo.getInstance().build.productCode == "GW") return null if (feature.supportsConnectionType(it)) { diff --git a/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/ToolkitAuthManager.kt b/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/ToolkitAuthManager.kt index 86d94d9a5fb..ad7d1be3e67 100644 --- a/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/ToolkitAuthManager.kt +++ b/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/ToolkitAuthManager.kt @@ -89,6 +89,10 @@ interface ToolkitStartupAuthFactory { } interface ToolkitConnectionManager : Disposable { + @Deprecated( + "Fragile API. Probably leads to unexpected behavior. Use only for toolkit explorer dropdown state.", + ReplaceWith("activeConnectionForFeature(feature)") + ) fun activeConnection(): ToolkitConnection? fun activeConnectionForFeature(feature: FeatureWithPinnedConnection): ToolkitConnection? @@ -129,6 +133,7 @@ fun loginSso( project = project, connection = transientConnection, onPendingToken = onPendingToken, + isReAuth = false, source = metadata?.sourceId, ) } @@ -149,7 +154,7 @@ fun loginSso( val manager = ToolkitAuthManager.getInstance() val allScopes = requestedScopes.toMutableSet() - return manager.getConnection(connectionId)?.let { connection -> + val connection = manager.getConnection(connectionId)?.let { connection -> val logger = getLogger() if (connection !is AwsBearerTokenConnection) { @@ -167,32 +172,32 @@ fun loginSso( """.trimIndent() } // can't reuse since requested scopes are not in current connection. forcing reauth - return createAndAuthNewConnection( - ManagedSsoProfile( - region, - startUrl, - allScopes.toList() - ) - ) + return@let null } // For the case when the existing connection is in invalid state, we need to re-auth reauthConnectionIfNeeded( project = project, connection = connection, - isReAuth = true + onPendingToken = onPendingToken, + isReAuth = true, + source = metadata?.sourceId, ) + return@let connection + } + + if (connection != null) { return connection - } ?: run { - // No existing connection, start from scratch - createAndAuthNewConnection( - ManagedSsoProfile( - region, - startUrl, - allScopes.toList() - ) - ) } + + // No existing connection, start from scratch + return createAndAuthNewConnection( + ManagedSsoProfile( + region, + startUrl, + allScopes.toList() + ) + ) } @Suppress("UnusedParameter") @@ -242,7 +247,9 @@ fun reauthConnectionIfNeeded( } val startUrl = (connection as AwsBearerTokenConnection).startUrl + var didReauth = false maybeReauthProviderIfNeeded(project, tokenProvider) { + didReauth = true runUnderProgressIfNeeded(project, AwsCoreBundle.message("credentials.pending.title"), true) { try { tokenProvider.reauthenticate() @@ -283,6 +290,11 @@ fun reauthConnectionIfNeeded( } } } + + if (!didReauth) { + // webview is stuck if reauth was not needed (i.e. token on disk is valid) + project?.let { ToolkitConnectionManager.getInstance(it).switchConnection(connection) } + } return tokenProvider } diff --git a/plugins/core/jetbrains-community/tst/software/aws/toolkits/jetbrains/core/credentials/DefaultToolkitAuthManagerTest.kt b/plugins/core/jetbrains-community/tst/software/aws/toolkits/jetbrains/core/credentials/DefaultToolkitAuthManagerTest.kt index 4692a111d67..b34ae000b48 100644 --- a/plugins/core/jetbrains-community/tst/software/aws/toolkits/jetbrains/core/credentials/DefaultToolkitAuthManagerTest.kt +++ b/plugins/core/jetbrains-community/tst/software/aws/toolkits/jetbrains/core/credentials/DefaultToolkitAuthManagerTest.kt @@ -18,6 +18,7 @@ import org.junit.jupiter.api.extension.RegisterExtension import org.mockito.Mockito.mockConstruction import org.mockito.kotlin.any import org.mockito.kotlin.argumentCaptor +import org.mockito.kotlin.atLeastOnce import org.mockito.kotlin.doNothing import org.mockito.kotlin.mock import org.mockito.kotlin.spy @@ -25,7 +26,6 @@ import org.mockito.kotlin.timeout import org.mockito.kotlin.verify import org.mockito.kotlin.verifyNoMoreInteractions import org.mockito.kotlin.whenever -import software.amazon.awssdk.regions.Region import software.amazon.awssdk.services.ssooidc.SsoOidcClient import software.aws.toolkits.core.telemetry.MetricEvent import software.aws.toolkits.core.telemetry.TelemetryBatcher @@ -36,8 +36,6 @@ import software.aws.toolkits.jetbrains.core.credentials.profiles.ProfileSsoSessi import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenAuthState import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProviderListener import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.InteractiveBearerTokenProvider -import software.aws.toolkits.jetbrains.core.region.MockRegionProviderExtension -import software.aws.toolkits.jetbrains.core.region.MockRegionProviderRule import software.aws.toolkits.jetbrains.services.telemetry.NoOpPublisher import software.aws.toolkits.jetbrains.services.telemetry.TelemetryService import software.aws.toolkits.jetbrains.settings.AwsSettings @@ -51,9 +49,6 @@ class DefaultToolkitAuthManagerTest { batcher: TelemetryBatcher ) : TelemetryService(publisher, batcher) - @ExtendWith(MockRegionProviderExtension::class) - val regionProvider = MockRegionProviderRule() - @JvmField @RegisterExtension val mockClientManager = MockClientManagerExtension() @@ -67,9 +62,13 @@ class DefaultToolkitAuthManagerTest { @BeforeEach fun setUp(@TestDisposable disposable: Disposable) { mockClientManager.create() + sut = DefaultToolkitAuthManager() ApplicationManager.getApplication().replaceService(ToolkitAuthManager::class.java, sut, disposable) - connectionManager = DefaultToolkitConnectionManager() + + connectionManager = DefaultToolkitConnectionManager(projectRule.project) + projectRule.project.replaceService(ToolkitConnectionManager::class.java, connectionManager, disposable) + batcher = mock() telemetryService = spy(TestTelemetryService(batcher = batcher)) ApplicationManager.getApplication().replaceService(TelemetryService::class.java, telemetryService, disposable) @@ -222,12 +221,7 @@ class DefaultToolkitAuthManagerTest { } @Test - fun `loginSso with an working existing connection`(@TestDisposable disposable: Disposable) { - val connectionManager: ToolkitConnectionManager = mock() - regionProvider.addRegion(Region.US_EAST_1) - projectRule.project.replaceService(ToolkitConnectionManager::class.java, connectionManager, disposable) - ApplicationManager.getApplication().replaceService(ToolkitAuthManager::class.java, sut, disposable) - + fun `loginSso with an working existing connection`() { mockConstruction(InteractiveBearerTokenProvider::class.java) { context, _ -> whenever(context.state()).thenReturn(BearerTokenAuthState.AUTHORIZED) }.use { @@ -248,12 +242,7 @@ class DefaultToolkitAuthManagerTest { } @Test - fun `loginSso with an existing connection but expired and refresh token is valid, should refreshToken`(@TestDisposable disposable: Disposable) { - val connectionManager = ToolkitConnectionManager.getInstance(projectRule.project) - regionProvider.addRegion(Region.US_EAST_1) - projectRule.project.replaceService(ToolkitConnectionManager::class.java, connectionManager, disposable) - ApplicationManager.getApplication().replaceService(ToolkitAuthManager::class.java, sut, disposable) - + fun `loginSso with an existing connection but expired and refresh token is valid, should refreshToken`() { mockConstruction(InteractiveBearerTokenProvider::class.java) { context, _ -> whenever(context.id).thenReturn("id") whenever(context.state()).thenReturn(BearerTokenAuthState.NEEDS_REFRESH) @@ -276,13 +265,7 @@ class DefaultToolkitAuthManagerTest { } @Test - fun `loginSso with an existing connection that token is invalid and there's no refresh token, should re-authenticate`( - @TestDisposable disposable: Disposable - ) { - val connectionManager = ToolkitConnectionManager.getInstance(projectRule.project) - regionProvider.addRegion(Region.US_EAST_1) - ApplicationManager.getApplication().replaceService(ToolkitAuthManager::class.java, sut, disposable) - + fun `loginSso with an existing connection that token is invalid and there's no refresh token, should re-authenticate`() { mockConstruction(InteractiveBearerTokenProvider::class.java) { context, _ -> whenever(context.state()).thenReturn(BearerTokenAuthState.NOT_AUTHENTICATED) }.use { @@ -305,13 +288,12 @@ class DefaultToolkitAuthManagerTest { @Test fun `loginSso reuses connection if requested scopes are subset of existing`(@TestDisposable disposable: Disposable) { - val connectionManager = ToolkitConnectionManager.getInstance(projectRule.project) - regionProvider.addRegion(Region.US_EAST_1) - ApplicationManager.getApplication().replaceService(ToolkitAuthManager::class.java, sut, disposable) - mockConstruction(InteractiveBearerTokenProvider::class.java) { context, _ -> whenever(context.state()).thenReturn(BearerTokenAuthState.AUTHORIZED) }.use { + val connectionManager = spy(connectionManager) + projectRule.project.replaceService(ToolkitConnectionManager::class.java, connectionManager, disposable) + val existingConnection = sut.createConnection( ManagedSsoProfile( "us-east-1", @@ -328,16 +310,12 @@ class DefaultToolkitAuthManagerTest { verify(tokenProvider).state() verifyNoMoreInteractions(tokenProvider) assertThat(connectionManager.activeConnection()).isEqualTo(existingConnection) + verify(connectionManager, atLeastOnce()).switchConnection(existingConnection) } } @Test - fun `loginSso forces reauth if requested scopes are not complete subset`(@TestDisposable disposable: Disposable) { - regionProvider.addRegion(Region.US_EAST_1) - - projectRule.project.replaceService(ToolkitConnectionManager::class.java, connectionManager, disposable) - ApplicationManager.getApplication().replaceService(ToolkitAuthManager::class.java, sut, disposable) - + fun `loginSso forces reauth if requested scopes are not complete subset`() { mockConstruction(InteractiveBearerTokenProvider::class.java) { context, _ -> whenever(context.state()).thenReturn(BearerTokenAuthState.AUTHORIZED) }.use { @@ -364,14 +342,12 @@ class DefaultToolkitAuthManagerTest { @Test fun `loginSso with a new connection`(@TestDisposable disposable: Disposable) { - val connectionManager: ToolkitConnectionManager = mock() - ApplicationManager.getApplication().replaceService(ToolkitAuthManager::class.java, sut, disposable) - projectRule.project.replaceService(ToolkitConnectionManager::class.java, connectionManager, disposable) - mockConstruction(InteractiveBearerTokenProvider::class.java) { context, _ -> doNothing().whenever(context).reauthenticate() whenever(context.state()).thenReturn(BearerTokenAuthState.NOT_AUTHENTICATED) }.use { + val connectionManager = spy(connectionManager) + projectRule.project.replaceService(ToolkitConnectionManager::class.java, connectionManager, disposable) // before assertThat(sut.listConnections()).hasSize(0) @@ -399,17 +375,14 @@ class DefaultToolkitAuthManagerTest { @Test fun `logoutFromConnection should invalidate the token provider and the connection and invoke callback`(@TestDisposable disposable: Disposable) { - regionProvider.addRegion(Region.US_EAST_1) - projectRule.project.replaceService(ToolkitConnectionManager::class.java, connectionManager, disposable) - val profile = ManagedSsoProfile("us-east-1", "startUrl000", listOf("scopes")) - val connection = ToolkitAuthManager.getInstance().createConnection(profile) as ManagedBearerSsoConnection + val connection = sut.createConnection(profile) as ManagedBearerSsoConnection connectionManager.switchConnection(connection) var providerInvalidatedMessageReceived = 0 var connectionSwitchedMessageReceived = 0 var callbackInvoked = 0 - ApplicationManager.getApplication().messageBus.connect().subscribe( + ApplicationManager.getApplication().messageBus.connect(disposable).subscribe( BearerTokenProviderListener.TOPIC, object : BearerTokenProviderListener { override fun invalidate(providerId: String) { @@ -419,7 +392,7 @@ class DefaultToolkitAuthManagerTest { } } ) - ApplicationManager.getApplication().messageBus.connect().subscribe( + ApplicationManager.getApplication().messageBus.connect(disposable).subscribe( ToolkitConnectionManagerListener.TOPIC, object : ToolkitConnectionManagerListener { override fun activeConnectionChanged(newConnection: ToolkitConnection?) {