diff --git a/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/credentials/CodeWhispererClientAdaptor.kt b/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/credentials/CodeWhispererClientAdaptor.kt index c72f480355a..e2e55b86430 100644 --- a/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/credentials/CodeWhispererClientAdaptor.kt +++ b/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/credentials/CodeWhispererClientAdaptor.kt @@ -4,6 +4,7 @@ package software.aws.toolkits.jetbrains.services.codewhisperer.credentials import com.intellij.openapi.Disposable +import com.intellij.openapi.application.ApplicationManager import com.intellij.openapi.components.service import com.intellij.openapi.project.Project import com.intellij.util.text.nullize @@ -40,10 +41,14 @@ import software.amazon.awssdk.services.codewhispererruntime.model.TargetCode import software.amazon.awssdk.services.codewhispererruntime.model.UserIntent import software.aws.toolkits.core.utils.debug import software.aws.toolkits.core.utils.getLogger +import software.aws.toolkits.core.utils.warn import software.aws.toolkits.jetbrains.core.AwsClientManager import software.aws.toolkits.jetbrains.core.awsClient +import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection +import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnection import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager -import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection +import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManagerListener +import software.aws.toolkits.jetbrains.core.credentials.pinning.CodeWhispererConnection import software.aws.toolkits.jetbrains.services.amazonq.codeWhispererUserContext import software.aws.toolkits.jetbrains.services.codewhisperer.customization.CodeWhispererCustomization import software.aws.toolkits.jetbrains.services.codewhisperer.explorer.CodeWhispererExplorerActionManager @@ -62,6 +67,7 @@ import java.util.concurrent.TimeUnit import kotlin.reflect.KProperty0 import kotlin.reflect.jvm.isAccessible +// TODO: move this file to package "/client" // As the connection is project-level, we need to make this project-level too @Deprecated("Methods can throw a NullPointerException if callee does not check if connection is valid") interface CodeWhispererClientAdaptor : Disposable { @@ -277,16 +283,37 @@ interface CodeWhispererClientAdaptor : Disposable { open class CodeWhispererClientAdaptorImpl(override val project: Project) : CodeWhispererClientAdaptor { private val mySigv4Client by lazy { createUnmanagedSigv4Client() } + @Volatile + private var myBearerClient: CodeWhispererRuntimeClient? = null + private val KProperty0<*>.isLazyInitialized: Boolean get() { isAccessible = true return (getDelegate() as Lazy<*>).isInitialized() } - fun bearerClient(): CodeWhispererRuntimeClient = - ToolkitConnectionManager.getInstance(project).activeConnectionForFeature(QConnection.getInstance())?.getConnectionSettings() - ?.awsClient() - ?: throw Exception("attempt to get bearer client while there is no valid credential") + init { + initClientUpdateListener() + } + + private fun initClientUpdateListener() { + ApplicationManager.getApplication().messageBus.connect(this).subscribe( + ToolkitConnectionManagerListener.TOPIC, + object : ToolkitConnectionManagerListener { + override fun activeConnectionChanged(newConnection: ToolkitConnection?) { + if (newConnection is AwsBearerTokenConnection) { + myBearerClient = getBearerClient(newConnection.getConnectionSettings().providerId) + } + } + } + ) + } + + private fun bearerClient(): CodeWhispererRuntimeClient { + if (myBearerClient != null) return myBearerClient as CodeWhispererRuntimeClient + myBearerClient = getBearerClient() + return myBearerClient as CodeWhispererRuntimeClient + } override fun generateCompletionsPaginator(firstRequest: GenerateCompletionsRequest) = sequence { var nextToken: String? = firstRequest.nextToken() @@ -827,6 +854,28 @@ open class CodeWhispererClientAdaptorImpl(override val project: Project) : CodeW if (this::mySigv4Client.isLazyInitialized) { mySigv4Client.close() } + myBearerClient?.close() + } + + /** + * Every different SSO/AWS Builder ID connection requires a new client which has its corresponding bearer token provider, + * thus we have to create them dynamically. + * Invalidate and recycle the old client first, and create a new client with the new connection. + * This makes sure when we invoke CW, we always use the up-to-date connection. + * In case this fails to close the client, myBearerClient is already set to null thus next time when we invoke CW, + * it will go through this again which should get the current up-to-date connection. This stale client would be + * unused and stay in memory for a while until eventually closed by ToolkitClientManager. + */ + open fun getBearerClient(oldProviderIdToRemove: String = ""): CodeWhispererRuntimeClient? { + myBearerClient = null + + val connection = ToolkitConnectionManager.getInstance(project).activeConnectionForFeature(CodeWhispererConnection.getInstance()) + connection as? AwsBearerTokenConnection ?: run { + LOG.warn { "$connection is not a bearer token connection" } + return null + } + + return AwsClientManager.getInstance().getClient(connection.getConnectionSettings()) } companion object { @@ -840,6 +889,7 @@ open class CodeWhispererClientAdaptorImpl(override val project: Project) : CodeW } class MockCodeWhispererClientAdaptor(override val project: Project) : CodeWhispererClientAdaptorImpl(project) { + override fun getBearerClient(oldProviderIdToRemove: String): CodeWhispererRuntimeClient = project.awsClient() override fun dispose() {} } diff --git a/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererClientAdaptorTest.kt b/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererClientAdaptorTest.kt index a0bec60e74e..86ef92f8ccc 100644 --- a/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererClientAdaptorTest.kt +++ b/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererClientAdaptorTest.kt @@ -15,7 +15,6 @@ import org.junit.After import org.junit.Before import org.junit.Rule import org.junit.Test -import org.junit.jupiter.api.assertThrows import org.mockito.kotlin.any import org.mockito.kotlin.argThat import org.mockito.kotlin.argumentCaptor @@ -69,20 +68,17 @@ import software.aws.toolkits.core.TokenConnectionSettings import software.aws.toolkits.core.utils.test.aString import software.aws.toolkits.jetbrains.core.MockClientManagerRule import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection -import software.aws.toolkits.jetbrains.core.credentials.DefaultToolkitConnectionManager import software.aws.toolkits.jetbrains.core.credentials.ManagedSsoProfile import software.aws.toolkits.jetbrains.core.credentials.MockCredentialManagerRule import software.aws.toolkits.jetbrains.core.credentials.MockToolkitAuthManagerRule import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager -import software.aws.toolkits.jetbrains.core.credentials.logoutFromSsoConnection -import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection -import software.aws.toolkits.jetbrains.core.credentials.sono.Q_SCOPES import software.aws.toolkits.jetbrains.core.credentials.sono.SONO_REGION import software.aws.toolkits.jetbrains.services.amazonq.FEATURE_EVALUATION_PRODUCT_NAME import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.metadata import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.pythonRequest import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.pythonResponseWithToken import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.sdkHttpResponse +import software.aws.toolkits.jetbrains.services.codewhisperer.credentials.CodeWhispererClientAdaptor import software.aws.toolkits.jetbrains.services.codewhisperer.credentials.CodeWhispererClientAdaptorImpl import software.aws.toolkits.jetbrains.services.codewhisperer.customization.CodeWhispererCustomization import software.aws.toolkits.jetbrains.services.codewhisperer.customization.CodeWhispererModelConfigurator @@ -113,7 +109,7 @@ class CodeWhispererClientAdaptorTest { private lateinit var bearerClient: CodeWhispererRuntimeClient private lateinit var ssoClient: SsoOidcClient - private lateinit var sut: CodeWhispererClientAdaptorImpl + private lateinit var sut: CodeWhispererClientAdaptor private lateinit var connectionManager: ToolkitConnectionManager private var isTelemetryEnabledDefault: Boolean = false @@ -167,41 +163,6 @@ class CodeWhispererClientAdaptorTest { assertThat("us-east-1").isEqualTo(SONO_REGION) } - @Test - fun `should throw if there is no valid credential, otherwise return codewhispererRuntimeClient`() { - val connectionManager = DefaultToolkitConnectionManager() - projectRule.project.replaceService(ToolkitConnectionManager::class.java, DefaultToolkitConnectionManager(), disposableRule.disposable) - - assertThat(ToolkitConnectionManager.getInstance(projectRule.project).activeConnectionForFeature(QConnection.getInstance())).isNull() - assertThrows("attempt to get bearer client while there is no valid credential") { - sut.bearerClient() - } - - val qConnection = authManagerRule.createConnection(ManagedSsoProfile("us-east-1", aString(), Q_SCOPES)) - connectionManager.switchConnection(qConnection) - assertThat(connectionManager.activeConnectionForFeature(QConnection.getInstance())) - .isNotNull - .isEqualTo(qConnection) - assertThat(sut.bearerClient()) - .isNotNull - .isInstanceOf(CodeWhispererRuntimeClient::class.java) - - logoutFromSsoConnection(projectRule.project, qConnection as AwsBearerTokenConnection) - assertThat(connectionManager.activeConnectionForFeature(QConnection.getInstance())).isNull() - assertThrows("attempt to get bearer client while there is no valid credential") { - sut.bearerClient() - } - - val anotherQConnection = authManagerRule.createConnection(ManagedSsoProfile("us-east-1", aString(), Q_SCOPES)) - connectionManager.switchConnection(anotherQConnection) - assertThat(connectionManager.activeConnectionForFeature(QConnection.getInstance())) - .isNotNull - .isEqualTo(anotherQConnection) - assertThat(sut.bearerClient()) - .isNotNull - .isInstanceOf(CodeWhispererRuntimeClient::class.java) - } - @Test fun `listCustomizations`() { val sdkIterable = ListAvailableCustomizationsIterable(bearerClient, ListAvailableCustomizationsRequest.builder().build()) diff --git a/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererTestBase.kt b/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererTestBase.kt index a3e97f60c38..aa0dd9f2270 100644 --- a/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererTestBase.kt +++ b/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererTestBase.kt @@ -30,13 +30,8 @@ import org.mockito.kotlin.verify import software.amazon.awssdk.services.codewhispererruntime.CodeWhispererRuntimeClient import software.amazon.awssdk.services.codewhispererruntime.model.GenerateCompletionsRequest import software.amazon.awssdk.services.codewhispererruntime.paginators.GenerateCompletionsIterable -import software.amazon.awssdk.services.ssooidc.SsoOidcClient import software.aws.toolkits.jetbrains.core.MockClientManagerRule -import software.aws.toolkits.jetbrains.core.credentials.ManagedSsoProfile import software.aws.toolkits.jetbrains.core.credentials.MockCredentialManagerRule -import software.aws.toolkits.jetbrains.core.credentials.MockToolkitAuthManagerRule -import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager -import software.aws.toolkits.jetbrains.core.credentials.sono.Q_SCOPES import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.codeWhispererRecommendationActionId import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.pythonFileName import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.pythonResponse @@ -70,11 +65,10 @@ open class CodeWhispererTestBase { val mockClientManagerRule = MockClientManagerRule() val mockCredentialRule = MockCredentialManagerRule() val disposableRule = DisposableRule() - val authManagerRule = MockToolkitAuthManagerRule() @Rule @JvmField - val ruleChain = RuleChain(projectRule, mockCredentialRule, mockClientManagerRule, authManagerRule, disposableRule) + val ruleChain = RuleChain(projectRule, mockCredentialRule, mockClientManagerRule, disposableRule) protected lateinit var mockClient: CodeWhispererRuntimeClient @@ -92,7 +86,6 @@ open class CodeWhispererTestBase { @Before open fun setUp() { mockClient = mockClientManagerRule.create() - mockClientManagerRule.create() val requestCaptor = argumentCaptor() mockClient.stub { on { @@ -166,9 +159,6 @@ open class CodeWhispererTestBase { projectRule.project.replaceService(CodeWhispererClientAdaptor::class.java, clientAdaptorSpy, disposableRule.disposable) ApplicationManager.getApplication().replaceService(CodeWhispererExplorerActionManager::class.java, stateManager, disposableRule.disposable) stateManager.setAutoEnabled(false) - - val conn = authManagerRule.createConnection(ManagedSsoProfile("us-east-1", "url", Q_SCOPES)) - ToolkitConnectionManager.getInstance(projectRule.project).switchConnection(conn) } @After