Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package software.aws.toolkits.jetbrains.services.codewhisperer.credentials

import com.intellij.openapi.Disposable
import com.intellij.openapi.application.ApplicationManager
import com.intellij.openapi.components.service
import com.intellij.openapi.project.Project
import com.intellij.util.text.nullize
Expand Down Expand Up @@ -40,10 +41,14 @@ import software.amazon.awssdk.services.codewhispererruntime.model.TargetCode
import software.amazon.awssdk.services.codewhispererruntime.model.UserIntent
import software.aws.toolkits.core.utils.debug
import software.aws.toolkits.core.utils.getLogger
import software.aws.toolkits.core.utils.warn
import software.aws.toolkits.jetbrains.core.AwsClientManager
import software.aws.toolkits.jetbrains.core.awsClient
import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnection
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManagerListener
import software.aws.toolkits.jetbrains.core.credentials.pinning.CodeWhispererConnection
import software.aws.toolkits.jetbrains.services.amazonq.codeWhispererUserContext
import software.aws.toolkits.jetbrains.services.codewhisperer.customization.CodeWhispererCustomization
import software.aws.toolkits.jetbrains.services.codewhisperer.explorer.CodeWhispererExplorerActionManager
Expand All @@ -62,6 +67,7 @@ import java.util.concurrent.TimeUnit
import kotlin.reflect.KProperty0
import kotlin.reflect.jvm.isAccessible

// TODO: move this file to package "/client"
// As the connection is project-level, we need to make this project-level too
@Deprecated("Methods can throw a NullPointerException if callee does not check if connection is valid")
interface CodeWhispererClientAdaptor : Disposable {
Expand Down Expand Up @@ -277,16 +283,37 @@ interface CodeWhispererClientAdaptor : Disposable {
open class CodeWhispererClientAdaptorImpl(override val project: Project) : CodeWhispererClientAdaptor {
private val mySigv4Client by lazy { createUnmanagedSigv4Client() }

@Volatile
private var myBearerClient: CodeWhispererRuntimeClient? = null

private val KProperty0<*>.isLazyInitialized: Boolean
get() {
isAccessible = true
return (getDelegate() as Lazy<*>).isInitialized()
}

fun bearerClient(): CodeWhispererRuntimeClient =
ToolkitConnectionManager.getInstance(project).activeConnectionForFeature(QConnection.getInstance())?.getConnectionSettings()
?.awsClient<CodeWhispererRuntimeClient>()
?: throw Exception("attempt to get bearer client while there is no valid credential")
init {
initClientUpdateListener()
}

private fun initClientUpdateListener() {
ApplicationManager.getApplication().messageBus.connect(this).subscribe(
ToolkitConnectionManagerListener.TOPIC,
object : ToolkitConnectionManagerListener {
override fun activeConnectionChanged(newConnection: ToolkitConnection?) {
if (newConnection is AwsBearerTokenConnection) {
myBearerClient = getBearerClient(newConnection.getConnectionSettings().providerId)
}
}
}
)
}

private fun bearerClient(): CodeWhispererRuntimeClient {
if (myBearerClient != null) return myBearerClient as CodeWhispererRuntimeClient
myBearerClient = getBearerClient()
return myBearerClient as CodeWhispererRuntimeClient
}

override fun generateCompletionsPaginator(firstRequest: GenerateCompletionsRequest) = sequence<GenerateCompletionsResponse> {
var nextToken: String? = firstRequest.nextToken()
Expand Down Expand Up @@ -827,6 +854,28 @@ open class CodeWhispererClientAdaptorImpl(override val project: Project) : CodeW
if (this::mySigv4Client.isLazyInitialized) {
mySigv4Client.close()
}
myBearerClient?.close()
}

/**
* Every different SSO/AWS Builder ID connection requires a new client which has its corresponding bearer token provider,
* thus we have to create them dynamically.
* Invalidate and recycle the old client first, and create a new client with the new connection.
* This makes sure when we invoke CW, we always use the up-to-date connection.
* In case this fails to close the client, myBearerClient is already set to null thus next time when we invoke CW,
* it will go through this again which should get the current up-to-date connection. This stale client would be
* unused and stay in memory for a while until eventually closed by ToolkitClientManager.
*/
open fun getBearerClient(oldProviderIdToRemove: String = ""): CodeWhispererRuntimeClient? {
myBearerClient = null

val connection = ToolkitConnectionManager.getInstance(project).activeConnectionForFeature(CodeWhispererConnection.getInstance())
connection as? AwsBearerTokenConnection ?: run {
LOG.warn { "$connection is not a bearer token connection" }
return null
}

return AwsClientManager.getInstance().getClient<CodeWhispererRuntimeClient>(connection.getConnectionSettings())
}

companion object {
Expand All @@ -840,6 +889,7 @@ open class CodeWhispererClientAdaptorImpl(override val project: Project) : CodeW
}

class MockCodeWhispererClientAdaptor(override val project: Project) : CodeWhispererClientAdaptorImpl(project) {
override fun getBearerClient(oldProviderIdToRemove: String): CodeWhispererRuntimeClient = project.awsClient()
override fun dispose() {}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import org.junit.After
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import org.junit.jupiter.api.assertThrows
import org.mockito.kotlin.any
import org.mockito.kotlin.argThat
import org.mockito.kotlin.argumentCaptor
Expand Down Expand Up @@ -69,20 +68,17 @@ import software.aws.toolkits.core.TokenConnectionSettings
import software.aws.toolkits.core.utils.test.aString
import software.aws.toolkits.jetbrains.core.MockClientManagerRule
import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection
import software.aws.toolkits.jetbrains.core.credentials.DefaultToolkitConnectionManager
import software.aws.toolkits.jetbrains.core.credentials.ManagedSsoProfile
import software.aws.toolkits.jetbrains.core.credentials.MockCredentialManagerRule
import software.aws.toolkits.jetbrains.core.credentials.MockToolkitAuthManagerRule
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
import software.aws.toolkits.jetbrains.core.credentials.logoutFromSsoConnection
import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection
import software.aws.toolkits.jetbrains.core.credentials.sono.Q_SCOPES
import software.aws.toolkits.jetbrains.core.credentials.sono.SONO_REGION
import software.aws.toolkits.jetbrains.services.amazonq.FEATURE_EVALUATION_PRODUCT_NAME
import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.metadata
import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.pythonRequest
import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.pythonResponseWithToken
import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.sdkHttpResponse
import software.aws.toolkits.jetbrains.services.codewhisperer.credentials.CodeWhispererClientAdaptor
import software.aws.toolkits.jetbrains.services.codewhisperer.credentials.CodeWhispererClientAdaptorImpl
import software.aws.toolkits.jetbrains.services.codewhisperer.customization.CodeWhispererCustomization
import software.aws.toolkits.jetbrains.services.codewhisperer.customization.CodeWhispererModelConfigurator
Expand Down Expand Up @@ -113,7 +109,7 @@ class CodeWhispererClientAdaptorTest {
private lateinit var bearerClient: CodeWhispererRuntimeClient
private lateinit var ssoClient: SsoOidcClient

private lateinit var sut: CodeWhispererClientAdaptorImpl
private lateinit var sut: CodeWhispererClientAdaptor
private lateinit var connectionManager: ToolkitConnectionManager
private var isTelemetryEnabledDefault: Boolean = false

Expand Down Expand Up @@ -167,41 +163,6 @@ class CodeWhispererClientAdaptorTest {
assertThat("us-east-1").isEqualTo(SONO_REGION)
}

@Test
fun `should throw if there is no valid credential, otherwise return codewhispererRuntimeClient`() {
val connectionManager = DefaultToolkitConnectionManager()
projectRule.project.replaceService(ToolkitConnectionManager::class.java, DefaultToolkitConnectionManager(), disposableRule.disposable)

assertThat(ToolkitConnectionManager.getInstance(projectRule.project).activeConnectionForFeature(QConnection.getInstance())).isNull()
assertThrows<Exception>("attempt to get bearer client while there is no valid credential") {
sut.bearerClient()
}

val qConnection = authManagerRule.createConnection(ManagedSsoProfile("us-east-1", aString(), Q_SCOPES))
connectionManager.switchConnection(qConnection)
assertThat(connectionManager.activeConnectionForFeature(QConnection.getInstance()))
.isNotNull
.isEqualTo(qConnection)
assertThat(sut.bearerClient())
.isNotNull
.isInstanceOf(CodeWhispererRuntimeClient::class.java)

logoutFromSsoConnection(projectRule.project, qConnection as AwsBearerTokenConnection)
assertThat(connectionManager.activeConnectionForFeature(QConnection.getInstance())).isNull()
assertThrows<Exception>("attempt to get bearer client while there is no valid credential") {
sut.bearerClient()
}

val anotherQConnection = authManagerRule.createConnection(ManagedSsoProfile("us-east-1", aString(), Q_SCOPES))
connectionManager.switchConnection(anotherQConnection)
assertThat(connectionManager.activeConnectionForFeature(QConnection.getInstance()))
.isNotNull
.isEqualTo(anotherQConnection)
assertThat(sut.bearerClient())
.isNotNull
.isInstanceOf(CodeWhispererRuntimeClient::class.java)
}

@Test
fun `listCustomizations`() {
val sdkIterable = ListAvailableCustomizationsIterable(bearerClient, ListAvailableCustomizationsRequest.builder().build())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,8 @@ import org.mockito.kotlin.verify
import software.amazon.awssdk.services.codewhispererruntime.CodeWhispererRuntimeClient
import software.amazon.awssdk.services.codewhispererruntime.model.GenerateCompletionsRequest
import software.amazon.awssdk.services.codewhispererruntime.paginators.GenerateCompletionsIterable
import software.amazon.awssdk.services.ssooidc.SsoOidcClient
import software.aws.toolkits.jetbrains.core.MockClientManagerRule
import software.aws.toolkits.jetbrains.core.credentials.ManagedSsoProfile
import software.aws.toolkits.jetbrains.core.credentials.MockCredentialManagerRule
import software.aws.toolkits.jetbrains.core.credentials.MockToolkitAuthManagerRule
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
import software.aws.toolkits.jetbrains.core.credentials.sono.Q_SCOPES
import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.codeWhispererRecommendationActionId
import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.pythonFileName
import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.pythonResponse
Expand Down Expand Up @@ -70,11 +65,10 @@ open class CodeWhispererTestBase {
val mockClientManagerRule = MockClientManagerRule()
val mockCredentialRule = MockCredentialManagerRule()
val disposableRule = DisposableRule()
val authManagerRule = MockToolkitAuthManagerRule()

@Rule
@JvmField
val ruleChain = RuleChain(projectRule, mockCredentialRule, mockClientManagerRule, authManagerRule, disposableRule)
val ruleChain = RuleChain(projectRule, mockCredentialRule, mockClientManagerRule, disposableRule)

protected lateinit var mockClient: CodeWhispererRuntimeClient

Expand All @@ -92,7 +86,6 @@ open class CodeWhispererTestBase {
@Before
open fun setUp() {
mockClient = mockClientManagerRule.create()
mockClientManagerRule.create<SsoOidcClient>()
val requestCaptor = argumentCaptor<GenerateCompletionsRequest>()
mockClient.stub {
on {
Expand Down Expand Up @@ -166,9 +159,6 @@ open class CodeWhispererTestBase {
projectRule.project.replaceService(CodeWhispererClientAdaptor::class.java, clientAdaptorSpy, disposableRule.disposable)
ApplicationManager.getApplication().replaceService(CodeWhispererExplorerActionManager::class.java, stateManager, disposableRule.disposable)
stateManager.setAutoEnabled(false)

val conn = authManagerRule.createConnection(ManagedSsoProfile("us-east-1", "url", Q_SCOPES))
ToolkitConnectionManager.getInstance(projectRule.project).switchConnection(conn)
}

@After
Expand Down
Loading