Skip to content

Commit 41965f8

Browse files
authored
config(amazonq): Add project context to inline completion (#4976)
1 parent 4eb0b30 commit 41965f8

File tree

12 files changed

+619
-110
lines changed

12 files changed

+619
-110
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type" : "bugfix",
3+
"description" : "Update `@workspace` index when adding or deleting a file"
4+
}

plugins/amazonq/chat/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/workspace/context/ProjectContextProviderTest.kt

Lines changed: 204 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,37 +13,56 @@ import com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo
1313
import com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig
1414
import com.github.tomakehurst.wiremock.http.Body
1515
import com.github.tomakehurst.wiremock.junit.WireMockRule
16+
import com.intellij.openapi.application.ApplicationManager
1617
import com.intellij.openapi.project.Project
18+
import com.intellij.testFramework.DisposableRule
19+
import com.intellij.testFramework.replaceService
20+
import kotlinx.coroutines.ExperimentalCoroutinesApi
21+
import kotlinx.coroutines.TimeoutCancellationException
22+
import kotlinx.coroutines.test.StandardTestDispatcher
1723
import kotlinx.coroutines.test.TestScope
24+
import kotlinx.coroutines.test.advanceUntilIdle
1825
import kotlinx.coroutines.test.runTest
26+
import kotlinx.coroutines.withContext
1927
import org.assertj.core.api.Assertions.assertThat
2028
import org.junit.Before
2129
import org.junit.Rule
2230
import org.junit.jupiter.api.assertThrows
2331
import org.mockito.kotlin.any
2432
import org.mockito.kotlin.doReturn
33+
import org.mockito.kotlin.mock
2534
import org.mockito.kotlin.spy
2635
import org.mockito.kotlin.stub
2736
import org.mockito.kotlin.times
2837
import org.mockito.kotlin.verify
2938
import org.mockito.kotlin.whenever
39+
import software.aws.toolkits.jetbrains.core.coroutines.getCoroutineBgContext
3040
import software.aws.toolkits.jetbrains.services.amazonq.project.EncoderServer
3141
import software.aws.toolkits.jetbrains.services.amazonq.project.IndexRequest
42+
import software.aws.toolkits.jetbrains.services.amazonq.project.IndexUpdateMode
43+
import software.aws.toolkits.jetbrains.services.amazonq.project.InlineBm25Chunk
3244
import software.aws.toolkits.jetbrains.services.amazonq.project.LspMessage
3345
import software.aws.toolkits.jetbrains.services.amazonq.project.ProjectContextProvider
3446
import software.aws.toolkits.jetbrains.services.amazonq.project.QueryChatRequest
47+
import software.aws.toolkits.jetbrains.services.amazonq.project.QueryInlineCompletionRequest
3548
import software.aws.toolkits.jetbrains.services.amazonq.project.RelevantDocument
3649
import software.aws.toolkits.jetbrains.services.amazonq.project.UpdateIndexRequest
50+
import software.aws.toolkits.jetbrains.settings.CodeWhispererSettings
3751
import software.aws.toolkits.jetbrains.utils.rules.CodeInsightTestFixtureRule
3852
import software.aws.toolkits.jetbrains.utils.rules.JavaCodeInsightTestFixtureRule
3953
import java.net.ConnectException
4054
import kotlin.test.Test
4155

56+
@OptIn(ExperimentalCoroutinesApi::class)
4257
class ProjectContextProviderTest {
4358
@Rule
4459
@JvmField
4560
val projectRule: CodeInsightTestFixtureRule = JavaCodeInsightTestFixtureRule()
4661

62+
@Rule
63+
@JvmField
64+
val disposableRule: DisposableRule = DisposableRule()
65+
4766
@Rule
4867
@JvmField
4968
val wireMock: WireMockRule = createMockServer()
@@ -56,21 +75,23 @@ class ProjectContextProviderTest {
5675

5776
private val mapper = jacksonObjectMapper()
5877

78+
private val dispatcher = StandardTestDispatcher()
79+
5980
@Before
6081
fun setup() {
6182
encoderServer = spy(EncoderServer(project))
6283
encoderServer.stub { on { port } doReturn wireMock.port() }
6384

64-
sut = ProjectContextProvider(project, encoderServer, TestScope())
85+
sut = ProjectContextProvider(project, encoderServer, TestScope(context = dispatcher))
6586

6687
// initialization
6788
stubFor(any(urlPathEqualTo("/initialize")).willReturn(aResponse().withStatus(200).withResponseBody(Body("initialize response"))))
6889

6990
// build index
70-
stubFor(any(urlPathEqualTo("/indexFiles")).willReturn(aResponse().withStatus(200).withResponseBody(Body("initialize response"))))
91+
stubFor(any(urlPathEqualTo("/buildIndex")).willReturn(aResponse().withStatus(200).withResponseBody(Body("initialize response"))))
7192

7293
// update index
73-
stubFor(any(urlPathEqualTo("/updateIndex")).willReturn(aResponse().withStatus(200).withResponseBody(Body("initialize response"))))
94+
stubFor(any(urlPathEqualTo("/updateIndexV2")).willReturn(aResponse().withStatus(200).withResponseBody(Body("initialize response"))))
7495

7596
// query
7697
stubFor(
@@ -80,6 +101,15 @@ class ProjectContextProviderTest {
80101
.withResponseBody(Body(validQueryChatResponse))
81102
)
82103
)
104+
stubFor(
105+
any(urlPathEqualTo("/queryInlineProjectContext")).willReturn(
106+
aResponse()
107+
.withStatus(200)
108+
.withResponseBody(
109+
Body(validQueryInlineResponse)
110+
)
111+
)
112+
)
83113

84114
stubFor(
85115
any(urlPathEqualTo("/getUsage"))
@@ -92,32 +122,73 @@ class ProjectContextProviderTest {
92122
}
93123

94124
@Test
95-
fun `Lsp endpoint are correct`() {
125+
fun `Lsp endpoint correctness`() {
96126
assertThat(LspMessage.Initialize.endpoint).isEqualTo("initialize")
97-
assertThat(LspMessage.Index.endpoint).isEqualTo("indexFiles")
127+
assertThat(LspMessage.Index.endpoint).isEqualTo("buildIndex")
128+
assertThat(LspMessage.UpdateIndex.endpoint).isEqualTo("updateIndexV2")
98129
assertThat(LspMessage.QueryChat.endpoint).isEqualTo("query")
130+
assertThat(LspMessage.QueryInlineCompletion.endpoint).isEqualTo("queryInlineProjectContext")
99131
assertThat(LspMessage.GetUsageMetrics.endpoint).isEqualTo("getUsage")
100132
}
101133

102134
@Test
103-
fun `index should send files within the project to lsp`() {
135+
fun `index should send files within the project to lsp - vector index enabled`() {
136+
ApplicationManager.getApplication().replaceService(
137+
CodeWhispererSettings::class.java,
138+
mock { on { isProjectContextEnabled() } doReturn true },
139+
disposableRule.disposable
140+
)
141+
142+
projectRule.fixture.addFileToProject("Foo.java", "foo")
143+
projectRule.fixture.addFileToProject("Bar.java", "bar")
144+
projectRule.fixture.addFileToProject("Baz.java", "baz")
145+
146+
sut.index()
147+
148+
val request = IndexRequest(listOf("/src/Foo.java", "/src/Bar.java", "/src/Baz.java"), "/src", "all", "")
149+
assertThat(request.filePaths).hasSize(3)
150+
assertThat(request.filePaths).satisfies({
151+
it.contains("/src/Foo.java") &&
152+
it.contains("/src/Baz.java") &&
153+
it.contains("/src/Bar.java")
154+
})
155+
assertThat(request.config).isEqualTo("all")
156+
157+
wireMock.verify(
158+
1,
159+
postRequestedFor(urlPathEqualTo("/buildIndex"))
160+
.withHeader("Content-Type", equalTo("text/plain"))
161+
// comment it out because order matters and will cause json string different
162+
// .withRequestBody(equalTo(encryptedRequest))
163+
)
164+
}
165+
166+
@Test
167+
fun `index should send files within the project to lsp - vector index disabled`() {
168+
ApplicationManager.getApplication().replaceService(
169+
CodeWhispererSettings::class.java,
170+
mock { on { isProjectContextEnabled() } doReturn false },
171+
disposableRule.disposable
172+
)
173+
104174
projectRule.fixture.addFileToProject("Foo.java", "foo")
105175
projectRule.fixture.addFileToProject("Bar.java", "bar")
106176
projectRule.fixture.addFileToProject("Baz.java", "baz")
107177

108178
sut.index()
109179

110-
val request = IndexRequest(listOf("/src/Foo.java", "/src/Bar.java", "/src/Baz.java"), "/src", false)
180+
val request = IndexRequest(listOf("/src/Foo.java", "/src/Bar.java", "/src/Baz.java"), "/src", "default", "")
111181
assertThat(request.filePaths).hasSize(3)
112182
assertThat(request.filePaths).satisfies({
113183
it.contains("/src/Foo.java") &&
114184
it.contains("/src/Baz.java") &&
115185
it.contains("/src/Bar.java")
116186
})
187+
assertThat(request.config).isEqualTo("default")
117188

118189
wireMock.verify(
119190
1,
120-
postRequestedFor(urlPathEqualTo("/indexFiles"))
191+
postRequestedFor(urlPathEqualTo("/buildIndex"))
121192
.withHeader("Content-Type", equalTo("text/plain"))
122193
// comment it out because order matters and will cause json string different
123194
// .withRequestBody(equalTo(encryptedRequest))
@@ -126,17 +197,17 @@ class ProjectContextProviderTest {
126197

127198
@Test
128199
fun `updateIndex should send correct encrypted request to lsp`() {
129-
sut.updateIndex("foo.java")
130-
val request = UpdateIndexRequest("foo.java")
200+
sut.updateIndex(listOf("foo.java"), IndexUpdateMode.UPDATE)
201+
val request = UpdateIndexRequest(listOf("foo.java"), IndexUpdateMode.UPDATE.command)
131202
val requestJson = mapper.writeValueAsString(request)
132203

133-
assertThat(mapper.readTree(requestJson)).isEqualTo(mapper.readTree("""{ "filePath": "foo.java" }"""))
204+
assertThat(mapper.readTree(requestJson)).isEqualTo(mapper.readTree("""{ "filePaths": ["foo.java"], "mode": "update" }"""))
134205

135206
val encryptedRequest = encoderServer.encrypt(requestJson)
136207

137208
wireMock.verify(
138209
1,
139-
postRequestedFor(urlPathEqualTo("/updateIndex"))
210+
postRequestedFor(urlPathEqualTo("/updateIndexV2"))
140211
.withHeader("Content-Type", equalTo("text/plain"))
141212
.withRequestBody(equalTo(encryptedRequest))
142213
)
@@ -161,6 +232,26 @@ class ProjectContextProviderTest {
161232
)
162233
}
163234

235+
@Test
236+
fun `queryInline should send correct encrypted request to lsp`() = runTest {
237+
sut = ProjectContextProvider(project, encoderServer, this)
238+
sut.queryInline("foo", "Foo.java")
239+
advanceUntilIdle()
240+
241+
val request = QueryInlineCompletionRequest("foo", "Foo.java")
242+
val requestJson = mapper.writeValueAsString(request)
243+
244+
assertThat(mapper.readTree(requestJson)).isEqualTo(mapper.readTree("""{ "query": "foo", "filePath": "Foo.java" }"""))
245+
246+
val encryptedRequest = encoderServer.encrypt(requestJson)
247+
wireMock.verify(
248+
1,
249+
postRequestedFor(urlPathEqualTo("/queryInlineProjectContext"))
250+
.withHeader("Content-Type", equalTo("text/plain"))
251+
.withRequestBody(equalTo(encryptedRequest))
252+
)
253+
}
254+
164255
@Test
165256
fun `query chat should return empty if result set non deserializable`() = runTest {
166257
stubFor(
@@ -200,12 +291,92 @@ class ProjectContextProviderTest {
200291
)
201292
}
202293

294+
@Test
295+
fun `query inline should throw if resultset not deserializable`() {
296+
assertThrows<Exception> {
297+
runTest {
298+
sut = ProjectContextProvider(project, encoderServer, this)
299+
stubFor(
300+
any(urlPathEqualTo("/queryInlineProjectContext")).willReturn(
301+
aResponse().withStatus(200).withResponseBody(
302+
Body(
303+
"""
304+
[
305+
"foo", "bar"
306+
]
307+
""".trimIndent()
308+
)
309+
)
310+
)
311+
)
312+
313+
assertThrows<Exception> {
314+
sut.queryInline("foo", "filepath")
315+
advanceUntilIdle()
316+
}
317+
}
318+
}
319+
}
320+
321+
@Test
322+
fun `query inline should return deserialized bm25 chunks`() = runTest {
323+
sut = ProjectContextProvider(project, encoderServer, this)
324+
advanceUntilIdle()
325+
val r = sut.queryInline("foo", "filepath")
326+
assertThat(r).hasSize(3)
327+
assertThat(r[0]).isEqualTo(
328+
InlineBm25Chunk(
329+
"content1",
330+
"file1",
331+
0.1
332+
)
333+
)
334+
assertThat(r[1]).isEqualTo(
335+
InlineBm25Chunk(
336+
"content2",
337+
"file2",
338+
0.2
339+
)
340+
)
341+
assertThat(r[2]).isEqualTo(
342+
InlineBm25Chunk(
343+
"content3",
344+
"file3",
345+
0.3
346+
)
347+
)
348+
}
349+
203350
@Test
204351
fun `get usage should return memory, cpu usage`() = runTest {
205352
val r = sut.getUsage()
206353
assertThat(r).isEqualTo(ProjectContextProvider.Usage(123, 456))
207354
}
208355

356+
@Test
357+
fun `queryInline should throw if time elapsed is greater than 50ms`() = runTest {
358+
assertThrows<TimeoutCancellationException> {
359+
sut = ProjectContextProvider(project, encoderServer, this)
360+
stubFor(
361+
any(urlPathEqualTo("/queryInlineProjectContext")).willReturn(
362+
aResponse()
363+
.withStatus(200)
364+
.withResponseBody(
365+
Body(validQueryInlineResponse)
366+
)
367+
.withFixedDelay(51) // 10 sec
368+
)
369+
)
370+
371+
// it won't throw if it's executed within TestDispatcher context
372+
withContext(getCoroutineBgContext()) {
373+
sut.queryInline("foo", "bar")
374+
}
375+
376+
advanceUntilIdle()
377+
}
378+
}
379+
209380
@Test
210381
fun `test index payload is encrypted`() = runTest {
211382
whenever(encoderServer.port).thenReturn(3000)
@@ -231,6 +402,27 @@ class ProjectContextProviderTest {
231402
private fun createMockServer() = WireMockRule(wireMockConfig().dynamicPort())
232403
}
233404

405+
// language=JSON
406+
val validQueryInlineResponse = """
407+
[
408+
{
409+
"content": "content1",
410+
"filePath": "file1",
411+
"score": 0.1
412+
},
413+
{
414+
"content": "content2",
415+
"filePath": "file2",
416+
"score": 0.2
417+
},
418+
{
419+
"content": "content3",
420+
"filePath": "file3",
421+
"score": 0.3
422+
}
423+
]
424+
""".trimIndent()
425+
234426
// language=JSON
235427
val validQueryChatResponse = """
236428
[

plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/util/CodeWhispererConstants.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ object CodeWhispererConstants {
3131
const val POPUP_DELAY_CHECK_INTERVAL: Long = 25
3232
const val IDLE_TIME_CHECK_INTERVAL: Long = 25
3333
const val SUPPLEMENTAL_CONTEXT_TIMEOUT = 50L
34+
const val SUPPLEMETAL_CONTEXT_BUFFER = 10L
3435

3536
val AWSTemplateKeyWordsRegex = Regex("(AWSTemplateFormatVersion|Resources|AWS::|Description)")
3637
val AWSTemplateCaseInsensitiveKeyWordsRegex = Regex("(cloudformation|cfn|template|description)")

0 commit comments

Comments
 (0)