Skip to content

Commit aed4ab7

Browse files
Will-ShaoHualeigaolrli
authored
config(amazonq): make queryChat suspend and called within withTimeout (#5020)
Co-authored-by: Lei Gao <[email protected]> Co-authored-by: Richard Li <[email protected]>
1 parent 662a251 commit aed4ab7

File tree

5 files changed

+51
-23
lines changed

5 files changed

+51
-23
lines changed

plugins/amazonq/chat/jetbrains-community/src/software/aws/toolkits/jetbrains/services/cwc/controller/ChatController.kt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ import software.aws.toolkits.core.utils.getLogger
3636
import software.aws.toolkits.core.utils.info
3737
import software.aws.toolkits.core.utils.warn
3838
import software.aws.toolkits.jetbrains.core.coroutines.EDT
39+
import software.aws.toolkits.jetbrains.services.amazonq.CHAT_IMPLICIT_PROJECT_CONTEXT_TIMEOUT
3940
import software.aws.toolkits.jetbrains.services.amazonq.CodeWhispererFeatureConfigService
4041
import software.aws.toolkits.jetbrains.services.amazonq.apps.AmazonQAppInitContext
4142
import software.aws.toolkits.jetbrains.services.amazonq.auth.AuthController
@@ -137,15 +138,15 @@ class ChatController private constructor(
137138
shouldUseWorkspaceContext = true
138139
prompt = prompt.replace("@workspace", "")
139140
val projectContextController = ProjectContextController.getInstance(context.project)
140-
queryResult = projectContextController.query(prompt)
141+
queryResult = projectContextController.query(prompt, timeout = null)
141142
if (!projectContextController.getProjectContextIndexComplete()) shouldAddIndexInProgressMessage = true
142143
logger.info { "project context relevant document count: ${queryResult.size}" }
143144
} else {
144145
sendOpenSettingsMessage(message.tabId)
145146
}
146147
} else if (CodeWhispererSettings.getInstance().isProjectContextEnabled() && isDataCollectionGroup) {
147148
val projectContextController = ProjectContextController.getInstance(context.project)
148-
queryResult = projectContextController.query(prompt)
149+
queryResult = projectContextController.query(prompt, timeout = CHAT_IMPLICIT_PROJECT_CONTEXT_TIMEOUT)
149150
}
150151

151152
handleChat(

plugins/amazonq/chat/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/workspace/context/ProjectContextProviderTest.kt

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,10 @@ class ProjectContextProviderTest {
214214
}
215215

216216
@Test
217-
fun `query should send correct encrypted request to lsp`() {
218-
sut.query("foo")
217+
fun `query should send correct encrypted request to lsp`() = runTest {
218+
sut = ProjectContextProvider(project, encoderServer, this)
219+
val r = sut.query("foo", null)
220+
advanceUntilIdle()
219221

220222
val request = QueryChatRequest("foo")
221223
val requestJson = mapper.writeValueAsString(request)
@@ -269,13 +271,15 @@ class ProjectContextProviderTest {
269271
)
270272

271273
assertThrows<Exception> {
272-
sut.query("foo")
274+
sut.query("foo", null)
273275
}
274276
}
275277

276278
@Test
277279
fun `query chat should return deserialized relevantDocument`() = runTest {
278-
val r = sut.query("foo")
280+
sut = ProjectContextProvider(project, encoderServer, this)
281+
val r = sut.query("foo", null)
282+
advanceUntilIdle()
279283
assertThat(r).hasSize(2)
280284
assertThat(r[0]).isEqualTo(
281285
RelevantDocument(
@@ -377,6 +381,29 @@ class ProjectContextProviderTest {
377381
}
378382
}
379383

384+
@Test
385+
fun `queryChat should throw if time elapsed is greather than 500ms`() = runTest {
386+
assertThrows<TimeoutCancellationException> {
387+
sut = ProjectContextProvider(project, encoderServer, this)
388+
stubFor(
389+
any(urlPathEqualTo("/query")).willReturn(
390+
aResponse()
391+
.withStatus(200)
392+
.withResponseBody(
393+
Body(validQueryChatResponse)
394+
)
395+
.withFixedDelay(501)
396+
)
397+
)
398+
399+
withContext(getCoroutineBgContext()) {
400+
sut.query("foo", timeout = 500L)
401+
}
402+
403+
advanceUntilIdle()
404+
}
405+
}
406+
380407
@Test
381408
fun `test index payload is encrypted`() = runTest {
382409
whenever(encoderServer.port).thenReturn(3000)
@@ -390,12 +417,9 @@ class ProjectContextProviderTest {
390417

391418
@Test
392419
fun `test query payload is encrypted`() = runTest {
393-
whenever(encoderServer.port).thenReturn(3000)
394-
try {
395-
sut.query("what does this project do")
396-
} catch (e: ConnectException) {
397-
// no-op
398-
}
420+
sut = ProjectContextProvider(project, encoderServer, this)
421+
sut.query("what does this project do", null)
422+
advanceUntilIdle()
399423
verify(encoderServer, times(1)).encrypt(any())
400424
}
401425

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/Constants.kt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
package software.aws.toolkits.jetbrains.services.amazonq
55

6+
import kotlin.time.Duration.Companion.minutes
7+
68
const val APPLICATION_ZIP = "application/zip"
79
const val SERVER_SIDE_ENCRYPTION = "x-amz-server-side-encryption"
810
const val AWS_KMS = "aws:kms"
@@ -39,3 +41,7 @@ const val CODE_TRANSFORM_PREREQUISITES =
3941
const val FEATURE_EVALUATION_PRODUCT_NAME = "CodeWhisperer"
4042

4143
const val SUPPLEMENTAL_CONTEXT_TIMEOUT = 100L
44+
45+
val CHAT_EXPLICIT_PROJECT_CONTEXT_TIMEOUT = 5.minutes.inWholeMilliseconds
46+
47+
const val CHAT_IMPLICIT_PROJECT_CONTEXT_TIMEOUT = 500L // 500ms

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/project/ProjectContextController.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@ class ProjectContextController(private val project: Project, private val cs: Cor
4747

4848
fun getProjectContextIndexComplete() = projectContextProvider.isIndexComplete.get()
4949

50-
fun query(prompt: String): List<RelevantDocument> {
50+
suspend fun query(prompt: String, timeout: Long?): List<RelevantDocument> {
5151
try {
52-
return projectContextProvider.query(prompt)
52+
return projectContextProvider.query(prompt, timeout)
5353
} catch (e: Exception) {
5454
logger.warn { "error while querying for project context $e.message" }
5555
return emptyList()

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/project/ProjectContextProvider.kt

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@ import kotlinx.coroutines.launch
2222
import kotlinx.coroutines.runBlocking
2323
import kotlinx.coroutines.withTimeout
2424
import software.aws.toolkits.core.utils.debug
25-
import software.aws.toolkits.core.utils.error
2625
import software.aws.toolkits.core.utils.getLogger
2726
import software.aws.toolkits.core.utils.info
2827
import software.aws.toolkits.core.utils.warn
28+
import software.aws.toolkits.jetbrains.services.amazonq.CHAT_EXPLICIT_PROJECT_CONTEXT_TIMEOUT
2929
import software.aws.toolkits.jetbrains.services.amazonq.FeatureDevSessionContext
3030
import software.aws.toolkits.jetbrains.services.amazonq.SUPPLEMENTAL_CONTEXT_TIMEOUT
3131
import software.aws.toolkits.jetbrains.services.cwc.controller.chat.telemetry.getStartUrl
@@ -161,17 +161,14 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En
161161
}
162162

163163
// TODO: rename queryChat
164-
fun query(prompt: String): List<RelevantDocument> {
165-
val encrypted = encryptRequest(QueryChatRequest(prompt))
166-
val response = sendMsgToLsp(LspMessage.QueryChat, encrypted)
164+
suspend fun query(prompt: String, timeout: Long?): List<RelevantDocument> = withTimeout(timeout ?: CHAT_EXPLICIT_PROJECT_CONTEXT_TIMEOUT) {
165+
cs.async {
166+
val encrypted = encryptRequest(QueryChatRequest(prompt))
167+
val response = sendMsgToLsp(LspMessage.QueryChat, encrypted)
167168

168-
return try {
169169
val parsedResponse = mapper.readValue<List<Chunk>>(response.responseBody)
170170
queryResultToRelevantDocuments(parsedResponse)
171-
} catch (e: Exception) {
172-
logger.error { "error parsing query response ${e.message}" }
173-
throw e
174-
}
171+
}.await()
175172
}
176173

177174
suspend fun queryInline(query: String, filePath: String): List<InlineBm25Chunk> = withTimeout(SUPPLEMENTAL_CONTEXT_TIMEOUT) {

0 commit comments

Comments
 (0)