diff --git a/plugins/amazonq/chat/jetbrains-community/src/software/aws/toolkits/jetbrains/services/cwc/controller/ChatController.kt b/plugins/amazonq/chat/jetbrains-community/src/software/aws/toolkits/jetbrains/services/cwc/controller/ChatController.kt index b2cb29d8de5..af3a400a455 100644 --- a/plugins/amazonq/chat/jetbrains-community/src/software/aws/toolkits/jetbrains/services/cwc/controller/ChatController.kt +++ b/plugins/amazonq/chat/jetbrains-community/src/software/aws/toolkits/jetbrains/services/cwc/controller/ChatController.kt @@ -36,6 +36,7 @@ import software.aws.toolkits.core.utils.getLogger import software.aws.toolkits.core.utils.info import software.aws.toolkits.core.utils.warn import software.aws.toolkits.jetbrains.core.coroutines.EDT +import software.aws.toolkits.jetbrains.services.amazonq.CHAT_IMPLICIT_PROJECT_CONTEXT_TIMEOUT import software.aws.toolkits.jetbrains.services.amazonq.CodeWhispererFeatureConfigService import software.aws.toolkits.jetbrains.services.amazonq.apps.AmazonQAppInitContext import software.aws.toolkits.jetbrains.services.amazonq.auth.AuthController @@ -137,7 +138,7 @@ class ChatController private constructor( shouldUseWorkspaceContext = true prompt = prompt.replace("@workspace", "") val projectContextController = ProjectContextController.getInstance(context.project) - queryResult = projectContextController.query(prompt) + queryResult = projectContextController.query(prompt, timeout = null) if (!projectContextController.getProjectContextIndexComplete()) shouldAddIndexInProgressMessage = true logger.info { "project context relevant document count: ${queryResult.size}" } } else { @@ -145,7 +146,7 @@ class ChatController private constructor( } } else if (CodeWhispererSettings.getInstance().isProjectContextEnabled() && isDataCollectionGroup) { val projectContextController = ProjectContextController.getInstance(context.project) - queryResult = projectContextController.query(prompt) + queryResult = projectContextController.query(prompt, timeout = CHAT_IMPLICIT_PROJECT_CONTEXT_TIMEOUT) } handleChat( diff --git a/plugins/amazonq/chat/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/workspace/context/ProjectContextProviderTest.kt b/plugins/amazonq/chat/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/workspace/context/ProjectContextProviderTest.kt index abde58fe26e..8853bb4a68c 100644 --- a/plugins/amazonq/chat/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/workspace/context/ProjectContextProviderTest.kt +++ b/plugins/amazonq/chat/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/workspace/context/ProjectContextProviderTest.kt @@ -214,8 +214,10 @@ class ProjectContextProviderTest { } @Test - fun `query should send correct encrypted request to lsp`() { - sut.query("foo") + fun `query should send correct encrypted request to lsp`() = runTest { + sut = ProjectContextProvider(project, encoderServer, this) + val r = sut.query("foo", null) + advanceUntilIdle() val request = QueryChatRequest("foo") val requestJson = mapper.writeValueAsString(request) @@ -269,13 +271,15 @@ class ProjectContextProviderTest { ) assertThrows { - sut.query("foo") + sut.query("foo", null) } } @Test fun `query chat should return deserialized relevantDocument`() = runTest { - val r = sut.query("foo") + sut = ProjectContextProvider(project, encoderServer, this) + val r = sut.query("foo", null) + advanceUntilIdle() assertThat(r).hasSize(2) assertThat(r[0]).isEqualTo( RelevantDocument( @@ -377,6 +381,29 @@ class ProjectContextProviderTest { } } + @Test + fun `queryChat should throw if time elapsed is greather than 500ms`() = runTest { + assertThrows { + sut = ProjectContextProvider(project, encoderServer, this) + stubFor( + any(urlPathEqualTo("/query")).willReturn( + aResponse() + .withStatus(200) + .withResponseBody( + Body(validQueryChatResponse) + ) + .withFixedDelay(501) + ) + ) + + withContext(getCoroutineBgContext()) { + sut.query("foo", timeout = 500L) + } + + advanceUntilIdle() + } + } + @Test fun `test index payload is encrypted`() = runTest { whenever(encoderServer.port).thenReturn(3000) @@ -390,12 +417,9 @@ class ProjectContextProviderTest { @Test fun `test query payload is encrypted`() = runTest { - whenever(encoderServer.port).thenReturn(3000) - try { - sut.query("what does this project do") - } catch (e: ConnectException) { - // no-op - } + sut = ProjectContextProvider(project, encoderServer, this) + sut.query("what does this project do", null) + advanceUntilIdle() verify(encoderServer, times(1)).encrypt(any()) } diff --git a/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/Constants.kt b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/Constants.kt index 45341767295..cda847954e6 100644 --- a/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/Constants.kt +++ b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/Constants.kt @@ -3,6 +3,8 @@ package software.aws.toolkits.jetbrains.services.amazonq +import kotlin.time.Duration.Companion.minutes + const val APPLICATION_ZIP = "application/zip" const val SERVER_SIDE_ENCRYPTION = "x-amz-server-side-encryption" const val AWS_KMS = "aws:kms" @@ -39,3 +41,7 @@ const val CODE_TRANSFORM_PREREQUISITES = const val FEATURE_EVALUATION_PRODUCT_NAME = "CodeWhisperer" const val SUPPLEMENTAL_CONTEXT_TIMEOUT = 100L + +val CHAT_EXPLICIT_PROJECT_CONTEXT_TIMEOUT = 5.minutes.inWholeMilliseconds + +const val CHAT_IMPLICIT_PROJECT_CONTEXT_TIMEOUT = 500L // 500ms diff --git a/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/project/ProjectContextController.kt b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/project/ProjectContextController.kt index add579965ac..7e98c61b74d 100644 --- a/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/project/ProjectContextController.kt +++ b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/project/ProjectContextController.kt @@ -47,9 +47,9 @@ class ProjectContextController(private val project: Project, private val cs: Cor fun getProjectContextIndexComplete() = projectContextProvider.isIndexComplete.get() - fun query(prompt: String): List { + suspend fun query(prompt: String, timeout: Long?): List { try { - return projectContextProvider.query(prompt) + return projectContextProvider.query(prompt, timeout) } catch (e: Exception) { logger.warn { "error while querying for project context $e.message" } return emptyList() diff --git a/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/project/ProjectContextProvider.kt b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/project/ProjectContextProvider.kt index 9e15336279b..1f657837c4e 100644 --- a/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/project/ProjectContextProvider.kt +++ b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/project/ProjectContextProvider.kt @@ -22,10 +22,10 @@ import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking import kotlinx.coroutines.withTimeout import software.aws.toolkits.core.utils.debug -import software.aws.toolkits.core.utils.error import software.aws.toolkits.core.utils.getLogger import software.aws.toolkits.core.utils.info import software.aws.toolkits.core.utils.warn +import software.aws.toolkits.jetbrains.services.amazonq.CHAT_EXPLICIT_PROJECT_CONTEXT_TIMEOUT import software.aws.toolkits.jetbrains.services.amazonq.FeatureDevSessionContext import software.aws.toolkits.jetbrains.services.amazonq.SUPPLEMENTAL_CONTEXT_TIMEOUT import software.aws.toolkits.jetbrains.services.cwc.controller.chat.telemetry.getStartUrl @@ -161,17 +161,14 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En } // TODO: rename queryChat - fun query(prompt: String): List { - val encrypted = encryptRequest(QueryChatRequest(prompt)) - val response = sendMsgToLsp(LspMessage.QueryChat, encrypted) + suspend fun query(prompt: String, timeout: Long?): List = withTimeout(timeout ?: CHAT_EXPLICIT_PROJECT_CONTEXT_TIMEOUT) { + cs.async { + val encrypted = encryptRequest(QueryChatRequest(prompt)) + val response = sendMsgToLsp(LspMessage.QueryChat, encrypted) - return try { val parsedResponse = mapper.readValue>(response.responseBody) queryResultToRelevantDocuments(parsedResponse) - } catch (e: Exception) { - logger.error { "error parsing query response ${e.message}" } - throw e - } + }.await() } suspend fun queryInline(query: String, filePath: String): List = withTimeout(SUPPLEMENTAL_CONTEXT_TIMEOUT) {