Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type" : "feature",
"description" : "Enhance Q inline completion context fetching for better suggestion quality"
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import software.aws.toolkits.jetbrains.services.amazonq.project.EncoderServer
import software.aws.toolkits.jetbrains.services.amazonq.project.IndexRequest
import software.aws.toolkits.jetbrains.services.amazonq.project.IndexUpdateMode
import software.aws.toolkits.jetbrains.services.amazonq.project.InlineBm25Chunk
import software.aws.toolkits.jetbrains.services.amazonq.project.InlineContextTarget
import software.aws.toolkits.jetbrains.services.amazonq.project.LspMessage
import software.aws.toolkits.jetbrains.services.amazonq.project.ProjectContextProvider
import software.aws.toolkits.jetbrains.services.amazonq.project.QueryChatRequest
Expand Down Expand Up @@ -237,13 +238,13 @@ class ProjectContextProviderTest {
@Test
fun `queryInline should send correct encrypted request to lsp`() = runTest {
sut = ProjectContextProvider(project, encoderServer, this)
sut.queryInline("foo", "Foo.java")
sut.queryInline("foo", "Foo.java", InlineContextTarget.CODEMAP)
advanceUntilIdle()

val request = QueryInlineCompletionRequest("foo", "Foo.java")
val request = QueryInlineCompletionRequest("foo", "Foo.java", "codemap")
val requestJson = mapper.writeValueAsString(request)

assertThat(mapper.readTree(requestJson)).isEqualTo(mapper.readTree("""{ "query": "foo", "filePath": "Foo.java" }"""))
assertThat(mapper.readTree(requestJson)).isEqualTo(mapper.readTree("""{ "query": "foo", "filePath": "Foo.java", "target": "codemap" }"""))

val encryptedRequest = encoderServer.encrypt(requestJson)
wireMock.verify(
Expand Down Expand Up @@ -315,7 +316,7 @@ class ProjectContextProviderTest {
)

assertThrows<Exception> {
sut.queryInline("foo", "filepath")
sut.queryInline("foo", "filepath", InlineContextTarget.CODEMAP)
advanceUntilIdle()
}
}
Expand All @@ -326,7 +327,7 @@ class ProjectContextProviderTest {
fun `query inline should return deserialized bm25 chunks`() = runTest {
sut = ProjectContextProvider(project, encoderServer, this)
advanceUntilIdle()
val r = sut.queryInline("foo", "filepath")
val r = sut.queryInline("foo", "filepath", InlineContextTarget.CODEMAP)
assertThat(r).hasSize(3)
assertThat(r[0]).isEqualTo(
InlineBm25Chunk(
Expand Down Expand Up @@ -374,7 +375,7 @@ class ProjectContextProviderTest {

// it won't throw if it's executed within TestDispatcher context
withContext(getCoroutineBgContext()) {
sut.queryInline("foo", "bar")
sut.queryInline("foo", "bar", InlineContextTarget.CODEMAP)
}

advanceUntilIdle()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ object CodeWhispererConstants {
const val CHUNK_SIZE = 60
const val NUMBER_OF_LINE_IN_CHUNK = 50
const val NUMBER_OF_CHUNK_TO_FETCH = 3
const val MAX_TOTAL_LENGTH = 20480
}

object Utg {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import software.aws.toolkits.core.utils.debug
import software.aws.toolkits.core.utils.getLogger
import software.aws.toolkits.core.utils.info
import software.aws.toolkits.core.utils.warn
import software.aws.toolkits.jetbrains.services.amazonq.CodeWhispererFeatureConfigService
import software.aws.toolkits.jetbrains.services.amazonq.project.ProjectContextController
import software.aws.toolkits.jetbrains.services.codewhisperer.editor.CodeWhispererEditorUtil
import software.aws.toolkits.jetbrains.services.codewhisperer.language.CodeWhispererProgrammingLanguage
Expand All @@ -44,6 +43,7 @@ import java.io.DataInput
import java.io.DataOutput
import java.util.Collections
import kotlin.coroutines.coroutineContext
import kotlin.time.measureTimedValue

private val contentRootPathProvider = CopyContentRootPathProvider()

Expand Down Expand Up @@ -147,12 +147,20 @@ class DefaultCodeWhispererFileContextProvider(private val project: Project) : Fi
val latency = System.currentTimeMillis() - startFetchingTimestamp
if (it.contents.isNotEmpty()) {
val logStr = buildString {
append("Successfully fetched supplemental context with strategy ${it.strategy} with $latency ms")
append(
"""Q inline completion supplemental context:
| Strategy: ${it.strategy},
| Latency: $latency ms,
| Contents: ${it.contents.size} chunks,
| ContentLength: ${it.contentLength} chars,
| TargetFile: ${it.targetFileName},
""".trimMargin()
)
it.contents.forEachIndexed { index, chunk ->
append(
"""
|
| Chunk ${index + 1}:
| Chunk $index:
| path = ${chunk.path},
| score = ${chunk.score},
| contentLength = ${chunk.content.length}
Expand Down Expand Up @@ -219,55 +227,113 @@ class DefaultCodeWhispererFileContextProvider(private val project: Project) : Fi
val query = generateQuery(targetContext)

val contexts = withContext(coroutineContext) {
val projectContextDeferred1 = if (CodeWhispererFeatureConfigService.getInstance().getInlineCompletion()) {
async {
val t0 = System.currentTimeMillis()
val r = fetchProjectContext(query, psiFile, targetContext)
val t1 = System.currentTimeMillis()
LOG.debug {
buildString {
append("time elapse for fetching project context=${t1 - t0}ms; ")
append("numberOfChunks=${r.contents.size}; ")
append("totalLength=${r.contentLength}")
}
val projectContextDeferred1 = async {
val timedCodemapContext = measureTimedValue { fetchProjectContext(query, psiFile, targetContext) }
val codemapContext = timedCodemapContext.value
LOG.debug {
buildString {
append("time elapse for fetching project context=${timedCodemapContext.duration.inWholeMilliseconds}ms; ")
append("numberOfChunks=${codemapContext.contents.size}; ")
append("totalLength=${codemapContext.contentLength}")
}

r
}
} else {
null

codemapContext
}

val openTabsContextDeferred1 = async {
val t0 = System.currentTimeMillis()
val r = fetchOpenTabsContext(query, psiFile, targetContext)
val t1 = System.currentTimeMillis()
val timedOpentabContext = measureTimedValue { fetchOpenTabsContext(query, psiFile, targetContext) }
val opentabContext = timedOpentabContext.value
LOG.debug {
buildString {
append("time elapse for open tabs context=${t1 - t0}ms; ")
append("numberOfChunks=${r.contents.size}; ")
append("totalLength=${r.contentLength}")
append("time elapse for open tabs context=${timedOpentabContext.duration.inWholeMilliseconds}ms; ")
append("numberOfChunks=${opentabContext.contents.size}; ")
append("totalLength=${opentabContext.contentLength}")
}
}

r
opentabContext
}

if (projectContextDeferred1 != null) {
awaitAll(projectContextDeferred1, openTabsContextDeferred1)
} else {
awaitAll(openTabsContextDeferred1)
}
awaitAll(projectContextDeferred1, openTabsContextDeferred1)
}

val projectContext = contexts.find { it.strategy == CrossFileStrategy.ProjectContext }
val projectContext = contexts.find { it.strategy == CrossFileStrategy.Codemap }
val openTabsContext = contexts.find { it.strategy == CrossFileStrategy.OpenTabsBM25 }

return if (projectContext != null && projectContext.contents.isNotEmpty()) {
projectContext
} else {
openTabsContext ?: SupplementalContextInfo.emptyCrossFileContextInfo(targetContext.filename)
/**
* We're using both codemap and opentabs context
* 1. If both are present, codemap should live in the first of supplemental context list, i.e [codemap, opentabs_0, opentabs_1...] with strategy name codemap
* 2. If only one is present, return the one present with corresponding strategy name, either codemap or opentabs
* 3. If none is present, return empty list with strategy name empty
*
* Service will throw 400 error when context length is greater than 20480, drop the last chunk until the total length fits in the cap
*/
val contextBeforeTruncation = when {
projectContext == null && openTabsContext == null -> SupplementalContextInfo.emptyCrossFileContextInfo(targetContext.filename)

projectContext != null && openTabsContext != null -> {
val context1 = projectContext.contents
val context2 = openTabsContext.contents
val mergedContext = (context1 + context2).filter { it.content.isNotEmpty() }

val strategy = if (projectContext.contentLength != 0 && openTabsContext.contentLength != 0) {
CrossFileStrategy.Codemap
} else if (projectContext.contentLength != 0) {
CrossFileStrategy.Codemap
} else if (openTabsContext.contentLength != 0) {
CrossFileStrategy.OpenTabsBM25
} else {
CrossFileStrategy.Empty
}

SupplementalContextInfo(
isUtg = false,
contents = mergedContext,
targetFileName = targetContext.filename,
strategy = strategy
)
}

projectContext != null -> {
return if (projectContext.contentLength == 0) {
SupplementalContextInfo.emptyCrossFileContextInfo(targetContext.filename)
} else {
SupplementalContextInfo(
isUtg = false,
contents = projectContext.contents,
targetFileName = targetContext.filename,
strategy = CrossFileStrategy.Codemap
)
}
}

openTabsContext != null -> {
return if (openTabsContext.contentLength == 0) {
SupplementalContextInfo.emptyCrossFileContextInfo(targetContext.filename)
} else {
SupplementalContextInfo(
isUtg = false,
contents = openTabsContext.contents,
targetFileName = targetContext.filename,
strategy = CrossFileStrategy.OpenTabsBM25
)
}
}

else -> SupplementalContextInfo.emptyCrossFileContextInfo(targetContext.filename)
}

return truncateContext(contextBeforeTruncation)
}

fun truncateContext(context: SupplementalContextInfo): SupplementalContextInfo {
var c = context.contents
while (c.sumOf { it.content.length } >= CodeWhispererConstants.CrossFile.MAX_TOTAL_LENGTH) {
c = c.dropLast(1)
Comment on lines +332 to +333
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks nice, but maybe optimize

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea agreeeeeeeee

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let me follow up on it

}

return context.copy(contents = c)
}

@VisibleForTesting
Expand All @@ -285,7 +351,7 @@ class DefaultCodeWhispererFileContextProvider(private val project: Project) : Fi
)
},
targetFileName = targetContext.filename,
strategy = CrossFileStrategy.ProjectContext
strategy = CrossFileStrategy.Codemap
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,23 @@ enum class UtgStrategy : SupplementalContextStrategy {
;

override fun toString() = when (this) {
ByName -> "ByName"
ByContent -> "ByContent"
Empty -> "Empty"
ByName -> "byName"
ByContent -> "byContent"
Empty -> "empty"
}
}

enum class CrossFileStrategy : SupplementalContextStrategy {
OpenTabsBM25,
Empty,
ProjectContext,
Codemap,
;

override fun toString() = when (this) {
OpenTabsBM25 -> "OpenTabs_BM25"
Empty -> "Empty"
ProjectContext -> "ProjectContext"
OpenTabsBM25 -> "opentabs"
Empty -> "empty"
ProjectContext -> "projectContext"
Codemap -> "codemap"
}
}
Loading
Loading