Skip to content

Commit 5d68556

Browse files
committed
fix potential inline completion failure due to input validation exception for supplemental context
1 parent c28033d commit 5d68556

File tree

5 files changed

+134
-2
lines changed

5 files changed

+134
-2
lines changed

plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/util/CodeWhispererConstants.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ object CodeWhispererConstants {
188188
const val NUMBER_OF_LINE_IN_CHUNK = 50
189189
const val NUMBER_OF_CHUNK_TO_FETCH = 3
190190
const val MAX_TOTAL_LENGTH = 20480
191+
const val MAX_LENGTH_PER_CHUNK = 10240
191192
}
192193

193194
object Utg {

plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/util/CodeWhispererFileContextProvider.kt

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import software.aws.toolkits.core.utils.getLogger
2626
import software.aws.toolkits.core.utils.info
2727
import software.aws.toolkits.core.utils.warn
2828
import software.aws.toolkits.jetbrains.services.amazonq.project.ProjectContextController
29+
import software.aws.toolkits.jetbrains.services.amazonq.project.ProjectContextProvider
2930
import software.aws.toolkits.jetbrains.services.codewhisperer.editor.CodeWhispererEditorUtil
3031
import software.aws.toolkits.jetbrains.services.codewhisperer.language.CodeWhispererProgrammingLanguage
3132
import software.aws.toolkits.jetbrains.services.codewhisperer.language.languages.CodeWhispererJava
@@ -327,10 +328,26 @@ class DefaultCodeWhispererFileContextProvider(private val project: Project) : Fi
327328
return truncateContext(contextBeforeTruncation)
328329
}
329330

331+
/**
332+
* Requirement
333+
* - Maximum 5 supplemental context.
334+
* - Each chunk can't exceed 10240 characters
335+
* - Sum of all chunks can't exceed 20480 characters
336+
*/
330337
fun truncateContext(context: SupplementalContextInfo): SupplementalContextInfo {
331-
var c = context.contents
332-
while (c.sumOf { it.content.length } >= CodeWhispererConstants.CrossFile.MAX_TOTAL_LENGTH) {
338+
var c = context.contents.map {
339+
return@map if (it.content.length > CodeWhispererConstants.CrossFile.MAX_LENGTH_PER_CHUNK) {
340+
it.copy(content = truncateLineByLine(it.content, CodeWhispererConstants.CrossFile.MAX_LENGTH_PER_CHUNK))
341+
} else {
342+
it
343+
}
344+
}
345+
346+
var curTotalLength = c.sumOf { it.content.length }
347+
while (curTotalLength >= CodeWhispererConstants.CrossFile.MAX_TOTAL_LENGTH) {
348+
val last = c.last()
333349
c = c.dropLast(1)
350+
curTotalLength -= last.content.length
334351
}
335352

336353
return context.copy(contents = c)

plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/util/CodeWhispererUtil.kt

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,28 @@ suspend fun String.toCodeChunk(path: String): List<Chunk> {
110110
}
111111
}
112112

113+
fun truncateLineByLine(input: String, l: Int): String {
114+
val maxLength = if (l > 0) l else -1 * l
115+
if (input.isEmpty()) {
116+
return ""
117+
}
118+
val shouldAddNewLineBack = input.last() == '\n'
119+
var lines = input.trim().split("\n")
120+
var curLen = input.length
121+
while (curLen > maxLength) {
122+
val last = lines.last()
123+
lines = lines.dropLast(1)
124+
curLen -= last.length + 1
125+
}
126+
127+
val r = lines.joinToString("\n")
128+
return if (shouldAddNewLineBack) {
129+
r + "\n"
130+
} else {
131+
r
132+
}
133+
}
134+
113135
fun getAuthType(project: Project): CredentialSourceId? {
114136
val connection = checkBearerConnectionValidity(project, BearerTokenFeatureSet.Q)
115137
var authType: CredentialSourceId? = null

plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererFileContextProviderTest.kt

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
package software.aws.toolkits.jetbrains.services.codewhisperer
55

6+
import ai.grazie.utils.chainIfNotNull
67
import com.intellij.openapi.application.ApplicationManager
78
import com.intellij.openapi.application.readAction
89
import com.intellij.openapi.project.Project
@@ -13,6 +14,7 @@ import com.intellij.testFramework.fixtures.JavaCodeInsightTestFixture
1314
import com.intellij.testFramework.replaceService
1415
import com.intellij.testFramework.runInEdtAndGet
1516
import com.intellij.testFramework.runInEdtAndWait
17+
import fleet.util.letIfNotNull
1618
import kotlinx.coroutines.delay
1719
import kotlinx.coroutines.runBlocking
1820
import kotlinx.coroutines.test.TestScope
@@ -529,6 +531,47 @@ class CodeWhispererFileContextProviderTest {
529531
assertThat(r.targetFileName).isEqualTo("foo")
530532
}
531533

534+
@Test
535+
fun `truncate context should make context length per item fit in 10240 cap`() {
536+
val chunkA = Chunk(content = "a\n".repeat(4000), path = "a.java")
537+
val chunkB = Chunk(content = "b\n".repeat(6000), path = "b.java")
538+
val chunkC = Chunk(content = "c\n".repeat(1000), path = "c.java")
539+
val chunkD = Chunk(content = "d\n".repeat(1500), path = "d.java")
540+
541+
assertThat(chunkA.content.length).isEqualTo(8000)
542+
assertThat(chunkB.content.length).isEqualTo(12000)
543+
assertThat(chunkC.content.length).isEqualTo(2000)
544+
assertThat(chunkD.content.length).isEqualTo(3000)
545+
assertThat(chunkA.content.length + chunkB.content.length + chunkC.content.length + chunkD.content.length).isEqualTo(25000)
546+
547+
val supplementalContext = SupplementalContextInfo(
548+
isUtg = false,
549+
contents = listOf(
550+
chunkA,
551+
chunkB,
552+
chunkC,
553+
chunkD,
554+
),
555+
targetFileName = "foo",
556+
strategy = CrossFileStrategy.Codemap
557+
)
558+
559+
val r = sut.truncateContext(supplementalContext)
560+
561+
assertThat(r.contents).hasSize(3)
562+
val truncatedChunkA = r.contents[0]
563+
val truncatedChunkB = r.contents[1]
564+
val truncatedChunkC = r.contents[2]
565+
566+
assertThat(truncatedChunkA.content.length).isEqualTo(8000)
567+
assertThat(truncatedChunkB.content.length).isEqualTo(10240)
568+
assertThat(truncatedChunkC.content.length).isEqualTo(2000)
569+
570+
assertThat(r.contentLength).isEqualTo(20240)
571+
assertThat(r.strategy).isEqualTo(CrossFileStrategy.Codemap)
572+
assertThat(r.targetFileName).isEqualTo("foo")
573+
}
574+
532575
private fun setupFixture(fixture: JavaCodeInsightTestFixture): List<PsiFile> {
533576
val psiFile1 = fixture.addFileToProject("Main.java", JAVA_MAIN)
534577
val psiFile2 = fixture.addFileToProject("UtilClass.java", JAVA_UTILCLASS)

plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererUtilTest.kt

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import software.aws.toolkits.jetbrains.services.codewhisperer.util.CodeWhisperer
2727
import software.aws.toolkits.jetbrains.services.codewhisperer.util.isWithin
2828
import software.aws.toolkits.jetbrains.services.codewhisperer.util.runIfIdcConnectionOrTelemetryEnabled
2929
import software.aws.toolkits.jetbrains.services.codewhisperer.util.toCodeChunk
30+
import software.aws.toolkits.jetbrains.services.codewhisperer.util.truncateLineByLine
3031
import software.aws.toolkits.jetbrains.settings.AwsSettings
3132
import software.aws.toolkits.jetbrains.utils.rules.JavaCodeInsightTestFixtureRule
3233
import software.aws.toolkits.telemetry.CodewhispererCompletionType
@@ -61,6 +62,54 @@ class CodeWhispererUtilTest {
6162
AwsSettings.getInstance().isTelemetryEnabled = isTelemetryEnabledDefault
6263
}
6364

65+
@Test
66+
fun `truncateLineByLine should drop the last line if max length is greater than threshold`() {
67+
val input: String = """
68+
${"a".repeat(11)}
69+
${"b".repeat(11)}
70+
${"c".repeat(11)}
71+
${"d".repeat(11)}
72+
${"e".repeat(11)}
73+
""".trimIndent()
74+
assertThat(input.length).isGreaterThan(50)
75+
val actual = truncateLineByLine(input, 50)
76+
assertThat(actual).isEqualTo(
77+
"""
78+
${"a".repeat(11)}
79+
${"b".repeat(11)}
80+
${"c".repeat(11)}
81+
${"d".repeat(11)}
82+
""".trimIndent()
83+
)
84+
85+
val input2 = "b\n".repeat(10)
86+
val actual2 = truncateLineByLine(input2, 8)
87+
assertThat(actual2.length).isEqualTo(8)
88+
}
89+
90+
@Test
91+
fun `truncateLineByLine should return empty if empty string is provided`() {
92+
val input = ""
93+
val actual = truncateLineByLine(input, 50)
94+
assertThat(actual).isEqualTo("")
95+
}
96+
97+
@Test
98+
fun `truncateLineByLine should return empty if 0 max length is provided`() {
99+
val input = "aaaaa"
100+
val actual = truncateLineByLine(input, 0)
101+
assertThat(actual).isEqualTo("")
102+
}
103+
104+
@Test
105+
fun `truncateLineByLine should return flip the value if negative max length is provided`() {
106+
val input = "aaaaa\nbbbbb"
107+
val actual = truncateLineByLine(input, -6)
108+
val expected1 = truncateLineByLine(input, 6)
109+
assertThat(actual).isEqualTo(expected1)
110+
assertThat(actual).isEqualTo("aaaaa")
111+
}
112+
64113
@Test
65114
fun `checkIfIdentityCenterLoginOrTelemetryEnabled will execute callback if the connection is IamIdentityCenter`() {
66115
val modificationTracker = SimpleModificationTracker()

0 commit comments

Comments
 (0)