44package software.aws.toolkits.jetbrains.services.amazonq.profile
55
66import com.intellij.openapi.Disposable
7+ import com.intellij.openapi.application.ApplicationManager
78import com.intellij.openapi.components.BaseState
89import com.intellij.openapi.components.PersistentStateComponent
910import com.intellij.openapi.components.Service
1011import com.intellij.openapi.components.State
1112import com.intellij.openapi.components.Storage
1213import com.intellij.openapi.components.service
1314import com.intellij.openapi.project.Project
15+ import com.intellij.util.concurrency.annotations.RequiresBackgroundThread
1416import com.intellij.util.xmlb.annotations.MapAnnotation
1517import com.intellij.util.xmlb.annotations.Property
1618import software.amazon.awssdk.core.SdkClient
@@ -25,6 +27,7 @@ import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection
2527import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
2628import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection
2729import software.aws.toolkits.jetbrains.core.credentials.sono.isSono
30+ import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProviderListener
2831import software.aws.toolkits.jetbrains.core.region.AwsRegionProvider
2932import software.aws.toolkits.jetbrains.utils.notifyInfo
3033import software.aws.toolkits.resources.AmazonQBundle.message
@@ -40,9 +43,23 @@ class QRegionProfileManager : PersistentStateComponent<QProfileState>, Disposabl
4043
4144 // Map to store connectionId to its active profile
4245 private val connectionIdToActiveProfile = Collections .synchronizedMap<String , QRegionProfile >(mutableMapOf ())
43- private val connectionIdToProfileList = mutableMapOf<String , Int >()
46+ private val connectionIdToProfileCount = mutableMapOf<String , Int >()
47+
48+ init {
49+ ApplicationManager .getApplication().messageBus.connect(this )
50+ .subscribe(
51+ BearerTokenProviderListener .TOPIC ,
52+ object : BearerTokenProviderListener {
53+ override fun invalidate (providerId : String ) {
54+ connectionIdToActiveProfile.remove(providerId)
55+ connectionIdToProfileCount.remove(providerId)
56+ }
57+ }
58+ )
59+ }
4460
4561 // should be call on project startup to validate if profile is still active
62+ @RequiresBackgroundThread
4663 fun validateProfile (project : Project ) {
4764 val conn = getIdcConnectionOrNull(project)
4865 val selected = activeProfile(project) ? : return
@@ -78,7 +95,7 @@ class QRegionProfileManager : PersistentStateComponent<QProfileState>, Disposabl
7895 switchProfile(project, mappedProfiles.first(), intent = QProfileSwitchIntent .Update )
7996 }
8097 mappedProfiles.takeIf { it.isNotEmpty() }?.also {
81- connectionIdToProfileList [connection.id] = it.size
98+ connectionIdToProfileCount [connection.id] = it.size
8299 } ? : error(" You don't have access to the resource" )
83100 } catch (e: Exception ) {
84101 LOG .warn(e) { " Failed to list region profiles: ${e.message} " }
@@ -110,7 +127,7 @@ class QRegionProfileManager : PersistentStateComponent<QProfileState>, Disposabl
110127 Telemetry .amazonq.didSelectProfile.use { span ->
111128 span.source(intent.value)
112129 .amazonQProfileRegion(newProfile.region)
113- .profileCount(connectionIdToProfileList [conn.id])
130+ .profileCount(connectionIdToProfileCount [conn.id])
114131 .ssoRegion(conn.region)
115132 .credentialStartUrl(conn.startUrl)
116133 .result(MetricResult .Succeeded )
@@ -139,13 +156,13 @@ class QRegionProfileManager : PersistentStateComponent<QProfileState>, Disposabl
139156
140157 // for each idc connection, user should have a profile, otherwise should show the profile selection error page
141158 fun isPendingProfileSelection (project : Project ): Boolean = getIdcConnectionOrNull(project)?.let { conn ->
142- val profileCounts = connectionIdToProfileList [conn.id] ? : 0
159+ val profileCounts = connectionIdToProfileCount [conn.id] ? : 0
143160 val activeProfile = connectionIdToActiveProfile[conn.id]
144161 profileCounts == 0 || (profileCounts > 1 && activeProfile?.arn.isNullOrEmpty())
145162 } ? : false
146163
147164 fun shouldDisplayProfileInfo (project : Project ): Boolean = getIdcConnectionOrNull(project)?.let { conn ->
148- (connectionIdToProfileList [conn.id] ? : 0 ) > 1
165+ (connectionIdToProfileCount [conn.id] ? : 0 ) > 1
149166 } ? : false
150167
151168 fun getQClientSettings (project : Project ): TokenConnectionSettings {
@@ -191,16 +208,16 @@ class QRegionProfileManager : PersistentStateComponent<QProfileState>, Disposabl
191208 override fun getState (): QProfileState {
192209 val state = QProfileState ()
193210 state.connectionIdToActiveProfile.putAll(this .connectionIdToActiveProfile)
194- state.connectionIdToProfileList.putAll(this .connectionIdToProfileList )
211+ state.connectionIdToProfileList.putAll(this .connectionIdToProfileCount )
195212 return state
196213 }
197214
198215 override fun loadState (state : QProfileState ) {
199216 connectionIdToActiveProfile.clear()
200217 connectionIdToActiveProfile.putAll(state.connectionIdToActiveProfile)
201218
202- connectionIdToProfileList .clear()
203- connectionIdToProfileList .putAll(state.connectionIdToProfileList)
219+ connectionIdToProfileCount .clear()
220+ connectionIdToProfileCount .putAll(state.connectionIdToProfileList)
204221 }
205222}
206223
0 commit comments