diff --git a/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/credentials/CodeWhispererClientAdaptor.kt b/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/credentials/CodeWhispererClientAdaptor.kt index e2e55b86430..c72f480355a 100644 --- a/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/credentials/CodeWhispererClientAdaptor.kt +++ b/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/credentials/CodeWhispererClientAdaptor.kt @@ -4,7 +4,6 @@ package software.aws.toolkits.jetbrains.services.codewhisperer.credentials import com.intellij.openapi.Disposable -import com.intellij.openapi.application.ApplicationManager import com.intellij.openapi.components.service import com.intellij.openapi.project.Project import com.intellij.util.text.nullize @@ -41,14 +40,10 @@ import software.amazon.awssdk.services.codewhispererruntime.model.TargetCode import software.amazon.awssdk.services.codewhispererruntime.model.UserIntent import software.aws.toolkits.core.utils.debug import software.aws.toolkits.core.utils.getLogger -import software.aws.toolkits.core.utils.warn import software.aws.toolkits.jetbrains.core.AwsClientManager import software.aws.toolkits.jetbrains.core.awsClient -import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection -import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnection import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager -import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManagerListener -import software.aws.toolkits.jetbrains.core.credentials.pinning.CodeWhispererConnection +import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection import software.aws.toolkits.jetbrains.services.amazonq.codeWhispererUserContext import software.aws.toolkits.jetbrains.services.codewhisperer.customization.CodeWhispererCustomization import software.aws.toolkits.jetbrains.services.codewhisperer.explorer.CodeWhispererExplorerActionManager @@ -67,7 +62,6 @@ import java.util.concurrent.TimeUnit import kotlin.reflect.KProperty0 import kotlin.reflect.jvm.isAccessible -// TODO: move this file to package "/client" // As the connection is project-level, we need to make this project-level too @Deprecated("Methods can throw a NullPointerException if callee does not check if connection is valid") interface CodeWhispererClientAdaptor : Disposable { @@ -283,37 +277,16 @@ interface CodeWhispererClientAdaptor : Disposable { open class CodeWhispererClientAdaptorImpl(override val project: Project) : CodeWhispererClientAdaptor { private val mySigv4Client by lazy { createUnmanagedSigv4Client() } - @Volatile - private var myBearerClient: CodeWhispererRuntimeClient? = null - private val KProperty0<*>.isLazyInitialized: Boolean get() { isAccessible = true return (getDelegate() as Lazy<*>).isInitialized() } - init { - initClientUpdateListener() - } - - private fun initClientUpdateListener() { - ApplicationManager.getApplication().messageBus.connect(this).subscribe( - ToolkitConnectionManagerListener.TOPIC, - object : ToolkitConnectionManagerListener { - override fun activeConnectionChanged(newConnection: ToolkitConnection?) { - if (newConnection is AwsBearerTokenConnection) { - myBearerClient = getBearerClient(newConnection.getConnectionSettings().providerId) - } - } - } - ) - } - - private fun bearerClient(): CodeWhispererRuntimeClient { - if (myBearerClient != null) return myBearerClient as CodeWhispererRuntimeClient - myBearerClient = getBearerClient() - return myBearerClient as CodeWhispererRuntimeClient - } + fun bearerClient(): CodeWhispererRuntimeClient = + ToolkitConnectionManager.getInstance(project).activeConnectionForFeature(QConnection.getInstance())?.getConnectionSettings() + ?.awsClient() + ?: throw Exception("attempt to get bearer client while there is no valid credential") override fun generateCompletionsPaginator(firstRequest: GenerateCompletionsRequest) = sequence { var nextToken: String? = firstRequest.nextToken() @@ -854,28 +827,6 @@ open class CodeWhispererClientAdaptorImpl(override val project: Project) : CodeW if (this::mySigv4Client.isLazyInitialized) { mySigv4Client.close() } - myBearerClient?.close() - } - - /** - * Every different SSO/AWS Builder ID connection requires a new client which has its corresponding bearer token provider, - * thus we have to create them dynamically. - * Invalidate and recycle the old client first, and create a new client with the new connection. - * This makes sure when we invoke CW, we always use the up-to-date connection. - * In case this fails to close the client, myBearerClient is already set to null thus next time when we invoke CW, - * it will go through this again which should get the current up-to-date connection. This stale client would be - * unused and stay in memory for a while until eventually closed by ToolkitClientManager. - */ - open fun getBearerClient(oldProviderIdToRemove: String = ""): CodeWhispererRuntimeClient? { - myBearerClient = null - - val connection = ToolkitConnectionManager.getInstance(project).activeConnectionForFeature(CodeWhispererConnection.getInstance()) - connection as? AwsBearerTokenConnection ?: run { - LOG.warn { "$connection is not a bearer token connection" } - return null - } - - return AwsClientManager.getInstance().getClient(connection.getConnectionSettings()) } companion object { @@ -889,7 +840,6 @@ open class CodeWhispererClientAdaptorImpl(override val project: Project) : CodeW } class MockCodeWhispererClientAdaptor(override val project: Project) : CodeWhispererClientAdaptorImpl(project) { - override fun getBearerClient(oldProviderIdToRemove: String): CodeWhispererRuntimeClient = project.awsClient() override fun dispose() {} } diff --git a/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererClientAdaptorTest.kt b/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererClientAdaptorTest.kt index 86ef92f8ccc..a0bec60e74e 100644 --- a/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererClientAdaptorTest.kt +++ b/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererClientAdaptorTest.kt @@ -15,6 +15,7 @@ import org.junit.After import org.junit.Before import org.junit.Rule import org.junit.Test +import org.junit.jupiter.api.assertThrows import org.mockito.kotlin.any import org.mockito.kotlin.argThat import org.mockito.kotlin.argumentCaptor @@ -68,17 +69,20 @@ import software.aws.toolkits.core.TokenConnectionSettings import software.aws.toolkits.core.utils.test.aString import software.aws.toolkits.jetbrains.core.MockClientManagerRule import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection +import software.aws.toolkits.jetbrains.core.credentials.DefaultToolkitConnectionManager import software.aws.toolkits.jetbrains.core.credentials.ManagedSsoProfile import software.aws.toolkits.jetbrains.core.credentials.MockCredentialManagerRule import software.aws.toolkits.jetbrains.core.credentials.MockToolkitAuthManagerRule import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager +import software.aws.toolkits.jetbrains.core.credentials.logoutFromSsoConnection +import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection +import software.aws.toolkits.jetbrains.core.credentials.sono.Q_SCOPES import software.aws.toolkits.jetbrains.core.credentials.sono.SONO_REGION import software.aws.toolkits.jetbrains.services.amazonq.FEATURE_EVALUATION_PRODUCT_NAME import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.metadata import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.pythonRequest import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.pythonResponseWithToken import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.sdkHttpResponse -import software.aws.toolkits.jetbrains.services.codewhisperer.credentials.CodeWhispererClientAdaptor import software.aws.toolkits.jetbrains.services.codewhisperer.credentials.CodeWhispererClientAdaptorImpl import software.aws.toolkits.jetbrains.services.codewhisperer.customization.CodeWhispererCustomization import software.aws.toolkits.jetbrains.services.codewhisperer.customization.CodeWhispererModelConfigurator @@ -109,7 +113,7 @@ class CodeWhispererClientAdaptorTest { private lateinit var bearerClient: CodeWhispererRuntimeClient private lateinit var ssoClient: SsoOidcClient - private lateinit var sut: CodeWhispererClientAdaptor + private lateinit var sut: CodeWhispererClientAdaptorImpl private lateinit var connectionManager: ToolkitConnectionManager private var isTelemetryEnabledDefault: Boolean = false @@ -163,6 +167,41 @@ class CodeWhispererClientAdaptorTest { assertThat("us-east-1").isEqualTo(SONO_REGION) } + @Test + fun `should throw if there is no valid credential, otherwise return codewhispererRuntimeClient`() { + val connectionManager = DefaultToolkitConnectionManager() + projectRule.project.replaceService(ToolkitConnectionManager::class.java, DefaultToolkitConnectionManager(), disposableRule.disposable) + + assertThat(ToolkitConnectionManager.getInstance(projectRule.project).activeConnectionForFeature(QConnection.getInstance())).isNull() + assertThrows("attempt to get bearer client while there is no valid credential") { + sut.bearerClient() + } + + val qConnection = authManagerRule.createConnection(ManagedSsoProfile("us-east-1", aString(), Q_SCOPES)) + connectionManager.switchConnection(qConnection) + assertThat(connectionManager.activeConnectionForFeature(QConnection.getInstance())) + .isNotNull + .isEqualTo(qConnection) + assertThat(sut.bearerClient()) + .isNotNull + .isInstanceOf(CodeWhispererRuntimeClient::class.java) + + logoutFromSsoConnection(projectRule.project, qConnection as AwsBearerTokenConnection) + assertThat(connectionManager.activeConnectionForFeature(QConnection.getInstance())).isNull() + assertThrows("attempt to get bearer client while there is no valid credential") { + sut.bearerClient() + } + + val anotherQConnection = authManagerRule.createConnection(ManagedSsoProfile("us-east-1", aString(), Q_SCOPES)) + connectionManager.switchConnection(anotherQConnection) + assertThat(connectionManager.activeConnectionForFeature(QConnection.getInstance())) + .isNotNull + .isEqualTo(anotherQConnection) + assertThat(sut.bearerClient()) + .isNotNull + .isInstanceOf(CodeWhispererRuntimeClient::class.java) + } + @Test fun `listCustomizations`() { val sdkIterable = ListAvailableCustomizationsIterable(bearerClient, ListAvailableCustomizationsRequest.builder().build()) diff --git a/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererTestBase.kt b/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererTestBase.kt index aa0dd9f2270..a3e97f60c38 100644 --- a/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererTestBase.kt +++ b/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererTestBase.kt @@ -30,8 +30,13 @@ import org.mockito.kotlin.verify import software.amazon.awssdk.services.codewhispererruntime.CodeWhispererRuntimeClient import software.amazon.awssdk.services.codewhispererruntime.model.GenerateCompletionsRequest import software.amazon.awssdk.services.codewhispererruntime.paginators.GenerateCompletionsIterable +import software.amazon.awssdk.services.ssooidc.SsoOidcClient import software.aws.toolkits.jetbrains.core.MockClientManagerRule +import software.aws.toolkits.jetbrains.core.credentials.ManagedSsoProfile import software.aws.toolkits.jetbrains.core.credentials.MockCredentialManagerRule +import software.aws.toolkits.jetbrains.core.credentials.MockToolkitAuthManagerRule +import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager +import software.aws.toolkits.jetbrains.core.credentials.sono.Q_SCOPES import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.codeWhispererRecommendationActionId import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.pythonFileName import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.pythonResponse @@ -65,10 +70,11 @@ open class CodeWhispererTestBase { val mockClientManagerRule = MockClientManagerRule() val mockCredentialRule = MockCredentialManagerRule() val disposableRule = DisposableRule() + val authManagerRule = MockToolkitAuthManagerRule() @Rule @JvmField - val ruleChain = RuleChain(projectRule, mockCredentialRule, mockClientManagerRule, disposableRule) + val ruleChain = RuleChain(projectRule, mockCredentialRule, mockClientManagerRule, authManagerRule, disposableRule) protected lateinit var mockClient: CodeWhispererRuntimeClient @@ -86,6 +92,7 @@ open class CodeWhispererTestBase { @Before open fun setUp() { mockClient = mockClientManagerRule.create() + mockClientManagerRule.create() val requestCaptor = argumentCaptor() mockClient.stub { on { @@ -159,6 +166,9 @@ open class CodeWhispererTestBase { projectRule.project.replaceService(CodeWhispererClientAdaptor::class.java, clientAdaptorSpy, disposableRule.disposable) ApplicationManager.getApplication().replaceService(CodeWhispererExplorerActionManager::class.java, stateManager, disposableRule.disposable) stateManager.setAutoEnabled(false) + + val conn = authManagerRule.createConnection(ManagedSsoProfile("us-east-1", "url", Q_SCOPES)) + ToolkitConnectionManager.getInstance(projectRule.project).switchConnection(conn) } @After