Skip to content

Commit 3840926

Browse files
committed
timeout LSP query inline for 50ms
1 parent a59204c commit 3840926

File tree

7 files changed

+366
-119
lines changed

7 files changed

+366
-119
lines changed

plugins/amazonq/chat/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/workspace/context/ProjectContextProviderTest.kt

Lines changed: 97 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import org.junit.Rule
1919
import org.junit.jupiter.api.assertThrows
2020
import org.mockito.kotlin.any
2121
import org.mockito.kotlin.doReturn
22-
import org.mockito.kotlin.mock
22+
import org.mockito.kotlin.spy
2323
import org.mockito.kotlin.stub
2424
import org.mockito.kotlin.times
2525
import org.mockito.kotlin.verify
@@ -32,6 +32,7 @@ import software.aws.toolkits.jetbrains.services.amazonq.project.RelevantDocument
3232
import software.aws.toolkits.jetbrains.utils.rules.CodeInsightTestFixtureRule
3333
import software.aws.toolkits.jetbrains.utils.rules.JavaCodeInsightTestFixtureRule
3434
import java.net.ConnectException
35+
import java.util.concurrent.TimeoutException
3536
import kotlin.test.Test
3637

3738
class ProjectContextProviderTest {
@@ -46,13 +47,16 @@ class ProjectContextProviderTest {
4647
private val project: Project
4748
get() = projectRule.project
4849

49-
private val encoderServer: EncoderServer = mock()
50+
private lateinit var encoderServer: EncoderServer
5051
private lateinit var sut: ProjectContextProvider
5152

5253
@Before
5354
fun setup() {
54-
sut = ProjectContextProvider(project, encoderServer, TestScope())
55+
encoderServer = spy(EncoderServer(project))
5556
encoderServer.stub { on { port } doReturn wireMock.port() }
57+
58+
sut = ProjectContextProvider(project, encoderServer, TestScope())
59+
5660
// initialization
5761
stubFor(any(urlPathEqualTo("/initialize")).willReturn(aResponse().withStatus(200).withResponseBody(Body("initialize response"))))
5862

@@ -67,73 +71,18 @@ class ProjectContextProviderTest {
6771
// query
6872
stubFor(
6973
any(urlPathEqualTo("/query")).willReturn(
70-
aResponse().withStatus(200).withResponseBody(
71-
Body(
72-
"""
73-
[
74-
{
75-
"filePath": "file1",
76-
"content": "content1",
77-
"id": "id1",
78-
"index": "index1",
79-
"vec": [
80-
"vec_1-1",
81-
"vec_1-2",
82-
"vec_1-3"
83-
],
84-
"context": "context1",
85-
"prev": "prev1",
86-
"next": "next1",
87-
"relativePath": "relativeFilePath1",
88-
"programmingLanguage": "language1"
89-
},
90-
{
91-
"filePath": "file2",
92-
"content": "content2",
93-
"id": "id2",
94-
"index": "index2",
95-
"vec": [
96-
"vec_2-1",
97-
"vec_2-2",
98-
"vec_2-3"
99-
],
100-
"context": "context2",
101-
"prev": "prev2",
102-
"next": "next2",
103-
"relativePath": "relativeFilePath2",
104-
"programmingLanguage": "language2"
105-
}
106-
]
107-
""".trimIndent()
108-
)
109-
)
74+
aResponse()
75+
.withStatus(200)
76+
.withResponseBody(Body(validQueryChatResponse))
11077
)
11178
)
11279
stubFor(
11380
any(urlPathEqualTo("/queryInlineProjectContext")).willReturn(
114-
aResponse().withStatus(200).withResponseBody(
115-
Body(
116-
"""
117-
[
118-
{
119-
"content": "content1",
120-
"filePath": "file1",
121-
"score": 0.1
122-
},
123-
{
124-
"content": "content2",
125-
"filePath": "file2",
126-
"score": 0.2
127-
},
128-
{
129-
"content": "content3",
130-
"filePath": "file3",
131-
"score": 0.3
132-
}
133-
]
134-
""".trimIndent()
81+
aResponse()
82+
.withStatus(200)
83+
.withResponseBody(
84+
Body(validQueryInlineResponse)
13585
)
136-
)
13786
)
13887
)
13988

@@ -142,16 +91,7 @@ class ProjectContextProviderTest {
14291
.willReturn(
14392
aResponse()
14493
.withStatus(200)
145-
.withResponseBody(
146-
Body(
147-
"""
148-
{
149-
"memoryUsage":123,
150-
"cpuUsage":456
151-
}
152-
""".trimIndent()
153-
)
154-
)
94+
.withResponseBody(Body(validGetUsageResponse))
15595
)
15696
)
15797
}
@@ -259,6 +199,24 @@ class ProjectContextProviderTest {
259199
assertThat(r).isEqualTo(ProjectContextProvider.Usage(123, 456))
260200
}
261201

202+
@Test
203+
fun `should return empty if timeout with 50ms`() {
204+
stubFor(
205+
any(urlPathEqualTo("/queryInlineProjectContext")).willReturn(
206+
aResponse()
207+
.withStatus(200)
208+
.withResponseBody(
209+
Body(validQueryInlineResponse)
210+
)
211+
.withFixedDelay(51) // 10 sec
212+
)
213+
)
214+
215+
assertThrows<TimeoutException> {
216+
sut.queryInline("foo", "bar")
217+
}
218+
}
219+
262220
@Test
263221
fun `test index payload is encrypted`() = runTest {
264222
whenever(encoderServer.port).thenReturn(3000)
@@ -283,3 +241,67 @@ class ProjectContextProviderTest {
283241

284242
private fun createMockServer() = WireMockRule(wireMockConfig().dynamicPort())
285243
}
244+
245+
val validQueryInlineResponse = """
246+
[
247+
{
248+
"content": "content1",
249+
"filePath": "file1",
250+
"score": 0.1
251+
},
252+
{
253+
"content": "content2",
254+
"filePath": "file2",
255+
"score": 0.2
256+
},
257+
{
258+
"content": "content3",
259+
"filePath": "file3",
260+
"score": 0.3
261+
}
262+
]
263+
""".trimIndent()
264+
265+
val validQueryChatResponse = """
266+
[
267+
{
268+
"filePath": "file1",
269+
"content": "content1",
270+
"id": "id1",
271+
"index": "index1",
272+
"vec": [
273+
"vec_1-1",
274+
"vec_1-2",
275+
"vec_1-3"
276+
],
277+
"context": "context1",
278+
"prev": "prev1",
279+
"next": "next1",
280+
"relativePath": "relativeFilePath1",
281+
"programmingLanguage": "language1"
282+
},
283+
{
284+
"filePath": "file2",
285+
"content": "content2",
286+
"id": "id2",
287+
"index": "index2",
288+
"vec": [
289+
"vec_2-1",
290+
"vec_2-2",
291+
"vec_2-3"
292+
],
293+
"context": "context2",
294+
"prev": "prev2",
295+
"next": "next2",
296+
"relativePath": "relativeFilePath2",
297+
"programmingLanguage": "language2"
298+
}
299+
]
300+
""".trimIndent()
301+
302+
val validGetUsageResponse = """
303+
{
304+
"memoryUsage":123,
305+
"cpuUsage":456
306+
}
307+
""".trimIndent()

plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/util/CodeWhispererFileContextProvider.kt

Lines changed: 42 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ import software.aws.toolkits.core.utils.debug
2222
import software.aws.toolkits.core.utils.getLogger
2323
import software.aws.toolkits.core.utils.info
2424
import software.aws.toolkits.core.utils.warn
25+
import software.aws.toolkits.jetbrains.services.amazonq.CodeWhispererFeatureConfigService
26+
import software.aws.toolkits.jetbrains.services.amazonq.project.ProjectContextController
2527
import software.aws.toolkits.jetbrains.services.codewhisperer.editor.CodeWhispererEditorUtil
2628
import software.aws.toolkits.jetbrains.services.codewhisperer.language.CodeWhispererProgrammingLanguage
2729
import software.aws.toolkits.jetbrains.services.codewhisperer.language.languages.CodeWhispererJava
@@ -206,33 +208,49 @@ class DefaultCodeWhispererFileContextProvider(private val project: Project) : Fi
206208

207209
override fun isTestFile(psiFile: PsiFile) = psiFile.programmingLanguage().fileCrawler.isTestFile(psiFile.virtualFile, psiFile.project)
208210

209-
@VisibleForTesting
210211
suspend fun extractSupplementalFileContextForSrc(psiFile: PsiFile, targetContext: FileContextInfo): SupplementalContextInfo {
211212
if (!targetContext.programmingLanguage.isSupplementalContextSupported()) {
212213
return SupplementalContextInfo.emptyCrossFileContextInfo(targetContext.filename)
213214
}
214215

215-
// takeLast(11) will extract 10 lines (exclusing current line) of left context as the query parameter
216-
val query = targetContext.caretContext.leftFileContext.split("\n").takeLast(11).joinToString("\n")
217-
218-
// TODO: uncomment
219-
// if (CodeWhispererFeatureConfigService.getInstance().getInlineCompletion()) {
220-
// val response = ProjectContextController.getInstance(project).queryInline(query, psiFile.virtualFile?.path ?: "").filter { it.content.isNotBlank() }
221-
// return SupplementalContextInfo(
222-
// isUtg = false,
223-
// contents = response.map {
224-
// Chunk(
225-
// content = it.content,
226-
// path = it.filePath,
227-
// nextChunk = it.content,
228-
// score = it.score
229-
// )
230-
// },
231-
// targetFileName = targetContext.filename,
232-
// strategy = CrossFileStrategy.ProjectContext
233-
// )
234-
// }
216+
val query = generateQuery(targetContext)
217+
218+
val projectContext = if (CodeWhispererFeatureConfigService.getInstance().getInlineCompletion()) {
219+
fetchProjectContext(query, psiFile, targetContext)
220+
} else {
221+
null
222+
}
223+
224+
val openTabsContext = fetchOpentabsContext(query, psiFile, targetContext)
225+
226+
return if (projectContext == null || projectContext.contents.isEmpty()) {
227+
openTabsContext
228+
} else {
229+
projectContext
230+
}
231+
}
232+
233+
@VisibleForTesting
234+
suspend fun fetchProjectContext(query: String, psiFile: PsiFile, targetContext: FileContextInfo): SupplementalContextInfo {
235+
val response = ProjectContextController.getInstance(project).queryInline(query, psiFile.virtualFile?.path ?: "")
235236

237+
return SupplementalContextInfo(
238+
isUtg = false,
239+
contents = response.map {
240+
Chunk(
241+
content = it.content,
242+
path = it.filePath,
243+
nextChunk = it.content,
244+
score = it.score
245+
)
246+
},
247+
targetFileName = targetContext.filename,
248+
strategy = CrossFileStrategy.ProjectContext
249+
)
250+
}
251+
252+
@VisibleForTesting
253+
suspend fun fetchOpentabsContext(query: String, psiFile: PsiFile, targetContext: FileContextInfo): SupplementalContextInfo {
236254
// step 1: prepare data
237255
val first60Chunks: List<Chunk> = try {
238256
runReadAction { codewhispererCodeChunksIndex.getFileData(psiFile) }
@@ -323,6 +341,9 @@ class DefaultCodeWhispererFileContextProvider(private val project: Project) : Fi
323341
}
324342
}
325343

344+
// takeLast(11) will extract 10 lines (exclusing current line) of left context as the query parameter
345+
fun generateQuery(fileContext: FileContextInfo) = fileContext.caretContext.leftFileContext.split("\n").takeLast(11).joinToString("\n")
346+
326347
companion object {
327348
private val LOG = getLogger<DefaultCodeWhispererFileContextProvider>()
328349

plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/util/CodeWhispererFileCrawler.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,11 @@ abstract class CodeWhispererFileCrawler : FileCrawler {
8484
}.orEmpty()
8585

8686
override fun listCrossFileCandidate(target: PsiFile): List<VirtualFile> {
87-
val targetFile = target.virtualFile
87+
val targetFile = target.viewProvider.virtualFile
8888

8989
val openedFiles = runReadAction {
9090
FileEditorManager.getInstance(target.project).openFiles.toList().filter {
91-
it.name != target.virtualFile.name &&
91+
it.name != targetFile.name &&
9292
isSameDialect(it.extension) &&
9393
!isTestFile(it, target.project)
9494
}

0 commit comments

Comments
 (0)