diff --git a/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/popup/QInlineCompletionProvider.kt b/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/popup/QInlineCompletionProvider.kt index 82719b8dbd6..24c5f0d8c84 100644 --- a/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/popup/QInlineCompletionProvider.kt +++ b/plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/popup/QInlineCompletionProvider.kt @@ -42,10 +42,18 @@ import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.flow.receiveAsFlow import kotlinx.coroutines.future.await import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext import migration.software.aws.toolkits.jetbrains.services.codewhisperer.explorer.CodeWhispererExplorerActionManager +import org.eclipse.lsp4j.jsonrpc.ResponseErrorException import org.eclipse.lsp4j.jsonrpc.messages.Either +import software.amazon.awssdk.services.ssooidc.model.InvalidGrantException import software.aws.toolkits.core.utils.debug import software.aws.toolkits.core.utils.getLogger +import software.aws.toolkits.jetbrains.core.coroutines.getCoroutineBgContext +import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection +import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager +import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection +import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProvider import software.aws.toolkits.jetbrains.services.amazonq.lsp.AmazonQLspService import software.aws.toolkits.jetbrains.services.amazonq.profile.QRegionProfileManager import software.aws.toolkits.jetbrains.services.codewhisperer.importadder.CodeWhispererImportAdder @@ -59,8 +67,10 @@ import software.aws.toolkits.jetbrains.services.codewhisperer.service.CodeWhispe import software.aws.toolkits.jetbrains.services.codewhisperer.telemetry.CodeWhispererTelemetryService import software.aws.toolkits.jetbrains.services.codewhisperer.toolwindow.CodeWhispererCodeReferenceManager import software.aws.toolkits.jetbrains.services.codewhisperer.util.CodeWhispererConstants +import software.aws.toolkits.jetbrains.services.codewhisperer.util.CodeWhispererUtil import software.aws.toolkits.jetbrains.services.codewhisperer.util.getDocumentDiagnostics import software.aws.toolkits.jetbrains.utils.isQConnected +import software.aws.toolkits.jetbrains.utils.isQExpired import software.aws.toolkits.resources.message import software.aws.toolkits.telemetry.CodewhispererTriggerType import java.awt.Dimension @@ -393,8 +403,22 @@ class QInlineCompletionProvider(private val cs: CoroutineScope) : InlineCompleti override suspend fun getSuggestion(request: InlineCompletionRequest): InlineCompletionSuggestion { val editor = request.editor - val document = editor.document val project = editor.project ?: return InlineCompletionSuggestion.Empty + + // try to refresh automatically if possible, otherwise ask user to login again + if (isQExpired(project)) { + // consider changing to only running once a ~minute since this is relatively expensive + // say the connection is un-refreshable if refresh fails for 3 times + val shouldReauth = withContext(getCoroutineBgContext()) { + CodeWhispererUtil.promptReAuth(project) + } + + if (shouldReauth) { + return InlineCompletionSuggestion.Empty + } + } + + val document = editor.document val handler = InlineCompletion.getHandlerOrNull(editor) ?: return InlineCompletionSuggestion.Empty val session = InlineCompletionSession.getOrNull(editor) ?: return InlineCompletionSuggestion.Empty val triggerSessionId = triggerSessionId++ @@ -446,31 +470,28 @@ class QInlineCompletionProvider(private val cs: CoroutineScope) : InlineCompleti } try { - // Launch coroutine for background pagination progress - cs.launch { - var nextToken: Either? = null - do { - nextToken = startPaginationInBackground( - project, - editor, - triggerTypeInfo, - triggerSessionId, - nextToken, - sessionContext, - ) - } while (nextToken != null && !nextToken.left.isNullOrEmpty()) + var nextToken: Either? = null + do { + nextToken = startPaginationInBackground( + project, + editor, + triggerTypeInfo, + triggerSessionId, + nextToken, + sessionContext, + ) + } while (nextToken != null && !nextToken.left.isNullOrEmpty()) - // closing all channels since pagination for this session has finished + // closing all channels since pagination for this session has finished + logInline(triggerSessionId) { + "Pagination finished, closing all channels" + } + sessionContext.itemContexts.forEach { + it.channel.close() + } + if (session.context.isDisposed) { logInline(triggerSessionId) { - "Pagination finished, closing all channels" - } - sessionContext.itemContexts.forEach { - it.channel.close() - } - if (session.context.isDisposed) { - logInline(triggerSessionId) { - "Current display session already disposed by a new trigger before pagination finishes, exiting" - } + "Current display session already disposed by a new trigger before pagination finishes, exiting" } } @@ -575,6 +596,27 @@ class QInlineCompletionProvider(private val cs: CoroutineScope) : InlineCompleti logInline(triggerSessionId, e) { "Error during pagination" } + if (e is ResponseErrorException) { + // convoluted but lines up with "The bearer token included in the request is invalid" + // https://github.com/aws/language-servers/blob/1f3e93024eeb22186a34f0bd560f8d552f517300/server/aws-lsp-codewhisperer/src/language-server/chat/utils.ts#L22-L23 + // error data is nullable + if (e.responseError.data?.toString()?.contains("E_AMAZON_Q_CONNECTION_EXPIRED") == true) { + // kill the session if the connection is expired + val connection = ToolkitConnectionManager + .getInstance(project) + .activeConnectionForFeature(QConnection.getInstance()) as? AwsBearerTokenConnection + val tokenProvider = connection?.let { it.getConnectionSettings().tokenProvider.delegate as? BearerTokenProvider } + tokenProvider?.let { + // TODO: fragile + try { + it.refresh() + } catch (_: InvalidGrantException) { + it.invalidate() + CodeWhispererUtil.reconnectCodeWhisperer(project) + } + } + } + } return null } } @@ -594,6 +636,7 @@ class QInlineCompletionProvider(private val cs: CoroutineScope) : InlineCompleti val editor = request.editor val project = editor.project ?: return false + // qExpired case handled in completion handler if (!isQConnected(project)) return false if (QRegionProfileManager.getInstance().hasValidConnectionButNoActiveProfile(project)) return false if (event.isManualCall()) return true