Skip to content

Commit 3cad360

Browse files
Merge main into feature/remote-chat-lsp
2 parents 85ee7e7 + 3b9b6c1 commit 3cad360

File tree

28 files changed

+301
-215
lines changed

28 files changed

+301
-215
lines changed

plugins/amazonq/chat/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/toolwindow/AmazonQToolWindowFactory.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ class AmazonQToolWindowFactory : ToolWindowFactory, DumbAware {
8383
project.messageBus.connect(toolWindow.disposable).subscribe(
8484
BearerTokenProviderListener.TOPIC,
8585
object : BearerTokenProviderListener {
86-
override fun onChange(providerId: String, newScopes: List<String>?) {
86+
override fun onProviderChange(providerId: String, newScopes: List<String>?) {
8787
if (ToolkitConnectionManager.getInstance(project).connectionStateForFeature(QConnection.getInstance()) == BearerTokenAuthState.AUTHORIZED) {
8888
preparePanelContent(project, qPanel)
8989
}

plugins/amazonq/chat/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonqCodeScan/CodeScanChatApp.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ class CodeScanChatApp(private val scope: CoroutineScope) : AmazonQApp {
125125
ApplicationManager.getApplication().messageBus.connect(this).subscribe(
126126
BearerTokenProviderListener.TOPIC,
127127
object : BearerTokenProviderListener {
128-
override fun onChange(providerId: String, newScopes: List<String>?) {
128+
override fun onProviderChange(providerId: String, newScopes: List<String>?) {
129129
val qProvider = getQTokenProvider(context.project)
130130
val isQ = qProvider?.id == providerId
131131
val isAuthorized = qProvider?.state() == BearerTokenAuthState.AUTHORIZED

plugins/amazonq/chat/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/AmazonQTestBase.kt

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,14 @@ import org.mockito.kotlin.spy
1717
import org.mockito.kotlin.whenever
1818
import software.aws.toolkits.core.TokenConnectionSettings
1919
import software.aws.toolkits.core.credentials.ToolkitBearerTokenProvider
20-
import software.aws.toolkits.core.utils.test.aString
2120
import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection
2221
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
23-
import software.aws.toolkits.jetbrains.core.credentials.sso.DeviceAuthorizationGrantToken
2422
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProvider
2523
import software.aws.toolkits.jetbrains.services.amazonq.clients.AmazonQStreamingClient
2624
import software.aws.toolkits.jetbrains.utils.rules.CodeInsightTestFixtureRule
2725
import software.aws.toolkits.jetbrains.utils.rules.HeavyJavaCodeInsightTestFixtureRule
2826
import software.aws.toolkits.jetbrains.utils.rules.JavaCodeInsightTestFixtureRule
2927
import software.aws.toolkits.jetbrains.utils.rules.addModule
30-
import java.time.Instant
3128

3229
open class AmazonQTestBase(
3330
@Rule @JvmField
@@ -47,11 +44,7 @@ open class AmazonQTestBase(
4744
project = projectRule.project
4845
toolkitConnectionManager = spy(ToolkitConnectionManager.getInstance(project))
4946

50-
val accessToken = DeviceAuthorizationGrantToken(aString(), aString(), aString(), aString(), Instant.MAX, Instant.now())
51-
52-
val provider = mock<BearerTokenProvider> {
53-
doReturn(accessToken).whenever(it).refresh()
54-
}
47+
val provider = mock<BearerTokenProvider>()
5548

5649
val mockBearerProvider = mock<ToolkitBearerTokenProvider> {
5750
doReturn(provider).whenever(it).delegate

plugins/amazonq/chat/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonqFeatureDev/FeatureDevTestBase.kt

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,8 @@ import software.amazon.awssdk.services.codewhispererruntime.model.SendTelemetryE
2828
import software.amazon.awssdk.services.codewhispererruntime.model.StartTaskAssistCodeGenerationResponse
2929
import software.aws.toolkits.core.TokenConnectionSettings
3030
import software.aws.toolkits.core.credentials.ToolkitBearerTokenProvider
31-
import software.aws.toolkits.core.utils.test.aString
3231
import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection
3332
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
34-
import software.aws.toolkits.jetbrains.core.credentials.sso.DeviceAuthorizationGrantToken
3533
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProvider
3634
import software.aws.toolkits.jetbrains.services.amazonqFeatureDev.clients.FeatureDevClient
3735
import software.aws.toolkits.jetbrains.services.amazonqFeatureDev.session.CodeGenerationStreamResult
@@ -41,7 +39,6 @@ import software.aws.toolkits.jetbrains.utils.rules.HeavyJavaCodeInsightTestFixtu
4139
import software.aws.toolkits.jetbrains.utils.rules.JavaCodeInsightTestFixtureRule
4240
import software.aws.toolkits.jetbrains.utils.rules.addModule
4341
import java.io.File
44-
import java.time.Instant
4542

4643
open class FeatureDevTestBase(
4744
@Rule @JvmField
@@ -164,11 +161,7 @@ open class FeatureDevTestBase(
164161
open fun setup() {
165162
project = projectRule.project
166163
toolkitConnectionManager = spy(ToolkitConnectionManager.getInstance(project))
167-
val accessToken = DeviceAuthorizationGrantToken(aString(), aString(), aString(), aString(), Instant.MAX, Instant.now())
168-
val provider =
169-
mock<BearerTokenProvider> {
170-
doReturn(accessToken).whenever(it).refresh()
171-
}
164+
val provider = mock<BearerTokenProvider>()
172165
val mockBearerProvider =
173166
mock<ToolkitBearerTokenProvider> {
174167
doReturn(provider).whenever(it).delegate

plugins/amazonq/codetransform/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codemodernizer/CodeTransformChatApp.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ class CodeTransformChatApp : AmazonQApp {
149149
ApplicationManager.getApplication().messageBus.connect(this).subscribe(
150150
BearerTokenProviderListener.TOPIC,
151151
object : BearerTokenProviderListener {
152-
override fun onChange(providerId: String, newScopes: List<String>?) {
152+
override fun onProviderChange(providerId: String, newScopes: List<String>?) {
153153
val qProvider = getQTokenProvider(context.project)
154154
val isQ = qProvider?.id == providerId
155155
val isAuthorized = qProvider?.state() == BearerTokenAuthState.AUTHORIZED

plugins/amazonq/codetransform/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codemodernizer/utils/CodeTransformApiUtils.kt

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,6 @@ suspend fun JobId.pollTransformationStatusAndPlan(
8787
var transformationPlan: TransformationPlan? = null
8888
var didSleepOnce = false
8989
var hasSeenTransforming = false
90-
val maxRefreshes = 10
91-
var numRefreshes = 0
92-
93-
// refresh token at start of polling since local build just prior can take a long time
94-
refreshToken(project)
9590

9691
try {
9792
waitUntil(
@@ -138,13 +133,10 @@ suspend fun JobId.pollTransformationStatusAndPlan(
138133
onStateChange(state, newStatus, transformationPlan)
139134
}
140135
state = newStatus
141-
numRefreshes = 0
142-
return@waitUntil state
143-
} catch (e: AccessDeniedException) {
144-
if (numRefreshes++ > maxRefreshes) throw e
145-
refreshToken(project)
146136
return@waitUntil state
147-
} catch (e: InvalidGrantException) {
137+
} catch (e: Exception) {
138+
if (e !is AccessDeniedException && e !is InvalidGrantException) throw e
139+
148140
CodeTransformMessageListener.instance.onReauthStarted()
149141
notifyStickyWarn(
150142
message("codemodernizer.notification.warn.expired_credentials.title"),

plugins/amazonq/codetransform/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codemodernizer/utils/CodeTransformUtils.kt

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import com.intellij.openapi.vfs.VfsUtilCore
88
import com.intellij.openapi.vfs.VirtualFileManager
99
import software.amazon.awssdk.services.codewhispererruntime.model.TransformationLanguage
1010
import software.amazon.awssdk.services.codewhispererruntime.model.TransformationStatus
11-
import software.aws.toolkits.core.TokenConnectionSettings
1211
import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection
1312
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
1413
import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection
@@ -43,12 +42,6 @@ val STATES_AFTER_STARTED = setOf(
4342
*STATES_AFTER_INITIAL_BUILD.toTypedArray(),
4443
)
4544

46-
fun refreshToken(project: Project) {
47-
val connection = ToolkitConnectionManager.getInstance(project).activeConnectionForFeature(QConnection.getInstance())
48-
val provider = (connection?.getConnectionSettings() as TokenConnectionSettings).tokenProvider.delegate as BearerTokenProvider
49-
provider.refresh()
50-
}
51-
5245
fun getAuthType(project: Project): CredentialSourceId? {
5346
val connection = checkBearerConnectionValidity(project, BearerTokenFeatureSet.Q)
5447
var authType: CredentialSourceId? = null

plugins/amazonq/codetransform/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codemodernizer/CodeWhispererCodeModernizerTestBase.kt

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ import software.aws.toolkits.core.credentials.ToolkitBearerTokenProvider
4545
import software.aws.toolkits.core.utils.test.aString
4646
import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection
4747
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
48-
import software.aws.toolkits.jetbrains.core.credentials.sso.DeviceAuthorizationGrantToken
4948
import software.aws.toolkits.jetbrains.core.credentials.sso.PKCEAuthorizationGrantToken
5049
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenAuthState
5150
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProvider
@@ -250,11 +249,7 @@ open class CodeWhispererCodeModernizerTestBase(
250249
project = projectRule.project
251250
toolkitConnectionManager = spy(ToolkitConnectionManager.getInstance(project))
252251

253-
val accessToken = DeviceAuthorizationGrantToken(aString(), aString(), aString(), aString(), Instant.MAX, Instant.now())
254-
val provider =
255-
mock<BearerTokenProvider> {
256-
doReturn(accessToken).whenever(it).refresh()
257-
}
252+
val provider = mock<BearerTokenProvider> { }
258253
val mockBearerProvider =
259254
mock<ToolkitBearerTokenProvider> {
260255
doReturn(provider).whenever(it).delegate
@@ -340,7 +335,6 @@ open class CodeWhispererCodeModernizerTestBase(
340335
val accessToken = PKCEAuthorizationGrantToken(aString(), aString(), aString(), aString(), Instant.MAX, Instant.now())
341336

342337
val provider = mock<BearerTokenProvider> {
343-
doReturn(accessToken).whenever(it).refresh()
344338
doReturn(accessToken).whenever(it).currentToken()
345339
doReturn(authState).whenever(it).state()
346340
}

plugins/amazonq/codetransform/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codemodernizer/CodeWhispererCodeModernizerUtilsTest.kt

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ import software.aws.toolkits.jetbrains.services.codemodernizer.utils.getTableMap
3838
import software.aws.toolkits.jetbrains.services.codemodernizer.utils.isPlanComplete
3939
import software.aws.toolkits.jetbrains.services.codemodernizer.utils.parseBuildFile
4040
import software.aws.toolkits.jetbrains.services.codemodernizer.utils.pollTransformationStatusAndPlan
41-
import software.aws.toolkits.jetbrains.services.codemodernizer.utils.refreshToken
4241
import software.aws.toolkits.jetbrains.services.codemodernizer.utils.validateCustomVersionsFile
4342
import software.aws.toolkits.jetbrains.services.codemodernizer.utils.validateSctMetadata
4443
import software.aws.toolkits.jetbrains.utils.notifyStickyWarn
@@ -90,18 +89,18 @@ class CodeWhispererCodeModernizerUtilsTest : CodeWhispererCodeModernizerTestBase
9089
}
9190

9291
@Test
93-
fun `refresh on access denied`() {
92+
fun `show re-auth notification on access denied`() {
9493
val mockAccessDeniedException = Mockito.mock(AccessDeniedException::class.java)
9594

96-
mockkStatic(::refreshToken)
97-
every { refreshToken(any()) } just runs
95+
mockkStatic(::notifyStickyWarn)
96+
every { notifyStickyWarn(any(), any(), any(), any(), any()) } just runs
9897

9998
Mockito.doThrow(
10099
mockAccessDeniedException
101100
).doReturn(
102101
exampleGetCodeMigrationResponse,
103102
exampleGetCodeMigrationResponse.replace(TransformationStatus.STARTED),
104-
exampleGetCodeMigrationResponse.replace(TransformationStatus.COMPLETED), // Should stop before this point
103+
exampleGetCodeMigrationResponse.replace(TransformationStatus.COMPLETED),
105104
).whenever(clientAdaptorSpy).getCodeModernizationJob(any())
106105

107106
Mockito.doReturn(exampleGetCodeMigrationPlanResponse)
@@ -128,7 +127,7 @@ class CodeWhispererCodeModernizerUtilsTest : CodeWhispererCodeModernizerTestBase
128127
TransformationStatus.STARTED,
129128
)
130129
assertThat(expected).isEqualTo(mutableList)
131-
io.mockk.verify { refreshToken(any()) }
130+
verify { notifyStickyWarn(message("codemodernizer.notification.warn.expired_credentials.title"), any(), any(), any(), any()) }
132131
}
133132

134133
@Test

plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/popup/QInlineCompletionProvider.kt

Lines changed: 67 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,18 @@ import kotlinx.coroutines.channels.Channel
4242
import kotlinx.coroutines.flow.receiveAsFlow
4343
import kotlinx.coroutines.future.await
4444
import kotlinx.coroutines.launch
45+
import kotlinx.coroutines.withContext
4546
import migration.software.aws.toolkits.jetbrains.services.codewhisperer.explorer.CodeWhispererExplorerActionManager
47+
import org.eclipse.lsp4j.jsonrpc.ResponseErrorException
4648
import org.eclipse.lsp4j.jsonrpc.messages.Either
49+
import software.amazon.awssdk.services.ssooidc.model.InvalidGrantException
4750
import software.aws.toolkits.core.utils.debug
4851
import software.aws.toolkits.core.utils.getLogger
52+
import software.aws.toolkits.jetbrains.core.coroutines.getCoroutineBgContext
53+
import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection
54+
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
55+
import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection
56+
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProvider
4957
import software.aws.toolkits.jetbrains.services.amazonq.lsp.AmazonQLspService
5058
import software.aws.toolkits.jetbrains.services.amazonq.profile.QRegionProfileManager
5159
import software.aws.toolkits.jetbrains.services.codewhisperer.importadder.CodeWhispererImportAdder
@@ -59,8 +67,10 @@ import software.aws.toolkits.jetbrains.services.codewhisperer.service.CodeWhispe
5967
import software.aws.toolkits.jetbrains.services.codewhisperer.telemetry.CodeWhispererTelemetryService
6068
import software.aws.toolkits.jetbrains.services.codewhisperer.toolwindow.CodeWhispererCodeReferenceManager
6169
import software.aws.toolkits.jetbrains.services.codewhisperer.util.CodeWhispererConstants
70+
import software.aws.toolkits.jetbrains.services.codewhisperer.util.CodeWhispererUtil
6271
import software.aws.toolkits.jetbrains.services.codewhisperer.util.getDocumentDiagnostics
6372
import software.aws.toolkits.jetbrains.utils.isQConnected
73+
import software.aws.toolkits.jetbrains.utils.isQExpired
6474
import software.aws.toolkits.resources.message
6575
import software.aws.toolkits.telemetry.CodewhispererTriggerType
6676
import java.awt.Dimension
@@ -393,8 +403,22 @@ class QInlineCompletionProvider(private val cs: CoroutineScope) : InlineCompleti
393403

394404
override suspend fun getSuggestion(request: InlineCompletionRequest): InlineCompletionSuggestion {
395405
val editor = request.editor
396-
val document = editor.document
397406
val project = editor.project ?: return InlineCompletionSuggestion.Empty
407+
408+
// try to refresh automatically if possible, otherwise ask user to login again
409+
if (isQExpired(project)) {
410+
// consider changing to only running once a ~minute since this is relatively expensive
411+
// say the connection is un-refreshable if refresh fails for 3 times
412+
val shouldReauth = withContext(getCoroutineBgContext()) {
413+
CodeWhispererUtil.promptReAuth(project)
414+
}
415+
416+
if (shouldReauth) {
417+
return InlineCompletionSuggestion.Empty
418+
}
419+
}
420+
421+
val document = editor.document
398422
val handler = InlineCompletion.getHandlerOrNull(editor) ?: return InlineCompletionSuggestion.Empty
399423
val session = InlineCompletionSession.getOrNull(editor) ?: return InlineCompletionSuggestion.Empty
400424
val triggerSessionId = triggerSessionId++
@@ -446,31 +470,28 @@ class QInlineCompletionProvider(private val cs: CoroutineScope) : InlineCompleti
446470
}
447471

448472
try {
449-
// Launch coroutine for background pagination progress
450-
cs.launch {
451-
var nextToken: Either<String, Int>? = null
452-
do {
453-
nextToken = startPaginationInBackground(
454-
project,
455-
editor,
456-
triggerTypeInfo,
457-
triggerSessionId,
458-
nextToken,
459-
sessionContext,
460-
)
461-
} while (nextToken != null && !nextToken.left.isNullOrEmpty())
473+
var nextToken: Either<String, Int>? = null
474+
do {
475+
nextToken = startPaginationInBackground(
476+
project,
477+
editor,
478+
triggerTypeInfo,
479+
triggerSessionId,
480+
nextToken,
481+
sessionContext,
482+
)
483+
} while (nextToken != null && !nextToken.left.isNullOrEmpty())
462484

463-
// closing all channels since pagination for this session has finished
485+
// closing all channels since pagination for this session has finished
486+
logInline(triggerSessionId) {
487+
"Pagination finished, closing all channels"
488+
}
489+
sessionContext.itemContexts.forEach {
490+
it.channel.close()
491+
}
492+
if (session.context.isDisposed) {
464493
logInline(triggerSessionId) {
465-
"Pagination finished, closing all channels"
466-
}
467-
sessionContext.itemContexts.forEach {
468-
it.channel.close()
469-
}
470-
if (session.context.isDisposed) {
471-
logInline(triggerSessionId) {
472-
"Current display session already disposed by a new trigger before pagination finishes, exiting"
473-
}
494+
"Current display session already disposed by a new trigger before pagination finishes, exiting"
474495
}
475496
}
476497

@@ -575,6 +596,27 @@ class QInlineCompletionProvider(private val cs: CoroutineScope) : InlineCompleti
575596
logInline(triggerSessionId, e) {
576597
"Error during pagination"
577598
}
599+
if (e is ResponseErrorException) {
600+
// convoluted but lines up with "The bearer token included in the request is invalid"
601+
// https://github.com/aws/language-servers/blob/1f3e93024eeb22186a34f0bd560f8d552f517300/server/aws-lsp-codewhisperer/src/language-server/chat/utils.ts#L22-L23
602+
// error data is nullable
603+
if (e.responseError.data?.toString()?.contains("E_AMAZON_Q_CONNECTION_EXPIRED") == true) {
604+
// kill the session if the connection is expired
605+
val connection = ToolkitConnectionManager
606+
.getInstance(project)
607+
.activeConnectionForFeature(QConnection.getInstance()) as? AwsBearerTokenConnection
608+
val tokenProvider = connection?.let { it.getConnectionSettings().tokenProvider.delegate as? BearerTokenProvider }
609+
tokenProvider?.let {
610+
// TODO: fragile
611+
try {
612+
it.refresh()
613+
} catch (_: InvalidGrantException) {
614+
it.invalidate()
615+
CodeWhispererUtil.reconnectCodeWhisperer(project)
616+
}
617+
}
618+
}
619+
}
578620
return null
579621
}
580622
}
@@ -594,6 +636,7 @@ class QInlineCompletionProvider(private val cs: CoroutineScope) : InlineCompleti
594636
val editor = request.editor
595637
val project = editor.project ?: return false
596638

639+
// qExpired case handled in completion handler
597640
if (!isQConnected(project)) return false
598641
if (QRegionProfileManager.getInstance().hasValidConnectionButNoActiveProfile(project)) return false
599642
if (event.isManualCall()) return true

0 commit comments

Comments
 (0)