diff --git a/plugins/amazonq/chat/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/toolwindow/AmazonQToolWindowFactory.kt b/plugins/amazonq/chat/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/toolwindow/AmazonQToolWindowFactory.kt index 7137d4f966c..6f5f8331307 100644 --- a/plugins/amazonq/chat/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/toolwindow/AmazonQToolWindowFactory.kt +++ b/plugins/amazonq/chat/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/toolwindow/AmazonQToolWindowFactory.kt @@ -85,7 +85,8 @@ class AmazonQToolWindowFactory : ToolWindowFactory, DumbAware { object : BearerTokenProviderListener { override fun onChange(providerId: String, newScopes: List?) { if (ToolkitConnectionManager.getInstance(project).connectionStateForFeature(QConnection.getInstance()) == BearerTokenAuthState.AUTHORIZED) { - prepareChatContent(project, qPanel) + AmazonQToolWindow.getInstance(project).disposeAndRecreate() + qPanel.setContent(AmazonQToolWindow.getInstance(project).component) } } } diff --git a/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/QRegionProfileManagerTest.kt b/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/QRegionProfileManagerTest.kt index 640d77e7cd1..628ab8e36c9 100644 --- a/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/QRegionProfileManagerTest.kt +++ b/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/QRegionProfileManagerTest.kt @@ -6,6 +6,7 @@ package software.aws.toolkits.jetbrains.services.codewhisperer import com.intellij.openapi.project.Project import com.intellij.testFramework.DisposableRule import com.intellij.testFramework.ProjectRule +import com.intellij.testFramework.replaceService import com.intellij.util.xmlb.XmlSerializer import org.assertj.core.api.Assertions.assertThat import org.jdom.output.XMLOutputter @@ -15,7 +16,9 @@ import org.junit.Test import org.mockito.kotlin.any import org.mockito.kotlin.doReturn import org.mockito.kotlin.mock +import org.mockito.kotlin.spy import org.mockito.kotlin.stub +import org.mockito.kotlin.whenever import software.amazon.awssdk.core.pagination.sync.SdkIterable import software.amazon.awssdk.regions.Region import software.amazon.awssdk.services.codewhispererruntime.CodeWhispererRuntimeClient @@ -34,6 +37,7 @@ import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager import software.aws.toolkits.jetbrains.core.credentials.logoutFromSsoConnection import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection import software.aws.toolkits.jetbrains.core.credentials.sono.Q_SCOPES +import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenAuthState import software.aws.toolkits.jetbrains.core.region.MockRegionProviderRule import software.aws.toolkits.jetbrains.services.amazonq.profile.QEndpoints import software.aws.toolkits.jetbrains.services.amazonq.profile.QProfileResources @@ -85,6 +89,10 @@ class QRegionProfileManagerTest { sut = QRegionProfileManager() val conn = authRule.createConnection(ManagedSsoProfile(ssoRegion = "us-east-1", startUrl = "", scopes = Q_SCOPES)) ToolkitConnectionManager.getInstance(project).switchConnection(conn) + val realManager = ToolkitConnectionManager.getInstance(project) + val managerSpy = spy(realManager) + doReturn(BearerTokenAuthState.AUTHORIZED).whenever(managerSpy).connectionStateForFeature(QConnection.getInstance()) + project.replaceService(ToolkitConnectionManager::class.java, managerSpy, disposableRule.disposable) } @Test @@ -106,7 +114,7 @@ class QRegionProfileManagerTest { logoutFromSsoConnection(project, it) } } - + ToolkitConnectionManager.getInstance(project).switchConnection(null) assertThat(ToolkitConnectionManager.getInstance(project).activeConnectionForFeature(QConnection.getInstance())).isNull() assertThat(sut.activeProfile(project)).isNull() } @@ -384,4 +392,28 @@ class QRegionProfileManagerTest { val profileList = actualState.connectionIdToProfileList["conn-123"] assertThat(profileList).isEqualTo(2) } + + @Test + fun `getIdcConnectionOrNull handles NOT_AUTH and AUTHORIZED correctly`() { + val managerSpy = ToolkitConnectionManager.getInstance(project) + doReturn(BearerTokenAuthState.NOT_AUTHENTICATED).whenever(managerSpy) + .connectionStateForFeature(QConnection.getInstance()) + + // NOT AUTHORIZED + val notAuthConn = sut.getIdcConnectionOrNull(project) + assertThat(notAuthConn).isNull() + + doReturn(BearerTokenAuthState.AUTHORIZED) + .whenever(managerSpy).connectionStateForFeature(QConnection.getInstance()) + + // AUTHORIZED + val normalConn = authRule.createConnection( + ManagedSsoProfile(ssoRegion = "us-east-1", startUrl = "", scopes = Q_SCOPES) + ) + managerSpy.switchConnection(normalConn) + + val normalConnectionResult = sut.getIdcConnectionOrNull(project) + assertThat(normalConnectionResult).isNotNull() + assertThat(normalConnectionResult).isEqualTo(normalConn) + } } diff --git a/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/profile/QRegionProfileManager.kt b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/profile/QRegionProfileManager.kt index c9014d16d20..60f98c26d0f 100644 --- a/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/profile/QRegionProfileManager.kt +++ b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/profile/QRegionProfileManager.kt @@ -27,6 +27,7 @@ import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection import software.aws.toolkits.jetbrains.core.credentials.sono.isSono +import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenAuthState import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProviderListener import software.aws.toolkits.jetbrains.core.region.AwsRegionProvider import software.aws.toolkits.jetbrains.utils.notifyInfo @@ -190,12 +191,16 @@ class QRegionProfileManager : PersistentStateComponent, Disposabl return client } - private fun getIdcConnectionOrNull(project: Project): AwsBearerTokenConnection? { - val connection = ToolkitConnectionManager.getInstance(project).activeConnectionForFeature(QConnection.getInstance()) - if (connection is AwsBearerTokenConnection && !connection.isSono()) { - return connection + fun getIdcConnectionOrNull(project: Project): AwsBearerTokenConnection? { + val manager = ToolkitConnectionManager.getInstance(project) + val connection = manager.activeConnectionForFeature(QConnection.getInstance()) as? AwsBearerTokenConnection + val state = manager.connectionStateForFeature(QConnection.getInstance()) + + return if (connection != null && !connection.isSono() && state == BearerTokenAuthState.AUTHORIZED) { + connection + } else { + null } - return null } companion object {