Skip to content

Commit 2e82d2d

Browse files
authored
CodeWhisperer: enhanced Java file fetching (#3766)
1 parent a556a6c commit 2e82d2d

File tree

13 files changed

+127
-52
lines changed

13 files changed

+127
-52
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type" : "feature",
3+
"description" : "CodeWhisperer: Improve Java suggestion quality with enhanced file context fetching"
4+
}

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/language/languages/CodeWhispererJavaScript.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ class CodeWhispererJavaScript private constructor() : CodeWhispererProgrammingLa
1919

2020
override fun isAllClassifier(): Boolean = true
2121

22+
override fun isSupplementalContextSupported() = true
23+
2224
companion object {
2325
const val ID = "javascript"
2426

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/language/languages/CodeWhispererPython.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ class CodeWhispererPython private constructor() : CodeWhispererProgrammingLangua
2323

2424
override fun isUTGSupported() = true
2525

26+
override fun isSupplementalContextSupported() = true
27+
2628
companion object {
2729
const val ID = "python"
2830

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/language/languages/CodeWhispererTypeScript.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ class CodeWhispererTypeScript private constructor() : CodeWhispererProgrammingLa
1717

1818
override fun isAllClassifier(): Boolean = true
1919

20+
override fun isSupplementalContextSupported() = true
21+
2022
companion object {
2123
const val ID = "typescript"
2224

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/service/CodeWhispererService.kt

Lines changed: 18 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -511,36 +511,27 @@ class CodeWhispererService {
511511
// 2. supplemental context
512512
val startFetchingTimestamp = System.currentTimeMillis()
513513
val isTstFile = FileContextProvider.getInstance(project).isTestFile(psiFile)
514-
val supplementalContext = if (CodeWhispererUserGroupSettings.getInstance().getUserGroup() == CodeWhispererUserGroup.CrossFile) {
515-
runBlocking {
516-
try {
517-
withTimeout(SUPPLEMENTAL_CONTEXT_TIMEOUT) {
518-
FileContextProvider.getInstance(project).extractSupplementalFileContext(psiFile, fileContext)
519-
}
520-
} catch (e: Exception) {
521-
if (e is TimeoutCancellationException) {
522-
LOG.debug {
523-
"Supplemental context fetch timed out in ${System.currentTimeMillis() - startFetchingTimestamp}ms"
524-
}
525-
SupplementalContextInfo(
526-
isUtg = isTstFile,
527-
contents = emptyList(),
528-
latency = System.currentTimeMillis() - startFetchingTimestamp,
529-
targetFileName = fileContext.filename
530-
)
531-
} else {
532-
LOG.debug { "Run into unexpected error when fetching supplemental context, error: ${e.message}" }
533-
null
514+
val supplementalContext = runBlocking {
515+
try {
516+
withTimeout(SUPPLEMENTAL_CONTEXT_TIMEOUT) {
517+
FileContextProvider.getInstance(project).extractSupplementalFileContext(psiFile, fileContext)
518+
}
519+
} catch (e: Exception) {
520+
if (e is TimeoutCancellationException) {
521+
LOG.debug {
522+
"Supplemental context fetch timed out in ${System.currentTimeMillis() - startFetchingTimestamp}ms"
534523
}
524+
SupplementalContextInfo(
525+
isUtg = isTstFile,
526+
contents = emptyList(),
527+
latency = System.currentTimeMillis() - startFetchingTimestamp,
528+
targetFileName = fileContext.filename
529+
)
530+
} else {
531+
LOG.debug { "Run into unexpected error when fetching supplemental context, error: ${e.message}" }
532+
null
535533
}
536534
}
537-
} else {
538-
SupplementalContextInfo(
539-
isUtg = isTstFile,
540-
contents = emptyList(),
541-
latency = 0,
542-
targetFileName = fileContext.filename
543-
)
544535
}
545536

546537
// 3. caret position

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/util/CodeWhispererConstants.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ object CodeWhispererConstants {
7777
}
7878

7979
object CrossFile {
80-
const val CHUNK_SIZE = 3000
80+
const val CHUNK_SIZE = 60
8181
}
8282

8383
object Utg {

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/util/CodeWhispererFileContextProvider.kt

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,15 @@ import software.aws.toolkits.core.utils.warn
2424
import software.aws.toolkits.jetbrains.services.codewhisperer.editor.CodeWhispererEditorUtil
2525
import software.aws.toolkits.jetbrains.services.codewhisperer.language.CodeWhispererProgrammingLanguage
2626
import software.aws.toolkits.jetbrains.services.codewhisperer.language.languages.CodeWhispererJava
27+
import software.aws.toolkits.jetbrains.services.codewhisperer.language.languages.CodeWhispererJavaScript
2728
import software.aws.toolkits.jetbrains.services.codewhisperer.language.languages.CodeWhispererPython
29+
import software.aws.toolkits.jetbrains.services.codewhisperer.language.languages.CodeWhispererTypeScript
2830
import software.aws.toolkits.jetbrains.services.codewhisperer.language.programmingLanguage
2931
import software.aws.toolkits.jetbrains.services.codewhisperer.model.Chunk
3032
import software.aws.toolkits.jetbrains.services.codewhisperer.model.FileContextInfo
3133
import software.aws.toolkits.jetbrains.services.codewhisperer.model.SupplementalContextInfo
34+
import software.aws.toolkits.jetbrains.services.codewhisperer.service.CodeWhispererUserGroup
35+
import software.aws.toolkits.jetbrains.services.codewhisperer.service.CodeWhispererUserGroupSettings
3236
import java.io.DataInput
3337
import java.io.DataOutput
3438
import java.util.Collections
@@ -47,6 +51,8 @@ private val codewhispererCodeChunksIndex = GistManager.getInstance()
4751
private fun getFileCrawlerForLanguage(programmingLanguage: CodeWhispererProgrammingLanguage) = when (programmingLanguage) {
4852
is CodeWhispererJava -> JavaCodeWhispererFileCrawler
4953
is CodeWhispererPython -> PythonCodeWhispererFileCrawler
54+
is CodeWhispererJavaScript -> JavascriptCodeWhispererFileCrawler
55+
is CodeWhispererTypeScript -> TypescriptCodeWhispererFileCrawler
5056
else -> NoOpFileCrawler()
5157
}
5258

@@ -108,11 +114,27 @@ class DefaultCodeWhispererFileContextProvider(private val project: Project) : Fi
108114
override suspend fun extractSupplementalFileContext(psiFile: PsiFile, targetContext: FileContextInfo): SupplementalContextInfo? {
109115
val startFetchingTimestamp = System.currentTimeMillis()
110116
val isTst = isTestFile(psiFile)
117+
val userGroup = CodeWhispererUserGroupSettings.getInstance().getUserGroup()
118+
val language = targetContext.programmingLanguage
111119

112120
val chunks = if (isTst && targetContext.programmingLanguage.isUTGSupported()) {
113-
extractSupplementalFileContextForTst(psiFile, targetContext)
121+
if (userGroup == CodeWhispererUserGroup.CrossFile) {
122+
extractSupplementalFileContextForTst(psiFile, targetContext)
123+
} else {
124+
emptyList()
125+
}
114126
} else if (!isTst && targetContext.programmingLanguage.isSupplementalContextSupported()) {
115-
extractSupplementalFileContextForSrc(psiFile, targetContext)
127+
when (language) {
128+
is CodeWhispererJava -> extractSupplementalFileContextForSrc(psiFile, targetContext)
129+
130+
is CodeWhispererPython, is CodeWhispererJavaScript, is CodeWhispererTypeScript -> if (userGroup == CodeWhispererUserGroup.CrossFile) {
131+
extractSupplementalFileContextForSrc(psiFile, targetContext)
132+
} else {
133+
emptyList()
134+
}
135+
136+
else -> emptyList()
137+
}
116138
} else {
117139
LOG.debug { "${if (isTst) "UTG" else "CrossFile"} not supported for ${targetContext.programmingLanguage.languageId}" }
118140
null

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/util/CodeWhispererFileCrawler.kt

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import com.intellij.openapi.application.runReadAction
77
import com.intellij.openapi.fileEditor.FileEditorManager
88
import com.intellij.openapi.project.Project
99
import com.intellij.openapi.project.guessProjectDir
10+
import com.intellij.openapi.roots.TestSourcesFilter
1011
import com.intellij.openapi.vfs.VfsUtil
1112
import com.intellij.openapi.vfs.VirtualFile
1213
import com.intellij.psi.PsiFile
@@ -39,6 +40,9 @@ interface FileCrawler {
3940
*/
4041
fun findFocalFileForTest(psiFile: PsiFile): VirtualFile?
4142

43+
/**
44+
* List files opened in the editors and sorted by file distance @see [CodeWhispererFileCrawler.getFileDistance]
45+
*/
4246
fun listRelevantFilesInEditors(psiFile: PsiFile): List<VirtualFile>
4347
}
4448

@@ -63,6 +67,26 @@ abstract class CodeWhispererFileCrawler : FileCrawler {
6367
}
6468
}.orEmpty()
6569

70+
override fun listRelevantFilesInEditors(psiFile: PsiFile): List<VirtualFile> {
71+
val targetFile = psiFile.virtualFile
72+
73+
val openedFiles = runReadAction {
74+
FileEditorManager.getInstance(psiFile.project).openFiles.toList().filter {
75+
it.name != psiFile.virtualFile.name &&
76+
it.extension == psiFile.virtualFile.extension &&
77+
!TestSourcesFilter.isTestSources(it, psiFile.project)
78+
}
79+
}
80+
81+
val fileToFileDistanceList = runReadAction {
82+
openedFiles.map {
83+
return@map it to CodeWhispererFileCrawler.getFileDistance(targetFile = targetFile, candidateFile = it)
84+
}
85+
}
86+
87+
return fileToFileDistanceList.sortedBy { it.second }.map { it.first }
88+
}
89+
6690
abstract fun guessSourceFileName(tstFileName: String): String
6791

6892
companion object {

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/util/JavaCodeWhispererFileCrawler.kt

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import com.intellij.openapi.application.runReadAction
77
import com.intellij.openapi.fileEditor.FileEditorManager
88
import com.intellij.openapi.module.ModuleUtilCore
99
import com.intellij.openapi.project.rootManager
10-
import com.intellij.openapi.roots.TestSourcesFilter
1110
import com.intellij.openapi.vfs.VfsUtil
1211
import com.intellij.openapi.vfs.VirtualFile
1312
import com.intellij.psi.PsiClassOwner
@@ -20,27 +19,11 @@ import org.jetbrains.jps.model.java.JavaModuleSourceRootTypes
2019

2120
// version1: Utilize PSI import elements to resolve imported files
2221
object JavaCodeWhispererFileCrawler : CodeWhispererFileCrawler() {
23-
override val fileExtension: String = ".java"
22+
override val fileExtension: String = "java"
2423
override val testFilenamePattern: Regex = """(?:Test([^/\\]+)\.java|([^/\\]+)Test\.java)$""".toRegex()
2524

2625
override fun guessSourceFileName(tstFileName: String): String = tstFileName.substring(0, tstFileName.length - "Test.java".length) + ".java"
2726

28-
override fun listRelevantFilesInEditors(psiFile: PsiFile): List<VirtualFile> {
29-
val targetFile = psiFile.virtualFile
30-
31-
val openedFiles = FileEditorManager.getInstance(psiFile.project).openFiles.toList().filter {
32-
it.name != psiFile.virtualFile.name &&
33-
it.extension == psiFile.virtualFile.extension &&
34-
!TestSourcesFilter.isTestSources(it, psiFile.project)
35-
}
36-
37-
val fileToFileDistanceList = openedFiles.map {
38-
return@map it to CodeWhispererFileCrawler.getFileDistance(targetFile = targetFile, candidateFile = it)
39-
}
40-
41-
return fileToFileDistanceList.sortedBy { it.second }.map { it.first }
42-
}
43-
4427
override suspend fun listFilesImported(psiFile: PsiFile): List<VirtualFile> {
4528
if (psiFile !is PsiJavaFile) return emptyList()
4629
val result = mutableListOf<VirtualFile>()
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package software.aws.toolkits.jetbrains.services.codewhisperer.util
5+
6+
import com.intellij.openapi.vfs.VirtualFile
7+
import com.intellij.psi.PsiFile
8+
9+
object JavascriptCodeWhispererFileCrawler : CodeWhispererFileCrawler() {
10+
override val fileExtension: String = "js"
11+
12+
// TODO: Add implementation when UTG is enabled
13+
override val testFilenamePattern: Regex = "".toRegex()
14+
15+
// TODO: Add implementation when UTG is enabled
16+
override fun guessSourceFileName(tstFileName: String): String = ""
17+
18+
override suspend fun listFilesImported(psiFile: PsiFile): List<VirtualFile> = emptyList()
19+
20+
override fun listFilesWithinSamePackage(psiFile: PsiFile): List<VirtualFile> = emptyList()
21+
22+
override fun findFocalFileForTest(psiFile: PsiFile): VirtualFile? = null
23+
}

0 commit comments

Comments
 (0)