Skip to content

Commit 0e756b7

Browse files
authored
restructure and cleanup utg code path (#3774)
1 parent cca7b45 commit 0e756b7

File tree

7 files changed

+440
-168
lines changed

7 files changed

+440
-168
lines changed

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/util/CodeWhispererFileContextProvider.kt

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import com.intellij.openapi.application.runReadAction
88
import com.intellij.openapi.components.service
99
import com.intellij.openapi.editor.Editor
1010
import com.intellij.openapi.project.Project
11-
import com.intellij.openapi.roots.TestSourcesFilter
1211
import com.intellij.openapi.vfs.VirtualFile
1312
import com.intellij.psi.PsiFile
1413
import com.intellij.util.gist.GistManager
@@ -45,7 +44,7 @@ private val codewhispererCodeChunksIndex = GistManager.getInstance()
4544
.newPsiFileGist("psi to code chunk index", 0, CodeWhispererCodeChunkExternalizer) { psiFile ->
4645
runBlocking {
4746
val fileCrawler = getFileCrawlerForLanguage(psiFile.programmingLanguage())
48-
val fileProducers = listOf<suspend (PsiFile) -> List<VirtualFile>> { psiFile -> fileCrawler.listRelevantFilesInEditors(psiFile) }
47+
val fileProducers = listOf<suspend (PsiFile) -> List<VirtualFile>> { psiFile -> fileCrawler.listCrossFileCandidate(psiFile) }
4948
FileContextProvider.getInstance(psiFile.project).extractCodeChunksFromFiles(psiFile, fileProducers)
5049
}
5150
}
@@ -96,6 +95,10 @@ interface FileContextProvider {
9695

9796
suspend fun extractCodeChunksFromFiles(psiFile: PsiFile, fileProducers: List<suspend (PsiFile) -> List<VirtualFile>>): List<Chunk>
9897

98+
/**
99+
* It will actually delegate to invoke corresponding [CodeWhispererFileCrawler.isTestFile] for each language
100+
* as different languages have their own naming conventions.
101+
*/
99102
fun isTestFile(psiFile: PsiFile): Boolean
100103

101104
companion object {
@@ -198,13 +201,7 @@ class DefaultCodeWhispererFileContextProvider(private val project: Project) : Fi
198201
return chunks.take(CodeWhispererConstants.CrossFile.CHUNK_SIZE)
199202
}
200203

201-
override fun isTestFile(psiFile: PsiFile): Boolean {
202-
val path = runReadAction { contentRootPathProvider.getPathToElement(project, psiFile.virtualFile, null) ?: psiFile.virtualFile.path }
203-
return TestSourcesFilter.isTestSources(psiFile.virtualFile, project) ||
204-
path.contains("""test/""") ||
205-
path.contains("""tst/""") ||
206-
path.contains("""tests/""")
207-
}
204+
override fun isTestFile(psiFile: PsiFile) = getFileCrawlerForLanguage(psiFile.programmingLanguage()).isTestFile(psiFile.virtualFile, psiFile.project)
208205

209206
@VisibleForTesting
210207
suspend fun extractSupplementalFileContextForSrc(psiFile: PsiFile, targetContext: FileContextInfo): List<Chunk> {
@@ -257,7 +254,7 @@ class DefaultCodeWhispererFileContextProvider(private val project: Project) : Fi
257254
fun extractSupplementalFileContextForTst(psiFile: PsiFile, targetContext: FileContextInfo): List<Chunk> {
258255
if (!targetContext.programmingLanguage.isUTGSupported()) return emptyList()
259256

260-
val focalFile = getFileCrawlerForLanguage(targetContext.programmingLanguage).findFocalFileForTest(psiFile)
257+
val focalFile = getFileCrawlerForLanguage(targetContext.programmingLanguage).listUtgCandidate(psiFile)
261258

262259
return focalFile?.let { file ->
263260
runReadAction {

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/util/CodeWhispererFileCrawler.kt

Lines changed: 66 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@ import com.intellij.openapi.vfs.VfsUtil
1212
import com.intellij.openapi.vfs.VirtualFile
1313
import com.intellij.psi.PsiFile
1414
import com.intellij.psi.PsiManager
15-
import software.aws.toolkits.jetbrains.services.codewhisperer.language.languages.CodeWhispererJava
16-
import software.aws.toolkits.jetbrains.services.codewhisperer.language.programmingLanguage
15+
import software.aws.toolkits.core.utils.tryOrNull
1716

1817
/**
1918
* An interface define how do we parse and fetch files provided a psi file or project
@@ -40,51 +39,84 @@ interface FileCrawler {
4039
* @param psiFile psi of the test file we are searching with, e.g. MainTest.java
4140
* @return its source file e.g. Main.java, main.py or most relevant file if any
4241
*/
43-
fun findFocalFileForTest(psiFile: PsiFile): VirtualFile?
42+
fun listUtgCandidate(psiFile: PsiFile): VirtualFile?
4443

4544
/**
4645
* List files opened in the editors and sorted by file distance @see [CodeWhispererFileCrawler.getFileDistance]
46+
* @return opened files and satisfy the following conditions
47+
* (1) not the input file
48+
* (2) with the same file extension as the input file has
49+
* (3) non-test file which will be determined by [FileCrawler.isTestFile]
50+
* (4) writable file
4751
*/
48-
fun listRelevantFilesInEditors(psiFile: PsiFile): List<VirtualFile>
52+
fun listCrossFileCandidate(psiFile: PsiFile): List<VirtualFile>
53+
54+
/**
55+
* Determine if the file given is test file or not based on its path and file name
56+
*/
57+
fun isTestFile(virtualFile: VirtualFile, project: Project): Boolean
4958
}
5059

5160
class NoOpFileCrawler : FileCrawler {
5261
override suspend fun listFilesImported(psiFile: PsiFile): List<VirtualFile> = emptyList()
5362

5463
override fun listFilesUnderProjectRoot(project: Project): List<VirtualFile> = emptyList()
55-
override fun findFocalFileForTest(psiFile: PsiFile): VirtualFile? = null
64+
override fun listUtgCandidate(psiFile: PsiFile): VirtualFile? = null
5665

5766
override fun listFilesWithinSamePackage(psiFile: PsiFile): List<VirtualFile> = emptyList()
5867

59-
override fun listRelevantFilesInEditors(psiFile: PsiFile): List<VirtualFile> = emptyList()
68+
override fun listCrossFileCandidate(psiFile: PsiFile): List<VirtualFile> = emptyList()
69+
70+
override fun isTestFile(virtualFile: VirtualFile, project: Project): Boolean = false
6071
}
6172

6273
abstract class CodeWhispererFileCrawler : FileCrawler {
6374
abstract val fileExtension: String
64-
abstract val testFilenamePattern: Regex
6575
abstract val dialects: Set<String>
76+
abstract val testFileNamingPatterns: List<Regex>
77+
78+
override fun isTestFile(virtualFile: VirtualFile, project: Project): Boolean {
79+
val filePath = virtualFile.path
80+
81+
// if file path itself explicitly explains the file is under test sources
82+
if (TestSourcesFilter.isTestSources(virtualFile, project) ||
83+
filePath.contains("""test/""", ignoreCase = true) ||
84+
filePath.contains("""tst/""", ignoreCase = true) ||
85+
filePath.contains("""tests/""", ignoreCase = true)
86+
) {
87+
return true
88+
}
89+
90+
// no explicit clue from the file path, use regexes based on naming conventions
91+
return testFileNamingPatterns.any { it.matches(virtualFile.name) }
92+
}
6693

6794
override fun listFilesUnderProjectRoot(project: Project): List<VirtualFile> = project.guessProjectDir()?.let { rootDir ->
6895
VfsUtil.collectChildrenRecursively(rootDir).filter {
96+
// TODO: need to handle cases js vs. jsx, ts vs. tsx when we enable js/ts utg since we likely have different file extensions
6997
it.path.endsWith(fileExtension)
7098
}
7199
}.orEmpty()
72100

73-
override fun listRelevantFilesInEditors(psiFile: PsiFile): List<VirtualFile> {
74-
val targetFile = psiFile.virtualFile
75-
val language = psiFile.programmingLanguage()
101+
override fun listFilesWithinSamePackage(psiFile: PsiFile): List<VirtualFile> = runReadAction {
102+
psiFile.containingDirectory?.files?.mapNotNull {
103+
// exclude target file
104+
if (it != psiFile) {
105+
it.virtualFile
106+
} else {
107+
null
108+
}
109+
}.orEmpty()
110+
}
76111

77-
val isTestFilePredicate: (file: VirtualFile, project: Project) -> Boolean = if (language is CodeWhispererJava) {
78-
{ file, project -> TestSourcesFilter.isTestSources(file, project) }
79-
} else {
80-
{ file, _ -> testFilenamePattern.matches(file.name) }
81-
}
112+
override fun listCrossFileCandidate(psiFile: PsiFile): List<VirtualFile> {
113+
val targetFile = psiFile.virtualFile
82114

83115
val openedFiles = runReadAction {
84116
FileEditorManager.getInstance(psiFile.project).openFiles.toList().filter {
85117
it.name != psiFile.virtualFile.name &&
86118
isSameDialect(it.extension) &&
87-
!isTestFilePredicate(it, psiFile.project)
119+
!isTestFile(it, psiFile.project)
88120
}
89121
}
90122

@@ -97,7 +129,24 @@ abstract class CodeWhispererFileCrawler : FileCrawler {
97129
return fileToFileDistanceList.sortedBy { it.second }.map { it.first }
98130
}
99131

100-
abstract fun guessSourceFileName(tstFileName: String): String
132+
override fun listUtgCandidate(psiFile: PsiFile): VirtualFile? = findSourceFileByName(psiFile) ?: findSourceFileByContent(psiFile)
133+
134+
abstract fun findSourceFileByName(psiFile: PsiFile): VirtualFile?
135+
136+
abstract fun findSourceFileByContent(psiFile: PsiFile): VirtualFile?
137+
138+
// TODO: may need to update when we enable JS/TS UTG, since we have to factor in .jsx/.tsx combinations
139+
fun guessSourceFileName(tstFileName: String): String? {
140+
val srcFileName = tryOrNull {
141+
testFileNamingPatterns.firstNotNullOf { regex ->
142+
regex.find(tstFileName)?.groupValues?.let { groupValues ->
143+
groupValues.get(1) + groupValues.get(2)
144+
}
145+
}
146+
}
147+
148+
return srcFileName
149+
}
101150

102151
private fun isSameDialect(fileExt: String?): Boolean = fileExt?.let {
103152
dialects.contains(fileExt)

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/util/JavaCodeWhispererFileCrawler.kt

Lines changed: 14 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@ import com.intellij.psi.search.GlobalSearchScope
1717
import kotlinx.coroutines.yield
1818
import org.jetbrains.jps.model.java.JavaModuleSourceRootTypes
1919

20-
// version1: Utilize PSI import elements to resolve imported files
2120
object JavaCodeWhispererFileCrawler : CodeWhispererFileCrawler() {
2221
override val fileExtension: String = "java"
23-
override val testFilenamePattern: Regex = """(?:Test([^/\\]+)\.java|([^/\\]+)Test\.java)$""".toRegex()
2422
override val dialects: Set<String> = setOf("java")
25-
26-
override fun guessSourceFileName(tstFileName: String): String = tstFileName.substring(0, tstFileName.length - "Test.java".length) + ".java"
23+
override val testFileNamingPatterns = listOf(
24+
Regex("""^(.+)Test(\.java)$"""),
25+
Regex("""^(.+)Tests(\.java)$""")
26+
)
2727

2828
override suspend fun listFilesImported(psiFile: PsiFile): List<VirtualFile> {
2929
if (psiFile !is PsiJavaFile) return emptyList()
@@ -67,35 +67,23 @@ object JavaCodeWhispererFileCrawler : CodeWhispererFileCrawler() {
6767
return result
6868
}
6969

70-
override fun listFilesWithinSamePackage(targetFile: PsiFile): List<VirtualFile> = runReadAction {
71-
targetFile.containingDirectory?.files?.mapNotNull {
72-
// exclude target file
73-
if (it != targetFile) {
74-
it.virtualFile
75-
} else {
76-
null
77-
}
78-
}.orEmpty()
79-
}
80-
81-
override fun findFocalFileForTest(psiFile: PsiFile): VirtualFile? = findSourceFileByName(psiFile) ?: findRelevantFileFromEditors(psiFile)
82-
8370
// psiFile = "MainTest.java", targetFileName = "Main.java"
84-
private fun findSourceFileByName(psiFile: PsiFile): VirtualFile? {
85-
val module = ModuleUtilCore.findModuleForFile(psiFile)
71+
override fun findSourceFileByName(psiFile: PsiFile): VirtualFile? =
72+
guessSourceFileName(psiFile.virtualFile.name)?.let { srcName ->
73+
val module = ModuleUtilCore.findModuleForFile(psiFile)
8674

87-
return module?.rootManager?.getSourceRoots(JavaModuleSourceRootTypes.PRODUCTION)?.let { srcRoot ->
88-
srcRoot
89-
.map { root -> VfsUtil.collectChildrenRecursively(root) }
90-
.flatten()
91-
.find { !it.isDirectory && it.isWritable && it.name.contains(guessSourceFileName(psiFile.name)) }
75+
module?.rootManager?.getSourceRoots(JavaModuleSourceRootTypes.PRODUCTION)?.let { srcRoot ->
76+
srcRoot
77+
.map { root -> VfsUtil.collectChildrenRecursively(root) }
78+
.flatten()
79+
.find { !it.isDirectory && it.isWritable && it.name == srcName }
80+
}
9281
}
93-
}
9482

9583
/**
9684
* check files in editors and pick one which has most substring matches to the target
9785
*/
98-
private fun findRelevantFileFromEditors(psiFile: PsiFile): VirtualFile? = searchRelevantFileInEditors(psiFile) { myPsiFile ->
86+
override fun findSourceFileByContent(psiFile: PsiFile): VirtualFile? = searchRelevantFileInEditors(psiFile) { myPsiFile ->
9987
if (myPsiFile !is PsiClassOwner) {
10088
return@searchRelevantFileInEditors emptyList()
10189
}

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/util/JavascriptCodeWhispererFileCrawler.kt

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,14 @@ import com.intellij.psi.PsiFile
99
object JavascriptCodeWhispererFileCrawler : CodeWhispererFileCrawler() {
1010
override val fileExtension: String = "js"
1111
override val dialects: Set<String> = setOf("js", "jsx")
12-
13-
override val testFilenamePattern: Regex = """^.*\.test\.(js|jsx)${'$'}""".toRegex()
14-
15-
// TODO: Add implementation when UTG is enabled
16-
override fun guessSourceFileName(tstFileName: String): String = ""
12+
override val testFileNamingPatterns: List<Regex> = listOf(
13+
Regex("""^(.+)\.(?i:t)est(\.js|\.jsx)$"""),
14+
Regex("""^(.+)\.(?i:s)pec(\.js|\.jsx)$""")
15+
)
1716

1817
override suspend fun listFilesImported(psiFile: PsiFile): List<VirtualFile> = emptyList()
1918

20-
override fun listFilesWithinSamePackage(psiFile: PsiFile): List<VirtualFile> = emptyList()
19+
override fun findSourceFileByName(psiFile: PsiFile): VirtualFile? = null
2120

22-
override fun findFocalFileForTest(psiFile: PsiFile): VirtualFile? = null
21+
override fun findSourceFileByContent(psiFile: PsiFile): VirtualFile? = null
2322
}

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/util/PythonCodeWhispererFileCrawler.kt

Lines changed: 7 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -11,39 +11,24 @@ import com.jetbrains.python.psi.PyFile
1111
object PythonCodeWhispererFileCrawler : CodeWhispererFileCrawler() {
1212
override val fileExtension: String = "py"
1313
override val dialects: Set<String> = setOf("py")
14+
override val testFileNamingPatterns: List<Regex> = listOf(
15+
Regex("""^test_(.+)(\.py)$"""),
16+
Regex("""^(.+)_test(\.py)$""")
17+
)
1418

15-
override val testFilenamePattern: Regex = Regex("""(?:test_([^/\\]+)\.py|([^/\\]+)_test\.py)${'$'}""")
1619
override suspend fun listFilesImported(psiFile: PsiFile): List<VirtualFile> = emptyList()
1720

18-
override fun guessSourceFileName(tstFileName: String): String {
19-
assert(testFilenamePattern.matches(tstFileName))
20-
return tstFileName.substring(5)
21-
}
22-
23-
override fun listFilesWithinSamePackage(psiFile: PsiFile): List<VirtualFile> {
24-
val targetPackagePath = psiFile.virtualFile.let {
25-
it.path.removeSuffix(it.name)
26-
}
27-
return listFilesUnderProjectRoot(psiFile.project).filter {
28-
val packagePath = it.path.removeSuffix(it.name)
29-
targetPackagePath == packagePath && it != psiFile.virtualFile
30-
}
31-
}
32-
33-
override fun findFocalFileForTest(psiFile: PsiFile): VirtualFile? = findSourceFileByName(psiFile) ?: findRelevantFileFromEditors(psiFile)
34-
35-
private fun findSourceFileByName(psiFile: PsiFile): VirtualFile? = super.listFilesUnderProjectRoot(psiFile.project).find {
21+
override fun findSourceFileByName(psiFile: PsiFile): VirtualFile? = super.listFilesUnderProjectRoot(psiFile.project).find {
3622
!it.isDirectory &&
3723
it.isWritable &&
3824
it.name != psiFile.virtualFile.name &&
39-
// TODO: should we use strict equal instead?
40-
it.name.contains(guessSourceFileName(psiFile.name))
25+
it.name == guessSourceFileName(psiFile.name)
4126
}
4227

4328
/**
4429
* check files in editors and pick one which has most substring matches to the target
4530
*/
46-
private fun findRelevantFileFromEditors(psiFile: PsiFile): VirtualFile? = searchRelevantFileInEditors(psiFile) { myPsiFile ->
31+
override fun findSourceFileByContent(psiFile: PsiFile): VirtualFile? = searchRelevantFileInEditors(psiFile) { myPsiFile ->
4732
if (myPsiFile !is PyFile) {
4833
return@searchRelevantFileInEditors emptyList()
4934
}

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/util/TypescriptCodeWhispererFileCrawler.kt

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,14 @@ import com.intellij.psi.PsiFile
99
object TypescriptCodeWhispererFileCrawler : CodeWhispererFileCrawler() {
1010
override val fileExtension: String = "ts"
1111
override val dialects: Set<String> = setOf("ts", "tsx")
12-
13-
override val testFilenamePattern: Regex = """^.*\.test\.(ts|tsx)${'$'}""".toRegex()
14-
15-
// TODO: Add implementation when UTG is enabled
16-
override fun guessSourceFileName(tstFileName: String): String = ""
12+
override val testFileNamingPatterns: List<Regex> = listOf(
13+
Regex("""^(.+)\.(?i:t)est(\.ts|\.tsx)$"""),
14+
Regex("""^(.+)\.(?i:s)pec(\.ts|\.tsx)$""")
15+
)
1716

1817
override suspend fun listFilesImported(psiFile: PsiFile): List<VirtualFile> = emptyList()
1918

20-
override fun listFilesWithinSamePackage(psiFile: PsiFile): List<VirtualFile> = emptyList()
19+
override fun findSourceFileByName(psiFile: PsiFile): VirtualFile? = null
2120

22-
override fun findFocalFileForTest(psiFile: PsiFile): VirtualFile? = null
21+
override fun findSourceFileByContent(psiFile: PsiFile): VirtualFile? = null
2322
}

0 commit comments

Comments
 (0)