Skip to content

Commit 5aeb5ce

Browse files
committed
patch
1 parent 575a120 commit 5aeb5ce

File tree

4 files changed

+72
-2
lines changed

4 files changed

+72
-2
lines changed

plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/util/CodeWhispererFileContextProvider.kt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import software.aws.toolkits.core.utils.debug
2222
import software.aws.toolkits.core.utils.getLogger
2323
import software.aws.toolkits.core.utils.info
2424
import software.aws.toolkits.core.utils.warn
25+
import software.aws.toolkits.jetbrains.services.amazonq.project.ProjectContextController
2526
import software.aws.toolkits.jetbrains.services.codewhisperer.editor.CodeWhispererEditorUtil
2627
import software.aws.toolkits.jetbrains.services.codewhisperer.language.CodeWhispererProgrammingLanguage
2728
import software.aws.toolkits.jetbrains.services.codewhisperer.language.languages.CodeWhispererJava
@@ -214,6 +215,8 @@ class DefaultCodeWhispererFileContextProvider(private val project: Project) : Fi
214215

215216
// takeLast(11) will extract 10 lines (exclusing current line) of left context as the query parameter
216217
val query = targetContext.caretContext.leftFileContext.split("\n").takeLast(11).joinToString("\n")
218+
val bm25 = ProjectContextController.getInstance(project).queryBM25(query, targetContext.filename)
219+
println("---------------------------------${bm25} --------------------------------------------")
217220

218221
// step 1: prepare data
219222
val first60Chunks: List<Chunk> = try {
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
// Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package software.aws.toolkits.jetbrains.services.amazonq.project
5+
6+
data class BM25Chunk(
7+
val content: String,
8+
val filePath: String,
9+
val score: Double
10+
)

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/project/ProjectContextController.kt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,15 @@ class ProjectContextController(private val project: Project, private val cs: Cor
3737
}
3838
}
3939

40+
fun queryBM25(prompt: String, filePath: String): Any {
41+
try {
42+
return projectContextProvider.queryBM25(prompt, filePath)
43+
} catch (e: Exception) {
44+
logger.warn { "error while querying for project context $e.message" }
45+
return emptyList<Any>()
46+
}
47+
}
48+
4049
fun updateIndex(filePath: String) {
4150
try {
4251
return projectContextProvider.updateIndex(filePath)

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/project/ProjectContextProvider.kt

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import kotlinx.coroutines.launch
2020
import kotlinx.coroutines.runBlocking
2121
import kotlinx.coroutines.yield
2222
import software.aws.toolkits.core.utils.debug
23+
import software.aws.toolkits.core.utils.error
2324
import software.aws.toolkits.core.utils.getLogger
2425
import software.aws.toolkits.core.utils.info
2526
import software.aws.toolkits.core.utils.warn
@@ -59,6 +60,13 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En
5960
val refresh: Boolean,
6061
)
6162

63+
data class IndexRequestV2(
64+
val filePaths: List<String>,
65+
val projectRoot: String,
66+
val config: String,
67+
val language: String = ""
68+
)
69+
6270
data class FileCollectionResult(
6371
val files: List<String>,
6472
val fileSize: Int,
@@ -68,10 +76,20 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En
6876
val query: String,
6977
)
7078

79+
data class QueryRequestV2(
80+
val query: String,
81+
val filePath: String,
82+
)
83+
7184
data class UpdateIndexRequestPayload(
7285
val filePath: String,
7386
)
7487

88+
data class UpdateIndexRequestV2(
89+
val filePath: String,
90+
val content: String,
91+
)
92+
7593
data class Usage(
7694
@JsonIgnoreProperties(ignoreUnknown = true)
7795
@JsonProperty("memoryUsage")
@@ -118,6 +136,7 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En
118136
return@launch
119137
}
120138
} catch (e: Exception) {
139+
logger.error { "project context index error: message=${e.message} stack=${e.stackTraceToString()}"}
121140
if (e.stackTraceToString().contains("Connection refused")) {
122141
retryCount.incrementAndGet()
123142
delay(10000)
@@ -143,13 +162,14 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En
143162
fun index(): Boolean {
144163
logger.info { "project context: indexing ${project.name} on port ${encoderServer.port}" }
145164
val indexStartTime = System.currentTimeMillis()
146-
val url = URL("http://localhost:${encoderServer.port}/indexFiles")
165+
val url = URL("http://localhost:${encoderServer.port}/buildIndex")
147166
val filesResult = collectFiles()
148167
var duration = (System.currentTimeMillis() - indexStartTime).toDouble()
149168
logger.debug { "project context file collection time: ${duration}ms" }
150169
logger.debug { "list of files collected: ${filesResult.files.joinToString("\n")}" }
151170
val projectRoot = project.guessProjectDir()?.path ?: return false
152-
val payload = IndexRequestPayload(filesResult.files, projectRoot, false)
171+
// val payload = IndexRequestPayload(filesResult.files, projectRoot, false)
172+
val payload = IndexRequestV2(filesResult.files, projectRoot, "all", "")
153173
val payloadJson = mapper.writeValueAsString(payload)
154174
val encrypted = encoderServer.encrypt(payloadJson)
155175

@@ -200,6 +220,34 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En
200220
}
201221
}
202222

223+
fun queryBM25(prompt: String, filePath: String): List<BM25Chunk> {
224+
logger.info { "project context: querying ${project.name} on port ${encoderServer.port}" }
225+
val url = URL("http://localhost:${encoderServer.port}/queryInlineProjectContext")
226+
val payload = QueryRequestV2(prompt, filePath)
227+
val payloadJson = mapper.writeValueAsString(payload)
228+
val encrypted = encoderServer.encrypt(payloadJson)
229+
230+
val connection = url.openConnection() as HttpURLConnection
231+
setConnectionProperties(connection)
232+
setConnectionTimeout(connection)
233+
setConnectionRequest(connection, encrypted)
234+
235+
val responseCode = connection.responseCode
236+
logger.info { "project context query response code: $responseCode for $prompt" }
237+
val responseBody = if (responseCode == 200) {
238+
connection.inputStream.bufferedReader().use { reader -> reader.readText() }
239+
} else {
240+
""
241+
}
242+
connection.disconnect()
243+
try {
244+
return mapper.readValue<List<BM25Chunk>>(responseBody)
245+
} catch (e: Exception) {
246+
logger.warn { "error parsing query response ${e.message}" }
247+
throw e
248+
}
249+
}
250+
203251
private fun recordIndexWorkspace(
204252
duration: Double,
205253
fileCount: Int = 0,

0 commit comments

Comments
 (0)