4
4
package software.aws.toolkits.jetbrains.services.codewhisperer.codescan.sessionconfig
5
5
6
6
import com.intellij.openapi.project.Project
7
+ import com.intellij.openapi.project.guessProjectDir
7
8
import com.intellij.openapi.vfs.LocalFileSystem
9
+ import com.intellij.openapi.vfs.VfsUtil
8
10
import com.intellij.openapi.vfs.VirtualFile
9
11
import software.aws.toolkits.core.utils.createTemporaryZipFile
10
12
import software.aws.toolkits.core.utils.debug
11
13
import software.aws.toolkits.core.utils.getLogger
12
14
import software.aws.toolkits.core.utils.putNextEntry
13
15
import software.aws.toolkits.jetbrains.services.codewhisperer.codescan.fileFormatNotSupported
16
+ import software.aws.toolkits.jetbrains.services.codewhisperer.codescan.fileTooLarge
17
+ import software.aws.toolkits.jetbrains.services.codewhisperer.editor.CodeWhispererEditorUtil.codeWhispererLanguage
14
18
import software.aws.toolkits.jetbrains.services.codewhisperer.util.CodeWhispererConstants.CODE_SCAN_CREATE_PAYLOAD_TIMEOUT_IN_SECONDS
15
19
import software.aws.toolkits.jetbrains.services.codewhisperer.util.CodeWhispererConstants.TOTAL_BYTES_IN_KB
16
20
import software.aws.toolkits.jetbrains.services.codewhisperer.util.CodeWhispererConstants.TOTAL_BYTES_IN_MB
17
21
import software.aws.toolkits.telemetry.CodewhispererLanguage
18
22
import java.io.File
23
+ import java.nio.file.Files
19
24
import java.nio.file.Path
25
+ import java.time.Instant
20
26
21
- internal sealed class CodeScanSessionConfig {
22
- abstract fun createPayload (): Payload
27
+ internal sealed class CodeScanSessionConfig (
28
+ private val selectedFile : VirtualFile ,
29
+ private val project : Project
30
+ ) {
31
+ protected val projectRoot = project.guessProjectDir() ? : error(" Cannot guess base directory for project ${project.name} " )
32
+
33
+ abstract val sourceExt: String
23
34
24
35
/* *
25
36
* Timeout for the overall job - "Run Security Scan".
26
37
*/
27
38
abstract fun overallJobTimeoutInSeconds (): Long
28
39
40
+ abstract fun getPayloadLimitInBytes (): Int
41
+
42
+ open fun getImportedFiles (file : VirtualFile , includedSourceFiles : Set <String >): List <String > = listOf ()
43
+
44
+ open fun createPayload (): Payload {
45
+ // Fail fast if the selected file size is greater than the payload limit.
46
+ if (selectedFile.length > getPayloadLimitInBytes()) {
47
+ fileTooLarge(getPresentablePayloadLimit())
48
+ }
49
+
50
+ val start = Instant .now().toEpochMilli()
51
+
52
+ LOG .debug { " Creating payload. File selected as root for the context truncation: ${selectedFile.path} " }
53
+
54
+ val (includedSourceFiles, payloadSize, totalLines, _) = includeDependencies()
55
+
56
+ // Copy all the included source files to the source zip
57
+ val srcZip = zipFiles(includedSourceFiles.map { Path .of(it) })
58
+ val payloadContext = PayloadContext (
59
+ selectedFile.codeWhispererLanguage,
60
+ totalLines,
61
+ includedSourceFiles.size,
62
+ Instant .now().toEpochMilli() - start,
63
+ payloadSize,
64
+ srcZip.length()
65
+ )
66
+
67
+ return Payload (payloadContext, srcZip)
68
+ }
69
+
70
+ open fun includeDependencies (): PayloadMetadata {
71
+ val includedSourceFiles = mutableSetOf<String >()
72
+ var currentTotalFileSize = 0L
73
+ var currentTotalLines = 0L
74
+ val files = getSourceFilesUnderProjectRoot(selectedFile)
75
+ val queue = ArrayDeque <String >()
76
+
77
+ files.forEach { pivotFile ->
78
+ val filePath = pivotFile.path
79
+ queue.addLast(filePath)
80
+
81
+ // BFS
82
+ while (queue.isNotEmpty()) {
83
+ if (currentTotalFileSize.equals(getPayloadLimitInBytes())) {
84
+ return PayloadMetadata (includedSourceFiles, currentTotalFileSize, currentTotalLines)
85
+ }
86
+
87
+ val currentFilePath = queue.removeFirst()
88
+ val currentFile = File (currentFilePath).toVirtualFile()
89
+ if (includedSourceFiles.contains(currentFilePath) || currentFile == null ) continue
90
+
91
+ val currentFileSize = currentFile.length
92
+
93
+ // Ignore file if including it exceeds the payload limit.
94
+ if (currentTotalFileSize > getPayloadLimitInBytes() - currentFileSize) continue
95
+
96
+ currentTotalFileSize + = currentFileSize
97
+ currentTotalLines + = Files .lines(currentFile.toNioPath()).count()
98
+ includedSourceFiles.add(currentFilePath)
99
+
100
+ getImportedFiles(currentFile, includedSourceFiles).forEach {
101
+ if (! includedSourceFiles.contains(it)) queue.addLast(it)
102
+ }
103
+ }
104
+ }
105
+
106
+ return PayloadMetadata (includedSourceFiles, currentTotalFileSize, currentTotalLines)
107
+ }
108
+
29
109
/* *
30
110
* Timeout for creating the payload [createPayload]
31
111
*/
32
112
open fun createPayloadTimeoutInSeconds (): Long = CODE_SCAN_CREATE_PAYLOAD_TIMEOUT_IN_SECONDS
33
113
34
- abstract fun getPayloadLimitInBytes (): Int
35
-
36
114
open fun getPresentablePayloadLimit (): String = when (getPayloadLimitInBytes() >= TOTAL_BYTES_IN_MB ) {
37
115
true -> " ${getPayloadLimitInBytes() / TOTAL_BYTES_IN_MB } MB"
38
116
false -> " ${getPayloadLimitInBytes() / TOTAL_BYTES_IN_KB } KB"
@@ -45,13 +123,38 @@ internal sealed class CodeScanSessionConfig {
45
123
}
46
124
}.toFile()
47
125
126
+ /* *
127
+ * Returns all the source files for a given payload type.
128
+ */
129
+ open fun getSourceFilesUnderProjectRoot (selectedFile : VirtualFile ): List <VirtualFile > {
130
+ // Include the current selected file
131
+ val files = mutableListOf (selectedFile)
132
+ // Include other files only if the current file is in the project.
133
+ if (selectedFile.path.startsWith(projectRoot.path)) {
134
+ files.addAll(
135
+ VfsUtil .collectChildrenRecursively(projectRoot).filter {
136
+ it.path.endsWith(sourceExt) && it != selectedFile
137
+ }
138
+ )
139
+ }
140
+ return files
141
+ }
142
+
143
+ protected fun getPath (root : String , relativePath : String = ""): Path ? = try {
144
+ Path .of(root, relativePath).normalize()
145
+ } catch (e: Exception ) {
146
+ LOG .debug { " Cannot find file at path $relativePath relative to the root $root " }
147
+ null
148
+ }
149
+
48
150
protected fun File.toVirtualFile () = LocalFileSystem .getInstance().findFileByIoFile(this )
49
151
50
152
companion object {
51
153
private val LOG = getLogger<CodeScanSessionConfig >()
52
- fun create (file : VirtualFile , project : Project ): CodeScanSessionConfig = when (file.extension) {
53
- " java" -> JavaCodeScanSessionConfig (file, project)
54
- " py" -> PythonCodeScanSessionConfig (file, project)
154
+ const val FILE_SEPARATOR = ' /'
155
+ fun create (file : VirtualFile , project : Project ): CodeScanSessionConfig = when (file.codeWhispererLanguage) {
156
+ CodewhispererLanguage .Java -> JavaCodeScanSessionConfig (file, project)
157
+ CodewhispererLanguage .Python -> PythonCodeScanSessionConfig (file, project)
55
158
else -> fileFormatNotSupported(file.extension ? : " " )
56
159
}
57
160
}
@@ -73,3 +176,10 @@ data class PayloadContext(
73
176
val buildPayloadSize : Long? = null ,
74
177
val buildZipFileSize : Long? = null
75
178
)
179
+
180
+ data class PayloadMetadata (
181
+ val sourceFiles : Set <String >,
182
+ val payloadSize : Long ,
183
+ val linesScanned : Long ,
184
+ val buildPaths : Set <String > = setOf()
185
+ )
0 commit comments