Skip to content

Commit 080462d

Browse files
committed
handle race condition
1 parent d798928 commit 080462d

File tree

3 files changed

+160
-13
lines changed

3 files changed

+160
-13
lines changed

plugins/amazonq/chat/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/webview/BrowserConnector.kt

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ import software.aws.toolkits.telemetry.Telemetry
9898
import java.util.concurrent.CompletableFuture
9999
import java.util.concurrent.CompletionException
100100
import java.util.function.Function
101+
import software.aws.toolkits.jetbrains.services.amazonq.lsp.flareChat.ChatAsyncResultManager
101102

102103
class BrowserConnector(
103104
private val serializer: MessageSerializer = MessageSerializer.getInstance(),
@@ -106,6 +107,7 @@ class BrowserConnector(
106107
) {
107108
var uiReady = CompletableDeferred<Boolean>()
108109
private val chatCommunicationManager = ChatCommunicationManager.getInstance(project)
110+
private val chatAsyncResultManager = ChatAsyncResultManager.getInstance(project)
109111

110112
suspend fun connect(
111113
browser: Browser,
@@ -227,6 +229,8 @@ class BrowserConnector(
227229

228230
val tabId = requestFromUi.params.tabId
229231
val partialResultToken = chatCommunicationManager.addPartialChatMessage(tabId)
232+
chatCommunicationManager.registerPartialResultToken(partialResultToken)
233+
230234

231235
var encryptionManager: JwtEncryptionManager? = null
232236
val result = AmazonQLspService.executeIfRunning(project) { server ->
@@ -247,6 +251,7 @@ class BrowserConnector(
247251
val tabId = requestFromUi.params.tabId
248252
val quickActionParams = node.params ?: error("empty payload")
249253
val partialResultToken = chatCommunicationManager.addPartialChatMessage(tabId)
254+
chatCommunicationManager.registerPartialResultToken(partialResultToken)
250255
var encryptionManager: JwtEncryptionManager? = null
251256
val result = AmazonQLspService.executeIfRunning(project) { server ->
252257
encryptionManager = this.encryptionManager
@@ -478,7 +483,17 @@ class BrowserConnector(
478483
chatCommunicationManager.removeInflightRequestForTab(tabId)
479484
} catch (e: CancellationException) {
480485
LOG.warn { "Cancelled chat generation" }
481-
handleCancellation(tabId, partialResultToken, browser)
486+
try{
487+
chatAsyncResultManager.createRequestId(partialResultToken)
488+
chatAsyncResultManager.getResult(partialResultToken)
489+
handleCancellation(tabId, partialResultToken, browser)
490+
} catch (ex: Exception) {
491+
LOG.warn(ex) { "An error occurred while processing cancellation" }
492+
} finally {
493+
chatAsyncResultManager.removeRequestId(partialResultToken)
494+
chatCommunicationManager.removePartialResultLock(partialResultToken)
495+
chatCommunicationManager.removeFinalResultProcessed(partialResultToken)
496+
}
482497
} catch (e: Exception) {
483498
LOG.warn(e) { "Failed to send chat message" }
484499
browser.postChat(chatCommunicationManager.getErrorUiMessage(tabId, e, partialResultToken))
@@ -487,9 +502,6 @@ class BrowserConnector(
487502
}
488503

489504
private fun handleCancellation(tabId: String, partialResultToken: String, browser: Browser) {
490-
chatCommunicationManager.removePartialChatMessage(partialResultToken)
491-
chatCommunicationManager.removeInflightRequestForTab(tabId)
492-
493505
// Send a message to hide the stop button without showing an error
494506
val cancelMessage = chatCommunicationManager.getCancellationUiMessage(tabId)
495507
browser.postChat(cancelMessage)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
// Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package software.aws.toolkits.jetbrains.services.amazonq.lsp.flareChat
5+
6+
import com.intellij.openapi.components.Service
7+
import com.intellij.openapi.components.service
8+
import com.intellij.openapi.project.Project
9+
import java.util.concurrent.CompletableFuture
10+
import java.util.concurrent.ConcurrentHashMap
11+
import java.util.concurrent.TimeUnit
12+
import java.util.concurrent.TimeoutException
13+
14+
/**
15+
* Manages asynchronous results for chat operations, particularly handling the coordination
16+
* between partial results and final results during cancellation.
17+
*/
18+
@Service(Service.Level.PROJECT)
19+
class ChatAsyncResultManager() {
20+
private val results = ConcurrentHashMap<String, CompletableFuture<Any>>()
21+
private val completedResults = ConcurrentHashMap<String, Any>()
22+
private val timeout = 30L
23+
private val timeUnit = TimeUnit.SECONDS
24+
25+
fun createRequestId(requestId: String) {
26+
if (!completedResults.containsKey(requestId)) {
27+
results[requestId] = CompletableFuture()
28+
}
29+
}
30+
31+
fun removeRequestId(requestId: String) {
32+
val future = results.remove(requestId)
33+
if (future != null && !future.isDone) {
34+
future.cancel(true)
35+
}
36+
completedResults.remove(requestId)
37+
}
38+
39+
fun setResult(requestId: String, result: Any) {
40+
val future = results[requestId]
41+
if (future != null) {
42+
future.complete(result)
43+
results.remove(requestId)
44+
}
45+
completedResults[requestId] = result
46+
}
47+
48+
fun getResult(requestId: String): Any? {
49+
return getResult(requestId, timeout, timeUnit)
50+
}
51+
52+
private fun getResult(requestId: String, timeout: Long, unit: TimeUnit): Any? {
53+
val completedResult = completedResults[requestId]
54+
if (completedResult != null) {
55+
return completedResult
56+
}
57+
58+
val future = results[requestId] ?: throw IllegalArgumentException("Request ID not found: $requestId")
59+
60+
try {
61+
val result = future.get(timeout, unit)
62+
completedResults[requestId] = result
63+
results.remove(requestId)
64+
return result
65+
} catch (e: TimeoutException) {
66+
future.cancel(true)
67+
results.remove(requestId)
68+
throw TimeoutException("Operation timed out for requestId: $requestId")
69+
}
70+
}
71+
72+
companion object {
73+
fun getInstance(project: Project) = project.service<ChatAsyncResultManager>()
74+
}
75+
}

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/flareChat/ChatCommunicationManager.kt

Lines changed: 69 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ class ChatCommunicationManager(private val cs: CoroutineScope) {
4343
private val inflightRequestByTabId = ConcurrentHashMap<String, CompletableFuture<String>>()
4444
private val pendingSerializedChatRequests = ConcurrentHashMap<String, CompletableFuture<GetSerializedChatResult>>()
4545
private val pendingTabRequests = ConcurrentHashMap<String, CompletableFuture<LSPAny>>()
46+
private val partialResultLocks = ConcurrentHashMap<String, Any>()
47+
private val finalResultProcessed = ConcurrentHashMap<String, Boolean>()
4648

4749
fun setUiReady() {
4850
uiReady.complete(true)
@@ -97,6 +99,28 @@ class ChatCommunicationManager(private val cs: CoroutineScope) {
9799
fun removeTabOpenRequest(requestId: String) =
98100
pendingTabRequests.remove(requestId)
99101

102+
fun setPartialResultLock(token: String, lock: Any) {
103+
partialResultLocks[token] = lock
104+
}
105+
106+
fun removePartialResultLock(token: String) {
107+
partialResultLocks.remove(token)
108+
}
109+
110+
fun setFinalResultProcessed(token: String, processed: Boolean) {
111+
finalResultProcessed[token] = processed
112+
}
113+
114+
fun removeFinalResultProcessed(token: String) {
115+
finalResultProcessed.remove(token)
116+
}
117+
118+
fun registerPartialResultToken(partialResultToken: String) {
119+
val lock = Any()
120+
partialResultLocks[partialResultToken] = lock
121+
finalResultProcessed[partialResultToken] = false
122+
}
123+
100124
fun handlePartialResultProgressNotification(project: Project, params: ProgressParams) {
101125
val token = ProgressNotificationUtils.getToken(params)
102126
val tabId = getPartialChatMessage(token)
@@ -112,13 +136,49 @@ class ChatCommunicationManager(private val cs: CoroutineScope) {
112136
val encryptedPartialChatResult = getObject(params, String::class.java)
113137
if (encryptedPartialChatResult != null) {
114138
val partialChatResult = AmazonQLspService.getInstance(project).encryptionManager.decrypt(encryptedPartialChatResult)
115-
val uiMessage = convertToJsonToSendToChat(
116-
command = SEND_CHAT_COMMAND_PROMPT,
117-
tabId = tabId,
118-
params = partialChatResult,
119-
isPartialResult = true
120-
)
121-
AsyncChatUiListener.notifyPartialMessageUpdate(uiMessage)
139+
140+
// Special case: check for stop message before proceeding
141+
val partialResultMap = tryOrNull {
142+
Gson().fromJson(partialChatResult, Map::class.java)
143+
}
144+
145+
if (partialResultMap != null) {
146+
@Suppress("UNCHECKED_CAST")
147+
val additionalMessages = partialResultMap["additionalMessages"] as? List<Map<String, Any>>
148+
if (additionalMessages != null) {
149+
for (message in additionalMessages) {
150+
val messageId = message["messageId"] as? String
151+
if (messageId != null && messageId.startsWith("stopped")) {
152+
// Process stop messages immediately
153+
val uiMessage = convertToJsonToSendToChat(
154+
command = SEND_CHAT_COMMAND_PROMPT,
155+
tabId = tabId,
156+
params = partialChatResult,
157+
isPartialResult = true
158+
)
159+
AsyncChatUiListener.notifyPartialMessageUpdate(uiMessage)
160+
finalResultProcessed[token] = true
161+
ChatAsyncResultManager.getInstance(project).setResult(token, partialResultMap)
162+
return
163+
}
164+
}
165+
}
166+
}
167+
168+
// Normal processing for non-stop messages
169+
val lock = partialResultLocks[token] ?: return
170+
synchronized(lock) {
171+
if (finalResultProcessed[token] == true || partialResultLocks[token] == null) {
172+
return@synchronized
173+
}
174+
val uiMessage = convertToJsonToSendToChat(
175+
command = SEND_CHAT_COMMAND_PROMPT,
176+
tabId = tabId,
177+
params = partialChatResult,
178+
isPartialResult = true
179+
)
180+
AsyncChatUiListener.notifyPartialMessageUpdate(uiMessage)
181+
}
122182
}
123183
}
124184

@@ -147,12 +207,12 @@ class ChatCommunicationManager(private val cs: CoroutineScope) {
147207
""".trimIndent()
148208
return uiMessage
149209
}
150-
210+
151211
fun getCancellationUiMessage(tabId: String): String {
152212
// Create a minimal error params with empty error message to hide the stop button
153213
// without showing an actual error message to the user
154214
val errorParams = Gson().toJson(ErrorParams(tabId, null, "", "")).toString()
155-
215+
156216
return """
157217
{
158218
"command":"$CHAT_ERROR_PARAMS",

0 commit comments

Comments
 (0)