diff --git a/.changes/next-release/bugfix-e06391bd-f5ae-4c2a-a102-580d30f3be8d.json b/.changes/next-release/bugfix-e06391bd-f5ae-4c2a-a102-580d30f3be8d.json new file mode 100644 index 00000000000..8c4f599db4f --- /dev/null +++ b/.changes/next-release/bugfix-e06391bd-f5ae-4c2a-a102-580d30f3be8d.json @@ -0,0 +1,4 @@ +{ + "type" : "bugfix", + "description" : "Amazon Q: Attempt to reduce thread pool contention locking IDE caused by `@workspace` making a large number of requests" +} diff --git a/detekt-rules/detekt.yml b/detekt-rules/detekt.yml index 09fdafb5a9f..9429c533592 100644 --- a/detekt-rules/detekt.yml +++ b/detekt-rules/detekt.yml @@ -34,10 +34,8 @@ coroutines: active: true GlobalCoroutineUsage: active: true - RedundantSuspendModifier: - active: true - SleepInsteadOfDelay: - active: true + InjectDispatcher: + active: false SuspendFunWithFlowReturnType: active: true diff --git a/plugins/amazonq/chat/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/workspace/context/ProjectContextProviderTest.kt b/plugins/amazonq/chat/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/workspace/context/ProjectContextProviderTest.kt index 8cdf146948b..399bd114b57 100644 --- a/plugins/amazonq/chat/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/workspace/context/ProjectContextProviderTest.kt +++ b/plugins/amazonq/chat/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/workspace/context/ProjectContextProviderTest.kt @@ -1,6 +1,6 @@ // Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 - +@file:Suppress("BannedImports") package software.aws.toolkits.jetbrains.services.amazonq.workspace.context import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper @@ -19,6 +19,7 @@ import com.intellij.testFramework.DisposableRule import com.intellij.testFramework.replaceService import io.mockk.every import io.mockk.spyk +import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.TimeoutCancellationException import kotlinx.coroutines.test.StandardTestDispatcher @@ -136,7 +137,7 @@ class ProjectContextProviderTest { } @Test - fun `index should send files within the project to lsp - vector index enabled`() { + fun `index should send files within the project to lsp - vector index enabled`() = runTest { ApplicationManager.getApplication().replaceService( CodeWhispererSettings::class.java, mock { on { isProjectContextEnabled() } doReturn true }, @@ -171,7 +172,7 @@ class ProjectContextProviderTest { } @Test - fun `index should send files within the project to lsp - vector index disabled`() { + fun `index should send files within the project to lsp - vector index disabled`() = runTest { ApplicationManager.getApplication().replaceService( CodeWhispererSettings::class.java, mock { on { isProjectContextEnabled() } doReturn false }, @@ -225,43 +226,49 @@ class ProjectContextProviderTest { @Test fun `query should send correct encrypted request to lsp`() = runTest { - sut = ProjectContextProvider(project, encoderServer, this) - val r = sut.query("foo", null) - advanceUntilIdle() + // use real time + withContext(Dispatchers.Default.limitedParallelism(1)) { + sut = ProjectContextProvider(project, encoderServer, this) + val r = sut.query("foo", null) + advanceUntilIdle() - val request = QueryChatRequest("foo") - val requestJson = mapper.writeValueAsString(request) + val request = QueryChatRequest("foo") + val requestJson = mapper.writeValueAsString(request) - assertThat(mapper.readTree(requestJson)).isEqualTo(mapper.readTree("""{ "query": "foo" }""")) + assertThat(mapper.readTree(requestJson)).isEqualTo(mapper.readTree("""{ "query": "foo" }""")) - val encryptedRequest = encoderServer.encrypt(requestJson) + val encryptedRequest = encoderServer.encrypt(requestJson) - wireMock.verify( - 1, - postRequestedFor(urlPathEqualTo("/query")) - .withHeader("Content-Type", equalTo("text/plain")) - .withRequestBody(equalTo(encryptedRequest)) - ) + wireMock.verify( + 1, + postRequestedFor(urlPathEqualTo("/query")) + .withHeader("Content-Type", equalTo("text/plain")) + .withRequestBody(equalTo(encryptedRequest)) + ) + } } @Test fun `queryInline should send correct encrypted request to lsp`() = runTest { - sut = ProjectContextProvider(project, encoderServer, this) - sut.queryInline("foo", "Foo.java", InlineContextTarget.CODEMAP) - advanceUntilIdle() + // use real time + withContext(Dispatchers.Default.limitedParallelism(1)) { + sut = ProjectContextProvider(project, encoderServer, this) + sut.queryInline("foo", "Foo.java", InlineContextTarget.CODEMAP) + advanceUntilIdle() - val request = QueryInlineCompletionRequest("foo", "Foo.java", "codemap") - val requestJson = mapper.writeValueAsString(request) + val request = QueryInlineCompletionRequest("foo", "Foo.java", "codemap") + val requestJson = mapper.writeValueAsString(request) - assertThat(mapper.readTree(requestJson)).isEqualTo(mapper.readTree("""{ "query": "foo", "filePath": "Foo.java", "target": "codemap" }""")) + assertThat(mapper.readTree(requestJson)).isEqualTo(mapper.readTree("""{ "query": "foo", "filePath": "Foo.java", "target": "codemap" }""")) - val encryptedRequest = encoderServer.encrypt(requestJson) - wireMock.verify( - 1, - postRequestedFor(urlPathEqualTo("/queryInlineProjectContext")) - .withHeader("Content-Type", equalTo("text/plain")) - .withRequestBody(equalTo(encryptedRequest)) - ) + val encryptedRequest = encoderServer.encrypt(requestJson) + wireMock.verify( + 1, + postRequestedFor(urlPathEqualTo("/queryInlineProjectContext")) + .withHeader("Content-Type", equalTo("text/plain")) + .withRequestBody(equalTo(encryptedRequest)) + ) + } } @Test @@ -287,78 +294,84 @@ class ProjectContextProviderTest { @Test fun `query chat should return deserialized relevantDocument`() = runTest { - sut = ProjectContextProvider(project, encoderServer, this) - val r = sut.query("foo", null) - advanceUntilIdle() - assertThat(r).hasSize(2) - assertThat(r[0]).isEqualTo( - RelevantDocument( - "relativeFilePath1", - "context1" + // use real time + withContext(Dispatchers.Default.limitedParallelism(1)) { + sut = ProjectContextProvider(project, encoderServer, this) + val r = sut.query("foo", null) + advanceUntilIdle() + assertThat(r).hasSize(2) + assertThat(r[0]).isEqualTo( + RelevantDocument( + "relativeFilePath1", + "context1" + ) ) - ) - assertThat(r[1]).isEqualTo( - RelevantDocument( - "relativeFilePath2", - "context2" + assertThat(r[1]).isEqualTo( + RelevantDocument( + "relativeFilePath2", + "context2" + ) ) - ) + } } @Test - fun `query inline should throw if resultset not deserializable`() { - assertThrows { - runTest { - sut = ProjectContextProvider(project, encoderServer, this) - stubFor( - any(urlPathEqualTo("/queryInlineProjectContext")).willReturn( - aResponse().withStatus(200).withResponseBody( - Body( - """ + fun `query inline should throw if resultset not deserializable`() = + runTest { + sut = ProjectContextProvider(project, encoderServer, this) + stubFor( + any(urlPathEqualTo("/queryInlineProjectContext")).willReturn( + aResponse().withStatus(200).withResponseBody( + Body( + """ [ "foo", "bar" ] - """.trimIndent() - ) + """.trimIndent() ) ) ) + ) - assertThrows { + assertThrows { + withContext(getCoroutineBgContext()) { sut.queryInline("foo", "filepath", InlineContextTarget.CODEMAP) - advanceUntilIdle() } + + advanceUntilIdle() } } - } @Test fun `query inline should return deserialized bm25 chunks`() = runTest { - sut = ProjectContextProvider(project, encoderServer, this) - advanceUntilIdle() - val r = sut.queryInline("foo", "filepath", InlineContextTarget.CODEMAP) - assertThat(r).hasSize(3) - assertThat(r[0]).isEqualTo( - InlineBm25Chunk( - "content1", - "file1", - 0.1 + // use real time + withContext(Dispatchers.Default.limitedParallelism(1)) { + sut = ProjectContextProvider(project, encoderServer, this) + advanceUntilIdle() + val r = sut.queryInline("foo", "filepath", InlineContextTarget.CODEMAP) + assertThat(r).hasSize(3) + assertThat(r[0]).isEqualTo( + InlineBm25Chunk( + "content1", + "file1", + 0.1 + ) ) - ) - assertThat(r[1]).isEqualTo( - InlineBm25Chunk( - "content2", - "file2", - 0.2 + assertThat(r[1]).isEqualTo( + InlineBm25Chunk( + "content2", + "file2", + 0.2 + ) ) - ) - assertThat(r[2]).isEqualTo( - InlineBm25Chunk( - "content3", - "file3", - 0.3 + assertThat(r[2]).isEqualTo( + InlineBm25Chunk( + "content3", + "file3", + 0.3 + ) ) - ) + } } @Test @@ -431,10 +444,13 @@ class ProjectContextProviderTest { @Test fun `test query payload is encrypted`() = runTest { - sut = ProjectContextProvider(project, encoderServer, this) - sut.query("what does this project do", null) - advanceUntilIdle() - verify(encoderServer, times(1)).encrypt(any()) + // use real time + withContext(Dispatchers.Default.limitedParallelism(1)) { + sut = ProjectContextProvider(project, encoderServer, this) + sut.query("what does this project do", null) + advanceUntilIdle() + verify(encoderServer, times(1)).encrypt(any()) + } } private fun createMockServer() = WireMockRule(wireMockConfig().dynamicPort()) @@ -442,67 +458,67 @@ class ProjectContextProviderTest { // language=JSON val validQueryInlineResponse = """ - [ - { - "content": "content1", - "filePath": "file1", - "score": 0.1 - }, - { - "content": "content2", - "filePath": "file2", - "score": 0.2 - }, - { - "content": "content3", - "filePath": "file3", - "score": 0.3 - } - ] + [ + { + "content": "content1", + "filePath": "file1", + "score": 0.1 + }, + { + "content": "content2", + "filePath": "file2", + "score": 0.2 + }, + { + "content": "content3", + "filePath": "file3", + "score": 0.3 + } + ] """.trimIndent() // language=JSON val validQueryChatResponse = """ - [ - { - "filePath": "file1", - "content": "content1", - "id": "id1", - "index": "index1", - "vec": [ - "vec_1-1", - "vec_1-2", - "vec_1-3" - ], - "context": "context1", - "prev": "prev1", - "next": "next1", - "relativePath": "relativeFilePath1", - "programmingLanguage": "language1" - }, - { - "filePath": "file2", - "content": "content2", - "id": "id2", - "index": "index2", - "vec": [ - "vec_2-1", - "vec_2-2", - "vec_2-3" - ], - "context": "context2", - "prev": "prev2", - "next": "next2", - "relativePath": "relativeFilePath2", - "programmingLanguage": "language2" - } - ] + [ + { + "filePath": "file1", + "content": "content1", + "id": "id1", + "index": "index1", + "vec": [ + "vec_1-1", + "vec_1-2", + "vec_1-3" + ], + "context": "context1", + "prev": "prev1", + "next": "next1", + "relativePath": "relativeFilePath1", + "programmingLanguage": "language1" + }, + { + "filePath": "file2", + "content": "content2", + "id": "id2", + "index": "index2", + "vec": [ + "vec_2-1", + "vec_2-2", + "vec_2-3" + ], + "context": "context2", + "prev": "prev2", + "next": "next2", + "relativePath": "relativeFilePath2", + "programmingLanguage": "language2" + } + ] """.trimIndent() // language=JSON val validGetUsageResponse = """ - { - "memoryUsage":123, - "cpuUsage":456 - } + { + "memoryUsage":123, + "cpuUsage":456 + } """.trimIndent() diff --git a/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/project/ProjectContextController.kt b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/project/ProjectContextController.kt index dabe282f2fd..4c08dcd82e5 100644 --- a/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/project/ProjectContextController.kt +++ b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/project/ProjectContextController.kt @@ -20,6 +20,7 @@ import kotlinx.coroutines.TimeoutCancellationException import kotlinx.coroutines.launch import software.aws.toolkits.core.utils.getLogger import software.aws.toolkits.core.utils.warn +import software.aws.toolkits.jetbrains.core.coroutines.ioDispatcher import software.aws.toolkits.jetbrains.utils.pluginAwareExecuteOnPooledThread import java.util.concurrent.TimeoutException @@ -28,7 +29,7 @@ class ProjectContextController(private val project: Project, private val cs: Cor // TODO: Ideally we should inject dependencies via constructor for easier testing, refer to how [TelemetryService] inject publisher and batcher private val encoderServer: EncoderServer = EncoderServer(project) private val projectContextProvider: ProjectContextProvider = ProjectContextProvider(project, encoderServer, cs) - val initJob: Job = cs.launch { + val initJob: Job = cs.launch(ioDispatcher(1)) { encoderServer.downloadArtifactsAndStartServer() } diff --git a/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/project/ProjectContextProvider.kt b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/project/ProjectContextProvider.kt index c944ff1ac52..cb9d8b9362a 100644 --- a/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/project/ProjectContextProvider.kt +++ b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/project/ProjectContextProvider.kt @@ -16,15 +16,18 @@ import com.intellij.openapi.vfs.VfsUtilCore import com.intellij.openapi.vfs.VirtualFile import com.intellij.openapi.vfs.VirtualFileVisitor import com.intellij.openapi.vfs.isFile +import com.intellij.util.concurrency.annotations.RequiresBackgroundThread import kotlinx.coroutines.CoroutineScope -import kotlinx.coroutines.async import kotlinx.coroutines.delay import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withContext import kotlinx.coroutines.withTimeout import software.aws.toolkits.core.utils.debug import software.aws.toolkits.core.utils.getLogger import software.aws.toolkits.core.utils.info import software.aws.toolkits.core.utils.warn +import software.aws.toolkits.jetbrains.core.coroutines.ioDispatcher import software.aws.toolkits.jetbrains.services.amazonq.CHAT_EXPLICIT_PROJECT_CONTEXT_TIMEOUT import software.aws.toolkits.jetbrains.services.amazonq.SUPPLEMENTAL_CONTEXT_TIMEOUT import software.aws.toolkits.jetbrains.services.cwc.controller.chat.telemetry.getStartUrl @@ -32,7 +35,7 @@ import software.aws.toolkits.jetbrains.settings.CodeWhispererSettings import software.aws.toolkits.telemetry.AmazonqTelemetry import java.io.OutputStreamWriter import java.net.HttpURLConnection -import java.net.URL +import java.net.URI import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.atomic.AtomicInteger import kotlin.time.Duration.Companion.minutes @@ -42,6 +45,9 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En val isIndexComplete = AtomicBoolean(false) private val mapper = jacksonObjectMapper() + // max number of requests that can be ongoing to an given server instance, excluding index() + private val ioDispatcher = ioDispatcher(20) + init { cs.launch { if (ApplicationManager.getApplication().isUnitTestMode) { @@ -101,38 +107,36 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En val programmingLanguage: String? = null, ) - private fun initAndIndex() { - cs.launch { - while (retryCount.get() < 5) { - try { - logger.info { "project context: about to init key" } - val isInitSuccess = initEncryption() - if (isInitSuccess) { - logger.info { "project context index starting" } - delay(300) - val isIndexSuccess = index() - if (isIndexSuccess) isIndexComplete.set(true) - return@launch - } - } catch (e: Exception) { - if (e.stackTraceToString().contains("Connection refused")) { - retryCount.incrementAndGet() - delay(10000) - } else { - return@launch - } + private suspend fun initAndIndex() { + while (retryCount.get() < 5) { + try { + logger.info { "project context: about to init key" } + val isInitSuccess = initEncryption() + if (isInitSuccess) { + logger.info { "project context index starting" } + delay(300) + val isIndexSuccess = index() + if (isIndexSuccess) isIndexComplete.set(true) + return + } + } catch (e: Exception) { + if (e.stackTraceToString().contains("Connection refused")) { + retryCount.incrementAndGet() + delay(10000) + } else { + return } } } } - private fun initEncryption(): Boolean { + private suspend fun initEncryption(): Boolean { val request = encoderServer.getEncryptionRequest() val response = sendMsgToLsp(LspMessage.Initialize, request) return response?.responseCode == 200 } - fun index(): Boolean { + suspend fun index(): Boolean { val projectRoot = project.basePath ?: return false val indexStartTime = System.currentTimeMillis() @@ -166,23 +170,20 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En // TODO: rename queryChat suspend fun query(prompt: String, timeout: Long?): List = withTimeout(timeout ?: CHAT_EXPLICIT_PROJECT_CONTEXT_TIMEOUT) { - cs.async { - val encrypted = encryptRequest(QueryChatRequest(prompt)) - val response = sendMsgToLsp(LspMessage.QueryChat, encrypted) ?: return@async emptyList() - val parsedResponse = mapper.readValue>(response.responseBody) - queryResultToRelevantDocuments(parsedResponse) - }.await() + val encrypted = encryptRequest(QueryChatRequest(prompt)) + val response = sendMsgToLsp(LspMessage.QueryChat, encrypted) ?: return@withTimeout emptyList() + val parsedResponse = mapper.readValue>(response.responseBody) + + return@withTimeout queryResultToRelevantDocuments(parsedResponse) } suspend fun queryInline(query: String, filePath: String, target: InlineContextTarget): List = withTimeout(SUPPLEMENTAL_CONTEXT_TIMEOUT) { - cs.async { - val encrypted = encryptRequest(QueryInlineCompletionRequest(query, filePath, target.toString())) - val r = sendMsgToLsp(LspMessage.QueryInlineCompletion, encrypted) ?: return@async emptyList() - return@async mapper.readValue>(r.responseBody) - }.await() + val encrypted = encryptRequest(QueryInlineCompletionRequest(query, filePath, target.toString())) + val r = sendMsgToLsp(LspMessage.QueryInlineCompletion, encrypted) ?: return@withTimeout emptyList() + return@withTimeout mapper.readValue>(r.responseBody) } - fun getUsage(): Usage? { + suspend fun getUsage(): Usage? { val response = sendMsgToLsp(LspMessage.GetUsageMetrics, request = null) ?: return null return try { val parsedResponse = mapper.readValue(response.responseBody) @@ -193,9 +194,10 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En } } + @RequiresBackgroundThread fun updateIndex(filePaths: List, mode: IndexUpdateMode) { val encrypted = encryptRequest(UpdateIndexRequest(filePaths, mode.command)) - sendMsgToLsp(LspMessage.UpdateIndex, encrypted) + runBlocking { sendMsgToLsp(LspMessage.UpdateIndex, encrypted) } } private fun recordIndexWorkspace( @@ -312,31 +314,36 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En return encoderServer.encrypt(payloadJson) } - private fun sendMsgToLsp(msgType: LspMessage, request: String?): LspResponse? { + private suspend fun sendMsgToLsp(msgType: LspMessage, request: String?): LspResponse? { logger.info { "sending message: ${msgType.endpoint} to lsp on port ${encoderServer.port}" } - val url = URL("http://localhost:${encoderServer.port}/${msgType.endpoint}") + val url = URI("http://127.0.0.1:${encoderServer.port}/${msgType.endpoint}").toURL() if (!encoderServer.isNodeProcessRunning()) { logger.warn { "language server for ${project.name} is not running" } return null } // use 1h as timeout for index, 5 seconds for other APIs val timeoutMs = if (msgType is LspMessage.Index) 60.minutes.inWholeMilliseconds.toInt() else 5000 - return with(url.openConnection() as HttpURLConnection) { - setConnectionProperties(this) - setConnectionTimeout(this, timeoutMs) - request?.let { r -> - setConnectionRequest(this, r) - } - val responseCode = this.responseCode - logger.info { "receiving response for $msgType with responseCode $responseCode" } + // dedicate single thread to index operation because it can be long running + val dispatcher = if (msgType is LspMessage.Index) ioDispatcher(1) else ioDispatcher + + return withContext(dispatcher) { + with(url.openConnection() as HttpURLConnection) { + setConnectionProperties(this) + setConnectionTimeout(this, timeoutMs) + request?.let { r -> + setConnectionRequest(this, r) + } + val responseCode = this.responseCode + logger.info { "receiving response for $msgType with responseCode $responseCode" } - val responseBody = if (responseCode == 200) { - this.inputStream.bufferedReader().use { reader -> reader.readText() } - } else { - "" - } + val responseBody = if (responseCode == 200) { + this.inputStream.bufferedReader().use { reader -> reader.readText() } + } else { + "" + } - LspResponse(responseCode, responseBody) + LspResponse(responseCode, responseBody) + } } } diff --git a/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/coroutines/contexts.kt b/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/coroutines/contexts.kt index ba982a94244..762f487fe32 100644 --- a/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/coroutines/contexts.kt +++ b/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/coroutines/contexts.kt @@ -32,3 +32,6 @@ fun getCoroutineUiContext(): CoroutineContext = EdtCoroutineDispatcher fun getCoroutineBgContext(): CoroutineContext = AppExecutorUtil.getAppExecutorService().asCoroutineDispatcher() val EDT = Dispatchers.EDT + +// parallelism should be defined https://youtrack.jetbrains.com/issue/KTOR-6462 +fun ioDispatcher(limitedParallelism: Int) = Dispatchers.IO.limitedParallelism(limitedParallelism)