Skip to content

Commit c034541

Browse files
committed
fix(amazonq): switch to ulong to avoid overflow when input is larger than 2gb
2GB in bytes > INT_MAX so use ULong, which can handle 18 PB
1 parent d20b192 commit c034541

File tree

5 files changed

+106
-62
lines changed

5 files changed

+106
-62
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type" : "bugfix",
3+
"description" : "Fix integer overflow when local context index input is larger than 2GB"
4+
}

plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/settings/CodeWhispererConfigurable.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ class CodeWhispererConfigurable(private val project: Project) :
133133

134134
row(message("aws.settings.codewhisperer.project_context_index_thread")) {
135135
intTextField(
136-
range = IntRange(0, 50)
136+
range = CodeWhispererSettings.CONTEXT_INDEX_THREADS
137137
).bindIntText(codeWhispererSettings::getProjectContextIndexThreadCount, codeWhispererSettings::setProjectContextIndexThreadCount)
138138
.apply {
139139
connect.subscribe(
@@ -150,7 +150,7 @@ class CodeWhispererConfigurable(private val project: Project) :
150150

151151
row(message("aws.settings.codewhisperer.project_context_index_max_size")) {
152152
intTextField(
153-
range = IntRange(1, 4096)
153+
range = CodeWhispererSettings.CONTEXT_INDEX_SIZE
154154
).bindIntText(codeWhispererSettings::getProjectContextIndexMaxSize, codeWhispererSettings::setProjectContextIndexMaxSize)
155155
.apply {
156156
connect.subscribe(

plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererSettingsTest.kt

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,42 @@ class CodeWhispererSettingsTest : CodeWhispererTestBase() {
211211
assertThat(actual.autoBuildSetting).hasSize(1)
212212
assertThat(actual.autoBuildSetting["project1"]).isTrue()
213213
}
214+
215+
@Test
216+
fun `context thread count is returned in range`() {
217+
val sut = CodeWhispererSettings.getInstance()
218+
219+
mapOf(
220+
1 to 1,
221+
0 to 0,
222+
-1 to 0,
223+
123 to 50,
224+
50 to 50,
225+
51 to 50,
226+
).forEach { s, expected ->
227+
sut.setProjectContextIndexThreadCount(s)
228+
assertThat(sut.getProjectContextIndexThreadCount()).isEqualTo(expected)
229+
}
230+
231+
}
232+
233+
@Test
234+
fun `context index size is returned in range`() {
235+
val sut = CodeWhispererSettings.getInstance()
236+
237+
mapOf(
238+
1 to 1,
239+
0 to 1,
240+
-1 to 1,
241+
123 to 123,
242+
2047 to 2047,
243+
4096 to 4096,
244+
4097 to 4096,
245+
).forEach { s, expected ->
246+
sut.setProjectContextIndexMaxSize(s)
247+
assertThat(sut.getProjectContextIndexMaxSize()).isEqualTo(expected)
248+
}
249+
}
214250
}
215251

216252
class CodeWhispererSettingUnitTest {

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/project/ProjectContextProvider.kt

Lines changed: 55 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En
7070

7171
data class FileCollectionResult(
7272
val files: List<String>,
73-
val fileSize: Int,
73+
val fileSize: Int, // in MB
7474
)
7575

7676
// TODO: move to LspMessage.kt
@@ -241,59 +241,7 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En
241241
}
242242
}
243243

244-
private fun willExceedPayloadLimit(currentTotalFileSize: Long, currentFileSize: Long): Boolean {
245-
val maxSize = CodeWhispererSettings.getInstance().getProjectContextIndexMaxSize()
246-
return currentTotalFileSize.let { totalSize -> totalSize > (maxSize * 1024 * 1024 - currentFileSize) }
247-
}
248-
249-
private fun isBuildOrBin(fileName: String): Boolean {
250-
val regex = Regex("""bin|build|node_modules|venv|\.venv|env|\.idea|\.conda""", RegexOption.IGNORE_CASE)
251-
return regex.find(fileName) != null
252-
}
253-
254-
fun collectFiles(): FileCollectionResult {
255-
val collectedFiles = mutableListOf<String>()
256-
var currentTotalFileSize = 0L
257-
val allFiles = mutableListOf<VirtualFile>()
258-
259-
val projectBaseDirectories = project.getBaseDirectories()
260-
val changeListManager = ChangeListManager.getInstance(project)
261-
262-
projectBaseDirectories.forEach {
263-
VfsUtilCore.visitChildrenRecursively(
264-
it,
265-
object : VirtualFileVisitor<Unit>(NO_FOLLOW_SYMLINKS) {
266-
// TODO: refactor this along with /dev & codescan file traversing logic
267-
override fun visitFile(file: VirtualFile): Boolean {
268-
if ((file.isDirectory && isBuildOrBin(file.name)) ||
269-
!isWorkspaceSourceContent(file, projectBaseDirectories, changeListManager, additionalGlobalIgnoreRulesForStrictSources) ||
270-
(file.isFile && file.length > 10 * 1024 * 1024)
271-
) {
272-
return false
273-
}
274-
if (file.isFile) {
275-
allFiles.add(file)
276-
return false
277-
}
278-
return true
279-
}
280-
}
281-
)
282-
}
283-
284-
for (file in allFiles) {
285-
if (willExceedPayloadLimit(currentTotalFileSize, file.length)) {
286-
break
287-
}
288-
collectedFiles.add(file.path)
289-
currentTotalFileSize += file.length
290-
}
291-
292-
return FileCollectionResult(
293-
files = collectedFiles.toList(),
294-
fileSize = (currentTotalFileSize / 1024 / 1024).toInt()
295-
)
296-
}
244+
fun collectFiles(): FileCollectionResult = collectFiles(project.getBaseDirectories(), ChangeListManager.getInstance(project))
297245

298246
private fun queryResultToRelevantDocuments(queryResult: List<Chunk>): List<RelevantDocument> {
299247
val documents: MutableList<RelevantDocument> = mutableListOf()
@@ -353,5 +301,58 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En
353301

354302
companion object {
355303
private val logger = getLogger<ProjectContextProvider>()
304+
305+
private fun willExceedPayloadLimit(maxSize: ULong, currentTotalFileSize: ULong, currentFileSize: Long) =
306+
currentTotalFileSize.let { totalSize -> totalSize > (maxSize - currentFileSize.toUInt()) }
307+
308+
private fun isBuildOrBin(fileName: String): Boolean {
309+
val regex = Regex("""bin|build|node_modules|venv|\.venv|env|\.idea|\.conda""", RegexOption.IGNORE_CASE)
310+
return regex.find(fileName) != null
311+
}
312+
313+
fun collectFiles(projectBaseDirectories: Set<VirtualFile>, changeListManager: ChangeListManager): FileCollectionResult {
314+
val mega = 1024u * 1024u
315+
val maxSize = CodeWhispererSettings.getInstance()
316+
.getProjectContextIndexMaxSize().toULong() * mega
317+
val tenMb = 10 * mega.toInt()
318+
val collectedFiles = mutableListOf<String>()
319+
var currentTotalFileSize = 0UL
320+
val allFiles = mutableListOf<VirtualFile>()
321+
322+
projectBaseDirectories.forEach {
323+
VfsUtilCore.visitChildrenRecursively(
324+
it,
325+
object : VirtualFileVisitor<Unit>(NO_FOLLOW_SYMLINKS) {
326+
// TODO: refactor this along with /dev & codescan file traversing logic
327+
override fun visitFile(file: VirtualFile): Boolean {
328+
if ((file.isDirectory && isBuildOrBin(file.name)) ||
329+
!isWorkspaceSourceContent(file, projectBaseDirectories, changeListManager, additionalGlobalIgnoreRulesForStrictSources) ||
330+
(file.isFile && file.length > tenMb)
331+
) {
332+
return false
333+
}
334+
if (file.isFile) {
335+
allFiles.add(file)
336+
return false
337+
}
338+
return true
339+
}
340+
}
341+
)
342+
}
343+
344+
for (file in allFiles) {
345+
if (willExceedPayloadLimit(maxSize, currentTotalFileSize, file.length)) {
346+
break
347+
}
348+
collectedFiles.add(file.path)
349+
currentTotalFileSize += file.length.toUInt()
350+
}
351+
352+
return FileCollectionResult(
353+
files = collectedFiles.toList(),
354+
fileSize = (currentTotalFileSize / 1024u / 1024u).toInt()
355+
)
356+
}
356357
}
357358
}

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/settings/CodeWhispererSettings.kt

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ class CodeWhispererSettings : PersistentStateComponent<CodeWhispererConfiguratio
9292
fun getProjectContextIndexThreadCount(): Int = state.intValue.getOrDefault(
9393
CodeWhispererIntConfigurationType.ProjectContextIndexThreadCount,
9494
0
95-
)
95+
).coerceIn(CONTEXT_INDEX_THREADS)
9696

9797
fun setProjectContextIndexThreadCount(value: Int) {
9898
state.intValue[CodeWhispererIntConfigurationType.ProjectContextIndexThreadCount] = value
@@ -101,7 +101,7 @@ class CodeWhispererSettings : PersistentStateComponent<CodeWhispererConfiguratio
101101
fun getProjectContextIndexMaxSize(): Int = state.intValue.getOrDefault(
102102
CodeWhispererIntConfigurationType.ProjectContextIndexMaxSize,
103103
250
104-
)
104+
).coerceIn(CONTEXT_INDEX_SIZE)
105105

106106
fun setProjectContextIndexMaxSize(value: Int) {
107107
state.intValue[CodeWhispererIntConfigurationType.ProjectContextIndexMaxSize] = value
@@ -134,10 +134,6 @@ class CodeWhispererSettings : PersistentStateComponent<CodeWhispererConfiguratio
134134
state.value[CodeWhispererConfigurationType.IsTabAcceptPriorityNotificationShownOnce] = value
135135
}
136136

137-
companion object {
138-
fun getInstance(): CodeWhispererSettings = service()
139-
}
140-
141137
override fun getState(): CodeWhispererConfiguration = CodeWhispererConfiguration().apply {
142138
value.putAll(state.value)
143139
intValue.putAll(state.intValue)
@@ -155,6 +151,13 @@ class CodeWhispererSettings : PersistentStateComponent<CodeWhispererConfiguratio
155151
this.state.stringValue.putAll(state.stringValue)
156152
this.state.autoBuildSetting.putAll(state.autoBuildSetting)
157153
}
154+
155+
companion object {
156+
fun getInstance(): CodeWhispererSettings = service()
157+
158+
val CONTEXT_INDEX_SIZE = IntRange(1, 4096)
159+
val CONTEXT_INDEX_THREADS = IntRange(0, 50)
160+
}
158161
}
159162

160163
class CodeWhispererConfiguration : BaseState() {

0 commit comments

Comments
 (0)