diff --git a/plugins/amazonq/chat/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/toolwindow/AmazonQToolWindowFactory.kt b/plugins/amazonq/chat/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/toolwindow/AmazonQToolWindowFactory.kt index 3ae16dd96c6..6e1841499aa 100644 --- a/plugins/amazonq/chat/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/toolwindow/AmazonQToolWindowFactory.kt +++ b/plugins/amazonq/chat/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/toolwindow/AmazonQToolWindowFactory.kt @@ -83,7 +83,7 @@ class AmazonQToolWindowFactory : ToolWindowFactory, DumbAware { project.messageBus.connect(toolWindow.disposable).subscribe( BearerTokenProviderListener.TOPIC, object : BearerTokenProviderListener { - override fun onChange(providerId: String, newScopes: List?) { + override fun onProviderChange(providerId: String, newScopes: List?) { if (ToolkitConnectionManager.getInstance(project).connectionStateForFeature(QConnection.getInstance()) == BearerTokenAuthState.AUTHORIZED) { preparePanelContent(project, qPanel) } diff --git a/plugins/amazonq/chat/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonqCodeScan/CodeScanChatApp.kt b/plugins/amazonq/chat/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonqCodeScan/CodeScanChatApp.kt index 0aa8dc42b04..ce77b84f89a 100644 --- a/plugins/amazonq/chat/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonqCodeScan/CodeScanChatApp.kt +++ b/plugins/amazonq/chat/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonqCodeScan/CodeScanChatApp.kt @@ -125,7 +125,7 @@ class CodeScanChatApp(private val scope: CoroutineScope) : AmazonQApp { ApplicationManager.getApplication().messageBus.connect(this).subscribe( BearerTokenProviderListener.TOPIC, object : BearerTokenProviderListener { - override fun onChange(providerId: String, newScopes: List?) { + override fun onProviderChange(providerId: String, newScopes: List?) { val qProvider = getQTokenProvider(context.project) val isQ = qProvider?.id == providerId val isAuthorized = qProvider?.state() == BearerTokenAuthState.AUTHORIZED diff --git a/plugins/amazonq/chat/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/AmazonQTestBase.kt b/plugins/amazonq/chat/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/AmazonQTestBase.kt index 0216a39ef67..490fd794c64 100644 --- a/plugins/amazonq/chat/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/AmazonQTestBase.kt +++ b/plugins/amazonq/chat/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/AmazonQTestBase.kt @@ -17,17 +17,14 @@ import org.mockito.kotlin.spy import org.mockito.kotlin.whenever import software.aws.toolkits.core.TokenConnectionSettings import software.aws.toolkits.core.credentials.ToolkitBearerTokenProvider -import software.aws.toolkits.core.utils.test.aString import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager -import software.aws.toolkits.jetbrains.core.credentials.sso.DeviceAuthorizationGrantToken import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProvider import software.aws.toolkits.jetbrains.services.amazonq.clients.AmazonQStreamingClient import software.aws.toolkits.jetbrains.utils.rules.CodeInsightTestFixtureRule import software.aws.toolkits.jetbrains.utils.rules.HeavyJavaCodeInsightTestFixtureRule import software.aws.toolkits.jetbrains.utils.rules.JavaCodeInsightTestFixtureRule import software.aws.toolkits.jetbrains.utils.rules.addModule -import java.time.Instant open class AmazonQTestBase( @Rule @JvmField @@ -47,11 +44,7 @@ open class AmazonQTestBase( project = projectRule.project toolkitConnectionManager = spy(ToolkitConnectionManager.getInstance(project)) - val accessToken = DeviceAuthorizationGrantToken(aString(), aString(), aString(), aString(), Instant.MAX, Instant.now()) - - val provider = mock { - doReturn(accessToken).whenever(it).refresh() - } + val provider = mock() val mockBearerProvider = mock { doReturn(provider).whenever(it).delegate diff --git a/plugins/amazonq/chat/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonqFeatureDev/FeatureDevTestBase.kt b/plugins/amazonq/chat/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonqFeatureDev/FeatureDevTestBase.kt index da392cd08cd..ee655405844 100644 --- a/plugins/amazonq/chat/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonqFeatureDev/FeatureDevTestBase.kt +++ b/plugins/amazonq/chat/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonqFeatureDev/FeatureDevTestBase.kt @@ -28,10 +28,8 @@ import software.amazon.awssdk.services.codewhispererruntime.model.SendTelemetryE import software.amazon.awssdk.services.codewhispererruntime.model.StartTaskAssistCodeGenerationResponse import software.aws.toolkits.core.TokenConnectionSettings import software.aws.toolkits.core.credentials.ToolkitBearerTokenProvider -import software.aws.toolkits.core.utils.test.aString import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager -import software.aws.toolkits.jetbrains.core.credentials.sso.DeviceAuthorizationGrantToken import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProvider import software.aws.toolkits.jetbrains.services.amazonqFeatureDev.clients.FeatureDevClient import software.aws.toolkits.jetbrains.services.amazonqFeatureDev.session.CodeGenerationStreamResult @@ -41,7 +39,6 @@ import software.aws.toolkits.jetbrains.utils.rules.HeavyJavaCodeInsightTestFixtu import software.aws.toolkits.jetbrains.utils.rules.JavaCodeInsightTestFixtureRule import software.aws.toolkits.jetbrains.utils.rules.addModule import java.io.File -import java.time.Instant open class FeatureDevTestBase( @Rule @JvmField @@ -164,11 +161,7 @@ open class FeatureDevTestBase( open fun setup() { project = projectRule.project toolkitConnectionManager = spy(ToolkitConnectionManager.getInstance(project)) - val accessToken = DeviceAuthorizationGrantToken(aString(), aString(), aString(), aString(), Instant.MAX, Instant.now()) - val provider = - mock { - doReturn(accessToken).whenever(it).refresh() - } + val provider = mock() val mockBearerProvider = mock { doReturn(provider).whenever(it).delegate diff --git a/plugins/amazonq/codetransform/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codemodernizer/CodeTransformChatApp.kt b/plugins/amazonq/codetransform/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codemodernizer/CodeTransformChatApp.kt index aba84e640fb..0d647e0b7b7 100644 --- a/plugins/amazonq/codetransform/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codemodernizer/CodeTransformChatApp.kt +++ b/plugins/amazonq/codetransform/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codemodernizer/CodeTransformChatApp.kt @@ -149,7 +149,7 @@ class CodeTransformChatApp : AmazonQApp { ApplicationManager.getApplication().messageBus.connect(this).subscribe( BearerTokenProviderListener.TOPIC, object : BearerTokenProviderListener { - override fun onChange(providerId: String, newScopes: List?) { + override fun onProviderChange(providerId: String, newScopes: List?) { val qProvider = getQTokenProvider(context.project) val isQ = qProvider?.id == providerId val isAuthorized = qProvider?.state() == BearerTokenAuthState.AUTHORIZED diff --git a/plugins/amazonq/codetransform/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codemodernizer/utils/CodeTransformApiUtils.kt b/plugins/amazonq/codetransform/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codemodernizer/utils/CodeTransformApiUtils.kt index de8785b5c94..db25052ec40 100644 --- a/plugins/amazonq/codetransform/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codemodernizer/utils/CodeTransformApiUtils.kt +++ b/plugins/amazonq/codetransform/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codemodernizer/utils/CodeTransformApiUtils.kt @@ -87,11 +87,6 @@ suspend fun JobId.pollTransformationStatusAndPlan( var transformationPlan: TransformationPlan? = null var didSleepOnce = false var hasSeenTransforming = false - val maxRefreshes = 10 - var numRefreshes = 0 - - // refresh token at start of polling since local build just prior can take a long time - refreshToken(project) try { waitUntil( @@ -138,13 +133,10 @@ suspend fun JobId.pollTransformationStatusAndPlan( onStateChange(state, newStatus, transformationPlan) } state = newStatus - numRefreshes = 0 - return@waitUntil state - } catch (e: AccessDeniedException) { - if (numRefreshes++ > maxRefreshes) throw e - refreshToken(project) return@waitUntil state - } catch (e: InvalidGrantException) { + } catch (e: Exception) { + if (e !is AccessDeniedException && e !is InvalidGrantException) throw e + CodeTransformMessageListener.instance.onReauthStarted() notifyStickyWarn( message("codemodernizer.notification.warn.expired_credentials.title"), diff --git a/plugins/amazonq/codetransform/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codemodernizer/utils/CodeTransformUtils.kt b/plugins/amazonq/codetransform/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codemodernizer/utils/CodeTransformUtils.kt index 5c8a01f20a9..6729ce167f1 100644 --- a/plugins/amazonq/codetransform/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codemodernizer/utils/CodeTransformUtils.kt +++ b/plugins/amazonq/codetransform/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codemodernizer/utils/CodeTransformUtils.kt @@ -8,7 +8,6 @@ import com.intellij.openapi.vfs.VfsUtilCore import com.intellij.openapi.vfs.VirtualFileManager import software.amazon.awssdk.services.codewhispererruntime.model.TransformationLanguage import software.amazon.awssdk.services.codewhispererruntime.model.TransformationStatus -import software.aws.toolkits.core.TokenConnectionSettings import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection @@ -43,12 +42,6 @@ val STATES_AFTER_STARTED = setOf( *STATES_AFTER_INITIAL_BUILD.toTypedArray(), ) -fun refreshToken(project: Project) { - val connection = ToolkitConnectionManager.getInstance(project).activeConnectionForFeature(QConnection.getInstance()) - val provider = (connection?.getConnectionSettings() as TokenConnectionSettings).tokenProvider.delegate as BearerTokenProvider - provider.refresh() -} - fun getAuthType(project: Project): CredentialSourceId? { val connection = checkBearerConnectionValidity(project, BearerTokenFeatureSet.Q) var authType: CredentialSourceId? = null diff --git a/plugins/amazonq/codetransform/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codemodernizer/CodeWhispererCodeModernizerTestBase.kt b/plugins/amazonq/codetransform/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codemodernizer/CodeWhispererCodeModernizerTestBase.kt index e640eebbdd2..6ba37aa6b9f 100644 --- a/plugins/amazonq/codetransform/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codemodernizer/CodeWhispererCodeModernizerTestBase.kt +++ b/plugins/amazonq/codetransform/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codemodernizer/CodeWhispererCodeModernizerTestBase.kt @@ -45,7 +45,6 @@ import software.aws.toolkits.core.credentials.ToolkitBearerTokenProvider import software.aws.toolkits.core.utils.test.aString import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager -import software.aws.toolkits.jetbrains.core.credentials.sso.DeviceAuthorizationGrantToken import software.aws.toolkits.jetbrains.core.credentials.sso.PKCEAuthorizationGrantToken import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenAuthState import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProvider @@ -250,11 +249,7 @@ open class CodeWhispererCodeModernizerTestBase( project = projectRule.project toolkitConnectionManager = spy(ToolkitConnectionManager.getInstance(project)) - val accessToken = DeviceAuthorizationGrantToken(aString(), aString(), aString(), aString(), Instant.MAX, Instant.now()) - val provider = - mock { - doReturn(accessToken).whenever(it).refresh() - } + val provider = mock { } val mockBearerProvider = mock { doReturn(provider).whenever(it).delegate @@ -340,7 +335,6 @@ open class CodeWhispererCodeModernizerTestBase( val accessToken = PKCEAuthorizationGrantToken(aString(), aString(), aString(), aString(), Instant.MAX, Instant.now()) val provider = mock { - doReturn(accessToken).whenever(it).refresh() doReturn(accessToken).whenever(it).currentToken() doReturn(authState).whenever(it).state() } diff --git a/plugins/amazonq/codetransform/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codemodernizer/CodeWhispererCodeModernizerUtilsTest.kt b/plugins/amazonq/codetransform/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codemodernizer/CodeWhispererCodeModernizerUtilsTest.kt index 38335097869..2699ddc6178 100644 --- a/plugins/amazonq/codetransform/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codemodernizer/CodeWhispererCodeModernizerUtilsTest.kt +++ b/plugins/amazonq/codetransform/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codemodernizer/CodeWhispererCodeModernizerUtilsTest.kt @@ -38,7 +38,6 @@ import software.aws.toolkits.jetbrains.services.codemodernizer.utils.getTableMap import software.aws.toolkits.jetbrains.services.codemodernizer.utils.isPlanComplete import software.aws.toolkits.jetbrains.services.codemodernizer.utils.parseBuildFile import software.aws.toolkits.jetbrains.services.codemodernizer.utils.pollTransformationStatusAndPlan -import software.aws.toolkits.jetbrains.services.codemodernizer.utils.refreshToken import software.aws.toolkits.jetbrains.services.codemodernizer.utils.validateCustomVersionsFile import software.aws.toolkits.jetbrains.services.codemodernizer.utils.validateSctMetadata import software.aws.toolkits.jetbrains.utils.notifyStickyWarn @@ -90,18 +89,18 @@ class CodeWhispererCodeModernizerUtilsTest : CodeWhispererCodeModernizerTestBase } @Test - fun `refresh on access denied`() { + fun `show re-auth notification on access denied`() { val mockAccessDeniedException = Mockito.mock(AccessDeniedException::class.java) - mockkStatic(::refreshToken) - every { refreshToken(any()) } just runs + mockkStatic(::notifyStickyWarn) + every { notifyStickyWarn(any(), any(), any(), any(), any()) } just runs Mockito.doThrow( mockAccessDeniedException ).doReturn( exampleGetCodeMigrationResponse, exampleGetCodeMigrationResponse.replace(TransformationStatus.STARTED), - exampleGetCodeMigrationResponse.replace(TransformationStatus.COMPLETED), // Should stop before this point + exampleGetCodeMigrationResponse.replace(TransformationStatus.COMPLETED), ).whenever(clientAdaptorSpy).getCodeModernizationJob(any()) Mockito.doReturn(exampleGetCodeMigrationPlanResponse) @@ -128,7 +127,7 @@ class CodeWhispererCodeModernizerUtilsTest : CodeWhispererCodeModernizerTestBase TransformationStatus.STARTED, ) assertThat(expected).isEqualTo(mutableList) - io.mockk.verify { refreshToken(any()) } + verify { notifyStickyWarn(message("codemodernizer.notification.warn.expired_credentials.title"), any(), any(), any(), any()) } } @Test diff --git a/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/status/CodeWhispererStatusBarWidget.kt b/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/status/CodeWhispererStatusBarWidget.kt index b34d1265393..a5c3df0caf8 100644 --- a/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/status/CodeWhispererStatusBarWidget.kt +++ b/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/status/CodeWhispererStatusBarWidget.kt @@ -56,7 +56,7 @@ class CodeWhispererStatusBarWidget(project: Project) : ApplicationManager.getApplication().messageBus.connect(this).subscribe( BearerTokenProviderListener.TOPIC, object : BearerTokenProviderListener { - override fun onChange(providerId: String, newScopes: List?) { + override fun onProviderChange(providerId: String, newScopes: List?) { statusBar.updateWidget(ID) } } diff --git a/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/AuthCredentialsService.kt b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/AuthCredentialsService.kt deleted file mode 100644 index e2966004421..00000000000 --- a/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/AuthCredentialsService.kt +++ /dev/null @@ -1,13 +0,0 @@ -// Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package software.aws.toolkits.jetbrains.services.amazonq.lsp.auth - -import org.eclipse.lsp4j.jsonrpc.messages.ResponseMessage -import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnection -import java.util.concurrent.CompletableFuture - -interface AuthCredentialsService { - fun updateTokenCredentials(connection: ToolkitConnection, encrypted: Boolean): CompletableFuture - fun deleteTokenCredentials() -} diff --git a/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt index de03c0e7fa8..bb420a4e034 100644 --- a/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt +++ b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt @@ -19,7 +19,6 @@ import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnection import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManagerListener import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection -import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenAuthState import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProvider import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProviderListener import software.aws.toolkits.jetbrains.services.amazonq.lsp.AmazonQLspService @@ -43,15 +42,14 @@ class DefaultAuthCredentialsService( private val project: Project, private val encryptionManager: JwtEncryptionManager, private val cs: CoroutineScope, -) : AuthCredentialsService, - BearerTokenProviderListener, +) : BearerTokenProviderListener, ToolkitConnectionManagerListener, QRegionProfileSelectedListener, Disposable { private val scheduler: ScheduledExecutorService = AppExecutorUtil.getAppScheduledExecutorService() - private var tokenSyncTask: ScheduledFuture<*>? = null - private val tokenSyncIntervalMinutes = 5L + private var tokenRefreshTask: ScheduledFuture<*>? = null + private val tokenRefreshInterval = 5L init { project.messageBus.connect(this).apply { @@ -67,49 +65,37 @@ class DefaultAuthCredentialsService( } } - // Start periodic token sync - startPeriodicTokenSync() + // Start periodic token refresh + startPeriodicTokenRefresh() } - private fun startPeriodicTokenSync() { - tokenSyncTask = scheduler.scheduleWithFixedDelay( + // TODO: we really only need a single application-wide instance of this + private fun startPeriodicTokenRefresh() { + tokenRefreshTask = scheduler.scheduleWithFixedDelay( { try { if (isQConnected(project)) { - if (isQExpired(project)) { - val manager = ToolkitConnectionManager.getInstance(project) - val connection = manager.activeConnectionForFeature(QConnection.getInstance()) ?: return@scheduleWithFixedDelay - - // Try to refresh the token if it's in NEEDS_REFRESH state - val tokenProvider = (connection.getConnectionSettings() as? TokenConnectionSettings) - ?.tokenProvider - ?.delegate - ?.let { it as? BearerTokenProvider } ?: return@scheduleWithFixedDelay - - if (tokenProvider.state() == BearerTokenAuthState.NEEDS_REFRESH) { - try { - tokenProvider.resolveToken() - // Now that the token is refreshed, update it in Flare - updateTokenFromActiveConnection() - } catch (e: Exception) { - LOG.warn(e) { "Failed to refresh bearer token" } - } - } - } else { - updateTokenFromActiveConnection() - } + val manager = ToolkitConnectionManager.getInstance(project) + val connection = manager.activeConnectionForFeature(QConnection.getInstance()) ?: return@scheduleWithFixedDelay + + // periodically poll token to trigger a background refresh if needed + val tokenProvider = (connection.getConnectionSettings() as? TokenConnectionSettings) + ?.tokenProvider + ?.delegate + ?.let { it as? BearerTokenProvider } ?: return@scheduleWithFixedDelay + tokenProvider.resolveToken() } } catch (e: Exception) { - LOG.warn(e) { "Failed to sync bearer token to Flare" } + LOG.warn(e) { "Failed to refresh bearer token" } } }, - tokenSyncIntervalMinutes, - tokenSyncIntervalMinutes, + tokenRefreshInterval, + tokenRefreshInterval, TimeUnit.MINUTES ) } - override fun updateTokenCredentials(connection: ToolkitConnection, encrypted: Boolean): CompletableFuture { + fun updateTokenCredentials(connection: ToolkitConnection, encrypted: Boolean): CompletableFuture { val payload = try { createUpdateCredentialsPayload(connection, encrypted) } catch (e: Exception) { @@ -129,7 +115,7 @@ class DefaultAuthCredentialsService( }.asCompletableFuture() } - override fun deleteTokenCredentials() { + fun deleteTokenCredentials() { cs.launch { AmazonQLspService.executeAsyncIfRunning(project) { server -> server.deleteTokenCredentials() @@ -137,10 +123,18 @@ class DefaultAuthCredentialsService( } } - override fun onChange(providerId: String, newScopes: List?) { + override fun onProviderChange(providerId: String, newScopes: List?) { updateTokenFromActiveConnection() } + override fun onTokenModified(providerId: String) { + updateTokenFromActiveConnection() + } + + override fun invalidate(providerId: String) { + deleteTokenCredentials() + } + override fun activeConnectionChanged(newConnection: ToolkitConnection?) { val qConnection = ToolkitConnectionManager.getInstance(project) .activeConnectionForFeature(QConnection.getInstance()) @@ -161,10 +155,6 @@ class DefaultAuthCredentialsService( private fun updateTokenFromConnection(connection: ToolkitConnection): CompletableFuture = updateTokenCredentials(connection, true) - override fun invalidate(providerId: String) { - deleteTokenCredentials() - } - private fun createUpdateCredentialsPayload(connection: ToolkitConnection, encrypted: Boolean): UpdateCredentialsPayload { val token = (connection.getConnectionSettings() as? TokenConnectionSettings) ?.tokenProvider @@ -212,8 +202,8 @@ class DefaultAuthCredentialsService( } override fun dispose() { - tokenSyncTask?.cancel(false) - tokenSyncTask = null + tokenRefreshTask?.cancel(false) + tokenRefreshTask = null } companion object { diff --git a/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/profile/QRegionProfileManager.kt b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/profile/QRegionProfileManager.kt index 995bea0efe7..3729a4e0fc3 100644 --- a/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/profile/QRegionProfileManager.kt +++ b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/profile/QRegionProfileManager.kt @@ -114,7 +114,12 @@ class QRegionProfileManager : PersistentStateComponent, Disposabl connectionIdToProfileCount[connection.id] = it.size } ?: error("You don't have access to the resource") } catch (e: Exception) { - LOG.warn(e) { "Failed to list region profiles: ${e.message}" } + if (e is AccessDeniedException) { + LOG.warn { "Failed to list region profiles: ${e.message}" } + } else { + LOG.warn(e) { "Failed to list region profiles" } + } + throw e } } diff --git a/plugins/amazonq/shared/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsServiceTest.kt b/plugins/amazonq/shared/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsServiceTest.kt index aef0428732a..a5055b4077f 100644 --- a/plugins/amazonq/shared/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsServiceTest.kt +++ b/plugins/amazonq/shared/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsServiceTest.kt @@ -180,7 +180,7 @@ class DefaultAuthCredentialsServiceTest { sut = DefaultAuthCredentialsService(project, mockEncryptionManager, this) setupMockConnectionManager("updated-token") - sut.onChange("providerId", listOf("new-scope")) + sut.onProviderChange("providerId", listOf("new-scope")) advanceUntilIdle() verify(exactly = 1) { mockLanguageServer.updateTokenCredentials(any()) } diff --git a/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/AwsClientManager.kt b/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/AwsClientManager.kt index 23a32aa0baa..8c59456f7d4 100644 --- a/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/AwsClientManager.kt +++ b/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/AwsClientManager.kt @@ -51,7 +51,7 @@ open class AwsClientManager : ToolkitClientManager(), Disposable { busConnection.subscribe( BearerTokenProviderListener.TOPIC, object : BearerTokenProviderListener { - override fun onChange(providerId: String, newScopes: List?) { + override fun onProviderChange(providerId: String, newScopes: List?) { invalidateSdks(providerId) } diff --git a/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/AwsResourceCache.kt b/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/AwsResourceCache.kt index 07b1847b127..c7ad4da7998 100644 --- a/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/AwsResourceCache.kt +++ b/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/AwsResourceCache.kt @@ -165,7 +165,7 @@ class DefaultAwsResourceCache( subscribe( BearerTokenProviderListener.TOPIC, object : BearerTokenProviderListener { - override fun onChange(providerId: String, newScopes: List?) { + override fun onProviderChange(providerId: String, newScopes: List?) { clearByCredential(providerId) } } diff --git a/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/CredentialManager.kt b/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/CredentialManager.kt index e1e19dbac36..6c50c80597c 100644 --- a/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/CredentialManager.kt +++ b/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/CredentialManager.kt @@ -64,7 +64,7 @@ class DefaultCredentialManager : CredentialManager(), Disposable { ApplicationManager.getApplication().messageBus.connect(this).subscribe( BearerTokenProviderListener.TOPIC, object : BearerTokenProviderListener { - override fun onChange(providerId: String, newScopes: List?) { + override fun onProviderChange(providerId: String, newScopes: List?) { modifyDependentProviders(providerId) } diff --git a/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/DefaultToolkitAuthManager.kt b/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/DefaultToolkitAuthManager.kt index 83e36ad350c..1b3a1f1c520 100644 --- a/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/DefaultToolkitAuthManager.kt +++ b/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/DefaultToolkitAuthManager.kt @@ -61,7 +61,7 @@ class DefaultToolkitAuthManager : ToolkitAuthManager, PersistentStateComponent if (isDuplicate && existOldConn is Disposable) { ApplicationManager.getApplication().messageBus.syncPublisher(BearerTokenProviderListener.TOPIC) - .onChange(existOldConn.id, newConnection.scopes) + .onProviderChange(existOldConn.id, newConnection.scopes) Disposer.dispose(existOldConn) } } @@ -157,7 +157,7 @@ class DefaultToolkitAuthManager : ToolkitAuthManager, PersistentStateComponent if (isDuplicate && existOldConn is Disposable) { ApplicationManager.getApplication().messageBus.syncPublisher(BearerTokenProviderListener.TOPIC) - .onChange(existOldConn.id, newConnection.scopes) + .onProviderChange(existOldConn.id, newConnection.scopes) Disposer.dispose(existOldConn) } } diff --git a/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/sso/SsoAccessTokenProvider.kt b/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/sso/SsoAccessTokenProvider.kt index 7fd718a8220..6b3c9280201 100644 --- a/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/sso/SsoAccessTokenProvider.kt +++ b/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/sso/SsoAccessTokenProvider.kt @@ -194,7 +194,7 @@ class SsoAccessTokenProvider( @Deprecated("Device authorization grant flow is deprecated") private fun registerDAGClient(): ClientRegistration { - loadDagClientRegistration(SourceOfLoadRegistration.REGISTER_CLIENT.toString())?.let { + loadDagClientRegistration(SourceOfLoadRegistration.REGISTER_CLIENT)?.let { return it } @@ -235,7 +235,7 @@ class SsoAccessTokenProvider( } private fun registerPkceClient(): PKCEClientRegistration { - loadPkceClientRegistration(SourceOfLoadRegistration.REGISTER_CLIENT.toString())?.let { + loadPkceClientRegistration(SourceOfLoadRegistration.REGISTER_CLIENT)?.let { return it } @@ -431,8 +431,8 @@ class SsoAccessTokenProvider( stageName = RefreshCredentialStage.LOAD_REGISTRATION val registration = try { when (currentToken) { - is DeviceAuthorizationGrantToken -> loadDagClientRegistration(SourceOfLoadRegistration.REFRESH_TOKEN.toString()) - is PKCEAuthorizationGrantToken -> loadPkceClientRegistration(SourceOfLoadRegistration.REFRESH_TOKEN.toString()) + is DeviceAuthorizationGrantToken -> loadDagClientRegistration(SourceOfLoadRegistration.REFRESH_TOKEN) + is PKCEAuthorizationGrantToken -> loadPkceClientRegistration(SourceOfLoadRegistration.REFRESH_TOKEN) } } catch (e: Exception) { val message = e.message ?: "$stageName: ${e::class.java.name}" @@ -519,13 +519,13 @@ class SsoAccessTokenProvider( SAVE_TOKEN, } - private fun loadDagClientRegistration(source: String): ClientRegistration? = - cache.loadClientRegistration(dagClientRegistrationCacheKey, source)?.let { + private fun loadDagClientRegistration(source: SourceOfLoadRegistration): ClientRegistration? = + cache.loadClientRegistration(dagClientRegistrationCacheKey, source.toString())?.let { return it } - private fun loadPkceClientRegistration(source: String): PKCEClientRegistration? = - cache.loadClientRegistration(pkceClientRegistrationCacheKey, source)?.let { + private fun loadPkceClientRegistration(source: SourceOfLoadRegistration): PKCEClientRegistration? = + cache.loadClientRegistration(pkceClientRegistrationCacheKey, source.toString())?.let { return it as PKCEClientRegistration } diff --git a/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/sso/bearer/BearerTokenProvider.kt b/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/sso/bearer/BearerTokenProvider.kt index a8256f3394a..902e39a9e35 100644 --- a/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/sso/bearer/BearerTokenProvider.kt +++ b/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/sso/bearer/BearerTokenProvider.kt @@ -17,6 +17,7 @@ import software.amazon.awssdk.regions.Region import software.amazon.awssdk.services.ssooidc.SsoOidcClient import software.amazon.awssdk.services.ssooidc.SsoOidcTokenProvider import software.amazon.awssdk.services.ssooidc.internal.OnDiskTokenManager +import software.amazon.awssdk.services.ssooidc.model.InvalidGrantException import software.amazon.awssdk.services.ssooidc.model.SsoOidcException import software.amazon.awssdk.utils.SdkAutoCloseable import software.amazon.awssdk.utils.cache.CachedSupplier @@ -35,9 +36,12 @@ import software.aws.toolkits.jetbrains.core.credentials.sso.DeviceAuthorizationG import software.aws.toolkits.jetbrains.core.credentials.sso.DiskCache import software.aws.toolkits.jetbrains.core.credentials.sso.PendingAuthorization import software.aws.toolkits.jetbrains.core.credentials.sso.SsoAccessTokenProvider +import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProviderListener.Companion.TOPIC +import java.time.Clock import java.time.Duration -import java.time.Instant +import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.atomic.AtomicReference +import java.util.function.Supplier internal interface BearerTokenLogoutSupport @@ -47,11 +51,6 @@ interface BearerTokenProvider : SdkTokenProvider, SdkAutoCloseable, ToolkitBeare */ fun currentToken(): AccessToken? - /** - * Not meant to be invoked outside the implementation - */ - fun refresh(): AccessToken - /** * @return The authentication state of [currentToken] */ @@ -71,11 +70,11 @@ interface BearerTokenProvider : SdkTokenProvider, SdkAutoCloseable, ToolkitBeare } companion object { - internal fun tokenExpired(accessToken: AccessToken) = Instant.now().isAfter(accessToken.expiresAt) + private fun tokenExpired(accessToken: AccessToken, clock: Clock) = clock.instant().isAfter(accessToken.expiresAt) - internal fun state(accessToken: AccessToken?) = when { + internal fun state(accessToken: AccessToken?, clock: Clock = Clock.systemUTC()) = when { accessToken == null -> BearerTokenAuthState.NOT_AUTHENTICATED - tokenExpired(accessToken) -> { + tokenExpired(accessToken, clock) -> { if (accessToken.refreshToken != null) { BearerTokenAuthState.NEEDS_REFRESH } else { @@ -94,6 +93,7 @@ class InteractiveBearerTokenProvider( val scopes: List, override val id: String, cache: DiskCache = diskCache, + private val clock: Clock = Clock.systemUTC(), ) : BearerTokenProvider, BearerTokenLogoutSupport, Disposable { override val displayName = ToolkitBearerTokenProvider.ssoDisplayName(startUrl) @@ -107,16 +107,14 @@ class InteractiveBearerTokenProvider( scopes = scopes ) - private val supplier = CachedSupplier.builder { refreshToken() }.prefetchStrategy(NonBlocking("AWS SSO bearer token refresher")).build() - internal val lastToken = AtomicReference() + private var supplier = supplier() + val pendingAuthorization: PendingAuthorization? get() = accessTokenProvider.authorization init { - lastToken.set(accessTokenProvider.loadAccessToken()) - ApplicationManager.getApplication().messageBus.connect(this).subscribe( - BearerTokenProviderListener.TOPIC, + TOPIC, object : BearerTokenProviderListener { override fun invalidate(providerId: String) { if (id == providerId) { @@ -124,7 +122,7 @@ class InteractiveBearerTokenProvider( } } - override fun onChange(providerId: String, newScopes: List?) { + override fun onProviderChange(providerId: String, newScopes: List?) { newScopes?.let { if (id == providerId && it.toSet() != scopes.toSet()) { invalidate() @@ -135,27 +133,70 @@ class InteractiveBearerTokenProvider( ) } - // we need to seed CachedSupplier with an initial value, then subsequent calls need to hit the network - private fun refreshToken(): RefreshResult { - val lastToken = lastToken.get() ?: throw NoTokenInitializedException("Token refresh started before session initialized") - val token = if (Duration.between(Instant.now(), lastToken.expiresAt) > Duration.ofMinutes(30)) { - lastToken - } else { - refresh() + private data class SupplierHolder( + val supplier: SupplierWithInitialValue, + val cachedSupplier: CachedSupplier, + ) + + private fun supplier(initialValue: AccessToken? = null) = + SupplierWithInitialValue(initialValue, accessTokenProvider).let { + SupplierHolder( + it, + CachedSupplier.builder(it).clock(clock).prefetchStrategy(NonBlocking("AWS SSO bearer token refresher")).build() + ) } - return RefreshResult.builder(token) - .staleTime(token.expiresAt.minus(DEFAULT_STALE_DURATION)) - .prefetchTime(token.expiresAt.minus(DEFAULT_PREFETCH_DURATION)) - .build() + private inner class SupplierWithInitialValue( + initial: AccessToken?, + val accessTokenProvider: SsoAccessTokenProvider, + ) : Supplier> { + private val hasCalledAtLeastOnce = AtomicBoolean(false) + private val initialValue = initial ?: accessTokenProvider.loadAccessToken() + val lastToken = AtomicReference(initialValue) + + // we need to seed CachedSupplier with an initial value, then subsequent calls need to hit the network + override fun get(): RefreshResult { + val token = if (hasCalledAtLeastOnce.getAndSet(true)) { + refresh() + } else { + // on initial call, refresh if needed + if (initialValue != null && initialValue.expiresAt.minus(DEFAULT_PREFETCH_DURATION) < clock.instant()) { + refresh() + } else { + initialValue ?: throw NoTokenInitializedException("Token provider initialized with no token") + } + } + + return RefreshResult.builder(token) + .staleTime(token.expiresAt.minus(DEFAULT_STALE_DURATION)) + .prefetchTime(token.expiresAt.minus(DEFAULT_PREFETCH_DURATION)) + .build() + } + + fun refresh(): AccessToken { + val lastToken = lastToken.get() ?: throw NoTokenInitializedException("Token refresh started before session initialized") + return try { + accessTokenProvider.refreshToken(lastToken).also { + this.lastToken.set(it) + ApplicationManager.getApplication().messageBus.syncPublisher(TOPIC).onTokenModified(id) + } + } catch (e: InvalidGrantException) { + LOG.warn { "Invalidated token due to $e" } + invalidate() + + throw e + } + } } + override fun state() = BearerTokenProvider.state(currentToken(), clock) + // how we expect consumers to obtain a token - override fun resolveToken() = supplier.get() + override fun resolveToken() = supplier.cachedSupplier.get() override fun close() { ssoOidcClient.close() - supplier.close() + supplier.cachedSupplier.close() } override fun dispose() { @@ -163,21 +204,12 @@ class InteractiveBearerTokenProvider( } // internal nonsense so we can query the token without triggering a refresh - override fun currentToken() = lastToken.get() - - /** - * Only use if you know what you're doing. - */ - override fun refresh(): AccessToken { - val lastToken = lastToken.get() ?: throw NoTokenInitializedException("Token refresh started before session initialized") - return accessTokenProvider.refreshToken(lastToken).also { - this.lastToken.set(it) - } - } + override fun currentToken() = supplier.supplier.lastToken.get() override fun invalidate() { accessTokenProvider.invalidate() - lastToken.set(null) + supplier.cachedSupplier.close() + supplier = supplier() BearerTokenProviderListener.notifyCredUpdate(id) } @@ -185,10 +217,15 @@ class InteractiveBearerTokenProvider( // we probably don't need to invalidate this, but we might as well since we need to login again anyways invalidate() accessTokenProvider.accessToken().also { - lastToken.set(it) + supplier.cachedSupplier.close() + supplier = supplier(it) BearerTokenProviderListener.notifyCredUpdate(id) } } + + companion object { + private val LOG = getLogger() + } } class NoTokenInitializedException(message: String) : Exception(message) @@ -226,10 +263,6 @@ class ProfileSdkTokenProviderWrapper(private val sessionName: String, region: St ) } - override fun refresh(): AccessToken { - error("Not yet implemented") - } - override fun close() { sdkTokenManager.close() if (ssoOidcClient.isInitialized()) { diff --git a/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/sso/bearer/BearerTokenProviderListener.kt b/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/sso/bearer/BearerTokenProviderListener.kt index 2bd8af65aa5..40f57ff02d9 100644 --- a/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/sso/bearer/BearerTokenProviderListener.kt +++ b/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/sso/bearer/BearerTokenProviderListener.kt @@ -8,8 +8,19 @@ import com.intellij.util.messages.Topic import java.util.EventListener interface BearerTokenProviderListener : EventListener { - fun onChange(providerId: String, newScopes: List? = null) {} + /** + * Called when token permissions have potentially changed, or is no longer logged in + */ + fun onProviderChange(providerId: String, newScopes: List? = null) {} + /** + * Called when token has changed but connection properties are the same + */ + fun onTokenModified(providerId: String) {} + + /** + * Called when provider is being deleted + */ fun invalidate(providerId: String) {} companion object { @@ -17,7 +28,7 @@ interface BearerTokenProviderListener : EventListener { val TOPIC = Topic.create("AWS SSO bearer token provider status change", BearerTokenProviderListener::class.java) fun notifyCredUpdate(providerId: String) { - ApplicationManager.getApplication().messageBus.syncPublisher(BearerTokenProviderListener.TOPIC).onChange(providerId) + ApplicationManager.getApplication().messageBus.syncPublisher(TOPIC).onProviderChange(providerId) } } } diff --git a/plugins/core/jetbrains-community/tst/software/aws/toolkits/jetbrains/core/credentials/sso/bearer/InteractiveBearerTokenProviderTest.kt b/plugins/core/jetbrains-community/tst/software/aws/toolkits/jetbrains/core/credentials/sso/bearer/InteractiveBearerTokenProviderTest.kt index aa218427852..490f3fb48aa 100644 --- a/plugins/core/jetbrains-community/tst/software/aws/toolkits/jetbrains/core/credentials/sso/bearer/InteractiveBearerTokenProviderTest.kt +++ b/plugins/core/jetbrains-community/tst/software/aws/toolkits/jetbrains/core/credentials/sso/bearer/InteractiveBearerTokenProviderTest.kt @@ -13,11 +13,10 @@ import org.junit.Before import org.junit.Rule import org.junit.Test import org.junit.jupiter.api.assertThrows -import org.mockito.Mockito import org.mockito.kotlin.any import org.mockito.kotlin.argThat -import org.mockito.kotlin.eq import org.mockito.kotlin.mock +import org.mockito.kotlin.reset import org.mockito.kotlin.spy import org.mockito.kotlin.times import org.mockito.kotlin.verify @@ -49,6 +48,7 @@ import software.aws.toolkits.jetbrains.core.credentials.sso.DeviceAuthorizationG import software.aws.toolkits.jetbrains.core.credentials.sso.DeviceGrantAccessTokenCacheKey import software.aws.toolkits.jetbrains.core.credentials.sso.DiskCache import software.aws.toolkits.jetbrains.core.credentials.sso.PKCEAccessTokenCacheKey +import java.time.Clock import java.time.Instant import java.time.temporal.ChronoUnit @@ -158,7 +158,7 @@ class InteractiveBearerTokenProviderTest { } @Test - fun `resolveToken does't refresh if token was retrieved recently`() { + fun `resolveToken doesn't refresh if token was retrieved recently`() { stubClientRegistration() whenever(diskCache.loadAccessToken(any())).thenReturn( DeviceAuthorizationGrantToken( @@ -173,11 +173,56 @@ class InteractiveBearerTokenProviderTest { sut.resolveToken() } + @Test + fun `resolveToken attempts to refresh token on first invoke if expired`() { + stubClientRegistration() + stubAccessToken() + whenever(diskCache.loadAccessToken(any())).thenReturn( + DeviceAuthorizationGrantToken( + startUrl = startUrl, + region = region, + accessToken = "accessToken", + refreshToken = "refreshToken", + expiresAt = Instant.now() + ) + ) + val sut = buildSut() + sut.resolveToken() + + verify(oidcClient).createToken(any()) + } + + @Test + fun `resolveToken refreshes on subsequent invokes if expired`() { + val mockClock = mock() + whenever(mockClock.instant()).thenReturn(Instant.now()) + stubClientRegistration() + stubAccessToken() + whenever(diskCache.loadAccessToken(any())).thenReturn( + DeviceAuthorizationGrantToken( + startUrl = startUrl, + region = region, + accessToken = "accessToken", + refreshToken = "refreshToken", + expiresAt = Instant.now().plus(1, ChronoUnit.HOURS) + ) + ) + val sut = buildSut(mockClock) + // current token should be valid + assertThat(sut.resolveToken().accessToken).isEqualTo("accessToken") + verify(oidcClient, times(0)).createToken(any()) + + // then if we advance the clock it should refresh + whenever(mockClock.instant()).thenReturn(Instant.now().plus(100, ChronoUnit.DAYS)) + assertThat(sut.resolveToken().accessToken).isEqualTo("access1") + verify(oidcClient, times(1)).createToken(any()) + } + @Test fun `resolveToken throws if reauthentication is needed`() { stubClientRegistration() stubAccessToken() - Mockito.reset(oidcClient) + reset(oidcClient) whenever(oidcClient.createToken(any())).thenThrow(AccessDeniedException.create("denied", null)) val sut = buildSut() @@ -195,7 +240,7 @@ class InteractiveBearerTokenProviderTest { val sut = buildSut() sut.invalidate() - verify(mockListener).onChange(sut.id) + verify(mockListener).onProviderChange(sut.id) } @Test @@ -203,10 +248,13 @@ class InteractiveBearerTokenProviderTest { stubClientRegistration() stubAccessToken() val sut = buildSut() + whenever(diskCache.loadAccessToken(any())).thenReturn(null) sut.invalidate() // initial load - verify(diskCache).loadAccessToken(any()) + // invalidate attempts to reload token from disk + verify(diskCache, times(2)).loadAccessToken(any()) + verify(diskCache).loadAccessToken(any()) verify(diskCache).invalidateClientRegistration(region) verify(diskCache).invalidateAccessToken(startUrl) @@ -222,6 +270,10 @@ class InteractiveBearerTokenProviderTest { // nothing else verifyNoMoreInteractions(diskCache) + + // should not have a token now + assertThat(sut.currentToken()?.accessToken).isNull() + assertThrows { sut.resolveToken() } } @Test @@ -230,22 +282,22 @@ class InteractiveBearerTokenProviderTest { stubAccessToken() val sut = buildSut() - assertThat(sut.currentToken()?.accessToken).isEqualTo("accessToken") + assertThat(sut.resolveToken().accessToken).isEqualTo("access1") // and now instead of trying to stub out the entire OIDC device flow, abuse the fact that we short-circuit and read from disk if available - Mockito.reset(diskCache) + reset(diskCache) whenever(diskCache.loadAccessToken(any())).thenReturn( DeviceAuthorizationGrantToken( startUrl = startUrl, region = region, - accessToken = "access1", - refreshToken = "refresh1", + accessToken = "access1234", + refreshToken = "refresh1234", expiresAt = Instant.MAX ) ) sut.reauthenticate() - assertThat(sut.currentToken()?.accessToken).isEqualTo("access1") + assertThat(sut.resolveToken().accessToken).isEqualTo("access1234") } @Test @@ -260,19 +312,20 @@ class InteractiveBearerTokenProviderTest { sut.reauthenticate() // once for invalidate, once after the token has been retrieved - verify(mockListener, times(2)).onChange(sut.id) + verify(mockListener, times(2)).onProviderChange(sut.id) } - private fun buildSut() = InteractiveBearerTokenProvider( + private fun buildSut(clock: Clock = Clock.systemUTC()) = InteractiveBearerTokenProvider( startUrl = startUrl, region = region, scopes = scopes, cache = diskCache, - id = "test" + id = "test", + clock = clock, ) private fun stubClientRegistration() { - whenever(diskCache.loadClientRegistration(any(), eq("testSource"))).thenReturn( + whenever(diskCache.loadClientRegistration(any(), any())).thenReturn( DeviceAuthorizationClientRegistration( "", "", @@ -288,7 +341,7 @@ class InteractiveBearerTokenProviderTest { region = region, accessToken = "accessToken", refreshToken = "refreshToken", - expiresAt = Instant.MIN + expiresAt = Instant.now().minus(100, ChronoUnit.DAYS), ) ) whenever(oidcClient.createToken(any())).thenReturn( diff --git a/plugins/toolkit/jetbrains-core/src/software/aws/toolkits/jetbrains/core/credentials/AwsSettingsPanel.kt b/plugins/toolkit/jetbrains-core/src/software/aws/toolkits/jetbrains/core/credentials/AwsSettingsPanel.kt index f378f83c29b..cc53b06a2ff 100644 --- a/plugins/toolkit/jetbrains-core/src/software/aws/toolkits/jetbrains/core/credentials/AwsSettingsPanel.kt +++ b/plugins/toolkit/jetbrains-core/src/software/aws/toolkits/jetbrains/core/credentials/AwsSettingsPanel.kt @@ -109,7 +109,7 @@ private class AwsSettingsPanel(private val project: Project) : ApplicationManager.getApplication().messageBus.connect(this).subscribe( BearerTokenProviderListener.TOPIC, object : BearerTokenProviderListener { - override fun onChange(providerId: String, newScopes: List?) { + override fun onProviderChange(providerId: String, newScopes: List?) { updateWidget() } diff --git a/plugins/toolkit/jetbrains-core/src/software/aws/toolkits/jetbrains/core/explorer/AbstractExplorerTreeToolWindow.kt b/plugins/toolkit/jetbrains-core/src/software/aws/toolkits/jetbrains/core/explorer/AbstractExplorerTreeToolWindow.kt index b23d4c61598..f1eb5506f5e 100644 --- a/plugins/toolkit/jetbrains-core/src/software/aws/toolkits/jetbrains/core/explorer/AbstractExplorerTreeToolWindow.kt +++ b/plugins/toolkit/jetbrains-core/src/software/aws/toolkits/jetbrains/core/explorer/AbstractExplorerTreeToolWindow.kt @@ -118,7 +118,7 @@ abstract class AbstractExplorerTreeToolWindow( ApplicationManager.getApplication().messageBus.connect(this).subscribe( BearerTokenProviderListener.TOPIC, object : BearerTokenProviderListener { - override fun onChange(providerId: String, newScopes: List?) { + override fun onProviderChange(providerId: String, newScopes: List?) { redraw() } } diff --git a/plugins/toolkit/jetbrains-core/src/software/aws/toolkits/jetbrains/core/explorer/AwsToolkitExplorerFactory.kt b/plugins/toolkit/jetbrains-core/src/software/aws/toolkits/jetbrains/core/explorer/AwsToolkitExplorerFactory.kt index ffdbf2f312b..f0993a3c66a 100644 --- a/plugins/toolkit/jetbrains-core/src/software/aws/toolkits/jetbrains/core/explorer/AwsToolkitExplorerFactory.kt +++ b/plugins/toolkit/jetbrains-core/src/software/aws/toolkits/jetbrains/core/explorer/AwsToolkitExplorerFactory.kt @@ -133,7 +133,7 @@ class AwsToolkitExplorerFactory : ToolWindowFactory, DumbAware { project.messageBus.connect(toolWindow.disposable).subscribe( BearerTokenProviderListener.TOPIC, object : BearerTokenProviderListener { - override fun onChange(providerId: String, newScopes: List?) { + override fun onProviderChange(providerId: String, newScopes: List?) { if (ToolkitConnectionManager.getInstance(project) .connectionStateForFeature(CodeCatalystConnection.getInstance()) == BearerTokenAuthState.AUTHORIZED ) { diff --git a/plugins/toolkit/jetbrains-core/src/software/aws/toolkits/jetbrains/core/gettingstarted/editor/GettingStartedPanel.kt b/plugins/toolkit/jetbrains-core/src/software/aws/toolkits/jetbrains/core/gettingstarted/editor/GettingStartedPanel.kt index 414614db6b5..1103f451171 100644 --- a/plugins/toolkit/jetbrains-core/src/software/aws/toolkits/jetbrains/core/gettingstarted/editor/GettingStartedPanel.kt +++ b/plugins/toolkit/jetbrains-core/src/software/aws/toolkits/jetbrains/core/gettingstarted/editor/GettingStartedPanel.kt @@ -91,7 +91,7 @@ class GettingStartedPanel( ApplicationManager.getApplication().messageBus.connect(this).subscribe( BearerTokenProviderListener.TOPIC, object : BearerTokenProviderListener { - override fun onChange(providerId: String, newScopes: List?) { + override fun onProviderChange(providerId: String, newScopes: List?) { connectionUpdated() } } diff --git a/plugins/toolkit/jetbrains-core/src/software/aws/toolkits/jetbrains/ui/connection/CawsLoginOverlay.kt b/plugins/toolkit/jetbrains-core/src/software/aws/toolkits/jetbrains/ui/connection/CawsLoginOverlay.kt index e106cfbcac1..0135dc1d93f 100644 --- a/plugins/toolkit/jetbrains-core/src/software/aws/toolkits/jetbrains/ui/connection/CawsLoginOverlay.kt +++ b/plugins/toolkit/jetbrains-core/src/software/aws/toolkits/jetbrains/ui/connection/CawsLoginOverlay.kt @@ -74,7 +74,7 @@ open class CawsLoginOverlay( ApplicationManager.getApplication().messageBus.connect(disposable).subscribe( BearerTokenProviderListener.TOPIC, object : BearerTokenProviderListener { - override fun onChange(providerId: String, newScopes: List?) { + override fun onProviderChange(providerId: String, newScopes: List?) { drawContent() }