diff --git a/plugins/amazonq/codewhisperer/jetbrains-community/resources/META-INF/plugin-codewhisperer.xml b/plugins/amazonq/codewhisperer/jetbrains-community/resources/META-INF/plugin-codewhisperer.xml index 7b9bb7b9eca..07d8f29e713 100644 --- a/plugins/amazonq/codewhisperer/jetbrains-community/resources/META-INF/plugin-codewhisperer.xml +++ b/plugins/amazonq/codewhisperer/jetbrains-community/resources/META-INF/plugin-codewhisperer.xml @@ -33,8 +33,7 @@ serviceImplementation="software.aws.toolkits.jetbrains.services.codewhisperer.customization.DefaultCodeWhispererModelConfigurator"/> + serviceImplementation="software.aws.toolkits.jetbrains.services.codewhisperer.credentials.CodeWhispererClientAdaptorImpl"/> diff --git a/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/credentials/CodeWhispererClientAdaptor.kt b/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/credentials/CodeWhispererClientAdaptor.kt index 17cc419e1a6..36e60c6aa9f 100644 --- a/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/credentials/CodeWhispererClientAdaptor.kt +++ b/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/credentials/CodeWhispererClientAdaptor.kt @@ -3,8 +3,6 @@ package software.aws.toolkits.jetbrains.services.codewhisperer.credentials -import com.intellij.openapi.Disposable -import com.intellij.openapi.application.ApplicationManager import com.intellij.openapi.components.service import com.intellij.openapi.project.Project import com.intellij.util.text.nullize @@ -39,14 +37,9 @@ import software.amazon.awssdk.services.codewhispererruntime.model.TargetCode import software.amazon.awssdk.services.codewhispererruntime.model.UserIntent import software.aws.toolkits.core.utils.debug import software.aws.toolkits.core.utils.getLogger -import software.aws.toolkits.core.utils.warn -import software.aws.toolkits.jetbrains.core.AwsClientManager import software.aws.toolkits.jetbrains.core.awsClient -import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection -import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnection import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager -import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManagerListener -import software.aws.toolkits.jetbrains.core.credentials.pinning.CodeWhispererConnection +import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection import software.aws.toolkits.jetbrains.services.amazonq.codeWhispererUserContext import software.aws.toolkits.jetbrains.services.codewhisperer.customization.CodeWhispererCustomization import software.aws.toolkits.jetbrains.services.codewhisperer.language.CodeWhispererProgrammingLanguage @@ -62,8 +55,11 @@ import java.time.Instant import java.util.concurrent.TimeUnit // As the connection is project-level, we need to make this project-level too -@Deprecated("Methods can throw a NullPointerException if callee does not check if connection is valid") -interface CodeWhispererClientAdaptor : Disposable { +@Deprecated( + "It was needed as we were supporting two service models (sigv4 & bearer), " + + "it's no longer the case as we remove sigv4 support, should use AwsClientManager.getClient() directly" +) +interface CodeWhispererClientAdaptor { val project: Project fun generateCompletionsPaginator( @@ -261,32 +257,11 @@ interface CodeWhispererClientAdaptor : Disposable { } } -open class CodeWhispererClientAdaptorImpl(override val project: Project) : CodeWhispererClientAdaptor { - @Volatile - private var myBearerClient: CodeWhispererRuntimeClient? = null - - init { - initClientUpdateListener() - } - - private fun initClientUpdateListener() { - ApplicationManager.getApplication().messageBus.connect(this).subscribe( - ToolkitConnectionManagerListener.TOPIC, - object : ToolkitConnectionManagerListener { - override fun activeConnectionChanged(newConnection: ToolkitConnection?) { - if (newConnection is AwsBearerTokenConnection) { - myBearerClient = getBearerClient(newConnection.getConnectionSettings().providerId) - } - } - } - ) - } - - private fun bearerClient(): CodeWhispererRuntimeClient { - if (myBearerClient != null) return myBearerClient as CodeWhispererRuntimeClient - myBearerClient = getBearerClient() - return myBearerClient as CodeWhispererRuntimeClient - } +class CodeWhispererClientAdaptorImpl(override val project: Project) : CodeWhispererClientAdaptor { + fun bearerClient(): CodeWhispererRuntimeClient = + ToolkitConnectionManager.getInstance(project).activeConnectionForFeature(QConnection.getInstance())?.getConnectionSettings() + ?.awsClient() + ?: throw Exception("attempt to get bearer client while there is no valid credential") override fun generateCompletionsPaginator(firstRequest: GenerateCompletionsRequest) = sequence { var nextToken: String? = firstRequest.nextToken() @@ -809,41 +784,11 @@ open class CodeWhispererClientAdaptorImpl(override val project: Project) : CodeW requestBuilder.userContext(codeWhispererUserContext()) } - override fun dispose() { - myBearerClient?.close() - } - - /** - * Every different SSO/AWS Builder ID connection requires a new client which has its corresponding bearer token provider, - * thus we have to create them dynamically. - * Invalidate and recycle the old client first, and create a new client with the new connection. - * This makes sure when we invoke CW, we always use the up-to-date connection. - * In case this fails to close the client, myBearerClient is already set to null thus next time when we invoke CW, - * it will go through this again which should get the current up-to-date connection. This stale client would be - * unused and stay in memory for a while until eventually closed by ToolkitClientManager. - */ - open fun getBearerClient(oldProviderIdToRemove: String = ""): CodeWhispererRuntimeClient? { - myBearerClient = null - - val connection = ToolkitConnectionManager.getInstance(project).activeConnectionForFeature(CodeWhispererConnection.getInstance()) - connection as? AwsBearerTokenConnection ?: run { - LOG.warn { "$connection is not a bearer token connection" } - return null - } - - return AwsClientManager.getInstance().getClient(connection.getConnectionSettings()) - } - companion object { private val LOG = getLogger() } } -class MockCodeWhispererClientAdaptor(override val project: Project) : CodeWhispererClientAdaptorImpl(project) { - override fun getBearerClient(oldProviderIdToRemove: String): CodeWhispererRuntimeClient = project.awsClient() - override fun dispose() {} -} - private fun CodewhispererSuggestionState.toCodeWhispererSdkType() = when { this == CodewhispererSuggestionState.Accept -> SuggestionState.ACCEPT this == CodewhispererSuggestionState.Reject -> SuggestionState.REJECT diff --git a/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererClientAdaptorTest.kt b/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererClientAdaptorTest.kt index 268caa2ab63..9a3a455d6e1 100644 --- a/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererClientAdaptorTest.kt +++ b/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererClientAdaptorTest.kt @@ -4,7 +4,6 @@ package software.aws.toolkits.jetbrains.services.codewhisperer import com.intellij.openapi.application.ApplicationManager -import com.intellij.openapi.util.Disposer import com.intellij.openapi.util.SystemInfo import com.intellij.testFramework.DisposableRule import com.intellij.testFramework.RuleChain @@ -15,6 +14,7 @@ import org.junit.After import org.junit.Before import org.junit.Rule import org.junit.Test +import org.junit.jupiter.api.assertThrows import org.mockito.kotlin.any import org.mockito.kotlin.argThat import org.mockito.kotlin.argumentCaptor @@ -24,7 +24,6 @@ import org.mockito.kotlin.mock import org.mockito.kotlin.stub import org.mockito.kotlin.times import org.mockito.kotlin.verify -import org.mockito.kotlin.whenever import software.amazon.awssdk.services.codewhispererruntime.CodeWhispererRuntimeClient import software.amazon.awssdk.services.codewhispererruntime.model.ArtifactType import software.amazon.awssdk.services.codewhispererruntime.model.CodeAnalysisFindingsSchema @@ -54,7 +53,6 @@ import software.amazon.awssdk.services.codewhispererruntime.model.SuggestionStat import software.amazon.awssdk.services.codewhispererruntime.paginators.GenerateCompletionsIterable import software.amazon.awssdk.services.codewhispererruntime.paginators.ListAvailableCustomizationsIterable import software.amazon.awssdk.services.ssooidc.SsoOidcClient -import software.aws.toolkits.core.TokenConnectionSettings import software.aws.toolkits.core.utils.test.aString import software.aws.toolkits.jetbrains.core.MockClientManagerRule import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection @@ -62,13 +60,15 @@ import software.aws.toolkits.jetbrains.core.credentials.ManagedSsoProfile import software.aws.toolkits.jetbrains.core.credentials.MockCredentialManagerRule import software.aws.toolkits.jetbrains.core.credentials.MockToolkitAuthManagerRule import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager +import software.aws.toolkits.jetbrains.core.credentials.logoutFromSsoConnection +import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection +import software.aws.toolkits.jetbrains.core.credentials.sono.Q_SCOPES import software.aws.toolkits.jetbrains.core.credentials.sono.SONO_REGION import software.aws.toolkits.jetbrains.services.amazonq.FEATURE_EVALUATION_PRODUCT_NAME import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.metadata import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.pythonRequest import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.pythonResponseWithToken import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.sdkHttpResponse -import software.aws.toolkits.jetbrains.services.codewhisperer.credentials.CodeWhispererClientAdaptor import software.aws.toolkits.jetbrains.services.codewhisperer.credentials.CodeWhispererClientAdaptorImpl import software.aws.toolkits.jetbrains.services.codewhisperer.customization.CodeWhispererCustomization import software.aws.toolkits.jetbrains.services.codewhisperer.customization.CodeWhispererModelConfigurator @@ -93,13 +93,12 @@ class CodeWhispererClientAdaptorTest { @Rule @JvmField - val ruleChain = RuleChain(projectRule, mockCredentialRule, mockClientManagerRule, disposableRule) + val ruleChain = RuleChain(projectRule, mockCredentialRule, mockClientManagerRule, authManagerRule, disposableRule) private lateinit var bearerClient: CodeWhispererRuntimeClient private lateinit var ssoClient: SsoOidcClient - private lateinit var sut: CodeWhispererClientAdaptor - private lateinit var connectionManager: ToolkitConnectionManager + private lateinit var sut: CodeWhispererClientAdaptorImpl private var isTelemetryEnabledDefault: Boolean = false @Before @@ -117,15 +116,8 @@ class CodeWhispererClientAdaptorTest { on { listFeatureEvaluations(any()) } doReturn listFeatureEvaluationsResponse } - val mockConnection = mock() - whenever(mockConnection.getConnectionSettings()) doReturn mock() - - connectionManager = mock { - on { - activeConnectionForFeature(any()) - } doReturn authManagerRule.createConnection(ManagedSsoProfile("us-east-1", aString(), listOf("scopes"))) as AwsBearerTokenConnection - } - projectRule.project.replaceService(ToolkitConnectionManager::class.java, connectionManager, disposableRule.disposable) + val conn = authManagerRule.createConnection(ManagedSsoProfile("us-east-1", "url", Q_SCOPES)) + ToolkitConnectionManager.getInstance(projectRule.project).switchConnection(conn) isTelemetryEnabledDefault = AwsSettings.getInstance().isTelemetryEnabled } @@ -135,16 +127,37 @@ class CodeWhispererClientAdaptorTest { AwsSettings.getInstance().isTelemetryEnabled = isTelemetryEnabledDefault } - @After - fun cleanup() { - Disposer.dispose(sut) - } - @Test fun `Sono region is us-east-1`() { assertThat("us-east-1").isEqualTo(SONO_REGION) } + @Test + fun `should throw if there is no valid credential, otherwise return codewhispererRuntimeClient`() { + val connectionManager = ToolkitConnectionManager.getInstance(projectRule.project) + + assertThat(connectionManager.activeConnectionForFeature(QConnection.getInstance())) + .isNotNull + assertThat(sut.bearerClient()) + .isNotNull + .isInstanceOf(CodeWhispererRuntimeClient::class.java) + + logoutFromSsoConnection(projectRule.project, connectionManager.activeConnectionForFeature(QConnection.getInstance()) as AwsBearerTokenConnection) + assertThat(connectionManager.activeConnectionForFeature(QConnection.getInstance())).isNull() + assertThrows("attempt to get bearer client while there is no valid credential") { + sut.bearerClient() + } + + val anotherQConnection = authManagerRule.createConnection(ManagedSsoProfile("us-east-1", aString(), Q_SCOPES)) + connectionManager.switchConnection(anotherQConnection) + assertThat(connectionManager.activeConnectionForFeature(QConnection.getInstance())) + .isNotNull + .isEqualTo(anotherQConnection) + assertThat(sut.bearerClient()) + .isNotNull + .isInstanceOf(CodeWhispererRuntimeClient::class.java) + } + @Test fun `listCustomizations`() { val sdkIterable = ListAvailableCustomizationsIterable(bearerClient, ListAvailableCustomizationsRequest.builder().build()) diff --git a/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererSettingsTest.kt b/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererSettingsTest.kt index 63bc70ae2b8..94474eefa6a 100644 --- a/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererSettingsTest.kt +++ b/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererSettingsTest.kt @@ -105,7 +105,7 @@ class CodeWhispererSettingsTest : CodeWhispererTestBase() { stateManager.loadState(CodeWhispererExploreActionState()) CodeWhispererSettings.getInstance().loadState(CodeWhispererConfiguration()) - val problemsWindow = ProblemsView.getToolWindow(projectRule.project) ?: fail("Problems window not found") + ProblemsView.getToolWindow(projectRule.project) ?: fail("Problems window not found") val codeReferenceWindow = ToolWindowManager.getInstance(projectRule.project).getToolWindow( CodeWhispererCodeReferenceToolWindowFactory.id ) ?: fail("Code Reference Log window not found") @@ -114,7 +114,6 @@ class CodeWhispererSettingsTest : CodeWhispererTestBase() { } ?: fail("CodeWhisperer status bar widget not found") runInEdtAndWait { - assertThat(problemsWindow.contentManager.contentCount).isEqualTo(0) assertThat(codeReferenceWindow.isAvailable).isFalse assertThat(statusBarWidgetFactory.isAvailable(projectRule.project)).isTrue assertThat(settingsManager.isIncludeCodeWithReference()).isFalse diff --git a/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererTestBase.kt b/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererTestBase.kt index aa0dd9f2270..a3e97f60c38 100644 --- a/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererTestBase.kt +++ b/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererTestBase.kt @@ -30,8 +30,13 @@ import org.mockito.kotlin.verify import software.amazon.awssdk.services.codewhispererruntime.CodeWhispererRuntimeClient import software.amazon.awssdk.services.codewhispererruntime.model.GenerateCompletionsRequest import software.amazon.awssdk.services.codewhispererruntime.paginators.GenerateCompletionsIterable +import software.amazon.awssdk.services.ssooidc.SsoOidcClient import software.aws.toolkits.jetbrains.core.MockClientManagerRule +import software.aws.toolkits.jetbrains.core.credentials.ManagedSsoProfile import software.aws.toolkits.jetbrains.core.credentials.MockCredentialManagerRule +import software.aws.toolkits.jetbrains.core.credentials.MockToolkitAuthManagerRule +import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager +import software.aws.toolkits.jetbrains.core.credentials.sono.Q_SCOPES import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.codeWhispererRecommendationActionId import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.pythonFileName import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.pythonResponse @@ -65,10 +70,11 @@ open class CodeWhispererTestBase { val mockClientManagerRule = MockClientManagerRule() val mockCredentialRule = MockCredentialManagerRule() val disposableRule = DisposableRule() + val authManagerRule = MockToolkitAuthManagerRule() @Rule @JvmField - val ruleChain = RuleChain(projectRule, mockCredentialRule, mockClientManagerRule, disposableRule) + val ruleChain = RuleChain(projectRule, mockCredentialRule, mockClientManagerRule, authManagerRule, disposableRule) protected lateinit var mockClient: CodeWhispererRuntimeClient @@ -86,6 +92,7 @@ open class CodeWhispererTestBase { @Before open fun setUp() { mockClient = mockClientManagerRule.create() + mockClientManagerRule.create() val requestCaptor = argumentCaptor() mockClient.stub { on { @@ -159,6 +166,9 @@ open class CodeWhispererTestBase { projectRule.project.replaceService(CodeWhispererClientAdaptor::class.java, clientAdaptorSpy, disposableRule.disposable) ApplicationManager.getApplication().replaceService(CodeWhispererExplorerActionManager::class.java, stateManager, disposableRule.disposable) stateManager.setAutoEnabled(false) + + val conn = authManagerRule.createConnection(ManagedSsoProfile("us-east-1", "url", Q_SCOPES)) + ToolkitConnectionManager.getInstance(projectRule.project).switchConnection(conn) } @After