44package software.aws.toolkits.jetbrains.services.codewhisperer.credentials
55
66import com.intellij.openapi.Disposable
7+ import com.intellij.openapi.application.ApplicationManager
78import com.intellij.openapi.components.service
89import com.intellij.openapi.project.Project
910import com.intellij.util.text.nullize
@@ -40,10 +41,14 @@ import software.amazon.awssdk.services.codewhispererruntime.model.TargetCode
4041import software.amazon.awssdk.services.codewhispererruntime.model.UserIntent
4142import software.aws.toolkits.core.utils.debug
4243import software.aws.toolkits.core.utils.getLogger
44+ import software.aws.toolkits.core.utils.warn
4345import software.aws.toolkits.jetbrains.core.AwsClientManager
4446import software.aws.toolkits.jetbrains.core.awsClient
47+ import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection
48+ import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnection
4549import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
46- import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection
50+ import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManagerListener
51+ import software.aws.toolkits.jetbrains.core.credentials.pinning.CodeWhispererConnection
4752import software.aws.toolkits.jetbrains.services.amazonq.codeWhispererUserContext
4853import software.aws.toolkits.jetbrains.services.codewhisperer.customization.CodeWhispererCustomization
4954import software.aws.toolkits.jetbrains.services.codewhisperer.explorer.CodeWhispererExplorerActionManager
@@ -62,6 +67,7 @@ import java.util.concurrent.TimeUnit
6267import kotlin.reflect.KProperty0
6368import kotlin.reflect.jvm.isAccessible
6469
70+ // TODO: move this file to package "/client"
6571// As the connection is project-level, we need to make this project-level too
6672@Deprecated(" Methods can throw a NullPointerException if callee does not check if connection is valid" )
6773interface CodeWhispererClientAdaptor : Disposable {
@@ -277,16 +283,37 @@ interface CodeWhispererClientAdaptor : Disposable {
277283open class CodeWhispererClientAdaptorImpl (override val project : Project ) : CodeWhispererClientAdaptor {
278284 private val mySigv4Client by lazy { createUnmanagedSigv4Client() }
279285
286+ @Volatile
287+ private var myBearerClient: CodeWhispererRuntimeClient ? = null
288+
280289 private val KProperty0 <* >.isLazyInitialized: Boolean
281290 get() {
282291 isAccessible = true
283292 return (getDelegate() as Lazy <* >).isInitialized()
284293 }
285294
286- fun bearerClient (): CodeWhispererRuntimeClient =
287- ToolkitConnectionManager .getInstance(project).activeConnectionForFeature(QConnection .getInstance())?.getConnectionSettings()
288- ?.awsClient<CodeWhispererRuntimeClient >()
289- ? : throw Exception (" attempt to get bearer client while there is no valid credential" )
295+ init {
296+ initClientUpdateListener()
297+ }
298+
299+ private fun initClientUpdateListener () {
300+ ApplicationManager .getApplication().messageBus.connect(this ).subscribe(
301+ ToolkitConnectionManagerListener .TOPIC ,
302+ object : ToolkitConnectionManagerListener {
303+ override fun activeConnectionChanged (newConnection : ToolkitConnection ? ) {
304+ if (newConnection is AwsBearerTokenConnection ) {
305+ myBearerClient = getBearerClient(newConnection.getConnectionSettings().providerId)
306+ }
307+ }
308+ }
309+ )
310+ }
311+
312+ private fun bearerClient (): CodeWhispererRuntimeClient {
313+ if (myBearerClient != null ) return myBearerClient as CodeWhispererRuntimeClient
314+ myBearerClient = getBearerClient()
315+ return myBearerClient as CodeWhispererRuntimeClient
316+ }
290317
291318 override fun generateCompletionsPaginator (firstRequest : GenerateCompletionsRequest ) = sequence<GenerateCompletionsResponse > {
292319 var nextToken: String? = firstRequest.nextToken()
@@ -827,6 +854,28 @@ open class CodeWhispererClientAdaptorImpl(override val project: Project) : CodeW
827854 if (this ::mySigv4Client.isLazyInitialized) {
828855 mySigv4Client.close()
829856 }
857+ myBearerClient?.close()
858+ }
859+
860+ /* *
861+ * Every different SSO/AWS Builder ID connection requires a new client which has its corresponding bearer token provider,
862+ * thus we have to create them dynamically.
863+ * Invalidate and recycle the old client first, and create a new client with the new connection.
864+ * This makes sure when we invoke CW, we always use the up-to-date connection.
865+ * In case this fails to close the client, myBearerClient is already set to null thus next time when we invoke CW,
866+ * it will go through this again which should get the current up-to-date connection. This stale client would be
867+ * unused and stay in memory for a while until eventually closed by ToolkitClientManager.
868+ */
869+ open fun getBearerClient (oldProviderIdToRemove : String = ""): CodeWhispererRuntimeClient ? {
870+ myBearerClient = null
871+
872+ val connection = ToolkitConnectionManager .getInstance(project).activeConnectionForFeature(CodeWhispererConnection .getInstance())
873+ connection as ? AwsBearerTokenConnection ? : run {
874+ LOG .warn { " $connection is not a bearer token connection" }
875+ return null
876+ }
877+
878+ return AwsClientManager .getInstance().getClient<CodeWhispererRuntimeClient >(connection.getConnectionSettings())
830879 }
831880
832881 companion object {
@@ -840,6 +889,7 @@ open class CodeWhispererClientAdaptorImpl(override val project: Project) : CodeW
840889}
841890
842891class MockCodeWhispererClientAdaptor (override val project : Project ) : CodeWhispererClientAdaptorImpl(project) {
892+ override fun getBearerClient (oldProviderIdToRemove : String ): CodeWhispererRuntimeClient = project.awsClient()
843893 override fun dispose () {}
844894}
845895
0 commit comments