From 10fd021ea7210dbe952efdcc600236d898787210 Mon Sep 17 00:00:00 2001 From: Will Lo Date: Thu, 6 Feb 2025 09:10:20 -0800 Subject: [PATCH 1/6] cwsprClientAdaptor should use awsClientManager client directly --- .../credentials/CodeWhispererClientAdaptor.kt | 53 +------------------ 1 file changed, 1 insertion(+), 52 deletions(-) diff --git a/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/credentials/CodeWhispererClientAdaptor.kt b/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/credentials/CodeWhispererClientAdaptor.kt index 9d638fc60ae..9d4ac346156 100644 --- a/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/credentials/CodeWhispererClientAdaptor.kt +++ b/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/credentials/CodeWhispererClientAdaptor.kt @@ -4,7 +4,6 @@ package software.aws.toolkits.jetbrains.services.codewhisperer.credentials import com.intellij.openapi.Disposable -import com.intellij.openapi.application.ApplicationManager import com.intellij.openapi.components.service import com.intellij.openapi.project.Project import software.amazon.awssdk.auth.credentials.AnonymousCredentialsProvider @@ -40,14 +39,8 @@ import software.amazon.awssdk.services.codewhispererruntime.model.TargetCode import software.amazon.awssdk.services.codewhispererruntime.model.UserIntent import software.aws.toolkits.core.utils.debug import software.aws.toolkits.core.utils.getLogger -import software.aws.toolkits.core.utils.warn import software.aws.toolkits.jetbrains.core.AwsClientManager import software.aws.toolkits.jetbrains.core.awsClient -import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection -import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnection -import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager -import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManagerListener -import software.aws.toolkits.jetbrains.core.credentials.pinning.CodeWhispererConnection import software.aws.toolkits.jetbrains.services.amazonq.codeWhispererUserContext import software.aws.toolkits.jetbrains.services.codewhisperer.customization.CodeWhispererCustomization import software.aws.toolkits.jetbrains.services.codewhisperer.explorer.CodeWhispererExplorerActionManager @@ -66,7 +59,6 @@ import java.util.concurrent.TimeUnit import kotlin.reflect.KProperty0 import kotlin.reflect.jvm.isAccessible -// TODO: move this file to package "/client" // As the connection is project-level, we need to make this project-level too @Deprecated("Methods can throw a NullPointerException if callee does not check if connection is valid") interface CodeWhispererClientAdaptor : Disposable { @@ -295,28 +287,7 @@ open class CodeWhispererClientAdaptorImpl(override val project: Project) : CodeW return (getDelegate() as Lazy<*>).isInitialized() } - init { - initClientUpdateListener() - } - - private fun initClientUpdateListener() { - ApplicationManager.getApplication().messageBus.connect(this).subscribe( - ToolkitConnectionManagerListener.TOPIC, - object : ToolkitConnectionManagerListener { - override fun activeConnectionChanged(newConnection: ToolkitConnection?) { - if (newConnection is AwsBearerTokenConnection) { - myBearerClient = getBearerClient(newConnection.getConnectionSettings().providerId) - } - } - } - ) - } - - private fun bearerClient(): CodeWhispererRuntimeClient { - if (myBearerClient != null) return myBearerClient as CodeWhispererRuntimeClient - myBearerClient = getBearerClient() - return myBearerClient as CodeWhispererRuntimeClient - } + private fun bearerClient(): CodeWhispererRuntimeClient = project.awsClient() override fun generateCompletionsPaginator(firstRequest: GenerateCompletionsRequest) = sequence { var nextToken: String? = firstRequest.nextToken() @@ -866,27 +837,6 @@ open class CodeWhispererClientAdaptorImpl(override val project: Project) : CodeW myBearerClient?.close() } - /** - * Every different SSO/AWS Builder ID connection requires a new client which has its corresponding bearer token provider, - * thus we have to create them dynamically. - * Invalidate and recycle the old client first, and create a new client with the new connection. - * This makes sure when we invoke CW, we always use the up-to-date connection. - * In case this fails to close the client, myBearerClient is already set to null thus next time when we invoke CW, - * it will go through this again which should get the current up-to-date connection. This stale client would be - * unused and stay in memory for a while until eventually closed by ToolkitClientManager. - */ - open fun getBearerClient(oldProviderIdToRemove: String = ""): CodeWhispererRuntimeClient? { - myBearerClient = null - - val connection = ToolkitConnectionManager.getInstance(project).activeConnectionForFeature(CodeWhispererConnection.getInstance()) - connection as? AwsBearerTokenConnection ?: run { - LOG.warn { "$connection is not a bearer token connection" } - return null - } - - return AwsClientManager.getInstance().getClient(connection.getConnectionSettings()) - } - companion object { private val LOG = getLogger() private fun createUnmanagedSigv4Client(): CodeWhispererClient = AwsClientManager.getInstance().createUnmanagedClient( @@ -898,7 +848,6 @@ open class CodeWhispererClientAdaptorImpl(override val project: Project) : CodeW } class MockCodeWhispererClientAdaptor(override val project: Project) : CodeWhispererClientAdaptorImpl(project) { - override fun getBearerClient(oldProviderIdToRemove: String): CodeWhispererRuntimeClient = project.awsClient() override fun dispose() {} } From 0f7660f5ec06e20f8b3b24482e719cb6920c9436 Mon Sep 17 00:00:00 2001 From: Will Lo Date: Tue, 11 Feb 2025 17:57:41 -0800 Subject: [PATCH 2/6] connectionManager --- .../credentials/CodeWhispererClientAdaptor.kt | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/credentials/CodeWhispererClientAdaptor.kt b/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/credentials/CodeWhispererClientAdaptor.kt index 3084268e6f1..4bcebc06908 100644 --- a/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/credentials/CodeWhispererClientAdaptor.kt +++ b/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/credentials/CodeWhispererClientAdaptor.kt @@ -42,6 +42,8 @@ import software.aws.toolkits.core.utils.debug import software.aws.toolkits.core.utils.getLogger import software.aws.toolkits.jetbrains.core.AwsClientManager import software.aws.toolkits.jetbrains.core.awsClient +import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager +import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection import software.aws.toolkits.jetbrains.services.amazonq.codeWhispererUserContext import software.aws.toolkits.jetbrains.services.codewhisperer.customization.CodeWhispererCustomization import software.aws.toolkits.jetbrains.services.codewhisperer.explorer.CodeWhispererExplorerActionManager @@ -288,7 +290,10 @@ open class CodeWhispererClientAdaptorImpl(override val project: Project) : CodeW return (getDelegate() as Lazy<*>).isInitialized() } - private fun bearerClient(): CodeWhispererRuntimeClient = project.awsClient() + private fun bearerClient(): CodeWhispererRuntimeClient = + ToolkitConnectionManager.getInstance(project).activeConnectionForFeature(QConnection.getInstance())?.getConnectionSettings() + ?.awsClient() + ?: throw Exception("attempt to get bearer client while there is no valid credential") override fun generateCompletionsPaginator(firstRequest: GenerateCompletionsRequest) = sequence { var nextToken: String? = firstRequest.nextToken() @@ -298,6 +303,7 @@ open class CodeWhispererClientAdaptorImpl(override val project: Project) : CodeW yield(response) } while (!nextToken.isNullOrEmpty()) } + override fun generateCompletions(firstRequest: GenerateCompletionsRequest): GenerateCompletionsResponse = bearerClient().generateCompletions(firstRequest) From 371ab9bd10ae728ee0b77ffa1e1d88583f5f2c62 Mon Sep 17 00:00:00 2001 From: Will Lo Date: Tue, 11 Feb 2025 20:29:10 -0800 Subject: [PATCH 3/6] test --- .../codewhisperer/CodeWhispererClientAdaptorTest.kt | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererClientAdaptorTest.kt b/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererClientAdaptorTest.kt index 86ef92f8ccc..59f9f595d0b 100644 --- a/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererClientAdaptorTest.kt +++ b/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererClientAdaptorTest.kt @@ -15,6 +15,7 @@ import org.junit.After import org.junit.Before import org.junit.Rule import org.junit.Test +import org.junit.jupiter.api.assertThrows import org.mockito.kotlin.any import org.mockito.kotlin.argThat import org.mockito.kotlin.argumentCaptor @@ -68,10 +69,12 @@ import software.aws.toolkits.core.TokenConnectionSettings import software.aws.toolkits.core.utils.test.aString import software.aws.toolkits.jetbrains.core.MockClientManagerRule import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection +import software.aws.toolkits.jetbrains.core.credentials.DefaultToolkitConnectionManager import software.aws.toolkits.jetbrains.core.credentials.ManagedSsoProfile import software.aws.toolkits.jetbrains.core.credentials.MockCredentialManagerRule import software.aws.toolkits.jetbrains.core.credentials.MockToolkitAuthManagerRule import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager +import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection import software.aws.toolkits.jetbrains.core.credentials.sono.SONO_REGION import software.aws.toolkits.jetbrains.services.amazonq.FEATURE_EVALUATION_PRODUCT_NAME import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.metadata @@ -163,6 +166,16 @@ class CodeWhispererClientAdaptorTest { assertThat("us-east-1").isEqualTo(SONO_REGION) } + @Test + fun `should throw if there is no valid credential`() { + projectRule.project.replaceService(ToolkitConnectionManager::class.java, DefaultToolkitConnectionManager(), disposableRule.disposable) + assertThat(ToolkitConnectionManager.getInstance(projectRule.project).activeConnectionForFeature(QConnection.getInstance())).isNull() + + assertThrows("attempt to get bearer client while there is no valid credential") { + sut.listFeatureEvaluations() + } + } + @Test fun `listCustomizations`() { val sdkIterable = ListAvailableCustomizationsIterable(bearerClient, ListAvailableCustomizationsRequest.builder().build()) From 55bdab880d91bfd5e42ff4164744988cc4d83cd8 Mon Sep 17 00:00:00 2001 From: Will Lo Date: Tue, 11 Feb 2025 20:45:54 -0800 Subject: [PATCH 4/6] test failure --- .../services/codewhisperer/CodeWhispererTestBase.kt | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererTestBase.kt b/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererTestBase.kt index 6adb57070cd..62c0d6aa937 100644 --- a/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererTestBase.kt +++ b/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererTestBase.kt @@ -31,8 +31,13 @@ import org.mockito.kotlin.whenever import software.amazon.awssdk.services.codewhispererruntime.CodeWhispererRuntimeClient import software.amazon.awssdk.services.codewhispererruntime.model.GenerateCompletionsRequest import software.amazon.awssdk.services.codewhispererruntime.paginators.GenerateCompletionsIterable +import software.amazon.awssdk.services.ssooidc.SsoOidcClient import software.aws.toolkits.jetbrains.core.MockClientManagerRule +import software.aws.toolkits.jetbrains.core.credentials.ManagedSsoProfile import software.aws.toolkits.jetbrains.core.credentials.MockCredentialManagerRule +import software.aws.toolkits.jetbrains.core.credentials.MockToolkitAuthManagerRule +import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager +import software.aws.toolkits.jetbrains.core.credentials.sono.Q_SCOPES import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.codeWhispererRecommendationActionId import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.pythonFileName import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.pythonResponse @@ -66,10 +71,11 @@ open class CodeWhispererTestBase { val mockClientManagerRule = MockClientManagerRule() val mockCredentialRule = MockCredentialManagerRule() val disposableRule = DisposableRule() + val authManagerRule = MockToolkitAuthManagerRule() @Rule @JvmField - val ruleChain = RuleChain(projectRule, mockCredentialRule, mockClientManagerRule, disposableRule) + val ruleChain = RuleChain(projectRule, mockCredentialRule, mockClientManagerRule, authManagerRule, disposableRule) protected lateinit var mockClient: CodeWhispererRuntimeClient @@ -87,6 +93,7 @@ open class CodeWhispererTestBase { @Before open fun setUp() { mockClient = mockClientManagerRule.create() + mockClientManagerRule.create() val requestCaptor = argumentCaptor() mockClient.stub { on { @@ -163,6 +170,9 @@ open class CodeWhispererTestBase { projectRule.project.replaceService(CodeWhispererClientAdaptor::class.java, clientAdaptorSpy, disposableRule.disposable) ApplicationManager.getApplication().replaceService(CodeWhispererExplorerActionManager::class.java, stateManager, disposableRule.disposable) stateManager.setAutoEnabled(false) + + val conn = authManagerRule.createConnection(ManagedSsoProfile("us-east-1", "url", Q_SCOPES)) + ToolkitConnectionManager.getInstance(projectRule.project).switchConnection(conn) } @After From 4b4d9ba7a86cfa1d311918e82df4ade463b92a06 Mon Sep 17 00:00:00 2001 From: Will Lo Date: Wed, 12 Feb 2025 08:35:59 -0800 Subject: [PATCH 5/6] patch --- .../codewhisperer/credentials/CodeWhispererClientAdaptor.kt | 4 ---- 1 file changed, 4 deletions(-) diff --git a/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/credentials/CodeWhispererClientAdaptor.kt b/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/credentials/CodeWhispererClientAdaptor.kt index 4bcebc06908..0d3c16acaef 100644 --- a/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/credentials/CodeWhispererClientAdaptor.kt +++ b/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/credentials/CodeWhispererClientAdaptor.kt @@ -281,9 +281,6 @@ interface CodeWhispererClientAdaptor : Disposable { open class CodeWhispererClientAdaptorImpl(override val project: Project) : CodeWhispererClientAdaptor { private val mySigv4Client by lazy { createUnmanagedSigv4Client() } - @Volatile - private var myBearerClient: CodeWhispererRuntimeClient? = null - private val KProperty0<*>.isLazyInitialized: Boolean get() { isAccessible = true @@ -837,7 +834,6 @@ open class CodeWhispererClientAdaptorImpl(override val project: Project) : CodeW if (this::mySigv4Client.isLazyInitialized) { mySigv4Client.close() } - myBearerClient?.close() } companion object { From eab031b6194bd6af331a3efd0bdf854426b19487 Mon Sep 17 00:00:00 2001 From: Will Lo Date: Wed, 12 Feb 2025 09:22:36 -0800 Subject: [PATCH 6/6] p --- .../credentials/CodeWhispererClientAdaptor.kt | 2 +- .../CodeWhispererClientAdaptorTest.kt | 34 ++++++++++++++++--- 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/credentials/CodeWhispererClientAdaptor.kt b/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/credentials/CodeWhispererClientAdaptor.kt index 0d3c16acaef..14588a5afbc 100644 --- a/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/credentials/CodeWhispererClientAdaptor.kt +++ b/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/credentials/CodeWhispererClientAdaptor.kt @@ -287,7 +287,7 @@ open class CodeWhispererClientAdaptorImpl(override val project: Project) : CodeW return (getDelegate() as Lazy<*>).isInitialized() } - private fun bearerClient(): CodeWhispererRuntimeClient = + fun bearerClient(): CodeWhispererRuntimeClient = ToolkitConnectionManager.getInstance(project).activeConnectionForFeature(QConnection.getInstance())?.getConnectionSettings() ?.awsClient() ?: throw Exception("attempt to get bearer client while there is no valid credential") diff --git a/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererClientAdaptorTest.kt b/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererClientAdaptorTest.kt index 59f9f595d0b..a0bec60e74e 100644 --- a/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererClientAdaptorTest.kt +++ b/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererClientAdaptorTest.kt @@ -74,14 +74,15 @@ import software.aws.toolkits.jetbrains.core.credentials.ManagedSsoProfile import software.aws.toolkits.jetbrains.core.credentials.MockCredentialManagerRule import software.aws.toolkits.jetbrains.core.credentials.MockToolkitAuthManagerRule import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager +import software.aws.toolkits.jetbrains.core.credentials.logoutFromSsoConnection import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection +import software.aws.toolkits.jetbrains.core.credentials.sono.Q_SCOPES import software.aws.toolkits.jetbrains.core.credentials.sono.SONO_REGION import software.aws.toolkits.jetbrains.services.amazonq.FEATURE_EVALUATION_PRODUCT_NAME import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.metadata import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.pythonRequest import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.pythonResponseWithToken import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.sdkHttpResponse -import software.aws.toolkits.jetbrains.services.codewhisperer.credentials.CodeWhispererClientAdaptor import software.aws.toolkits.jetbrains.services.codewhisperer.credentials.CodeWhispererClientAdaptorImpl import software.aws.toolkits.jetbrains.services.codewhisperer.customization.CodeWhispererCustomization import software.aws.toolkits.jetbrains.services.codewhisperer.customization.CodeWhispererModelConfigurator @@ -112,7 +113,7 @@ class CodeWhispererClientAdaptorTest { private lateinit var bearerClient: CodeWhispererRuntimeClient private lateinit var ssoClient: SsoOidcClient - private lateinit var sut: CodeWhispererClientAdaptor + private lateinit var sut: CodeWhispererClientAdaptorImpl private lateinit var connectionManager: ToolkitConnectionManager private var isTelemetryEnabledDefault: Boolean = false @@ -167,13 +168,38 @@ class CodeWhispererClientAdaptorTest { } @Test - fun `should throw if there is no valid credential`() { + fun `should throw if there is no valid credential, otherwise return codewhispererRuntimeClient`() { + val connectionManager = DefaultToolkitConnectionManager() projectRule.project.replaceService(ToolkitConnectionManager::class.java, DefaultToolkitConnectionManager(), disposableRule.disposable) + assertThat(ToolkitConnectionManager.getInstance(projectRule.project).activeConnectionForFeature(QConnection.getInstance())).isNull() + assertThrows("attempt to get bearer client while there is no valid credential") { + sut.bearerClient() + } + val qConnection = authManagerRule.createConnection(ManagedSsoProfile("us-east-1", aString(), Q_SCOPES)) + connectionManager.switchConnection(qConnection) + assertThat(connectionManager.activeConnectionForFeature(QConnection.getInstance())) + .isNotNull + .isEqualTo(qConnection) + assertThat(sut.bearerClient()) + .isNotNull + .isInstanceOf(CodeWhispererRuntimeClient::class.java) + + logoutFromSsoConnection(projectRule.project, qConnection as AwsBearerTokenConnection) + assertThat(connectionManager.activeConnectionForFeature(QConnection.getInstance())).isNull() assertThrows("attempt to get bearer client while there is no valid credential") { - sut.listFeatureEvaluations() + sut.bearerClient() } + + val anotherQConnection = authManagerRule.createConnection(ManagedSsoProfile("us-east-1", aString(), Q_SCOPES)) + connectionManager.switchConnection(anotherQConnection) + assertThat(connectionManager.activeConnectionForFeature(QConnection.getInstance())) + .isNotNull + .isEqualTo(anotherQConnection) + assertThat(sut.bearerClient()) + .isNotNull + .isInstanceOf(CodeWhispererRuntimeClient::class.java) } @Test