Skip to content

Commit 3b9b6c1

Browse files
authored
fix(amazonq): clear out cached supplier if connection is invalidated … (#5881)
when a user reauthenticates while reusing a connection, the cached supplier can incorrectly retain the last token instead of immediately returning the new one additionally, the cached supplier can infinitely attempt token refresh since we are retaining invalid tokens instead of destroying them
1 parent 5f5ed91 commit 3b9b6c1

File tree

27 files changed

+234
-191
lines changed

27 files changed

+234
-191
lines changed

plugins/amazonq/chat/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/toolwindow/AmazonQToolWindowFactory.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ class AmazonQToolWindowFactory : ToolWindowFactory, DumbAware {
8383
project.messageBus.connect(toolWindow.disposable).subscribe(
8484
BearerTokenProviderListener.TOPIC,
8585
object : BearerTokenProviderListener {
86-
override fun onChange(providerId: String, newScopes: List<String>?) {
86+
override fun onProviderChange(providerId: String, newScopes: List<String>?) {
8787
if (ToolkitConnectionManager.getInstance(project).connectionStateForFeature(QConnection.getInstance()) == BearerTokenAuthState.AUTHORIZED) {
8888
preparePanelContent(project, qPanel)
8989
}

plugins/amazonq/chat/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonqCodeScan/CodeScanChatApp.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ class CodeScanChatApp(private val scope: CoroutineScope) : AmazonQApp {
125125
ApplicationManager.getApplication().messageBus.connect(this).subscribe(
126126
BearerTokenProviderListener.TOPIC,
127127
object : BearerTokenProviderListener {
128-
override fun onChange(providerId: String, newScopes: List<String>?) {
128+
override fun onProviderChange(providerId: String, newScopes: List<String>?) {
129129
val qProvider = getQTokenProvider(context.project)
130130
val isQ = qProvider?.id == providerId
131131
val isAuthorized = qProvider?.state() == BearerTokenAuthState.AUTHORIZED

plugins/amazonq/chat/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/AmazonQTestBase.kt

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,14 @@ import org.mockito.kotlin.spy
1717
import org.mockito.kotlin.whenever
1818
import software.aws.toolkits.core.TokenConnectionSettings
1919
import software.aws.toolkits.core.credentials.ToolkitBearerTokenProvider
20-
import software.aws.toolkits.core.utils.test.aString
2120
import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection
2221
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
23-
import software.aws.toolkits.jetbrains.core.credentials.sso.DeviceAuthorizationGrantToken
2422
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProvider
2523
import software.aws.toolkits.jetbrains.services.amazonq.clients.AmazonQStreamingClient
2624
import software.aws.toolkits.jetbrains.utils.rules.CodeInsightTestFixtureRule
2725
import software.aws.toolkits.jetbrains.utils.rules.HeavyJavaCodeInsightTestFixtureRule
2826
import software.aws.toolkits.jetbrains.utils.rules.JavaCodeInsightTestFixtureRule
2927
import software.aws.toolkits.jetbrains.utils.rules.addModule
30-
import java.time.Instant
3128

3229
open class AmazonQTestBase(
3330
@Rule @JvmField
@@ -47,11 +44,7 @@ open class AmazonQTestBase(
4744
project = projectRule.project
4845
toolkitConnectionManager = spy(ToolkitConnectionManager.getInstance(project))
4946

50-
val accessToken = DeviceAuthorizationGrantToken(aString(), aString(), aString(), aString(), Instant.MAX, Instant.now())
51-
52-
val provider = mock<BearerTokenProvider> {
53-
doReturn(accessToken).whenever(it).refresh()
54-
}
47+
val provider = mock<BearerTokenProvider>()
5548

5649
val mockBearerProvider = mock<ToolkitBearerTokenProvider> {
5750
doReturn(provider).whenever(it).delegate

plugins/amazonq/chat/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonqFeatureDev/FeatureDevTestBase.kt

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,8 @@ import software.amazon.awssdk.services.codewhispererruntime.model.SendTelemetryE
2828
import software.amazon.awssdk.services.codewhispererruntime.model.StartTaskAssistCodeGenerationResponse
2929
import software.aws.toolkits.core.TokenConnectionSettings
3030
import software.aws.toolkits.core.credentials.ToolkitBearerTokenProvider
31-
import software.aws.toolkits.core.utils.test.aString
3231
import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection
3332
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
34-
import software.aws.toolkits.jetbrains.core.credentials.sso.DeviceAuthorizationGrantToken
3533
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProvider
3634
import software.aws.toolkits.jetbrains.services.amazonqFeatureDev.clients.FeatureDevClient
3735
import software.aws.toolkits.jetbrains.services.amazonqFeatureDev.session.CodeGenerationStreamResult
@@ -41,7 +39,6 @@ import software.aws.toolkits.jetbrains.utils.rules.HeavyJavaCodeInsightTestFixtu
4139
import software.aws.toolkits.jetbrains.utils.rules.JavaCodeInsightTestFixtureRule
4240
import software.aws.toolkits.jetbrains.utils.rules.addModule
4341
import java.io.File
44-
import java.time.Instant
4542

4643
open class FeatureDevTestBase(
4744
@Rule @JvmField
@@ -164,11 +161,7 @@ open class FeatureDevTestBase(
164161
open fun setup() {
165162
project = projectRule.project
166163
toolkitConnectionManager = spy(ToolkitConnectionManager.getInstance(project))
167-
val accessToken = DeviceAuthorizationGrantToken(aString(), aString(), aString(), aString(), Instant.MAX, Instant.now())
168-
val provider =
169-
mock<BearerTokenProvider> {
170-
doReturn(accessToken).whenever(it).refresh()
171-
}
164+
val provider = mock<BearerTokenProvider>()
172165
val mockBearerProvider =
173166
mock<ToolkitBearerTokenProvider> {
174167
doReturn(provider).whenever(it).delegate

plugins/amazonq/codetransform/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codemodernizer/CodeTransformChatApp.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ class CodeTransformChatApp : AmazonQApp {
149149
ApplicationManager.getApplication().messageBus.connect(this).subscribe(
150150
BearerTokenProviderListener.TOPIC,
151151
object : BearerTokenProviderListener {
152-
override fun onChange(providerId: String, newScopes: List<String>?) {
152+
override fun onProviderChange(providerId: String, newScopes: List<String>?) {
153153
val qProvider = getQTokenProvider(context.project)
154154
val isQ = qProvider?.id == providerId
155155
val isAuthorized = qProvider?.state() == BearerTokenAuthState.AUTHORIZED

plugins/amazonq/codetransform/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codemodernizer/utils/CodeTransformApiUtils.kt

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,6 @@ suspend fun JobId.pollTransformationStatusAndPlan(
8787
var transformationPlan: TransformationPlan? = null
8888
var didSleepOnce = false
8989
var hasSeenTransforming = false
90-
val maxRefreshes = 10
91-
var numRefreshes = 0
92-
93-
// refresh token at start of polling since local build just prior can take a long time
94-
refreshToken(project)
9590

9691
try {
9792
waitUntil(
@@ -138,13 +133,10 @@ suspend fun JobId.pollTransformationStatusAndPlan(
138133
onStateChange(state, newStatus, transformationPlan)
139134
}
140135
state = newStatus
141-
numRefreshes = 0
142-
return@waitUntil state
143-
} catch (e: AccessDeniedException) {
144-
if (numRefreshes++ > maxRefreshes) throw e
145-
refreshToken(project)
146136
return@waitUntil state
147-
} catch (e: InvalidGrantException) {
137+
} catch (e: Exception) {
138+
if (e !is AccessDeniedException && e !is InvalidGrantException) throw e
139+
148140
CodeTransformMessageListener.instance.onReauthStarted()
149141
notifyStickyWarn(
150142
message("codemodernizer.notification.warn.expired_credentials.title"),

plugins/amazonq/codetransform/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codemodernizer/utils/CodeTransformUtils.kt

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import com.intellij.openapi.vfs.VfsUtilCore
88
import com.intellij.openapi.vfs.VirtualFileManager
99
import software.amazon.awssdk.services.codewhispererruntime.model.TransformationLanguage
1010
import software.amazon.awssdk.services.codewhispererruntime.model.TransformationStatus
11-
import software.aws.toolkits.core.TokenConnectionSettings
1211
import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection
1312
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
1413
import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection
@@ -43,12 +42,6 @@ val STATES_AFTER_STARTED = setOf(
4342
*STATES_AFTER_INITIAL_BUILD.toTypedArray(),
4443
)
4544

46-
fun refreshToken(project: Project) {
47-
val connection = ToolkitConnectionManager.getInstance(project).activeConnectionForFeature(QConnection.getInstance())
48-
val provider = (connection?.getConnectionSettings() as TokenConnectionSettings).tokenProvider.delegate as BearerTokenProvider
49-
provider.refresh()
50-
}
51-
5245
fun getAuthType(project: Project): CredentialSourceId? {
5346
val connection = checkBearerConnectionValidity(project, BearerTokenFeatureSet.Q)
5447
var authType: CredentialSourceId? = null

plugins/amazonq/codetransform/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codemodernizer/CodeWhispererCodeModernizerTestBase.kt

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ import software.aws.toolkits.core.credentials.ToolkitBearerTokenProvider
4545
import software.aws.toolkits.core.utils.test.aString
4646
import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection
4747
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
48-
import software.aws.toolkits.jetbrains.core.credentials.sso.DeviceAuthorizationGrantToken
4948
import software.aws.toolkits.jetbrains.core.credentials.sso.PKCEAuthorizationGrantToken
5049
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenAuthState
5150
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProvider
@@ -250,11 +249,7 @@ open class CodeWhispererCodeModernizerTestBase(
250249
project = projectRule.project
251250
toolkitConnectionManager = spy(ToolkitConnectionManager.getInstance(project))
252251

253-
val accessToken = DeviceAuthorizationGrantToken(aString(), aString(), aString(), aString(), Instant.MAX, Instant.now())
254-
val provider =
255-
mock<BearerTokenProvider> {
256-
doReturn(accessToken).whenever(it).refresh()
257-
}
252+
val provider = mock<BearerTokenProvider> { }
258253
val mockBearerProvider =
259254
mock<ToolkitBearerTokenProvider> {
260255
doReturn(provider).whenever(it).delegate
@@ -340,7 +335,6 @@ open class CodeWhispererCodeModernizerTestBase(
340335
val accessToken = PKCEAuthorizationGrantToken(aString(), aString(), aString(), aString(), Instant.MAX, Instant.now())
341336

342337
val provider = mock<BearerTokenProvider> {
343-
doReturn(accessToken).whenever(it).refresh()
344338
doReturn(accessToken).whenever(it).currentToken()
345339
doReturn(authState).whenever(it).state()
346340
}

plugins/amazonq/codetransform/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codemodernizer/CodeWhispererCodeModernizerUtilsTest.kt

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ import software.aws.toolkits.jetbrains.services.codemodernizer.utils.getTableMap
3838
import software.aws.toolkits.jetbrains.services.codemodernizer.utils.isPlanComplete
3939
import software.aws.toolkits.jetbrains.services.codemodernizer.utils.parseBuildFile
4040
import software.aws.toolkits.jetbrains.services.codemodernizer.utils.pollTransformationStatusAndPlan
41-
import software.aws.toolkits.jetbrains.services.codemodernizer.utils.refreshToken
4241
import software.aws.toolkits.jetbrains.services.codemodernizer.utils.validateCustomVersionsFile
4342
import software.aws.toolkits.jetbrains.services.codemodernizer.utils.validateSctMetadata
4443
import software.aws.toolkits.jetbrains.utils.notifyStickyWarn
@@ -90,18 +89,18 @@ class CodeWhispererCodeModernizerUtilsTest : CodeWhispererCodeModernizerTestBase
9089
}
9190

9291
@Test
93-
fun `refresh on access denied`() {
92+
fun `show re-auth notification on access denied`() {
9493
val mockAccessDeniedException = Mockito.mock(AccessDeniedException::class.java)
9594

96-
mockkStatic(::refreshToken)
97-
every { refreshToken(any()) } just runs
95+
mockkStatic(::notifyStickyWarn)
96+
every { notifyStickyWarn(any(), any(), any(), any(), any()) } just runs
9897

9998
Mockito.doThrow(
10099
mockAccessDeniedException
101100
).doReturn(
102101
exampleGetCodeMigrationResponse,
103102
exampleGetCodeMigrationResponse.replace(TransformationStatus.STARTED),
104-
exampleGetCodeMigrationResponse.replace(TransformationStatus.COMPLETED), // Should stop before this point
103+
exampleGetCodeMigrationResponse.replace(TransformationStatus.COMPLETED),
105104
).whenever(clientAdaptorSpy).getCodeModernizationJob(any())
106105

107106
Mockito.doReturn(exampleGetCodeMigrationPlanResponse)
@@ -128,7 +127,7 @@ class CodeWhispererCodeModernizerUtilsTest : CodeWhispererCodeModernizerTestBase
128127
TransformationStatus.STARTED,
129128
)
130129
assertThat(expected).isEqualTo(mutableList)
131-
io.mockk.verify { refreshToken(any()) }
130+
verify { notifyStickyWarn(message("codemodernizer.notification.warn.expired_credentials.title"), any(), any(), any(), any()) }
132131
}
133132

134133
@Test

plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/status/CodeWhispererStatusBarWidget.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class CodeWhispererStatusBarWidget(project: Project) :
5656
ApplicationManager.getApplication().messageBus.connect(this).subscribe(
5757
BearerTokenProviderListener.TOPIC,
5858
object : BearerTokenProviderListener {
59-
override fun onChange(providerId: String, newScopes: List<String>?) {
59+
override fun onProviderChange(providerId: String, newScopes: List<String>?) {
6060
statusBar.updateWidget(ID)
6161
}
6262
}

0 commit comments

Comments
 (0)