Skip to content

Commit dc8bc7a

Browse files
fix(amazonq): Stop button disappears after being clicked (#5765)
Fixed a race condition between two streams receiving cancellation information. Previously, if the CancellationException from the sendChatPrompt stream was processed before the "You stopped your work" message from the partialResult stream, the UI would not properly display the stop message. The fix prioritizes stop messages by checking for them before acquiring any locks, ensuring they're processed immediately when detected. This allows the cancellation to complete properly regardless of which stream receives information first. Key changes: * Added special case handling for stop messages before lock acquisition * Immediately mark final result as processed when stop message is detected * Set result in ChatAsyncResultManager to coordinate between streams This ensures a consistent user experience when cancelling operations, with the stop message always being displayed properly.
1 parent 04aba5b commit dc8bc7a

File tree

3 files changed

+170
-8
lines changed

3 files changed

+170
-8
lines changed

plugins/amazonq/chat/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/webview/BrowserConnector.kt

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ import software.aws.toolkits.jetbrains.services.amazonq.lsp.JsonRpcNotification
4141
import software.aws.toolkits.jetbrains.services.amazonq.lsp.JsonRpcRequest
4242
import software.aws.toolkits.jetbrains.services.amazonq.lsp.encryption.JwtEncryptionManager
4343
import software.aws.toolkits.jetbrains.services.amazonq.lsp.flareChat.AwsServerCapabilitiesProvider
44+
import software.aws.toolkits.jetbrains.services.amazonq.lsp.flareChat.ChatAsyncResultManager
4445
import software.aws.toolkits.jetbrains.services.amazonq.lsp.flareChat.ChatCommunicationManager
4546
import software.aws.toolkits.jetbrains.services.amazonq.lsp.flareChat.FlareUiMessage
4647
import software.aws.toolkits.jetbrains.services.amazonq.lsp.model.aws.chat.AUTH_FOLLOW_UP_CLICKED
@@ -106,6 +107,7 @@ class BrowserConnector(
106107
) {
107108
var uiReady = CompletableDeferred<Boolean>()
108109
private val chatCommunicationManager = ChatCommunicationManager.getInstance(project)
110+
private val chatAsyncResultManager = ChatAsyncResultManager.getInstance(project)
109111

110112
suspend fun connect(
111113
browser: Browser,
@@ -227,6 +229,7 @@ class BrowserConnector(
227229

228230
val tabId = requestFromUi.params.tabId
229231
val partialResultToken = chatCommunicationManager.addPartialChatMessage(tabId)
232+
chatCommunicationManager.registerPartialResultToken(partialResultToken)
230233

231234
var encryptionManager: JwtEncryptionManager? = null
232235
val result = AmazonQLspService.executeIfRunning(project) { server ->
@@ -247,6 +250,7 @@ class BrowserConnector(
247250
val tabId = requestFromUi.params.tabId
248251
val quickActionParams = node.params ?: error("empty payload")
249252
val partialResultToken = chatCommunicationManager.addPartialChatMessage(tabId)
253+
chatCommunicationManager.registerPartialResultToken(partialResultToken)
250254
var encryptionManager: JwtEncryptionManager? = null
251255
val result = AmazonQLspService.executeIfRunning(project) { server ->
252256
encryptionManager = this.encryptionManager
@@ -476,15 +480,32 @@ class BrowserConnector(
476480
)
477481
browser.postChat(messageToChat)
478482
chatCommunicationManager.removeInflightRequestForTab(tabId)
479-
} catch (_: CancellationException) {
483+
} catch (e: CancellationException) {
480484
LOG.warn { "Cancelled chat generation" }
485+
try {
486+
chatAsyncResultManager.createRequestId(partialResultToken)
487+
chatAsyncResultManager.getResult(partialResultToken)
488+
handleCancellation(tabId, browser)
489+
} catch (ex: Exception) {
490+
LOG.warn(ex) { "An error occurred while processing cancellation" }
491+
} finally {
492+
chatAsyncResultManager.removeRequestId(partialResultToken)
493+
chatCommunicationManager.removePartialResultLock(partialResultToken)
494+
chatCommunicationManager.removeFinalResultProcessed(partialResultToken)
495+
}
481496
} catch (e: Exception) {
482497
LOG.warn(e) { "Failed to send chat message" }
483498
browser.postChat(chatCommunicationManager.getErrorUiMessage(tabId, e, partialResultToken))
484499
}
485500
}
486501
}
487502

503+
private fun handleCancellation(tabId: String, browser: Browser) {
504+
// Send a message to hide the stop button without showing an error
505+
val cancelMessage = chatCommunicationManager.getCancellationUiMessage(tabId)
506+
browser.postChat(cancelMessage)
507+
}
508+
488509
private fun cancelInflightRequests(tabId: String) {
489510
chatCommunicationManager.getInflightRequestForTab(tabId)?.let { request ->
490511
request.cancel(true)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
// Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package software.aws.toolkits.jetbrains.services.amazonq.lsp.flareChat
5+
6+
import com.intellij.openapi.components.Service
7+
import com.intellij.openapi.components.service
8+
import com.intellij.openapi.project.Project
9+
import java.util.concurrent.CompletableFuture
10+
import java.util.concurrent.ConcurrentHashMap
11+
import java.util.concurrent.TimeUnit
12+
import java.util.concurrent.TimeoutException
13+
14+
/**
15+
* Manages asynchronous results for chat operations, particularly handling the coordination
16+
* between partial results and final results during cancellation.
17+
*/
18+
@Service(Service.Level.PROJECT)
19+
class ChatAsyncResultManager {
20+
private val results = ConcurrentHashMap<String, CompletableFuture<Any>>()
21+
private val completedResults = ConcurrentHashMap<String, Any>()
22+
private val timeout = 30L
23+
private val timeUnit = TimeUnit.SECONDS
24+
25+
fun createRequestId(requestId: String) {
26+
if (!completedResults.containsKey(requestId)) {
27+
results[requestId] = CompletableFuture()
28+
}
29+
}
30+
31+
fun removeRequestId(requestId: String) {
32+
val future = results.remove(requestId)
33+
if (future != null && !future.isDone) {
34+
future.cancel(true)
35+
}
36+
completedResults.remove(requestId)
37+
}
38+
39+
fun setResult(requestId: String, result: Any) {
40+
val future = results[requestId]
41+
if (future != null) {
42+
future.complete(result)
43+
results.remove(requestId)
44+
}
45+
completedResults[requestId] = result
46+
}
47+
48+
fun getResult(requestId: String): Any? =
49+
getResult(requestId, timeout, timeUnit)
50+
51+
private fun getResult(requestId: String, timeout: Long, unit: TimeUnit): Any? {
52+
val completedResult = completedResults[requestId]
53+
if (completedResult != null) {
54+
return completedResult
55+
}
56+
57+
val future = results[requestId] ?: throw IllegalArgumentException("Request ID not found: $requestId")
58+
59+
try {
60+
val result = future.get(timeout, unit)
61+
completedResults[requestId] = result
62+
results.remove(requestId)
63+
return result
64+
} catch (e: TimeoutException) {
65+
future.cancel(true)
66+
results.remove(requestId)
67+
throw TimeoutException("Operation timed out for requestId: $requestId")
68+
}
69+
}
70+
71+
companion object {
72+
fun getInstance(project: Project) = project.service<ChatAsyncResultManager>()
73+
}
74+
}

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/flareChat/ChatCommunicationManager.kt

Lines changed: 74 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ class ChatCommunicationManager(private val cs: CoroutineScope) {
4343
private val inflightRequestByTabId = ConcurrentHashMap<String, CompletableFuture<String>>()
4444
private val pendingSerializedChatRequests = ConcurrentHashMap<String, CompletableFuture<GetSerializedChatResult>>()
4545
private val pendingTabRequests = ConcurrentHashMap<String, CompletableFuture<LSPAny>>()
46+
private val partialResultLocks = ConcurrentHashMap<String, Any>()
47+
private val finalResultProcessed = ConcurrentHashMap<String, Boolean>()
4648

4749
fun setUiReady() {
4850
uiReady.complete(true)
@@ -97,6 +99,20 @@ class ChatCommunicationManager(private val cs: CoroutineScope) {
9799
fun removeTabOpenRequest(requestId: String) =
98100
pendingTabRequests.remove(requestId)
99101

102+
fun removePartialResultLock(token: String) {
103+
partialResultLocks.remove(token)
104+
}
105+
106+
fun removeFinalResultProcessed(token: String) {
107+
finalResultProcessed.remove(token)
108+
}
109+
110+
fun registerPartialResultToken(partialResultToken: String) {
111+
val lock = Any()
112+
partialResultLocks[partialResultToken] = lock
113+
finalResultProcessed[partialResultToken] = false
114+
}
115+
100116
fun handlePartialResultProgressNotification(project: Project, params: ProgressParams) {
101117
val token = ProgressNotificationUtils.getToken(params)
102118
val tabId = getPartialChatMessage(token)
@@ -112,13 +128,49 @@ class ChatCommunicationManager(private val cs: CoroutineScope) {
112128
val encryptedPartialChatResult = getObject(params, String::class.java)
113129
if (encryptedPartialChatResult != null) {
114130
val partialChatResult = AmazonQLspService.getInstance(project).encryptionManager.decrypt(encryptedPartialChatResult)
115-
val uiMessage = convertToJsonToSendToChat(
116-
command = SEND_CHAT_COMMAND_PROMPT,
117-
tabId = tabId,
118-
params = partialChatResult,
119-
isPartialResult = true
120-
)
121-
AsyncChatUiListener.notifyPartialMessageUpdate(uiMessage)
131+
132+
// Special case: check for stop message before proceeding
133+
val partialResultMap = tryOrNull {
134+
Gson().fromJson(partialChatResult, Map::class.java)
135+
}
136+
137+
if (partialResultMap != null) {
138+
@Suppress("UNCHECKED_CAST")
139+
val additionalMessages = partialResultMap["additionalMessages"] as? List<Map<String, Any>>
140+
if (additionalMessages != null) {
141+
for (message in additionalMessages) {
142+
val messageId = message["messageId"] as? String
143+
if (messageId != null && messageId.startsWith("stopped")) {
144+
// Process stop messages immediately
145+
val uiMessage = convertToJsonToSendToChat(
146+
command = SEND_CHAT_COMMAND_PROMPT,
147+
tabId = tabId,
148+
params = partialChatResult,
149+
isPartialResult = true
150+
)
151+
AsyncChatUiListener.notifyPartialMessageUpdate(uiMessage)
152+
finalResultProcessed[token] = true
153+
ChatAsyncResultManager.getInstance(project).setResult(token, partialResultMap)
154+
return
155+
}
156+
}
157+
}
158+
}
159+
160+
// Normal processing for non-stop messages
161+
val lock = partialResultLocks[token] ?: return
162+
synchronized(lock) {
163+
if (finalResultProcessed[token] == true || partialResultLocks[token] == null) {
164+
return@synchronized
165+
}
166+
val uiMessage = convertToJsonToSendToChat(
167+
command = SEND_CHAT_COMMAND_PROMPT,
168+
tabId = tabId,
169+
params = partialChatResult,
170+
isPartialResult = true
171+
)
172+
AsyncChatUiListener.notifyPartialMessageUpdate(uiMessage)
173+
}
122174
}
123175
}
124176

@@ -148,6 +200,21 @@ class ChatCommunicationManager(private val cs: CoroutineScope) {
148200
return uiMessage
149201
}
150202

203+
fun getCancellationUiMessage(tabId: String): String {
204+
// Create a minimal error params with empty error message to hide the stop button
205+
// without showing an actual error message to the user
206+
val errorParams = Gson().toJson(ErrorParams(tabId, null, "", "")).toString()
207+
208+
return """
209+
{
210+
"command":"$CHAT_ERROR_PARAMS",
211+
"tabId": "$tabId",
212+
"params": $errorParams,
213+
"isPartialResult": false
214+
}
215+
""".trimIndent()
216+
}
217+
151218
fun handleAuthFollowUpClicked(project: Project, params: AuthFollowUpClickedParams) {
152219
val incomingType = params.authFollowupType
153220
val connectionManager = ToolkitConnectionManager.getInstance(project)

0 commit comments

Comments
 (0)