Skip to content

Commit 5ae453f

Browse files
committed
fix(amazonq): clear out cached supplier if connection is invalidated / clear tokens on InvalidGrantException
when a user reauthenticates while reusing a connection, the cached supplier can incorrectly retain the last token instead of immediately returning the new one additionally, the cached supplier can infinitely attempt token refresh since we are retaining invalid tokens instead of destroying them
1 parent 5335ed8 commit 5ae453f

File tree

20 files changed

+115
-75
lines changed

20 files changed

+115
-75
lines changed

plugins/amazonq/chat/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/toolwindow/AmazonQToolWindowFactory.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ class AmazonQToolWindowFactory : ToolWindowFactory, DumbAware {
8383
project.messageBus.connect(toolWindow.disposable).subscribe(
8484
BearerTokenProviderListener.TOPIC,
8585
object : BearerTokenProviderListener {
86-
override fun onChange(providerId: String, newScopes: List<String>?) {
86+
override fun onProviderChange(providerId: String, newScopes: List<String>?) {
8787
if (ToolkitConnectionManager.getInstance(project).connectionStateForFeature(QConnection.getInstance()) == BearerTokenAuthState.AUTHORIZED) {
8888
preparePanelContent(project, qPanel)
8989
}

plugins/amazonq/chat/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonqCodeScan/CodeScanChatApp.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ class CodeScanChatApp(private val scope: CoroutineScope) : AmazonQApp {
125125
ApplicationManager.getApplication().messageBus.connect(this).subscribe(
126126
BearerTokenProviderListener.TOPIC,
127127
object : BearerTokenProviderListener {
128-
override fun onChange(providerId: String, newScopes: List<String>?) {
128+
override fun onProviderChange(providerId: String, newScopes: List<String>?) {
129129
val qProvider = getQTokenProvider(context.project)
130130
val isQ = qProvider?.id == providerId
131131
val isAuthorized = qProvider?.state() == BearerTokenAuthState.AUTHORIZED

plugins/amazonq/codetransform/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codemodernizer/CodeTransformChatApp.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ class CodeTransformChatApp : AmazonQApp {
149149
ApplicationManager.getApplication().messageBus.connect(this).subscribe(
150150
BearerTokenProviderListener.TOPIC,
151151
object : BearerTokenProviderListener {
152-
override fun onChange(providerId: String, newScopes: List<String>?) {
152+
override fun onProviderChange(providerId: String, newScopes: List<String>?) {
153153
val qProvider = getQTokenProvider(context.project)
154154
val isQ = qProvider?.id == providerId
155155
val isAuthorized = qProvider?.state() == BearerTokenAuthState.AUTHORIZED

plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/status/CodeWhispererStatusBarWidget.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class CodeWhispererStatusBarWidget(project: Project) :
5656
ApplicationManager.getApplication().messageBus.connect(this).subscribe(
5757
BearerTokenProviderListener.TOPIC,
5858
object : BearerTokenProviderListener {
59-
override fun onChange(providerId: String, newScopes: List<String>?) {
59+
override fun onProviderChange(providerId: String, newScopes: List<String>?) {
6060
statusBar.updateWidget(ID)
6161
}
6262
}

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/AuthCredentialsService.kt

Lines changed: 0 additions & 13 deletions
This file was deleted.

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,7 @@ class DefaultAuthCredentialsService(
4343
private val project: Project,
4444
private val encryptionManager: JwtEncryptionManager,
4545
private val cs: CoroutineScope,
46-
) : AuthCredentialsService,
47-
BearerTokenProviderListener,
46+
) : BearerTokenProviderListener,
4847
ToolkitConnectionManagerListener,
4948
QRegionProfileSelectedListener,
5049
Disposable {
@@ -71,6 +70,7 @@ class DefaultAuthCredentialsService(
7170
startPeriodicTokenSync()
7271
}
7372

73+
// TODO: we really only need a single application-wide instance of this
7474
private fun startPeriodicTokenSync() {
7575
tokenSyncTask = scheduler.scheduleWithFixedDelay(
7676
{
@@ -89,14 +89,10 @@ class DefaultAuthCredentialsService(
8989
if (tokenProvider.state() == BearerTokenAuthState.NEEDS_REFRESH) {
9090
try {
9191
tokenProvider.resolveToken()
92-
// Now that the token is refreshed, update it in Flare
93-
updateTokenFromActiveConnection()
9492
} catch (e: Exception) {
9593
LOG.warn(e) { "Failed to refresh bearer token" }
9694
}
9795
}
98-
} else {
99-
updateTokenFromActiveConnection()
10096
}
10197
}
10298
} catch (e: Exception) {
@@ -109,7 +105,7 @@ class DefaultAuthCredentialsService(
109105
)
110106
}
111107

112-
override fun updateTokenCredentials(connection: ToolkitConnection, encrypted: Boolean): CompletableFuture<ResponseMessage> {
108+
fun updateTokenCredentials(connection: ToolkitConnection, encrypted: Boolean): CompletableFuture<ResponseMessage> {
113109
val payload = try {
114110
createUpdateCredentialsPayload(connection, encrypted)
115111
} catch (e: Exception) {
@@ -129,18 +125,26 @@ class DefaultAuthCredentialsService(
129125
}.asCompletableFuture()
130126
}
131127

132-
override fun deleteTokenCredentials() {
128+
fun deleteTokenCredentials() {
133129
cs.launch {
134130
AmazonQLspService.executeAsyncIfRunning(project) { server ->
135131
server.deleteTokenCredentials()
136132
}
137133
}
138134
}
139135

140-
override fun onChange(providerId: String, newScopes: List<String>?) {
136+
override fun onProviderChange(providerId: String, newScopes: List<String>?) {
141137
updateTokenFromActiveConnection()
142138
}
143139

140+
override fun onTokenModified(providerId: String) {
141+
updateTokenFromActiveConnection()
142+
}
143+
144+
override fun invalidate(providerId: String) {
145+
deleteTokenCredentials()
146+
}
147+
144148
override fun activeConnectionChanged(newConnection: ToolkitConnection?) {
145149
val qConnection = ToolkitConnectionManager.getInstance(project)
146150
.activeConnectionForFeature(QConnection.getInstance())
@@ -161,9 +165,6 @@ class DefaultAuthCredentialsService(
161165
private fun updateTokenFromConnection(connection: ToolkitConnection): CompletableFuture<ResponseMessage> =
162166
updateTokenCredentials(connection, true)
163167

164-
override fun invalidate(providerId: String) {
165-
deleteTokenCredentials()
166-
}
167168

168169
private fun createUpdateCredentialsPayload(connection: ToolkitConnection, encrypted: Boolean): UpdateCredentialsPayload {
169170
val token = (connection.getConnectionSettings() as? TokenConnectionSettings)

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/profile/QRegionProfileManager.kt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,12 @@ class QRegionProfileManager : PersistentStateComponent<QProfileState>, Disposabl
114114
connectionIdToProfileCount[connection.id] = it.size
115115
} ?: error("You don't have access to the resource")
116116
} catch (e: Exception) {
117-
LOG.warn(e) { "Failed to list region profiles: ${e.message}" }
117+
if (e is AccessDeniedException) {
118+
LOG.warn { "Failed to list region profiles: ${e.message}" }
119+
} else {
120+
LOG.warn(e) { "Failed to list region profiles" }
121+
}
122+
118123
throw e
119124
}
120125
}

plugins/amazonq/shared/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsServiceTest.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ class DefaultAuthCredentialsServiceTest {
180180
sut = DefaultAuthCredentialsService(project, mockEncryptionManager, this)
181181
setupMockConnectionManager("updated-token")
182182

183-
sut.onChange("providerId", listOf("new-scope"))
183+
sut.onProviderChange("providerId", listOf("new-scope"))
184184

185185
advanceUntilIdle()
186186
verify(exactly = 1) { mockLanguageServer.updateTokenCredentials(any()) }

plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/AwsClientManager.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ open class AwsClientManager : ToolkitClientManager(), Disposable {
5151
busConnection.subscribe(
5252
BearerTokenProviderListener.TOPIC,
5353
object : BearerTokenProviderListener {
54-
override fun onChange(providerId: String, newScopes: List<String>?) {
54+
override fun onProviderChange(providerId: String, newScopes: List<String>?) {
5555
invalidateSdks(providerId)
5656
}
5757

plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/AwsResourceCache.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ class DefaultAwsResourceCache(
165165
subscribe(
166166
BearerTokenProviderListener.TOPIC,
167167
object : BearerTokenProviderListener {
168-
override fun onChange(providerId: String, newScopes: List<String>?) {
168+
override fun onProviderChange(providerId: String, newScopes: List<String>?) {
169169
clearByCredential(providerId)
170170
}
171171
}

0 commit comments

Comments
 (0)