Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
package software.aws.toolkits.jetbrains.services.codewhisperer.credentials

import com.intellij.openapi.Disposable
import com.intellij.openapi.application.ApplicationManager
import com.intellij.openapi.components.service
import com.intellij.openapi.project.Project
import com.intellij.util.text.nullize
Expand Down Expand Up @@ -41,14 +40,10 @@ import software.amazon.awssdk.services.codewhispererruntime.model.TargetCode
import software.amazon.awssdk.services.codewhispererruntime.model.UserIntent
import software.aws.toolkits.core.utils.debug
import software.aws.toolkits.core.utils.getLogger
import software.aws.toolkits.core.utils.warn
import software.aws.toolkits.jetbrains.core.AwsClientManager
import software.aws.toolkits.jetbrains.core.awsClient
import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnection
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManagerListener
import software.aws.toolkits.jetbrains.core.credentials.pinning.CodeWhispererConnection
import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection
import software.aws.toolkits.jetbrains.services.amazonq.codeWhispererUserContext
import software.aws.toolkits.jetbrains.services.codewhisperer.customization.CodeWhispererCustomization
import software.aws.toolkits.jetbrains.services.codewhisperer.explorer.CodeWhispererExplorerActionManager
Expand All @@ -67,7 +62,6 @@ import java.util.concurrent.TimeUnit
import kotlin.reflect.KProperty0
import kotlin.reflect.jvm.isAccessible

// TODO: move this file to package "/client"
// As the connection is project-level, we need to make this project-level too
@Deprecated("Methods can throw a NullPointerException if callee does not check if connection is valid")
interface CodeWhispererClientAdaptor : Disposable {
Expand Down Expand Up @@ -283,37 +277,16 @@ interface CodeWhispererClientAdaptor : Disposable {
open class CodeWhispererClientAdaptorImpl(override val project: Project) : CodeWhispererClientAdaptor {
private val mySigv4Client by lazy { createUnmanagedSigv4Client() }

@Volatile
private var myBearerClient: CodeWhispererRuntimeClient? = null

private val KProperty0<*>.isLazyInitialized: Boolean
get() {
isAccessible = true
return (getDelegate() as Lazy<*>).isInitialized()
}

init {
initClientUpdateListener()
}

private fun initClientUpdateListener() {
ApplicationManager.getApplication().messageBus.connect(this).subscribe(
ToolkitConnectionManagerListener.TOPIC,
object : ToolkitConnectionManagerListener {
override fun activeConnectionChanged(newConnection: ToolkitConnection?) {
if (newConnection is AwsBearerTokenConnection) {
myBearerClient = getBearerClient(newConnection.getConnectionSettings().providerId)
}
}
}
)
}

private fun bearerClient(): CodeWhispererRuntimeClient {
if (myBearerClient != null) return myBearerClient as CodeWhispererRuntimeClient
myBearerClient = getBearerClient()
return myBearerClient as CodeWhispererRuntimeClient
}
fun bearerClient(): CodeWhispererRuntimeClient =
ToolkitConnectionManager.getInstance(project).activeConnectionForFeature(QConnection.getInstance())?.getConnectionSettings()
?.awsClient<CodeWhispererRuntimeClient>()
?: throw Exception("attempt to get bearer client while there is no valid credential")

override fun generateCompletionsPaginator(firstRequest: GenerateCompletionsRequest) = sequence<GenerateCompletionsResponse> {
var nextToken: String? = firstRequest.nextToken()
Expand Down Expand Up @@ -854,28 +827,6 @@ open class CodeWhispererClientAdaptorImpl(override val project: Project) : CodeW
if (this::mySigv4Client.isLazyInitialized) {
mySigv4Client.close()
}
myBearerClient?.close()
}

/**
* Every different SSO/AWS Builder ID connection requires a new client which has its corresponding bearer token provider,
* thus we have to create them dynamically.
* Invalidate and recycle the old client first, and create a new client with the new connection.
* This makes sure when we invoke CW, we always use the up-to-date connection.
* In case this fails to close the client, myBearerClient is already set to null thus next time when we invoke CW,
* it will go through this again which should get the current up-to-date connection. This stale client would be
* unused and stay in memory for a while until eventually closed by ToolkitClientManager.
*/
open fun getBearerClient(oldProviderIdToRemove: String = ""): CodeWhispererRuntimeClient? {
myBearerClient = null

val connection = ToolkitConnectionManager.getInstance(project).activeConnectionForFeature(CodeWhispererConnection.getInstance())
connection as? AwsBearerTokenConnection ?: run {
LOG.warn { "$connection is not a bearer token connection" }
return null
}

return AwsClientManager.getInstance().getClient<CodeWhispererRuntimeClient>(connection.getConnectionSettings())
}

companion object {
Expand All @@ -889,7 +840,6 @@ open class CodeWhispererClientAdaptorImpl(override val project: Project) : CodeW
}

class MockCodeWhispererClientAdaptor(override val project: Project) : CodeWhispererClientAdaptorImpl(project) {
override fun getBearerClient(oldProviderIdToRemove: String): CodeWhispererRuntimeClient = project.awsClient()
override fun dispose() {}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import org.junit.After
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import org.junit.jupiter.api.assertThrows
import org.mockito.kotlin.any
import org.mockito.kotlin.argThat
import org.mockito.kotlin.argumentCaptor
Expand Down Expand Up @@ -68,17 +69,20 @@ import software.aws.toolkits.core.TokenConnectionSettings
import software.aws.toolkits.core.utils.test.aString
import software.aws.toolkits.jetbrains.core.MockClientManagerRule
import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection
import software.aws.toolkits.jetbrains.core.credentials.DefaultToolkitConnectionManager
import software.aws.toolkits.jetbrains.core.credentials.ManagedSsoProfile
import software.aws.toolkits.jetbrains.core.credentials.MockCredentialManagerRule
import software.aws.toolkits.jetbrains.core.credentials.MockToolkitAuthManagerRule
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
import software.aws.toolkits.jetbrains.core.credentials.logoutFromSsoConnection
import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection
import software.aws.toolkits.jetbrains.core.credentials.sono.Q_SCOPES
import software.aws.toolkits.jetbrains.core.credentials.sono.SONO_REGION
import software.aws.toolkits.jetbrains.services.amazonq.FEATURE_EVALUATION_PRODUCT_NAME
import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.metadata
import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.pythonRequest
import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.pythonResponseWithToken
import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.sdkHttpResponse
import software.aws.toolkits.jetbrains.services.codewhisperer.credentials.CodeWhispererClientAdaptor
import software.aws.toolkits.jetbrains.services.codewhisperer.credentials.CodeWhispererClientAdaptorImpl
import software.aws.toolkits.jetbrains.services.codewhisperer.customization.CodeWhispererCustomization
import software.aws.toolkits.jetbrains.services.codewhisperer.customization.CodeWhispererModelConfigurator
Expand Down Expand Up @@ -109,7 +113,7 @@ class CodeWhispererClientAdaptorTest {
private lateinit var bearerClient: CodeWhispererRuntimeClient
private lateinit var ssoClient: SsoOidcClient

private lateinit var sut: CodeWhispererClientAdaptor
private lateinit var sut: CodeWhispererClientAdaptorImpl
private lateinit var connectionManager: ToolkitConnectionManager
private var isTelemetryEnabledDefault: Boolean = false

Expand Down Expand Up @@ -163,6 +167,41 @@ class CodeWhispererClientAdaptorTest {
assertThat("us-east-1").isEqualTo(SONO_REGION)
}

@Test
fun `should throw if there is no valid credential, otherwise return codewhispererRuntimeClient`() {
val connectionManager = DefaultToolkitConnectionManager()
projectRule.project.replaceService(ToolkitConnectionManager::class.java, DefaultToolkitConnectionManager(), disposableRule.disposable)

assertThat(ToolkitConnectionManager.getInstance(projectRule.project).activeConnectionForFeature(QConnection.getInstance())).isNull()
assertThrows<Exception>("attempt to get bearer client while there is no valid credential") {
sut.bearerClient()
}

val qConnection = authManagerRule.createConnection(ManagedSsoProfile("us-east-1", aString(), Q_SCOPES))
connectionManager.switchConnection(qConnection)
assertThat(connectionManager.activeConnectionForFeature(QConnection.getInstance()))
.isNotNull
.isEqualTo(qConnection)
assertThat(sut.bearerClient())
.isNotNull
.isInstanceOf(CodeWhispererRuntimeClient::class.java)

logoutFromSsoConnection(projectRule.project, qConnection as AwsBearerTokenConnection)
assertThat(connectionManager.activeConnectionForFeature(QConnection.getInstance())).isNull()
assertThrows<Exception>("attempt to get bearer client while there is no valid credential") {
sut.bearerClient()
}

val anotherQConnection = authManagerRule.createConnection(ManagedSsoProfile("us-east-1", aString(), Q_SCOPES))
connectionManager.switchConnection(anotherQConnection)
assertThat(connectionManager.activeConnectionForFeature(QConnection.getInstance()))
.isNotNull
.isEqualTo(anotherQConnection)
assertThat(sut.bearerClient())
.isNotNull
.isInstanceOf(CodeWhispererRuntimeClient::class.java)
}

@Test
fun `listCustomizations`() {
val sdkIterable = ListAvailableCustomizationsIterable(bearerClient, ListAvailableCustomizationsRequest.builder().build())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,13 @@ import org.mockito.kotlin.verify
import software.amazon.awssdk.services.codewhispererruntime.CodeWhispererRuntimeClient
import software.amazon.awssdk.services.codewhispererruntime.model.GenerateCompletionsRequest
import software.amazon.awssdk.services.codewhispererruntime.paginators.GenerateCompletionsIterable
import software.amazon.awssdk.services.ssooidc.SsoOidcClient
import software.aws.toolkits.jetbrains.core.MockClientManagerRule
import software.aws.toolkits.jetbrains.core.credentials.ManagedSsoProfile
import software.aws.toolkits.jetbrains.core.credentials.MockCredentialManagerRule
import software.aws.toolkits.jetbrains.core.credentials.MockToolkitAuthManagerRule
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
import software.aws.toolkits.jetbrains.core.credentials.sono.Q_SCOPES
import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.codeWhispererRecommendationActionId
import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.pythonFileName
import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.pythonResponse
Expand Down Expand Up @@ -65,10 +70,11 @@ open class CodeWhispererTestBase {
val mockClientManagerRule = MockClientManagerRule()
val mockCredentialRule = MockCredentialManagerRule()
val disposableRule = DisposableRule()
val authManagerRule = MockToolkitAuthManagerRule()

@Rule
@JvmField
val ruleChain = RuleChain(projectRule, mockCredentialRule, mockClientManagerRule, disposableRule)
val ruleChain = RuleChain(projectRule, mockCredentialRule, mockClientManagerRule, authManagerRule, disposableRule)

protected lateinit var mockClient: CodeWhispererRuntimeClient

Expand All @@ -86,6 +92,7 @@ open class CodeWhispererTestBase {
@Before
open fun setUp() {
mockClient = mockClientManagerRule.create()
mockClientManagerRule.create<SsoOidcClient>()
val requestCaptor = argumentCaptor<GenerateCompletionsRequest>()
mockClient.stub {
on {
Expand Down Expand Up @@ -159,6 +166,9 @@ open class CodeWhispererTestBase {
projectRule.project.replaceService(CodeWhispererClientAdaptor::class.java, clientAdaptorSpy, disposableRule.disposable)
ApplicationManager.getApplication().replaceService(CodeWhispererExplorerActionManager::class.java, stateManager, disposableRule.disposable)
stateManager.setAutoEnabled(false)

val conn = authManagerRule.createConnection(ManagedSsoProfile("us-east-1", "url", Q_SCOPES))
ToolkitConnectionManager.getInstance(projectRule.project).switchConnection(conn)
}

@After
Expand Down
Loading