Skip to content

Commit 7b43dbe

Browse files
authored
CodeWhisperer: factor in language dialect for crossfile context fetching (#3767)
1 parent 2e82d2d commit 7b43dbe

File tree

5 files changed

+11
-1
lines changed

5 files changed

+11
-1
lines changed

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/util/CodeWhispererFileCrawler.kt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ class NoOpFileCrawler : FileCrawler {
6060
abstract class CodeWhispererFileCrawler : FileCrawler {
6161
abstract val fileExtension: String
6262
abstract val testFilenamePattern: Regex
63+
abstract val dialects: Set<String>
6364

6465
override fun listFilesUnderProjectRoot(project: Project): List<VirtualFile> = project.guessProjectDir()?.let { rootDir ->
6566
VfsUtil.collectChildrenRecursively(rootDir).filter {
@@ -73,7 +74,7 @@ abstract class CodeWhispererFileCrawler : FileCrawler {
7374
val openedFiles = runReadAction {
7475
FileEditorManager.getInstance(psiFile.project).openFiles.toList().filter {
7576
it.name != psiFile.virtualFile.name &&
76-
it.extension == psiFile.virtualFile.extension &&
77+
isSameDialect(psiFile.virtualFile.extension) &&
7778
!TestSourcesFilter.isTestSources(it, psiFile.project)
7879
}
7980
}
@@ -89,6 +90,10 @@ abstract class CodeWhispererFileCrawler : FileCrawler {
8990

9091
abstract fun guessSourceFileName(tstFileName: String): String
9192

93+
private fun isSameDialect(fileExt: String?): Boolean = fileExt?.let {
94+
dialects.contains(fileExt)
95+
} ?: false
96+
9297
companion object {
9398
fun searchRelevantFileInEditors(target: PsiFile, keywordProducer: (psiFile: PsiFile) -> List<String>): VirtualFile? {
9499
val project = target.project

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/util/JavaCodeWhispererFileCrawler.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import org.jetbrains.jps.model.java.JavaModuleSourceRootTypes
2121
object JavaCodeWhispererFileCrawler : CodeWhispererFileCrawler() {
2222
override val fileExtension: String = "java"
2323
override val testFilenamePattern: Regex = """(?:Test([^/\\]+)\.java|([^/\\]+)Test\.java)$""".toRegex()
24+
override val dialects: Set<String> = setOf("java")
2425

2526
override fun guessSourceFileName(tstFileName: String): String = tstFileName.substring(0, tstFileName.length - "Test.java".length) + ".java"
2627

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/util/JavascriptCodeWhispererFileCrawler.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import com.intellij.psi.PsiFile
88

99
object JavascriptCodeWhispererFileCrawler : CodeWhispererFileCrawler() {
1010
override val fileExtension: String = "js"
11+
override val dialects: Set<String> = setOf("js", "jsx")
1112

1213
// TODO: Add implementation when UTG is enabled
1314
override val testFilenamePattern: Regex = "".toRegex()

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/util/PythonCodeWhispererFileCrawler.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ import org.jetbrains.jps.model.java.JavaModuleSourceRootTypes
1414

1515
object PythonCodeWhispererFileCrawler : CodeWhispererFileCrawler() {
1616
override val fileExtension: String = "py"
17+
override val dialects: Set<String> = setOf("py")
18+
1719
override val testFilenamePattern: Regex = Regex("""(?:test_([^/\\]+)\.py|([^/\\]+)_test\.py)${'$'}""")
1820
override suspend fun listFilesImported(psiFile: PsiFile): List<VirtualFile> = emptyList()
1921

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/util/TypescriptCodeWhispererFileCrawler.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import com.intellij.psi.PsiFile
88

99
object TypescriptCodeWhispererFileCrawler : CodeWhispererFileCrawler() {
1010
override val fileExtension: String = "ts"
11+
override val dialects: Set<String> = setOf("ts", "tsx")
1112

1213
// TODO: Add implementation when UTG is enabled
1314
override val testFilenamePattern: Regex = "".toRegex()

0 commit comments

Comments
 (0)