Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
package software.aws.toolkits.jetbrains.services.codewhisperer.credentials

import com.intellij.openapi.Disposable
import com.intellij.openapi.application.ApplicationManager
import com.intellij.openapi.components.service
import com.intellij.openapi.project.Project
import com.intellij.util.text.nullize
Expand Down Expand Up @@ -41,14 +40,8 @@ import software.amazon.awssdk.services.codewhispererruntime.model.TargetCode
import software.amazon.awssdk.services.codewhispererruntime.model.UserIntent
import software.aws.toolkits.core.utils.debug
import software.aws.toolkits.core.utils.getLogger
import software.aws.toolkits.core.utils.warn
import software.aws.toolkits.jetbrains.core.AwsClientManager
import software.aws.toolkits.jetbrains.core.awsClient
import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnection
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManagerListener
import software.aws.toolkits.jetbrains.core.credentials.pinning.CodeWhispererConnection
import software.aws.toolkits.jetbrains.services.amazonq.codeWhispererUserContext
import software.aws.toolkits.jetbrains.services.codewhisperer.customization.CodeWhispererCustomization
import software.aws.toolkits.jetbrains.services.codewhisperer.explorer.CodeWhispererExplorerActionManager
Expand All @@ -67,7 +60,6 @@ import java.util.concurrent.TimeUnit
import kotlin.reflect.KProperty0
import kotlin.reflect.jvm.isAccessible

// TODO: move this file to package "/client"
// As the connection is project-level, we need to make this project-level too
@Deprecated("Methods can throw a NullPointerException if callee does not check if connection is valid")
interface CodeWhispererClientAdaptor : Disposable {
Expand Down Expand Up @@ -296,28 +288,7 @@ open class CodeWhispererClientAdaptorImpl(override val project: Project) : CodeW
return (getDelegate() as Lazy<*>).isInitialized()
}

init {
initClientUpdateListener()
}

private fun initClientUpdateListener() {
ApplicationManager.getApplication().messageBus.connect(this).subscribe(
ToolkitConnectionManagerListener.TOPIC,
object : ToolkitConnectionManagerListener {
override fun activeConnectionChanged(newConnection: ToolkitConnection?) {
if (newConnection is AwsBearerTokenConnection) {
myBearerClient = getBearerClient(newConnection.getConnectionSettings().providerId)
}
}
}
)
}

private fun bearerClient(): CodeWhispererRuntimeClient {
if (myBearerClient != null) return myBearerClient as CodeWhispererRuntimeClient
myBearerClient = getBearerClient()
return myBearerClient as CodeWhispererRuntimeClient
}
private fun bearerClient(): CodeWhispererRuntimeClient = project.awsClient()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think for safety this still needs to go through connection manager

Copy link
Contributor Author

@Will-ShaoHua Will-ShaoHua Feb 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sgtm


override fun generateCompletionsPaginator(firstRequest: GenerateCompletionsRequest) = sequence<GenerateCompletionsResponse> {
var nextToken: String? = firstRequest.nextToken()
Expand Down Expand Up @@ -863,27 +834,6 @@ open class CodeWhispererClientAdaptorImpl(override val project: Project) : CodeW
myBearerClient?.close()
}

/**
* Every different SSO/AWS Builder ID connection requires a new client which has its corresponding bearer token provider,
* thus we have to create them dynamically.
* Invalidate and recycle the old client first, and create a new client with the new connection.
* This makes sure when we invoke CW, we always use the up-to-date connection.
* In case this fails to close the client, myBearerClient is already set to null thus next time when we invoke CW,
* it will go through this again which should get the current up-to-date connection. This stale client would be
* unused and stay in memory for a while until eventually closed by ToolkitClientManager.
*/
open fun getBearerClient(oldProviderIdToRemove: String = ""): CodeWhispererRuntimeClient? {
myBearerClient = null

val connection = ToolkitConnectionManager.getInstance(project).activeConnectionForFeature(CodeWhispererConnection.getInstance())
connection as? AwsBearerTokenConnection ?: run {
LOG.warn { "$connection is not a bearer token connection" }
return null
}

return AwsClientManager.getInstance().getClient<CodeWhispererRuntimeClient>(connection.getConnectionSettings())
}

companion object {
private val LOG = getLogger<CodeWhispererClientAdaptorImpl>()
private fun createUnmanagedSigv4Client(): CodeWhispererClient = AwsClientManager.getInstance().createUnmanagedClient(
Expand All @@ -895,7 +845,6 @@ open class CodeWhispererClientAdaptorImpl(override val project: Project) : CodeW
}

class MockCodeWhispererClientAdaptor(override val project: Project) : CodeWhispererClientAdaptorImpl(project) {
override fun getBearerClient(oldProviderIdToRemove: String): CodeWhispererRuntimeClient = project.awsClient()
override fun dispose() {}
}

Expand Down
Loading