Skip to content

Commit 2e7ee12

Browse files
committed
make queryChat suspend and called within withTimeout
1 parent a81caf6 commit 2e7ee12

File tree

3 files changed

+40
-19
lines changed

3 files changed

+40
-19
lines changed

plugins/amazonq/chat/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/workspace/context/ProjectContextProviderTest.kt

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,10 @@ class ProjectContextProviderTest {
214214
}
215215

216216
@Test
217-
fun `query should send correct encrypted request to lsp`() {
218-
sut.query("foo")
217+
fun `query should send correct encrypted request to lsp`() = runTest {
218+
sut = ProjectContextProvider(project, encoderServer, this)
219+
val r = sut.query("foo")
220+
advanceUntilIdle()
219221

220222
val request = QueryChatRequest("foo")
221223
val requestJson = mapper.writeValueAsString(request)
@@ -275,7 +277,9 @@ class ProjectContextProviderTest {
275277

276278
@Test
277279
fun `query chat should return deserialized relevantDocument`() = runTest {
280+
sut = ProjectContextProvider(project, encoderServer, this)
278281
val r = sut.query("foo")
282+
advanceUntilIdle()
279283
assertThat(r).hasSize(2)
280284
assertThat(r[0]).isEqualTo(
281285
RelevantDocument(
@@ -364,7 +368,7 @@ class ProjectContextProviderTest {
364368
.withResponseBody(
365369
Body(validQueryInlineResponse)
366370
)
367-
.withFixedDelay(51) // 10 sec
371+
.withFixedDelay(51)
368372
)
369373
)
370374

@@ -377,6 +381,29 @@ class ProjectContextProviderTest {
377381
}
378382
}
379383

384+
@Test
385+
fun `queryChat should throw if time elapsed is greather than 500ms`() = runTest {
386+
assertThrows<TimeoutCancellationException> {
387+
sut = ProjectContextProvider(project, encoderServer, this)
388+
stubFor(
389+
any(urlPathEqualTo("/query")).willReturn(
390+
aResponse()
391+
.withStatus(200)
392+
.withResponseBody(
393+
Body(validQueryChatResponse)
394+
)
395+
.withFixedDelay(501)
396+
)
397+
)
398+
399+
withContext(getCoroutineBgContext()) {
400+
sut.query("foo")
401+
}
402+
403+
advanceUntilIdle()
404+
}
405+
}
406+
380407
@Test
381408
fun `test index payload is encrypted`() = runTest {
382409
whenever(encoderServer.port).thenReturn(3000)
@@ -390,12 +417,9 @@ class ProjectContextProviderTest {
390417

391418
@Test
392419
fun `test query payload is encrypted`() = runTest {
393-
whenever(encoderServer.port).thenReturn(3000)
394-
try {
395-
sut.query("what does this project do")
396-
} catch (e: ConnectException) {
397-
// no-op
398-
}
420+
sut = ProjectContextProvider(project, encoderServer, this)
421+
sut.query("what does this project do")
422+
advanceUntilIdle()
399423
verify(encoderServer, times(1)).encrypt(any())
400424
}
401425

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/project/ProjectContextController.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class ProjectContextController(private val project: Project, private val cs: Cor
4747

4848
fun getProjectContextIndexComplete() = projectContextProvider.isIndexComplete.get()
4949

50-
fun query(prompt: String): List<RelevantDocument> {
50+
suspend fun query(prompt: String): List<RelevantDocument> {
5151
try {
5252
return projectContextProvider.query(prompt)
5353
} catch (e: Exception) {

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/project/ProjectContextProvider.kt

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ import kotlinx.coroutines.runBlocking
2323
import kotlinx.coroutines.withTimeout
2424
import kotlinx.coroutines.yield
2525
import software.aws.toolkits.core.utils.debug
26-
import software.aws.toolkits.core.utils.error
2726
import software.aws.toolkits.core.utils.getLogger
2827
import software.aws.toolkits.core.utils.info
2928
import software.aws.toolkits.core.utils.warn
@@ -159,17 +158,15 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En
159158
}
160159

161160
// TODO: rename queryChat
162-
fun query(prompt: String): List<RelevantDocument> {
163-
val encrypted = encryptRequest(QueryChatRequest(prompt))
164-
val response = sendMsgToLsp(LspMessage.QueryChat, encrypted)
161+
// TODO: timeout window not decided
162+
suspend fun query(prompt: String): List<RelevantDocument> = withTimeout(500L) {
163+
cs.async {
164+
val encrypted = encryptRequest(QueryChatRequest(prompt))
165+
val response = sendMsgToLsp(LspMessage.QueryChat, encrypted)
165166

166-
return try {
167167
val parsedResponse = mapper.readValue<List<Chunk>>(response.responseBody)
168168
queryResultToRelevantDocuments(parsedResponse)
169-
} catch (e: Exception) {
170-
logger.error { "error parsing query response ${e.message}" }
171-
throw e
172-
}
169+
}.await()
173170
}
174171

175172
suspend fun queryInline(query: String, filePath: String): List<InlineBm25Chunk> = withTimeout(50L) {

0 commit comments

Comments
 (0)