diff --git a/plugins/amazonq/codewhisperer/jetbrains-community/resources/META-INF/plugin-codewhisperer.xml b/plugins/amazonq/codewhisperer/jetbrains-community/resources/META-INF/plugin-codewhisperer.xml
index 7b9bb7b9eca..07d8f29e713 100644
--- a/plugins/amazonq/codewhisperer/jetbrains-community/resources/META-INF/plugin-codewhisperer.xml
+++ b/plugins/amazonq/codewhisperer/jetbrains-community/resources/META-INF/plugin-codewhisperer.xml
@@ -33,8 +33,7 @@
serviceImplementation="software.aws.toolkits.jetbrains.services.codewhisperer.customization.DefaultCodeWhispererModelConfigurator"/>
+ serviceImplementation="software.aws.toolkits.jetbrains.services.codewhisperer.credentials.CodeWhispererClientAdaptorImpl"/>
diff --git a/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/credentials/CodeWhispererClientAdaptor.kt b/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/credentials/CodeWhispererClientAdaptor.kt
index 17cc419e1a6..36e60c6aa9f 100644
--- a/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/credentials/CodeWhispererClientAdaptor.kt
+++ b/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/credentials/CodeWhispererClientAdaptor.kt
@@ -3,8 +3,6 @@
package software.aws.toolkits.jetbrains.services.codewhisperer.credentials
-import com.intellij.openapi.Disposable
-import com.intellij.openapi.application.ApplicationManager
import com.intellij.openapi.components.service
import com.intellij.openapi.project.Project
import com.intellij.util.text.nullize
@@ -39,14 +37,9 @@ import software.amazon.awssdk.services.codewhispererruntime.model.TargetCode
import software.amazon.awssdk.services.codewhispererruntime.model.UserIntent
import software.aws.toolkits.core.utils.debug
import software.aws.toolkits.core.utils.getLogger
-import software.aws.toolkits.core.utils.warn
-import software.aws.toolkits.jetbrains.core.AwsClientManager
import software.aws.toolkits.jetbrains.core.awsClient
-import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection
-import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnection
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
-import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManagerListener
-import software.aws.toolkits.jetbrains.core.credentials.pinning.CodeWhispererConnection
+import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection
import software.aws.toolkits.jetbrains.services.amazonq.codeWhispererUserContext
import software.aws.toolkits.jetbrains.services.codewhisperer.customization.CodeWhispererCustomization
import software.aws.toolkits.jetbrains.services.codewhisperer.language.CodeWhispererProgrammingLanguage
@@ -62,8 +55,11 @@ import java.time.Instant
import java.util.concurrent.TimeUnit
// As the connection is project-level, we need to make this project-level too
-@Deprecated("Methods can throw a NullPointerException if callee does not check if connection is valid")
-interface CodeWhispererClientAdaptor : Disposable {
+@Deprecated(
+ "It was needed as we were supporting two service models (sigv4 & bearer), " +
+ "it's no longer the case as we remove sigv4 support, should use AwsClientManager.getClient() directly"
+)
+interface CodeWhispererClientAdaptor {
val project: Project
fun generateCompletionsPaginator(
@@ -261,32 +257,11 @@ interface CodeWhispererClientAdaptor : Disposable {
}
}
-open class CodeWhispererClientAdaptorImpl(override val project: Project) : CodeWhispererClientAdaptor {
- @Volatile
- private var myBearerClient: CodeWhispererRuntimeClient? = null
-
- init {
- initClientUpdateListener()
- }
-
- private fun initClientUpdateListener() {
- ApplicationManager.getApplication().messageBus.connect(this).subscribe(
- ToolkitConnectionManagerListener.TOPIC,
- object : ToolkitConnectionManagerListener {
- override fun activeConnectionChanged(newConnection: ToolkitConnection?) {
- if (newConnection is AwsBearerTokenConnection) {
- myBearerClient = getBearerClient(newConnection.getConnectionSettings().providerId)
- }
- }
- }
- )
- }
-
- private fun bearerClient(): CodeWhispererRuntimeClient {
- if (myBearerClient != null) return myBearerClient as CodeWhispererRuntimeClient
- myBearerClient = getBearerClient()
- return myBearerClient as CodeWhispererRuntimeClient
- }
+class CodeWhispererClientAdaptorImpl(override val project: Project) : CodeWhispererClientAdaptor {
+ fun bearerClient(): CodeWhispererRuntimeClient =
+ ToolkitConnectionManager.getInstance(project).activeConnectionForFeature(QConnection.getInstance())?.getConnectionSettings()
+ ?.awsClient()
+ ?: throw Exception("attempt to get bearer client while there is no valid credential")
override fun generateCompletionsPaginator(firstRequest: GenerateCompletionsRequest) = sequence {
var nextToken: String? = firstRequest.nextToken()
@@ -809,41 +784,11 @@ open class CodeWhispererClientAdaptorImpl(override val project: Project) : CodeW
requestBuilder.userContext(codeWhispererUserContext())
}
- override fun dispose() {
- myBearerClient?.close()
- }
-
- /**
- * Every different SSO/AWS Builder ID connection requires a new client which has its corresponding bearer token provider,
- * thus we have to create them dynamically.
- * Invalidate and recycle the old client first, and create a new client with the new connection.
- * This makes sure when we invoke CW, we always use the up-to-date connection.
- * In case this fails to close the client, myBearerClient is already set to null thus next time when we invoke CW,
- * it will go through this again which should get the current up-to-date connection. This stale client would be
- * unused and stay in memory for a while until eventually closed by ToolkitClientManager.
- */
- open fun getBearerClient(oldProviderIdToRemove: String = ""): CodeWhispererRuntimeClient? {
- myBearerClient = null
-
- val connection = ToolkitConnectionManager.getInstance(project).activeConnectionForFeature(CodeWhispererConnection.getInstance())
- connection as? AwsBearerTokenConnection ?: run {
- LOG.warn { "$connection is not a bearer token connection" }
- return null
- }
-
- return AwsClientManager.getInstance().getClient(connection.getConnectionSettings())
- }
-
companion object {
private val LOG = getLogger()
}
}
-class MockCodeWhispererClientAdaptor(override val project: Project) : CodeWhispererClientAdaptorImpl(project) {
- override fun getBearerClient(oldProviderIdToRemove: String): CodeWhispererRuntimeClient = project.awsClient()
- override fun dispose() {}
-}
-
private fun CodewhispererSuggestionState.toCodeWhispererSdkType() = when {
this == CodewhispererSuggestionState.Accept -> SuggestionState.ACCEPT
this == CodewhispererSuggestionState.Reject -> SuggestionState.REJECT
diff --git a/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererClientAdaptorTest.kt b/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererClientAdaptorTest.kt
index 268caa2ab63..9a3a455d6e1 100644
--- a/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererClientAdaptorTest.kt
+++ b/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererClientAdaptorTest.kt
@@ -4,7 +4,6 @@
package software.aws.toolkits.jetbrains.services.codewhisperer
import com.intellij.openapi.application.ApplicationManager
-import com.intellij.openapi.util.Disposer
import com.intellij.openapi.util.SystemInfo
import com.intellij.testFramework.DisposableRule
import com.intellij.testFramework.RuleChain
@@ -15,6 +14,7 @@ import org.junit.After
import org.junit.Before
import org.junit.Rule
import org.junit.Test
+import org.junit.jupiter.api.assertThrows
import org.mockito.kotlin.any
import org.mockito.kotlin.argThat
import org.mockito.kotlin.argumentCaptor
@@ -24,7 +24,6 @@ import org.mockito.kotlin.mock
import org.mockito.kotlin.stub
import org.mockito.kotlin.times
import org.mockito.kotlin.verify
-import org.mockito.kotlin.whenever
import software.amazon.awssdk.services.codewhispererruntime.CodeWhispererRuntimeClient
import software.amazon.awssdk.services.codewhispererruntime.model.ArtifactType
import software.amazon.awssdk.services.codewhispererruntime.model.CodeAnalysisFindingsSchema
@@ -54,7 +53,6 @@ import software.amazon.awssdk.services.codewhispererruntime.model.SuggestionStat
import software.amazon.awssdk.services.codewhispererruntime.paginators.GenerateCompletionsIterable
import software.amazon.awssdk.services.codewhispererruntime.paginators.ListAvailableCustomizationsIterable
import software.amazon.awssdk.services.ssooidc.SsoOidcClient
-import software.aws.toolkits.core.TokenConnectionSettings
import software.aws.toolkits.core.utils.test.aString
import software.aws.toolkits.jetbrains.core.MockClientManagerRule
import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection
@@ -62,13 +60,15 @@ import software.aws.toolkits.jetbrains.core.credentials.ManagedSsoProfile
import software.aws.toolkits.jetbrains.core.credentials.MockCredentialManagerRule
import software.aws.toolkits.jetbrains.core.credentials.MockToolkitAuthManagerRule
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
+import software.aws.toolkits.jetbrains.core.credentials.logoutFromSsoConnection
+import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection
+import software.aws.toolkits.jetbrains.core.credentials.sono.Q_SCOPES
import software.aws.toolkits.jetbrains.core.credentials.sono.SONO_REGION
import software.aws.toolkits.jetbrains.services.amazonq.FEATURE_EVALUATION_PRODUCT_NAME
import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.metadata
import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.pythonRequest
import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.pythonResponseWithToken
import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.sdkHttpResponse
-import software.aws.toolkits.jetbrains.services.codewhisperer.credentials.CodeWhispererClientAdaptor
import software.aws.toolkits.jetbrains.services.codewhisperer.credentials.CodeWhispererClientAdaptorImpl
import software.aws.toolkits.jetbrains.services.codewhisperer.customization.CodeWhispererCustomization
import software.aws.toolkits.jetbrains.services.codewhisperer.customization.CodeWhispererModelConfigurator
@@ -93,13 +93,12 @@ class CodeWhispererClientAdaptorTest {
@Rule
@JvmField
- val ruleChain = RuleChain(projectRule, mockCredentialRule, mockClientManagerRule, disposableRule)
+ val ruleChain = RuleChain(projectRule, mockCredentialRule, mockClientManagerRule, authManagerRule, disposableRule)
private lateinit var bearerClient: CodeWhispererRuntimeClient
private lateinit var ssoClient: SsoOidcClient
- private lateinit var sut: CodeWhispererClientAdaptor
- private lateinit var connectionManager: ToolkitConnectionManager
+ private lateinit var sut: CodeWhispererClientAdaptorImpl
private var isTelemetryEnabledDefault: Boolean = false
@Before
@@ -117,15 +116,8 @@ class CodeWhispererClientAdaptorTest {
on { listFeatureEvaluations(any()) } doReturn listFeatureEvaluationsResponse
}
- val mockConnection = mock()
- whenever(mockConnection.getConnectionSettings()) doReturn mock()
-
- connectionManager = mock {
- on {
- activeConnectionForFeature(any())
- } doReturn authManagerRule.createConnection(ManagedSsoProfile("us-east-1", aString(), listOf("scopes"))) as AwsBearerTokenConnection
- }
- projectRule.project.replaceService(ToolkitConnectionManager::class.java, connectionManager, disposableRule.disposable)
+ val conn = authManagerRule.createConnection(ManagedSsoProfile("us-east-1", "url", Q_SCOPES))
+ ToolkitConnectionManager.getInstance(projectRule.project).switchConnection(conn)
isTelemetryEnabledDefault = AwsSettings.getInstance().isTelemetryEnabled
}
@@ -135,16 +127,37 @@ class CodeWhispererClientAdaptorTest {
AwsSettings.getInstance().isTelemetryEnabled = isTelemetryEnabledDefault
}
- @After
- fun cleanup() {
- Disposer.dispose(sut)
- }
-
@Test
fun `Sono region is us-east-1`() {
assertThat("us-east-1").isEqualTo(SONO_REGION)
}
+ @Test
+ fun `should throw if there is no valid credential, otherwise return codewhispererRuntimeClient`() {
+ val connectionManager = ToolkitConnectionManager.getInstance(projectRule.project)
+
+ assertThat(connectionManager.activeConnectionForFeature(QConnection.getInstance()))
+ .isNotNull
+ assertThat(sut.bearerClient())
+ .isNotNull
+ .isInstanceOf(CodeWhispererRuntimeClient::class.java)
+
+ logoutFromSsoConnection(projectRule.project, connectionManager.activeConnectionForFeature(QConnection.getInstance()) as AwsBearerTokenConnection)
+ assertThat(connectionManager.activeConnectionForFeature(QConnection.getInstance())).isNull()
+ assertThrows("attempt to get bearer client while there is no valid credential") {
+ sut.bearerClient()
+ }
+
+ val anotherQConnection = authManagerRule.createConnection(ManagedSsoProfile("us-east-1", aString(), Q_SCOPES))
+ connectionManager.switchConnection(anotherQConnection)
+ assertThat(connectionManager.activeConnectionForFeature(QConnection.getInstance()))
+ .isNotNull
+ .isEqualTo(anotherQConnection)
+ assertThat(sut.bearerClient())
+ .isNotNull
+ .isInstanceOf(CodeWhispererRuntimeClient::class.java)
+ }
+
@Test
fun `listCustomizations`() {
val sdkIterable = ListAvailableCustomizationsIterable(bearerClient, ListAvailableCustomizationsRequest.builder().build())
diff --git a/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererSettingsTest.kt b/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererSettingsTest.kt
index 63bc70ae2b8..94474eefa6a 100644
--- a/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererSettingsTest.kt
+++ b/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererSettingsTest.kt
@@ -105,7 +105,7 @@ class CodeWhispererSettingsTest : CodeWhispererTestBase() {
stateManager.loadState(CodeWhispererExploreActionState())
CodeWhispererSettings.getInstance().loadState(CodeWhispererConfiguration())
- val problemsWindow = ProblemsView.getToolWindow(projectRule.project) ?: fail("Problems window not found")
+ ProblemsView.getToolWindow(projectRule.project) ?: fail("Problems window not found")
val codeReferenceWindow = ToolWindowManager.getInstance(projectRule.project).getToolWindow(
CodeWhispererCodeReferenceToolWindowFactory.id
) ?: fail("Code Reference Log window not found")
@@ -114,7 +114,6 @@ class CodeWhispererSettingsTest : CodeWhispererTestBase() {
} ?: fail("CodeWhisperer status bar widget not found")
runInEdtAndWait {
- assertThat(problemsWindow.contentManager.contentCount).isEqualTo(0)
assertThat(codeReferenceWindow.isAvailable).isFalse
assertThat(statusBarWidgetFactory.isAvailable(projectRule.project)).isTrue
assertThat(settingsManager.isIncludeCodeWithReference()).isFalse
diff --git a/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererTestBase.kt b/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererTestBase.kt
index aa0dd9f2270..a3e97f60c38 100644
--- a/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererTestBase.kt
+++ b/plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererTestBase.kt
@@ -30,8 +30,13 @@ import org.mockito.kotlin.verify
import software.amazon.awssdk.services.codewhispererruntime.CodeWhispererRuntimeClient
import software.amazon.awssdk.services.codewhispererruntime.model.GenerateCompletionsRequest
import software.amazon.awssdk.services.codewhispererruntime.paginators.GenerateCompletionsIterable
+import software.amazon.awssdk.services.ssooidc.SsoOidcClient
import software.aws.toolkits.jetbrains.core.MockClientManagerRule
+import software.aws.toolkits.jetbrains.core.credentials.ManagedSsoProfile
import software.aws.toolkits.jetbrains.core.credentials.MockCredentialManagerRule
+import software.aws.toolkits.jetbrains.core.credentials.MockToolkitAuthManagerRule
+import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
+import software.aws.toolkits.jetbrains.core.credentials.sono.Q_SCOPES
import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.codeWhispererRecommendationActionId
import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.pythonFileName
import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.pythonResponse
@@ -65,10 +70,11 @@ open class CodeWhispererTestBase {
val mockClientManagerRule = MockClientManagerRule()
val mockCredentialRule = MockCredentialManagerRule()
val disposableRule = DisposableRule()
+ val authManagerRule = MockToolkitAuthManagerRule()
@Rule
@JvmField
- val ruleChain = RuleChain(projectRule, mockCredentialRule, mockClientManagerRule, disposableRule)
+ val ruleChain = RuleChain(projectRule, mockCredentialRule, mockClientManagerRule, authManagerRule, disposableRule)
protected lateinit var mockClient: CodeWhispererRuntimeClient
@@ -86,6 +92,7 @@ open class CodeWhispererTestBase {
@Before
open fun setUp() {
mockClient = mockClientManagerRule.create()
+ mockClientManagerRule.create()
val requestCaptor = argumentCaptor()
mockClient.stub {
on {
@@ -159,6 +166,9 @@ open class CodeWhispererTestBase {
projectRule.project.replaceService(CodeWhispererClientAdaptor::class.java, clientAdaptorSpy, disposableRule.disposable)
ApplicationManager.getApplication().replaceService(CodeWhispererExplorerActionManager::class.java, stateManager, disposableRule.disposable)
stateManager.setAutoEnabled(false)
+
+ val conn = authManagerRule.createConnection(ManagedSsoProfile("us-east-1", "url", Q_SCOPES))
+ ToolkitConnectionManager.getInstance(projectRule.project).switchConnection(conn)
}
@After