diff --git a/plugins/amazonq/chat/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/webview/BrowserConnector.kt b/plugins/amazonq/chat/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/webview/BrowserConnector.kt index 6241eefedf..98b16e0029 100644 --- a/plugins/amazonq/chat/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/webview/BrowserConnector.kt +++ b/plugins/amazonq/chat/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/webview/BrowserConnector.kt @@ -20,13 +20,19 @@ import kotlinx.coroutines.channels.awaitClose import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.callbackFlow +import kotlinx.coroutines.flow.catch import kotlinx.coroutines.flow.distinctUntilChanged +import kotlinx.coroutines.flow.first import kotlinx.coroutines.flow.launchIn import kotlinx.coroutines.flow.merge import kotlinx.coroutines.flow.onEach +import kotlinx.coroutines.flow.timeout +import kotlinx.coroutines.flow.toList import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking import org.cef.browser.CefBrowser import org.eclipse.lsp4j.TextDocumentIdentifier +import org.eclipse.lsp4j.jsonrpc.JsonRpcException import org.eclipse.lsp4j.jsonrpc.ResponseErrorException import org.eclipse.lsp4j.jsonrpc.messages.ResponseErrorCode import software.aws.toolkits.core.utils.error @@ -37,6 +43,7 @@ import software.aws.toolkits.jetbrains.services.amazonq.apps.AppConnection import software.aws.toolkits.jetbrains.services.amazonq.commands.MessageSerializer import software.aws.toolkits.jetbrains.services.amazonq.lsp.AmazonQChatServer import software.aws.toolkits.jetbrains.services.amazonq.lsp.AmazonQLspService +import software.aws.toolkits.jetbrains.services.amazonq.lsp.AmazonQServerInstanceFacade import software.aws.toolkits.jetbrains.services.amazonq.lsp.JsonRpcMethod import software.aws.toolkits.jetbrains.services.amazonq.lsp.JsonRpcNotification import software.aws.toolkits.jetbrains.services.amazonq.lsp.JsonRpcRequest @@ -114,6 +121,7 @@ import software.aws.toolkits.telemetry.Telemetry import java.util.concurrent.CompletableFuture import java.util.concurrent.CompletionException import java.util.function.Function +import kotlin.time.Duration.Companion.milliseconds class BrowserConnector( private val serializer: MessageSerializer = MessageSerializer.getInstance(), @@ -237,43 +245,20 @@ class BrowserConnector( val chatParams: ObjectNode = (node.params as ObjectNode) .setAll(serializedEnrichmentParams) - val tabId = requestFromUi.params.tabId - val partialResultToken = chatCommunicationManager.addPartialChatMessage(tabId) - chatCommunicationManager.registerPartialResultToken(partialResultToken) - - var encryptionManager: JwtEncryptionManager? = null - val result = AmazonQLspService.executeAsyncIfRunning(project) { server -> - encryptionManager = this.encryptionManager - - val encryptedParams = EncryptedChatParams(this.encryptionManager.encrypt(chatParams), partialResultToken) - rawEndpoint.request(SEND_CHAT_COMMAND_PROMPT, encryptedParams) as CompletableFuture - } ?: (CompletableFuture.failedFuture(IllegalStateException("LSP Server not running"))) - - // We assume there is only one outgoing request per tab because the input is - // blocked when there is an outgoing request - chatCommunicationManager.setInflightRequestForTab(tabId, result) - showResult(result, partialResultToken, tabId, encryptionManager, browser) + doChatRequest(requestFromUi.params.tabId, browser) { serverFacade, partialResultToken -> + val encryptedParams = EncryptedChatParams(serverFacade.encryptionManager.encrypt(chatParams), partialResultToken) + (serverFacade.rawEndpoint.request(SEND_CHAT_COMMAND_PROMPT, encryptedParams) as CompletableFuture) + } } CHAT_QUICK_ACTION -> { - val requestFromUi = serializer.deserializeChatMessages(node) - val tabId = requestFromUi.params.tabId val quickActionParams = node.params ?: error("empty payload") - val partialResultToken = chatCommunicationManager.addPartialChatMessage(tabId) - chatCommunicationManager.registerPartialResultToken(partialResultToken) - var encryptionManager: JwtEncryptionManager? = null - val result = AmazonQLspService.executeAsyncIfRunning(project) { server -> - encryptionManager = this.encryptionManager - - val encryptedParams = EncryptedQuickActionChatParams(this.encryptionManager.encrypt(quickActionParams), partialResultToken) - rawEndpoint.request(CHAT_QUICK_ACTION, encryptedParams) as CompletableFuture - } ?: (CompletableFuture.failedFuture(IllegalStateException("LSP Server not running"))) - - // We assume there is only one outgoing request per tab because the input is - // blocked when there is an outgoing request - chatCommunicationManager.setInflightRequestForTab(tabId, result) + val requestFromUi = serializer.deserializeChatMessages(node) - showResult(result, partialResultToken, tabId, encryptionManager, browser) + doChatRequest(requestFromUi.params.tabId, browser) { serverFacade, partialResultToken -> + val encryptedParams = EncryptedQuickActionChatParams(serverFacade.encryptionManager.encrypt(quickActionParams), partialResultToken) + serverFacade.rawEndpoint.request(CHAT_QUICK_ACTION, encryptedParams) as CompletableFuture + } } CHAT_LIST_CONVERSATIONS -> { @@ -465,7 +450,6 @@ class BrowserConnector( AUTH_FOLLOW_UP_CLICKED -> { val message = serializer.deserializeChatMessages(node) chatCommunicationManager.handleAuthFollowUpClicked( - project, message.params ) } @@ -564,18 +548,44 @@ class BrowserConnector( } } - private fun showResult( - result: CompletableFuture, - partialResultToken: String, + private suspend fun doChatRequest( tabId: String, - encryptionManager: JwtEncryptionManager?, browser: Browser, + action: (AmazonQServerInstanceFacade, String) -> CompletableFuture, ) { + val partialResultToken = chatCommunicationManager.addPartialChatMessage(tabId) + chatCommunicationManager.registerPartialResultToken(partialResultToken) + var encryptionManager: JwtEncryptionManager? = null + val result = AmazonQLspService.executeAsyncIfRunning(project) { _ -> + // jank + encryptionManager = this@executeAsyncIfRunning.encryptionManager + action(this, partialResultToken) + .handle { result, ex -> + if (ex == null) { + return@handle result + } + + if (JsonRpcException.indicatesStreamClosed(ex)) { + // the flow buffer will never complete so insert some arbitrary timeout until we figure out how to end the flow + // after the error stream is closed and drained + val errorStream = runBlocking { this@executeAsyncIfRunning.errorStream.timeout(500.milliseconds).catch { }.toList() } + throw IllegalStateException("LSP execution error. See logs for more details: ${errorStream.joinToString(separator = "")}", ex.cause) + } + + throw ex + } + } ?: (CompletableFuture.failedFuture(IllegalStateException("LSP failed to start. See logs for more details: ${AmazonQLspService.getInstance(project).instanceFlow.first().errorStream.timeout(500.milliseconds).catch { }.toList().joinToString(separator = "")}"))) + + // We assume there is only one outgoing request per tab because the input is + // blocked when there is an outgoing request + chatCommunicationManager.setInflightRequestForTab(tabId, result) + result.whenComplete { value, error -> try { if (error != null) { throw error } + chatCommunicationManager.removePartialChatMessage(partialResultToken) val messageToChat = ChatCommunicationManager.convertToJsonToSendToChat( SEND_CHAT_COMMAND_PROMPT, @@ -585,7 +595,7 @@ class BrowserConnector( ) browser.postChat(messageToChat) chatCommunicationManager.removeInflightRequestForTab(tabId) - } catch (e: CancellationException) { + } catch (_: CancellationException) { LOG.warn { "Cancelled chat generation" } try { chatAsyncResultManager.createRequestId(partialResultToken) diff --git a/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/AmazonQLanguageClientImpl.kt b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/AmazonQLanguageClientImpl.kt index 4b10514248..937bf7c2a6 100644 --- a/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/AmazonQLanguageClientImpl.kt +++ b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/AmazonQLanguageClientImpl.kt @@ -84,7 +84,7 @@ import java.util.concurrent.TimeUnit /** * Concrete implementation of [AmazonQLanguageClient] to handle messages sent from server */ -class AmazonQLanguageClientImpl(private val project: Project) : AmazonQLanguageClient { +class AmazonQLanguageClientImpl(private val project: Project, private val instanceFacade: AmazonQServerInstanceFacade) : AmazonQLanguageClient { private val chatManager get() = ChatCommunicationManager.getInstance(project) private fun handleTelemetryMap(telemetryMap: Map<*, *>) { @@ -399,7 +399,7 @@ class AmazonQLanguageClientImpl(private val project: Project) : AmazonQLanguageC override fun notifyProgress(params: ProgressParams?) { if (params == null) return try { - chatManager.handlePartialResultProgressNotification(project, params) + chatManager.handlePartialResultProgressNotification(instanceFacade.encryptionManager, params) } catch (e: Exception) { LOG.error(e) { "Cannot handle partial chat" } } diff --git a/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/AmazonQLspService.kt b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/AmazonQLspService.kt index c4a57f2312..5c36da8479 100644 --- a/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/AmazonQLspService.kt +++ b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/AmazonQLspService.kt @@ -31,17 +31,23 @@ import com.intellij.util.net.ssl.CertificateManager import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Deferred import kotlinx.coroutines.Job +import kotlinx.coroutines.TimeoutCancellationException import kotlinx.coroutines.async import kotlinx.coroutines.channels.BufferOverflow import kotlinx.coroutines.delay +import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.MutableSharedFlow import kotlinx.coroutines.flow.asSharedFlow +import kotlinx.coroutines.flow.catch import kotlinx.coroutines.flow.map +import kotlinx.coroutines.flow.timeout +import kotlinx.coroutines.flow.toList import kotlinx.coroutines.future.asCompletableFuture import kotlinx.coroutines.future.await import kotlinx.coroutines.isActive import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.selects.select import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock import kotlinx.coroutines.withContext @@ -55,6 +61,7 @@ import org.eclipse.lsp4j.InitializedParams import org.eclipse.lsp4j.SynchronizationCapabilities import org.eclipse.lsp4j.TextDocumentClientCapabilities import org.eclipse.lsp4j.WorkspaceClientCapabilities +import org.eclipse.lsp4j.jsonrpc.JsonRpcException import org.eclipse.lsp4j.jsonrpc.Launcher import org.eclipse.lsp4j.jsonrpc.MessageConsumer import org.eclipse.lsp4j.jsonrpc.RemoteEndpoint @@ -101,6 +108,7 @@ import java.nio.file.Files import java.nio.file.Path import java.util.concurrent.Future import java.util.concurrent.TimeUnit +import kotlin.time.Duration.Companion.milliseconds import kotlin.time.Duration.Companion.seconds // https://github.com/redhat-developer/lsp4ij/blob/main/src/main/java/com/redhat/devtools/lsp4ij/server/LSPProcessListener.java @@ -109,6 +117,7 @@ internal class LSPProcessListener : ProcessListener { private val outputStream = PipedOutputStream() private val outputStreamWriter = OutputStreamWriter(outputStream, StandardCharsets.UTF_8) val inputStream = PipedInputStream(outputStream) + val errorStream = MutableSharedFlow(replay = 50, onBufferOverflow = BufferOverflow.DROP_OLDEST) override fun onTextAvailable(event: ProcessEvent, outputType: Key<*>) { if (ProcessOutputType.isStdout(outputType)) { @@ -120,6 +129,7 @@ internal class LSPProcessListener : ProcessListener { } } else if (ProcessOutputType.isStderr(outputType)) { LOG.warn { "LSP process stderr: ${event.text}" } + errorStream.tryEmit(event.text) } else if (outputType == ProcessOutputType.SYSTEM) { LOG.info { "LSP system events: ${event.text}" } } @@ -156,19 +166,13 @@ class AmazonQLspService @VisibleForTesting constructor( constructor(project: Project, cs: CoroutineScope) : this(DefaultAmazonQServerInstanceStarter, project, cs) private val _flowInstance = MutableSharedFlow(replay = 1, onBufferOverflow = BufferOverflow.DROP_OLDEST) - val instanceFlow = _flowInstance.asSharedFlow().map { it.languageServer } + val instanceFlow = _flowInstance.asSharedFlow() private var instance: Deferred - - val encryptionManager - get() = instance.getCompleted().encryptionManager private val heartbeatJob: Job private val restartTimestamps = ArrayDeque() private val restartMutex = Mutex() // Separate mutex for restart tracking - val rawEndpoint - get() = instance.getCompleted().rawEndpoint - // dont allow lsp commands if server is restarting private val mutex = Mutex(false) @@ -176,16 +180,17 @@ class AmazonQLspService @VisibleForTesting constructor( // manage lifecycle RAII-like so we can restart at arbitrary time // and suppress IDE error if server fails to start var attempts = 0 - while (attempts < 3) { + var lastInstance: AmazonQServerInstanceFacade? = null + while (isActive && attempts < 3 && checkForRemainingRestartAttempts()) { try { // no timeout; start() can download which may take long time - val instance = starter.start(project, cs).also { + lastInstance = starter.start(project, cs).also { Disposer.register(this@AmazonQLspService, it) } - // wait for handshake to complete - instance.initializeResult.join() - return@async instance.also { + lastInstance.waitForInitializeOrThrow(this) + + return@async lastInstance.also { _flowInstance.emit(it) } } catch (e: Exception) { @@ -194,7 +199,12 @@ class AmazonQLspService @VisibleForTesting constructor( attempts++ } - error("Failed to start LSP server in 3 attempts") + // does not capture all failure + lastInstance?.let { + _flowInstance.tryEmit(it) + } + + lastInstance ?: error("LSP failed all attempts to start") } init { @@ -294,20 +304,20 @@ class AmazonQLspService @VisibleForTesting constructor( suspend fun execute(runnable: suspend AmazonQLspService.(AmazonQLanguageServer) -> T): T { val lsp = withTimeout(5.seconds) { val holder = mutex.withLock { instance }.await() - holder.initializeResult.join() + holder.waitForInitializeOrThrow(this) holder.languageServer } return runnable(lsp) } - suspend fun executeIfRunning(runnable: suspend AmazonQLspService.(AmazonQLanguageServer) -> T): T? = withContext(dispatcher) { + suspend fun executeIfRunning(runnable: suspend AmazonQServerInstanceFacade.(AmazonQLanguageServer) -> T): T? = withContext(dispatcher) { val lsp = try { withTimeout(5.seconds) { val holder = mutex.withLock { instance }.await() - holder.initializeResult.join() + holder.waitForInitializeOrThrow(this) - holder.languageServer + holder } } catch (_: Exception) { LOG.debug { "LSP not running" } @@ -315,10 +325,10 @@ class AmazonQLspService @VisibleForTesting constructor( null } - lsp?.let { runnable(it) } + lsp?.let { runnable(it, it.languageServer) } } - fun syncExecuteIfRunning(runnable: suspend AmazonQLspService.(AmazonQLanguageServer) -> T): T? = + fun syncExecuteIfRunning(runnable: suspend AmazonQServerInstanceFacade.(AmazonQLanguageServer) -> T): T? = runBlocking(dispatcher) { executeIfRunning(runnable) } @@ -331,13 +341,14 @@ class AmazonQLspService @VisibleForTesting constructor( private const val RESTART_WINDOW_MS = 3 * 60 * 1000 fun getInstance(project: Project) = project.service() - suspend fun executeAsyncIfRunning(project: Project, runnable: suspend AmazonQLspService.(AmazonQLanguageServer) -> T): T? = + suspend fun executeAsyncIfRunning(project: Project, runnable: suspend AmazonQServerInstanceFacade.(AmazonQLanguageServer) -> T): T? = project.serviceIfCreated()?.executeIfRunning(runnable) } } interface AmazonQServerInstanceFacade : Disposable { val launcher: Launcher + val errorStream: Flow @Suppress("ForbiddenVoid") val launcherFuture: Future @@ -349,11 +360,18 @@ interface AmazonQServerInstanceFacade : Disposable { val rawEndpoint: RemoteEndpoint get() = launcher.remoteEndpoint + + suspend fun waitForInitializeOrThrow(scope: CoroutineScope) = + select { + initializeResult.onAwait { it } + scope.async { launcherFuture.get() }.onAwait { error(errorStream.timeout(500.milliseconds).catch { }.toList().joinToString(separator = "")) } + } } private class AmazonQServerInstance(private val project: Project, private val cs: CoroutineScope) : Disposable, AmazonQServerInstanceFacade { override val encryptionManager = JwtEncryptionManager() override val launcher: Launcher + override val errorStream: Flow @Suppress("ForbiddenVoid") override val launcherFuture: Future @@ -502,6 +520,7 @@ private class AmazonQServerInstance(private val project: Project, private val cs launcherHandler = KillableColoredProcessHandler.Silent(cmd) val inputWrapper = LSPProcessListener() + errorStream = inputWrapper.errorStream launcherHandler.addProcessListener(inputWrapper) launcherHandler.startNotify() @@ -539,11 +558,21 @@ private class AmazonQServerInstance(private val project: Project, private val cs AwsServerCapabilitiesProvider.getInstance(project).setAwsServerCapabilities(result.getAwsServerCapabilities()) } - // required - consumer?.consume(message) + try { + // required + consumer?.consume(message) + } catch (e: JsonRpcException) { + // suppress stream error if notification, else bubble up for correct error propagation + if (JsonRpcException.indicatesStreamClosed(e) && message is NotificationMessage) { + LOG.warn { "Failed to send notification message (${message.method}): ${e.cause}" } + LOG.debug(e) { "Failed to send notification message (${message.method})." } + } else { + throw e + } + } } } - .setLocalService(AmazonQLanguageClientImpl(project)) + .setLocalService(AmazonQLanguageClientImpl(project, this)) .setRemoteInterface(AmazonQLanguageServer::class.java) .configureGson { // otherwise Gson treats all numbers as double which causes deser issues diff --git a/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/flareChat/ChatCommunicationManager.kt b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/flareChat/ChatCommunicationManager.kt index 57f7aef7b4..b0e2efd922 100644 --- a/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/flareChat/ChatCommunicationManager.kt +++ b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/flareChat/ChatCommunicationManager.kt @@ -19,7 +19,7 @@ import software.aws.toolkits.core.utils.warn import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection import software.aws.toolkits.jetbrains.core.credentials.reauthConnectionIfNeeded -import software.aws.toolkits.jetbrains.services.amazonq.lsp.AmazonQLspService +import software.aws.toolkits.jetbrains.services.amazonq.lsp.encryption.JwtEncryptionManager import software.aws.toolkits.jetbrains.services.amazonq.lsp.flareChat.ProgressNotificationUtils.getObject import software.aws.toolkits.jetbrains.services.amazonq.lsp.model.aws.LSPAny import software.aws.toolkits.jetbrains.services.amazonq.lsp.model.aws.chat.AuthFollowUpClickedParams @@ -132,7 +132,7 @@ class ChatCommunicationManager(private val project: Project, private val cs: Cor finalResultProcessed[partialResultToken] = false } - fun handlePartialResultProgressNotification(project: Project, params: ProgressParams) { + fun handlePartialResultProgressNotification(encryptionManager: JwtEncryptionManager, params: ProgressParams) { val token = ProgressNotificationUtils.getToken(params) val tabId = getPartialChatMessage(token) if (tabId.isNullOrEmpty()) { @@ -146,7 +146,7 @@ class ChatCommunicationManager(private val project: Project, private val cs: Cor val encryptedPartialChatResult = getObject(params, String::class.java) if (encryptedPartialChatResult != null) { - val partialChatResult = AmazonQLspService.getInstance(project).encryptionManager.decrypt(encryptedPartialChatResult) + val partialChatResult = encryptionManager.decrypt(encryptedPartialChatResult) // Special case: check for stop message before proceeding val partialResultMap = tryOrNull { @@ -234,7 +234,7 @@ class ChatCommunicationManager(private val project: Project, private val cs: Cor """.trimIndent() } - fun handleAuthFollowUpClicked(project: Project, params: AuthFollowUpClickedParams) { + fun handleAuthFollowUpClicked(params: AuthFollowUpClickedParams) { val incomingType = params.authFollowupType val connectionManager = ToolkitConnectionManager.getInstance(project) try {