Skip to content

Commit e9bfb9a

Browse files
authored
refactor(amazonq) refactor ProjectContextProvider.kt and add tests (#4988)
* refactor ProjectContextProvider and add tests * add more test * lint * format json blob * revert wrong merge conflict resolve * lint * should call updateIndex regardless index is complete or not which should be handled by LSP
1 parent 764905c commit e9bfb9a

File tree

3 files changed

+333
-100
lines changed

3 files changed

+333
-100
lines changed

plugins/amazonq/chat/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/workspace/context/ProjectContextProviderTest.kt

Lines changed: 220 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,37 @@
33

44
package software.aws.toolkits.jetbrains.services.amazonq.workspace.context
55

6+
import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper
7+
import com.github.tomakehurst.wiremock.client.WireMock.aResponse
8+
import com.github.tomakehurst.wiremock.client.WireMock.any
9+
import com.github.tomakehurst.wiremock.client.WireMock.equalTo
10+
import com.github.tomakehurst.wiremock.client.WireMock.postRequestedFor
11+
import com.github.tomakehurst.wiremock.client.WireMock.stubFor
12+
import com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo
13+
import com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig
14+
import com.github.tomakehurst.wiremock.http.Body
15+
import com.github.tomakehurst.wiremock.junit.WireMockRule
616
import com.intellij.openapi.project.Project
717
import kotlinx.coroutines.test.TestScope
818
import kotlinx.coroutines.test.runTest
19+
import org.assertj.core.api.Assertions.assertThat
920
import org.junit.Before
1021
import org.junit.Rule
22+
import org.junit.jupiter.api.assertThrows
1123
import org.mockito.kotlin.any
12-
import org.mockito.kotlin.mock
24+
import org.mockito.kotlin.doReturn
25+
import org.mockito.kotlin.spy
26+
import org.mockito.kotlin.stub
1327
import org.mockito.kotlin.times
1428
import org.mockito.kotlin.verify
1529
import org.mockito.kotlin.whenever
1630
import software.aws.toolkits.jetbrains.services.amazonq.project.EncoderServer
31+
import software.aws.toolkits.jetbrains.services.amazonq.project.IndexRequest
32+
import software.aws.toolkits.jetbrains.services.amazonq.project.LspMessage
1733
import software.aws.toolkits.jetbrains.services.amazonq.project.ProjectContextProvider
34+
import software.aws.toolkits.jetbrains.services.amazonq.project.QueryChatRequest
35+
import software.aws.toolkits.jetbrains.services.amazonq.project.RelevantDocument
36+
import software.aws.toolkits.jetbrains.services.amazonq.project.UpdateIndexRequest
1837
import software.aws.toolkits.jetbrains.utils.rules.CodeInsightTestFixtureRule
1938
import software.aws.toolkits.jetbrains.utils.rules.JavaCodeInsightTestFixtureRule
2039
import java.net.ConnectException
@@ -25,15 +44,166 @@ class ProjectContextProviderTest {
2544
@JvmField
2645
val projectRule: CodeInsightTestFixtureRule = JavaCodeInsightTestFixtureRule()
2746

47+
@Rule
48+
@JvmField
49+
val wireMock: WireMockRule = createMockServer()
50+
2851
private val project: Project
2952
get() = projectRule.project
3053

31-
private val encoderServer: EncoderServer = mock()
54+
private lateinit var encoderServer: EncoderServer
3255
private lateinit var sut: ProjectContextProvider
3356

57+
private val mapper = jacksonObjectMapper()
58+
3459
@Before
3560
fun setup() {
61+
encoderServer = spy(EncoderServer(project))
62+
encoderServer.stub { on { port } doReturn wireMock.port() }
63+
3664
sut = ProjectContextProvider(project, encoderServer, TestScope())
65+
66+
// initialization
67+
stubFor(any(urlPathEqualTo("/initialize")).willReturn(aResponse().withStatus(200).withResponseBody(Body("initialize response"))))
68+
69+
// build index
70+
stubFor(any(urlPathEqualTo("/indexFiles")).willReturn(aResponse().withStatus(200).withResponseBody(Body("initialize response"))))
71+
72+
// update index
73+
stubFor(any(urlPathEqualTo("/updateIndex")).willReturn(aResponse().withStatus(200).withResponseBody(Body("initialize response"))))
74+
75+
// query
76+
stubFor(
77+
any(urlPathEqualTo("/query")).willReturn(
78+
aResponse()
79+
.withStatus(200)
80+
.withResponseBody(Body(validQueryChatResponse))
81+
)
82+
)
83+
84+
stubFor(
85+
any(urlPathEqualTo("/getUsage"))
86+
.willReturn(
87+
aResponse()
88+
.withStatus(200)
89+
.withResponseBody(Body(validGetUsageResponse))
90+
)
91+
)
92+
}
93+
94+
@Test
95+
fun `Lsp endpoint are correct`() {
96+
assertThat(LspMessage.Initialize.endpoint).isEqualTo("initialize")
97+
assertThat(LspMessage.Index.endpoint).isEqualTo("indexFiles")
98+
assertThat(LspMessage.QueryChat.endpoint).isEqualTo("query")
99+
assertThat(LspMessage.GetUsageMetrics.endpoint).isEqualTo("getUsage")
100+
}
101+
102+
@Test
103+
fun `index should send files within the project to lsp`() {
104+
projectRule.fixture.addFileToProject("Foo.java", "foo")
105+
projectRule.fixture.addFileToProject("Bar.java", "bar")
106+
projectRule.fixture.addFileToProject("Baz.java", "baz")
107+
108+
sut.index()
109+
110+
val request = IndexRequest(listOf("/src/Foo.java", "/src/Bar.java", "/src/Baz.java"), "/src", false)
111+
assertThat(request.filePaths).hasSize(3)
112+
assertThat(request.filePaths).satisfies({
113+
it.contains("/src/Foo.java") &&
114+
it.contains("/src/Baz.java") &&
115+
it.contains("/src/Bar.java")
116+
})
117+
118+
wireMock.verify(
119+
1,
120+
postRequestedFor(urlPathEqualTo("/indexFiles"))
121+
.withHeader("Content-Type", equalTo("text/plain"))
122+
// comment it out because order matters and will cause json string different
123+
// .withRequestBody(equalTo(encryptedRequest))
124+
)
125+
}
126+
127+
@Test
128+
fun `updateIndex should send correct encrypted request to lsp`() {
129+
sut.updateIndex("foo.java")
130+
val request = UpdateIndexRequest("foo.java")
131+
val requestJson = mapper.writeValueAsString(request)
132+
133+
assertThat(mapper.readTree(requestJson)).isEqualTo(mapper.readTree("""{ "filePath": "foo.java" }"""))
134+
135+
val encryptedRequest = encoderServer.encrypt(requestJson)
136+
137+
wireMock.verify(
138+
1,
139+
postRequestedFor(urlPathEqualTo("/updateIndex"))
140+
.withHeader("Content-Type", equalTo("text/plain"))
141+
.withRequestBody(equalTo(encryptedRequest))
142+
)
143+
}
144+
145+
@Test
146+
fun `query should send correct encrypted request to lsp`() {
147+
sut.query("foo")
148+
149+
val request = QueryChatRequest("foo")
150+
val requestJson = mapper.writeValueAsString(request)
151+
152+
assertThat(mapper.readTree(requestJson)).isEqualTo(mapper.readTree("""{ "query": "foo" }"""))
153+
154+
val encryptedRequest = encoderServer.encrypt(requestJson)
155+
156+
wireMock.verify(
157+
1,
158+
postRequestedFor(urlPathEqualTo("/query"))
159+
.withHeader("Content-Type", equalTo("text/plain"))
160+
.withRequestBody(equalTo(encryptedRequest))
161+
)
162+
}
163+
164+
@Test
165+
fun `query chat should return empty if result set non deserializable`() = runTest {
166+
stubFor(
167+
any(urlPathEqualTo("/query")).willReturn(
168+
aResponse().withStatus(200).withResponseBody(
169+
Body(
170+
"""
171+
[
172+
"foo", "bar"
173+
]
174+
""".trimIndent()
175+
)
176+
)
177+
)
178+
)
179+
180+
assertThrows<Exception> {
181+
sut.query("foo")
182+
}
183+
}
184+
185+
@Test
186+
fun `query chat should return deserialized relevantDocument`() = runTest {
187+
val r = sut.query("foo")
188+
assertThat(r).hasSize(2)
189+
assertThat(r[0]).isEqualTo(
190+
RelevantDocument(
191+
"relativeFilePath1",
192+
"context1"
193+
)
194+
)
195+
assertThat(r[1]).isEqualTo(
196+
RelevantDocument(
197+
"relativeFilePath2",
198+
"context2"
199+
)
200+
)
201+
}
202+
203+
@Test
204+
fun `get usage should return memory, cpu usage`() = runTest {
205+
val r = sut.getUsage()
206+
assertThat(r).isEqualTo(ProjectContextProvider.Usage(123, 456))
37207
}
38208

39209
@Test
@@ -57,4 +227,52 @@ class ProjectContextProviderTest {
57227
}
58228
verify(encoderServer, times(1)).encrypt(any())
59229
}
230+
231+
private fun createMockServer() = WireMockRule(wireMockConfig().dynamicPort())
60232
}
233+
234+
// language=JSON
235+
val validQueryChatResponse = """
236+
[
237+
{
238+
"filePath": "file1",
239+
"content": "content1",
240+
"id": "id1",
241+
"index": "index1",
242+
"vec": [
243+
"vec_1-1",
244+
"vec_1-2",
245+
"vec_1-3"
246+
],
247+
"context": "context1",
248+
"prev": "prev1",
249+
"next": "next1",
250+
"relativePath": "relativeFilePath1",
251+
"programmingLanguage": "language1"
252+
},
253+
{
254+
"filePath": "file2",
255+
"content": "content2",
256+
"id": "id2",
257+
"index": "index2",
258+
"vec": [
259+
"vec_2-1",
260+
"vec_2-2",
261+
"vec_2-3"
262+
],
263+
"context": "context2",
264+
"prev": "prev2",
265+
"next": "next2",
266+
"relativePath": "relativeFilePath2",
267+
"programmingLanguage": "language2"
268+
}
269+
]
270+
""".trimIndent()
271+
272+
// language=JSON
273+
val validGetUsageResponse = """
274+
{
275+
"memoryUsage":123,
276+
"cpuUsage":456
277+
}
278+
""".trimIndent()
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package software.aws.toolkits.jetbrains.services.amazonq.project
5+
6+
sealed interface LspMessage {
7+
val endpoint: String
8+
9+
data object Initialize : LspMessage {
10+
override val endpoint: String = "initialize"
11+
}
12+
13+
data object Index : LspMessage {
14+
override val endpoint: String = "indexFiles"
15+
}
16+
17+
data object UpdateIndex : LspMessage {
18+
override val endpoint: String = "updateIndex"
19+
}
20+
21+
data object QueryChat : LspMessage {
22+
override val endpoint: String = "query"
23+
}
24+
25+
data object GetUsageMetrics : LspMessage {
26+
override val endpoint: String = "getUsage"
27+
}
28+
}
29+
30+
interface LspRequest
31+
32+
data class IndexRequest(
33+
val filePaths: List<String>,
34+
val projectRoot: String,
35+
val refresh: Boolean,
36+
) : LspRequest
37+
38+
data class UpdateIndexRequest(
39+
val filePath: String,
40+
) : LspRequest
41+
42+
data class QueryChatRequest(
43+
val query: String,
44+
) : LspRequest
45+
46+
data class LspResponse(
47+
val responseCode: Int,
48+
val responseBody: String,
49+
)

0 commit comments

Comments
 (0)