Skip to content

Commit 3621194

Browse files
authored
refactor(amazonq): improve error handling of workspace index (#5373)
* improve error handling of workspace index
1 parent 0972fcc commit 3621194

File tree

2 files changed

+33
-12
lines changed

2 files changed

+33
-12
lines changed

plugins/amazonq/chat/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/workspace/context/ProjectContextProviderTest.kt

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ import com.intellij.openapi.application.ApplicationManager
1717
import com.intellij.openapi.project.Project
1818
import com.intellij.testFramework.DisposableRule
1919
import com.intellij.testFramework.replaceService
20+
import io.mockk.every
21+
import io.mockk.spyk
2022
import kotlinx.coroutines.ExperimentalCoroutinesApi
2123
import kotlinx.coroutines.TimeoutCancellationException
2224
import kotlinx.coroutines.test.StandardTestDispatcher
@@ -44,6 +46,7 @@ import software.aws.toolkits.jetbrains.services.amazonq.project.InlineBm25Chunk
4446
import software.aws.toolkits.jetbrains.services.amazonq.project.InlineContextTarget
4547
import software.aws.toolkits.jetbrains.services.amazonq.project.LspMessage
4648
import software.aws.toolkits.jetbrains.services.amazonq.project.ProjectContextProvider
49+
import software.aws.toolkits.jetbrains.services.amazonq.project.ProjectContextProvider.FileCollectionResult
4750
import software.aws.toolkits.jetbrains.services.amazonq.project.QueryChatRequest
4851
import software.aws.toolkits.jetbrains.services.amazonq.project.QueryInlineCompletionRequest
4952
import software.aws.toolkits.jetbrains.services.amazonq.project.RelevantDocument
@@ -82,8 +85,8 @@ class ProjectContextProviderTest {
8285
fun setup() {
8386
encoderServer = spy(EncoderServer(project))
8487
encoderServer.stub { on { port } doReturn wireMock.port() }
85-
86-
sut = ProjectContextProvider(project, encoderServer, TestScope(context = dispatcher))
88+
encoderServer.stub { on { isNodeProcessRunning() } doReturn true }
89+
sut = spyk(ProjectContextProvider(project, encoderServer, TestScope(context = dispatcher)))
8790

8891
// initialization
8992
stubFor(any(urlPathEqualTo("/initialize")).willReturn(aResponse().withStatus(200).withResponseBody(Body("initialize response"))))
@@ -143,7 +146,10 @@ class ProjectContextProviderTest {
143146
projectRule.fixture.addFileToProject("Foo.java", "foo")
144147
projectRule.fixture.addFileToProject("Bar.java", "bar")
145148
projectRule.fixture.addFileToProject("Baz.java", "baz")
146-
149+
every { sut.collectFiles() } returns FileCollectionResult(
150+
files = listOf("Foo.java", "Bar.java", "Baz.java"),
151+
fileSize = 10
152+
)
147153
sut.index()
148154

149155
val request = IndexRequest(listOf("/src/Foo.java", "/src/Bar.java", "/src/Baz.java"), "/src", "all", "")
@@ -175,7 +181,10 @@ class ProjectContextProviderTest {
175181
projectRule.fixture.addFileToProject("Foo.java", "foo")
176182
projectRule.fixture.addFileToProject("Bar.java", "bar")
177183
projectRule.fixture.addFileToProject("Baz.java", "baz")
178-
184+
every { sut.collectFiles() } returns FileCollectionResult(
185+
files = listOf("Foo.java", "Bar.java", "Baz.java"),
186+
fileSize = 10
187+
)
179188
sut.index()
180189

181190
val request = IndexRequest(listOf("/src/Foo.java", "/src/Bar.java", "/src/Baz.java"), "/src", "default", "")
@@ -408,6 +417,10 @@ class ProjectContextProviderTest {
408417
@Test
409418
fun `test index payload is encrypted`() = runTest {
410419
whenever(encoderServer.port).thenReturn(3000)
420+
every { sut.collectFiles() } returns FileCollectionResult(
421+
files = listOf("Foo.java", "Bar.java", "Baz.java"),
422+
fileSize = 10
423+
)
411424
try {
412425
sut.index()
413426
} catch (e: ConnectException) {

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/project/ProjectContextProvider.kt

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -130,14 +130,18 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En
130130
private fun initEncryption(): Boolean {
131131
val request = encoderServer.getEncryptionRequest()
132132
val response = sendMsgToLsp(LspMessage.Initialize, request)
133-
return response.responseCode == 200
133+
return response?.responseCode == 200
134134
}
135135

136136
fun index(): Boolean {
137137
val projectRoot = project.basePath ?: return false
138138

139139
val indexStartTime = System.currentTimeMillis()
140140
val filesResult = collectFiles()
141+
if (filesResult.files.isEmpty()) {
142+
logger.warn { "No file found in workspace" }
143+
return false
144+
}
141145
var duration = (System.currentTimeMillis() - indexStartTime).toDouble()
142146
logger.debug { "time elapsed to collect project context files: ${duration}ms, collected ${filesResult.files.size} files" }
143147

@@ -149,12 +153,13 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En
149153
logger.debug { "project context index time: ${duration}ms" }
150154

151155
val startUrl = getStartUrl(project)
152-
if (response.responseCode == 200) {
156+
if (response?.responseCode == 200) {
153157
val usage = getUsage()
154158
recordIndexWorkspace(duration, filesResult.files.size, filesResult.fileSize, true, usage?.memoryUsage, usage?.cpuUsage, startUrl)
155159
logger.debug { "project context index finished for ${project.name}" }
156160
return true
157161
} else {
162+
logger.debug { "project context index failed" }
158163
recordIndexWorkspace(duration, filesResult.files.size, filesResult.fileSize, false, null, null, startUrl)
159164
return false
160165
}
@@ -164,8 +169,7 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En
164169
suspend fun query(prompt: String, timeout: Long?): List<RelevantDocument> = withTimeout(timeout ?: CHAT_EXPLICIT_PROJECT_CONTEXT_TIMEOUT) {
165170
cs.async {
166171
val encrypted = encryptRequest(QueryChatRequest(prompt))
167-
val response = sendMsgToLsp(LspMessage.QueryChat, encrypted)
168-
172+
val response = sendMsgToLsp(LspMessage.QueryChat, encrypted) ?: return@async emptyList()
169173
val parsedResponse = mapper.readValue<List<Chunk>>(response.responseBody)
170174
queryResultToRelevantDocuments(parsedResponse)
171175
}.await()
@@ -174,13 +178,13 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En
174178
suspend fun queryInline(query: String, filePath: String, target: InlineContextTarget): List<InlineBm25Chunk> = withTimeout(SUPPLEMENTAL_CONTEXT_TIMEOUT) {
175179
cs.async {
176180
val encrypted = encryptRequest(QueryInlineCompletionRequest(query, filePath, target.toString()))
177-
val r = sendMsgToLsp(LspMessage.QueryInlineCompletion, encrypted)
181+
val r = sendMsgToLsp(LspMessage.QueryInlineCompletion, encrypted) ?: return@async emptyList()
178182
return@async mapper.readValue<List<InlineBm25Chunk>>(r.responseBody)
179183
}.await()
180184
}
181185

182186
fun getUsage(): Usage? {
183-
val response = sendMsgToLsp(LspMessage.GetUsageMetrics, request = null)
187+
val response = sendMsgToLsp(LspMessage.GetUsageMetrics, request = null) ?: return null
184188
return try {
185189
val parsedResponse = mapper.readValue<Usage>(response.responseBody)
186190
parsedResponse
@@ -246,7 +250,7 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En
246250
return regex.find(fileName) != null
247251
}
248252

249-
private fun collectFiles(): FileCollectionResult {
253+
fun collectFiles(): FileCollectionResult {
250254
val collectedFiles = mutableListOf<String>()
251255
var currentTotalFileSize = 0L
252256
val featureDevSessionContext = FeatureDevSessionContext(project)
@@ -306,9 +310,13 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En
306310
return encoderServer.encrypt(payloadJson)
307311
}
308312

309-
private fun sendMsgToLsp(msgType: LspMessage, request: String?): LspResponse {
313+
private fun sendMsgToLsp(msgType: LspMessage, request: String?): LspResponse? {
310314
logger.info { "sending message: ${msgType.endpoint} to lsp on port ${encoderServer.port}" }
311315
val url = URL("http://localhost:${encoderServer.port}/${msgType.endpoint}")
316+
if (!encoderServer.isNodeProcessRunning()) {
317+
logger.warn { "language server is not running" }
318+
return null
319+
}
312320
// use 1h as timeout for index, 5 seconds for other APIs
313321
val timeoutMs = if (msgType is LspMessage.Index) 60.minutes.inWholeMilliseconds.toInt() else 5000
314322
return with(url.openConnection() as HttpURLConnection) {

0 commit comments

Comments
 (0)