diff --git a/plugins/amazonq/chat/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/workspace/context/ProjectContextProviderTest.kt b/plugins/amazonq/chat/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/workspace/context/ProjectContextProviderTest.kt index cb1dd20ca3b..68b39fc0091 100644 --- a/plugins/amazonq/chat/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/workspace/context/ProjectContextProviderTest.kt +++ b/plugins/amazonq/chat/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/workspace/context/ProjectContextProviderTest.kt @@ -3,18 +3,37 @@ package software.aws.toolkits.jetbrains.services.amazonq.workspace.context +import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper +import com.github.tomakehurst.wiremock.client.WireMock.aResponse +import com.github.tomakehurst.wiremock.client.WireMock.any +import com.github.tomakehurst.wiremock.client.WireMock.equalTo +import com.github.tomakehurst.wiremock.client.WireMock.postRequestedFor +import com.github.tomakehurst.wiremock.client.WireMock.stubFor +import com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo +import com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig +import com.github.tomakehurst.wiremock.http.Body +import com.github.tomakehurst.wiremock.junit.WireMockRule import com.intellij.openapi.project.Project import kotlinx.coroutines.test.TestScope import kotlinx.coroutines.test.runTest +import org.assertj.core.api.Assertions.assertThat import org.junit.Before import org.junit.Rule +import org.junit.jupiter.api.assertThrows import org.mockito.kotlin.any -import org.mockito.kotlin.mock +import org.mockito.kotlin.doReturn +import org.mockito.kotlin.spy +import org.mockito.kotlin.stub import org.mockito.kotlin.times import org.mockito.kotlin.verify import org.mockito.kotlin.whenever import software.aws.toolkits.jetbrains.services.amazonq.project.EncoderServer +import software.aws.toolkits.jetbrains.services.amazonq.project.IndexRequest +import software.aws.toolkits.jetbrains.services.amazonq.project.LspMessage import software.aws.toolkits.jetbrains.services.amazonq.project.ProjectContextProvider +import software.aws.toolkits.jetbrains.services.amazonq.project.QueryChatRequest +import software.aws.toolkits.jetbrains.services.amazonq.project.RelevantDocument +import software.aws.toolkits.jetbrains.services.amazonq.project.UpdateIndexRequest import software.aws.toolkits.jetbrains.utils.rules.CodeInsightTestFixtureRule import software.aws.toolkits.jetbrains.utils.rules.JavaCodeInsightTestFixtureRule import java.net.ConnectException @@ -25,15 +44,166 @@ class ProjectContextProviderTest { @JvmField val projectRule: CodeInsightTestFixtureRule = JavaCodeInsightTestFixtureRule() + @Rule + @JvmField + val wireMock: WireMockRule = createMockServer() + private val project: Project get() = projectRule.project - private val encoderServer: EncoderServer = mock() + private lateinit var encoderServer: EncoderServer private lateinit var sut: ProjectContextProvider + private val mapper = jacksonObjectMapper() + @Before fun setup() { + encoderServer = spy(EncoderServer(project)) + encoderServer.stub { on { port } doReturn wireMock.port() } + sut = ProjectContextProvider(project, encoderServer, TestScope()) + + // initialization + stubFor(any(urlPathEqualTo("/initialize")).willReturn(aResponse().withStatus(200).withResponseBody(Body("initialize response")))) + + // build index + stubFor(any(urlPathEqualTo("/indexFiles")).willReturn(aResponse().withStatus(200).withResponseBody(Body("initialize response")))) + + // update index + stubFor(any(urlPathEqualTo("/updateIndex")).willReturn(aResponse().withStatus(200).withResponseBody(Body("initialize response")))) + + // query + stubFor( + any(urlPathEqualTo("/query")).willReturn( + aResponse() + .withStatus(200) + .withResponseBody(Body(validQueryChatResponse)) + ) + ) + + stubFor( + any(urlPathEqualTo("/getUsage")) + .willReturn( + aResponse() + .withStatus(200) + .withResponseBody(Body(validGetUsageResponse)) + ) + ) + } + + @Test + fun `Lsp endpoint are correct`() { + assertThat(LspMessage.Initialize.endpoint).isEqualTo("initialize") + assertThat(LspMessage.Index.endpoint).isEqualTo("indexFiles") + assertThat(LspMessage.QueryChat.endpoint).isEqualTo("query") + assertThat(LspMessage.GetUsageMetrics.endpoint).isEqualTo("getUsage") + } + + @Test + fun `index should send files within the project to lsp`() { + projectRule.fixture.addFileToProject("Foo.java", "foo") + projectRule.fixture.addFileToProject("Bar.java", "bar") + projectRule.fixture.addFileToProject("Baz.java", "baz") + + sut.index() + + val request = IndexRequest(listOf("/src/Foo.java", "/src/Bar.java", "/src/Baz.java"), "/src", false) + assertThat(request.filePaths).hasSize(3) + assertThat(request.filePaths).satisfies({ + it.contains("/src/Foo.java") && + it.contains("/src/Baz.java") && + it.contains("/src/Bar.java") + }) + + wireMock.verify( + 1, + postRequestedFor(urlPathEqualTo("/indexFiles")) + .withHeader("Content-Type", equalTo("text/plain")) + // comment it out because order matters and will cause json string different +// .withRequestBody(equalTo(encryptedRequest)) + ) + } + + @Test + fun `updateIndex should send correct encrypted request to lsp`() { + sut.updateIndex("foo.java") + val request = UpdateIndexRequest("foo.java") + val requestJson = mapper.writeValueAsString(request) + + assertThat(mapper.readTree(requestJson)).isEqualTo(mapper.readTree("""{ "filePath": "foo.java" }""")) + + val encryptedRequest = encoderServer.encrypt(requestJson) + + wireMock.verify( + 1, + postRequestedFor(urlPathEqualTo("/updateIndex")) + .withHeader("Content-Type", equalTo("text/plain")) + .withRequestBody(equalTo(encryptedRequest)) + ) + } + + @Test + fun `query should send correct encrypted request to lsp`() { + sut.query("foo") + + val request = QueryChatRequest("foo") + val requestJson = mapper.writeValueAsString(request) + + assertThat(mapper.readTree(requestJson)).isEqualTo(mapper.readTree("""{ "query": "foo" }""")) + + val encryptedRequest = encoderServer.encrypt(requestJson) + + wireMock.verify( + 1, + postRequestedFor(urlPathEqualTo("/query")) + .withHeader("Content-Type", equalTo("text/plain")) + .withRequestBody(equalTo(encryptedRequest)) + ) + } + + @Test + fun `query chat should return empty if result set non deserializable`() = runTest { + stubFor( + any(urlPathEqualTo("/query")).willReturn( + aResponse().withStatus(200).withResponseBody( + Body( + """ + [ + "foo", "bar" + ] + """.trimIndent() + ) + ) + ) + ) + + assertThrows { + sut.query("foo") + } + } + + @Test + fun `query chat should return deserialized relevantDocument`() = runTest { + val r = sut.query("foo") + assertThat(r).hasSize(2) + assertThat(r[0]).isEqualTo( + RelevantDocument( + "relativeFilePath1", + "context1" + ) + ) + assertThat(r[1]).isEqualTo( + RelevantDocument( + "relativeFilePath2", + "context2" + ) + ) + } + + @Test + fun `get usage should return memory, cpu usage`() = runTest { + val r = sut.getUsage() + assertThat(r).isEqualTo(ProjectContextProvider.Usage(123, 456)) } @Test @@ -57,4 +227,52 @@ class ProjectContextProviderTest { } verify(encoderServer, times(1)).encrypt(any()) } + + private fun createMockServer() = WireMockRule(wireMockConfig().dynamicPort()) } + +// language=JSON +val validQueryChatResponse = """ + [ + { + "filePath": "file1", + "content": "content1", + "id": "id1", + "index": "index1", + "vec": [ + "vec_1-1", + "vec_1-2", + "vec_1-3" + ], + "context": "context1", + "prev": "prev1", + "next": "next1", + "relativePath": "relativeFilePath1", + "programmingLanguage": "language1" + }, + { + "filePath": "file2", + "content": "content2", + "id": "id2", + "index": "index2", + "vec": [ + "vec_2-1", + "vec_2-2", + "vec_2-3" + ], + "context": "context2", + "prev": "prev2", + "next": "next2", + "relativePath": "relativeFilePath2", + "programmingLanguage": "language2" + } + ] +""".trimIndent() + +// language=JSON +val validGetUsageResponse = """ + { + "memoryUsage":123, + "cpuUsage":456 + } +""".trimIndent() diff --git a/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/project/LspMessage.kt b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/project/LspMessage.kt new file mode 100644 index 00000000000..4c44e48da0e --- /dev/null +++ b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/project/LspMessage.kt @@ -0,0 +1,49 @@ +// Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package software.aws.toolkits.jetbrains.services.amazonq.project + +sealed interface LspMessage { + val endpoint: String + + data object Initialize : LspMessage { + override val endpoint: String = "initialize" + } + + data object Index : LspMessage { + override val endpoint: String = "indexFiles" + } + + data object UpdateIndex : LspMessage { + override val endpoint: String = "updateIndex" + } + + data object QueryChat : LspMessage { + override val endpoint: String = "query" + } + + data object GetUsageMetrics : LspMessage { + override val endpoint: String = "getUsage" + } +} + +interface LspRequest + +data class IndexRequest( + val filePaths: List, + val projectRoot: String, + val refresh: Boolean, +) : LspRequest + +data class UpdateIndexRequest( + val filePath: String, +) : LspRequest + +data class QueryChatRequest( + val query: String, +) : LspRequest + +data class LspResponse( + val responseCode: Int, + val responseBody: String, +) diff --git a/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/project/ProjectContextProvider.kt b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/project/ProjectContextProvider.kt index ab46c1154e8..ccfa4b70966 100644 --- a/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/project/ProjectContextProvider.kt +++ b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/project/ProjectContextProvider.kt @@ -20,6 +20,7 @@ import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking import kotlinx.coroutines.yield import software.aws.toolkits.core.utils.debug +import software.aws.toolkits.core.utils.error import software.aws.toolkits.core.utils.getLogger import software.aws.toolkits.core.utils.info import software.aws.toolkits.core.utils.warn @@ -53,25 +54,12 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En } } } - data class IndexRequestPayload( - val filePaths: List, - val projectRoot: String, - val refresh: Boolean, - ) data class FileCollectionResult( val files: List, val fileSize: Int, ) - data class QueryRequestPayload( - val query: String, - ) - - data class UpdateIndexRequestPayload( - val filePath: String, - ) - data class Usage( @JsonIgnoreProperties(ignoreUnknown = true) @JsonProperty("memoryUsage") @@ -130,37 +118,27 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En } private fun initEncryption(): Boolean { - logger.info { "project context: init key for ${project.name} on port ${encoderServer.port}" } - val url = URL("http://localhost:${encoderServer.port}/initialize") - val payload = encoderServer.getEncryptionRequest() - val connection = url.openConnection() as HttpURLConnection - setConnectionProperties(connection) - setConnectionRequest(connection, payload) - logger.info { "project context initialize response code: ${connection.responseCode} for ${project.name}" } - return connection.responseCode == 200 + val request = encoderServer.getEncryptionRequest() + val response = sendMsgToLsp(LspMessage.Initialize, request) + return response.responseCode == 200 } fun index(): Boolean { - logger.info { "project context: indexing ${project.name} on port ${encoderServer.port}" } + val projectRoot = project.basePath ?: return false + val indexStartTime = System.currentTimeMillis() - val url = URL("http://localhost:${encoderServer.port}/indexFiles") val filesResult = collectFiles() var duration = (System.currentTimeMillis() - indexStartTime).toDouble() - logger.debug { "project context file collection time: ${duration}ms" } - logger.debug { "list of files collected: ${filesResult.files.joinToString("\n")}" } - val projectRoot = project.basePath ?: return false - val payload = IndexRequestPayload(filesResult.files, projectRoot, false) - val payloadJson = mapper.writeValueAsString(payload) - val encrypted = encoderServer.encrypt(payloadJson) - - val connection = url.openConnection() as HttpURLConnection - setConnectionProperties(connection) - setConnectionRequest(connection, encrypted) - logger.info { "project context index response code: ${connection.responseCode} for ${project.name}" } + logger.debug { "time elapsed to collect project context files: ${duration}ms, collected ${filesResult.files.size} files" } + + val encrypted = encryptRequest(IndexRequest(filesResult.files, projectRoot, false)) + val response = sendMsgToLsp(LspMessage.Index, encrypted) + duration = (System.currentTimeMillis() - indexStartTime).toDouble() - val startUrl = getStartUrl(project) logger.debug { "project context index time: ${duration}ms" } - if (connection.responseCode == 200) { + + val startUrl = getStartUrl(project) + if (response.responseCode == 200) { val usage = getUsage() recordIndexWorkspace(duration, filesResult.files.size, filesResult.fileSize, true, usage?.memoryUsage, usage?.cpuUsage, startUrl) logger.debug { "project context index finished for ${project.name}" } @@ -172,34 +150,34 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En } fun query(prompt: String): List { - logger.info { "project context: querying ${project.name} on port ${encoderServer.port}" } - val url = URL("http://localhost:${encoderServer.port}/query") - val payload = QueryRequestPayload(prompt) - val payloadJson = mapper.writeValueAsString(payload) - val encrypted = encoderServer.encrypt(payloadJson) - - val connection = url.openConnection() as HttpURLConnection - setConnectionProperties(connection) - setConnectionTimeout(connection) - setConnectionRequest(connection, encrypted) - - val responseCode = connection.responseCode - logger.info { "project context query response code: $responseCode for $prompt" } - val responseBody = if (responseCode == 200) { - connection.inputStream.bufferedReader().use { reader -> reader.readText() } - } else { - "" + val encrypted = encryptRequest(QueryChatRequest(prompt)) + val response = sendMsgToLsp(LspMessage.QueryChat, encrypted) + + return try { + val parsedResponse = mapper.readValue>(response.responseBody) + queryResultToRelevantDocuments(parsedResponse) + } catch (e: Exception) { + logger.error { "error parsing query response ${e.message}" } + throw e } - connection.disconnect() - try { - val parsedResponse = mapper.readValue>(responseBody) - return queryResultToRelevantDocuments(parsedResponse) + } + + fun getUsage(): Usage? { + val response = sendMsgToLsp(LspMessage.GetUsageMetrics, request = null) + return try { + val parsedResponse = mapper.readValue(response.responseBody) + parsedResponse } catch (e: Exception) { logger.warn { "error parsing query response ${e.message}" } - return emptyList() + null } } + fun updateIndex(filePath: String) { + val encrypted = encryptRequest(UpdateIndexRequest(filePath)) + sendMsgToLsp(LspMessage.UpdateIndex, encrypted) + } + private fun recordIndexWorkspace( duration: Double, fileCount: Int = 0, @@ -221,46 +199,6 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En ) } - private fun getUsage(): Usage? { - logger.info { "project context: getting usage for ${project.name} on port ${encoderServer.port}" } - val url = URL("http://localhost:${encoderServer.port}/getUsage") - val connection = url.openConnection() as HttpURLConnection - setConnectionProperties(connection) - val responseCode = connection.responseCode - - logger.info { "project context getUsage response code: $responseCode for ${project.name} " } - val responseBody = if (responseCode == 200) { - connection.inputStream.bufferedReader().use { reader -> reader.readText() } - } else { - "" - } - connection.disconnect() - try { - val parsedResponse = mapper.readValue(responseBody) - return parsedResponse - } catch (e: Exception) { - logger.warn { "error parsing query response ${e.message}" } - return null - } - } - - fun updateIndex(filePath: String) { - if (!isIndexComplete.get()) return - logger.info { "project context: updating index for $filePath on port ${encoderServer.port}" } - val url = URL("http://localhost:${encoderServer.port}/updateIndex") - val payload = UpdateIndexRequestPayload(filePath) - val payloadJson = mapper.writeValueAsString(payload) - val encrypted = encoderServer.encrypt(payloadJson) - with(url.openConnection() as HttpURLConnection) { - setConnectionProperties(this) - setConnectionTimeout(this) - setConnectionRequest(this, encrypted) - val responseCode = responseCode - logger.debug { "project context update index response code: $responseCode for $filePath" } - return - } - } - private fun setConnectionTimeout(connection: HttpURLConnection) { connection.connectTimeout = 5000 // 5 seconds connection.readTimeout = 5000 // 5 second @@ -346,6 +284,34 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En return documents } + private fun encryptRequest(r: LspRequest): String { + val payloadJson = mapper.writeValueAsString(r) + return encoderServer.encrypt(payloadJson) + } + + private fun sendMsgToLsp(msgType: LspMessage, request: String?): LspResponse { + logger.info { "sending message: ${msgType.endpoint} to lsp on port ${encoderServer.port}" } + val url = URL("http://localhost:${encoderServer.port}/${msgType.endpoint}") + + return with(url.openConnection() as HttpURLConnection) { + setConnectionProperties(this) + setConnectionTimeout(this) + request?.let { r -> + setConnectionRequest(this, r) + } + val responseCode = this.responseCode + logger.info { "receiving response for $msgType with responseCode $responseCode" } + + val responseBody = if (responseCode == 200) { + this.inputStream.bufferedReader().use { reader -> reader.readText() } + } else { + "" + } + + LspResponse(responseCode, responseBody) + } + } + override fun dispose() { retryCount.set(0) }