Skip to content

Commit 19742b8

Browse files
authored
crossfile/utg patch (#3768)
1 parent eaa7258 commit 19742b8

File tree

7 files changed

+55
-44
lines changed

7 files changed

+55
-44
lines changed

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/language/languages/CodeWhispererJsx.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ class CodeWhispererJsx private constructor() : CodeWhispererProgrammingLanguage(
1919

2020
override fun isAllClassifier(): Boolean = true
2121

22+
override fun isSupplementalContextSupported() = true
23+
2224
companion object {
2325
const val ID = "jsx"
2426

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/language/languages/CodeWhispererTsx.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ class CodeWhispererTsx private constructor() : CodeWhispererProgrammingLanguage(
1919

2020
override fun isAllClassifier(): Boolean = true
2121

22+
override fun isSupplementalContextSupported() = true
23+
2224
companion object {
2325
const val ID = "tsx"
2426

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/util/CodeWhispererFileContextProvider.kt

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@ import software.aws.toolkits.jetbrains.services.codewhisperer.editor.CodeWhisper
2525
import software.aws.toolkits.jetbrains.services.codewhisperer.language.CodeWhispererProgrammingLanguage
2626
import software.aws.toolkits.jetbrains.services.codewhisperer.language.languages.CodeWhispererJava
2727
import software.aws.toolkits.jetbrains.services.codewhisperer.language.languages.CodeWhispererJavaScript
28+
import software.aws.toolkits.jetbrains.services.codewhisperer.language.languages.CodeWhispererJsx
2829
import software.aws.toolkits.jetbrains.services.codewhisperer.language.languages.CodeWhispererPython
30+
import software.aws.toolkits.jetbrains.services.codewhisperer.language.languages.CodeWhispererTsx
2931
import software.aws.toolkits.jetbrains.services.codewhisperer.language.languages.CodeWhispererTypeScript
3032
import software.aws.toolkits.jetbrains.services.codewhisperer.language.programmingLanguage
3133
import software.aws.toolkits.jetbrains.services.codewhisperer.model.Chunk
@@ -51,8 +53,8 @@ private val codewhispererCodeChunksIndex = GistManager.getInstance()
5153
private fun getFileCrawlerForLanguage(programmingLanguage: CodeWhispererProgrammingLanguage) = when (programmingLanguage) {
5254
is CodeWhispererJava -> JavaCodeWhispererFileCrawler
5355
is CodeWhispererPython -> PythonCodeWhispererFileCrawler
54-
is CodeWhispererJavaScript -> JavascriptCodeWhispererFileCrawler
55-
is CodeWhispererTypeScript -> TypescriptCodeWhispererFileCrawler
56+
is CodeWhispererJavaScript, is CodeWhispererJsx -> JavascriptCodeWhispererFileCrawler
57+
is CodeWhispererTypeScript, is CodeWhispererTsx -> TypescriptCodeWhispererFileCrawler
5658
else -> NoOpFileCrawler()
5759
}
5860

@@ -127,11 +129,12 @@ class DefaultCodeWhispererFileContextProvider(private val project: Project) : Fi
127129
when (language) {
128130
is CodeWhispererJava -> extractSupplementalFileContextForSrc(psiFile, targetContext)
129131

130-
is CodeWhispererPython, is CodeWhispererJavaScript, is CodeWhispererTypeScript -> if (userGroup == CodeWhispererUserGroup.CrossFile) {
131-
extractSupplementalFileContextForSrc(psiFile, targetContext)
132-
} else {
133-
emptyList()
134-
}
132+
is CodeWhispererPython, is CodeWhispererJavaScript, is CodeWhispererTypeScript, is CodeWhispererJsx, is CodeWhispererTsx ->
133+
if (userGroup == CodeWhispererUserGroup.CrossFile) {
134+
extractSupplementalFileContextForSrc(psiFile, targetContext)
135+
} else {
136+
emptyList()
137+
}
135138

136139
else -> emptyList()
137140
}
@@ -195,10 +198,12 @@ class DefaultCodeWhispererFileContextProvider(private val project: Project) : Fi
195198
return chunks.take(CodeWhispererConstants.CrossFile.CHUNK_SIZE)
196199
}
197200

198-
override fun isTestFile(psiFile: PsiFile) = when (psiFile.programmingLanguage()) {
199-
is CodeWhispererJava -> TestSourcesFilter.isTestSources(psiFile.virtualFile, project)
200-
is CodeWhispererPython -> PythonCodeWhispererFileCrawler.testFilenamePattern.matches(psiFile.name)
201-
else -> true
201+
override fun isTestFile(psiFile: PsiFile): Boolean {
202+
val path = runReadAction { contentRootPathProvider.getPathToElement(project, psiFile.virtualFile, null) ?: psiFile.virtualFile.path }
203+
return TestSourcesFilter.isTestSources(psiFile.virtualFile, project) ||
204+
path.contains("""test/""") ||
205+
path.contains("""tst/""") ||
206+
path.contains("""tests/""")
202207
}
203208

204209
@VisibleForTesting
@@ -255,23 +260,25 @@ class DefaultCodeWhispererFileContextProvider(private val project: Project) : Fi
255260
val focalFile = getFileCrawlerForLanguage(targetContext.programmingLanguage).findFocalFileForTest(psiFile)
256261

257262
return focalFile?.let { file ->
258-
val relativePath = contentRootPathProvider.getPathToElement(project, file, null) ?: file.path
259-
val content = file.content()
263+
runReadAction {
264+
val relativePath = contentRootPathProvider.getPathToElement(project, file, null) ?: file.path
265+
val content = file.content()
260266

261-
if (content.isBlank()) {
262-
emptyList()
263-
} else {
264-
listOf(
265-
Chunk(
266-
content = CodeWhispererConstants.Utg.UTG_PREFIX + file.content().let {
267-
it.substring(
268-
0,
269-
minOf(it.length, CodeWhispererConstants.Utg.UTG_SEGMENT_SIZE)
270-
)
271-
},
272-
path = relativePath
267+
if (content.isBlank()) {
268+
emptyList()
269+
} else {
270+
listOf(
271+
Chunk(
272+
content = CodeWhispererConstants.Utg.UTG_PREFIX + file.content().let {
273+
it.substring(
274+
0,
275+
minOf(it.length, CodeWhispererConstants.Utg.UTG_SEGMENT_SIZE)
276+
)
277+
},
278+
path = relativePath
279+
)
273280
)
274-
)
281+
}
275282
}
276283
}.orEmpty()
277284
}

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/util/CodeWhispererFileCrawler.kt

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ import com.intellij.openapi.vfs.VfsUtil
1212
import com.intellij.openapi.vfs.VirtualFile
1313
import com.intellij.psi.PsiFile
1414
import com.intellij.psi.PsiManager
15+
import software.aws.toolkits.jetbrains.services.codewhisperer.language.languages.CodeWhispererJava
16+
import software.aws.toolkits.jetbrains.services.codewhisperer.language.programmingLanguage
1517

1618
/**
1719
* An interface define how do we parse and fetch files provided a psi file or project
@@ -70,12 +72,19 @@ abstract class CodeWhispererFileCrawler : FileCrawler {
7072

7173
override fun listRelevantFilesInEditors(psiFile: PsiFile): List<VirtualFile> {
7274
val targetFile = psiFile.virtualFile
75+
val language = psiFile.programmingLanguage()
76+
77+
val isTestFilePredicate: (file: VirtualFile, project: Project) -> Boolean = if (language is CodeWhispererJava) {
78+
{ file, project -> TestSourcesFilter.isTestSources(file, project) }
79+
} else {
80+
{ file, _ -> testFilenamePattern.matches(file.name) }
81+
}
7382

7483
val openedFiles = runReadAction {
7584
FileEditorManager.getInstance(psiFile.project).openFiles.toList().filter {
7685
it.name != psiFile.virtualFile.name &&
7786
isSameDialect(psiFile.virtualFile.extension) &&
78-
!TestSourcesFilter.isTestSources(it, psiFile.project)
87+
!isTestFilePredicate(it, psiFile.project)
7988
}
8089
}
8190

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/util/JavascriptCodeWhispererFileCrawler.kt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@ object JavascriptCodeWhispererFileCrawler : CodeWhispererFileCrawler() {
1010
override val fileExtension: String = "js"
1111
override val dialects: Set<String> = setOf("js", "jsx")
1212

13-
// TODO: Add implementation when UTG is enabled
14-
override val testFilenamePattern: Regex = "".toRegex()
13+
override val testFilenamePattern: Regex = """^.*\.test\.(js|jsx)${'$'}""".toRegex()
1514

1615
// TODO: Add implementation when UTG is enabled
1716
override fun guessSourceFileName(tstFileName: String): String = ""

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/util/PythonCodeWhispererFileCrawler.kt

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,9 @@
44
package software.aws.toolkits.jetbrains.services.codewhisperer.util
55

66
import com.intellij.openapi.application.runReadAction
7-
import com.intellij.openapi.module.ModuleUtilCore
8-
import com.intellij.openapi.project.rootManager
9-
import com.intellij.openapi.vfs.VfsUtil
107
import com.intellij.openapi.vfs.VirtualFile
118
import com.intellij.psi.PsiFile
129
import com.jetbrains.python.psi.PyFile
13-
import org.jetbrains.jps.model.java.JavaModuleSourceRootTypes
1410

1511
object PythonCodeWhispererFileCrawler : CodeWhispererFileCrawler() {
1612
override val fileExtension: String = "py"
@@ -36,15 +32,12 @@ object PythonCodeWhispererFileCrawler : CodeWhispererFileCrawler() {
3632

3733
override fun findFocalFileForTest(psiFile: PsiFile): VirtualFile? = findSourceFileByName(psiFile) ?: findRelevantFileFromEditors(psiFile)
3834

39-
private fun findSourceFileByName(psiFile: PsiFile): VirtualFile? {
40-
val module = ModuleUtilCore.findModuleForFile(psiFile)
41-
42-
return module?.rootManager?.getSourceRoots(JavaModuleSourceRootTypes.PRODUCTION)?.let { srcRoot ->
43-
srcRoot
44-
.map { root -> VfsUtil.collectChildrenRecursively(root) }
45-
.flatten()
46-
.find { !it.isDirectory && it.isWritable && it.name.contains(guessSourceFileName(psiFile.name)) }
47-
}
35+
private fun findSourceFileByName(psiFile: PsiFile): VirtualFile? = super.listFilesUnderProjectRoot(psiFile.project).find {
36+
!it.isDirectory &&
37+
it.isWritable &&
38+
it.name != psiFile.virtualFile.name &&
39+
// TODO: should we use strict equal instead?
40+
it.name.contains(guessSourceFileName(psiFile.name))
4841
}
4942

5043
/**

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/util/TypescriptCodeWhispererFileCrawler.kt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@ object TypescriptCodeWhispererFileCrawler : CodeWhispererFileCrawler() {
1010
override val fileExtension: String = "ts"
1111
override val dialects: Set<String> = setOf("ts", "tsx")
1212

13-
// TODO: Add implementation when UTG is enabled
14-
override val testFilenamePattern: Regex = "".toRegex()
13+
override val testFilenamePattern: Regex = """^.*\.test\.(ts|tsx)${'$'}""".toRegex()
1514

1615
// TODO: Add implementation when UTG is enabled
1716
override fun guessSourceFileName(tstFileName: String): String = ""

0 commit comments

Comments
 (0)