Skip to content

Commit f6d894d

Browse files
authored
[CodeWhisperer] Add code scan unit and integ tests (#3246)
* Add code scan unit and integ tests * Remove duplicated code.
1 parent 7989107 commit f6d894d

File tree

13 files changed

+1200
-182
lines changed

13 files changed

+1200
-182
lines changed

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/codescan/CodeWhispererCodeScanManager.kt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import kotlinx.coroutines.TimeoutCancellationException
3232
import kotlinx.coroutines.launch
3333
import kotlinx.coroutines.time.withTimeout
3434
import kotlinx.coroutines.withContext
35+
import org.jetbrains.annotations.TestOnly
3536
import software.amazon.awssdk.services.codewhisperer.model.CodeWhispererException
3637
import software.aws.toolkits.core.utils.WaiterTimeoutException
3738
import software.aws.toolkits.core.utils.debug
@@ -347,6 +348,12 @@ internal class CodeWhispererCodeScanManager(val project: Project) {
347348
}
348349
}
349350

351+
@TestOnly
352+
suspend fun testRenderResponseOnUIThread(issues: List<CodeWhispererCodeScanIssue>) {
353+
assert(ApplicationManager.getApplication().isUnitTestMode)
354+
renderResponseOnUIThread(issues)
355+
}
356+
350357
companion object {
351358
private val LOG = getLogger<CodeWhispererCodeScanManager>()
352359
fun getInstance(project: Project): CodeWhispererCodeScanManager = project.service()

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/codescan/CodeWhispererCodeScanSession.kt

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ import java.time.Instant
4848
import java.util.Base64
4949
import java.util.UUID
5050

51-
internal class CodeWhispererCodeScanSession(private val sessionContext: CodeScanSessionContext) {
51+
internal class CodeWhispererCodeScanSession(val sessionContext: CodeScanSessionContext) {
5252
private val clientToken: UUID = UUID.randomUUID()
5353
private val urlResponse = mutableMapOf<ArtifactType, CreateUploadUrlResponse>()
5454
private val codewhispererClient: CodeWhispererClient = CodeWhispererClientManager.getInstance().getClient()
@@ -147,7 +147,7 @@ internal class CodeWhispererCodeScanSession(private val sessionContext: CodeScan
147147
"Get security scan status: ${getCodeScanResponse.status()}, " +
148148
"request id: ${getCodeScanResponse.responseMetadata().requestId()}"
149149
}
150-
sleep(CODE_SCAN_POLLING_INTERVAL_IN_SECONDS * TOTAL_MILLIS_IN_SECOND)
150+
sleepThread()
151151
if (codeScanStatus == CodeScanStatus.FAILED) {
152152
LOG.debug {
153153
"CodeWhisperer service error occurred. Something went wrong fetching results for security scan: $getCodeScanResponse " +
@@ -178,6 +178,7 @@ internal class CodeWhispererCodeScanSession(private val sessionContext: CodeScan
178178
listCodeScanFindingsResponse = listCodeScanFindings(jobId)
179179
}
180180
LOG.debug { "Successfully fetched results for the security scan." }
181+
LOG.debug { "Code scan findings: ${listCodeScanFindingsResponse.codeScanFindings()}" }
181182
LOG.debug { "Rendering response to display security scan results." }
182183
issues = mapToCodeScanIssues(documents)
183184
codeScanResponseContext = codeScanResponseContext.copy(codeScanTotalIssues = issues.count())
@@ -191,7 +192,7 @@ internal class CodeWhispererCodeScanSession(private val sessionContext: CodeScan
191192
/**
192193
* Creates an upload URL and uplaods the zip file to the presigned URL
193194
*/
194-
private fun createUploadUrlAndUpload(zipFile: File, artifactType: String): CreateUploadUrlResponse = try {
195+
fun createUploadUrlAndUpload(zipFile: File, artifactType: String): CreateUploadUrlResponse = try {
195196
val fileMd5: String = Base64.getEncoder().encodeToString(DigestUtils.md5(FileInputStream(zipFile)))
196197
LOG.debug { "Fetching presigned URL for uploading $artifactType." }
197198
val createUploadUrlResponse = createUploadUrl(fileMd5, artifactType)
@@ -205,13 +206,13 @@ internal class CodeWhispererCodeScanSession(private val sessionContext: CodeScan
205206
throw e
206207
}
207208

208-
private fun createUploadUrl(md5Content: String, artifactType: String): CreateUploadUrlResponse = codewhispererClient.createUploadUrl {
209+
fun createUploadUrl(md5Content: String, artifactType: String): CreateUploadUrlResponse = codewhispererClient.createUploadUrl {
209210
it.contentMd5(md5Content)
210211
it.artifactType(artifactType)
211212
}
212213

213214
@Throws(IOException::class)
214-
private fun uploadArtifactTOS3(url: String, fileToUpload: File, md5: String) {
215+
fun uploadArtifactTOS3(url: String, fileToUpload: File, md5: String) {
215216
HttpRequests.put(url, "application/zip").userAgent(AwsClientManager.userAgent).tuner {
216217
it.setRequestProperty(CONTENT_MD5, md5)
217218
it.setRequestProperty(SERVER_SIDE_ENCRYPTION, AES256)
@@ -222,7 +223,7 @@ internal class CodeWhispererCodeScanSession(private val sessionContext: CodeScan
222223
}
223224
}
224225

225-
private fun createCodeScan(language: String): CreateCodeScanResponse {
226+
fun createCodeScan(language: String): CreateCodeScanResponse {
226227
val artifactsMap = mapOf(
227228
ArtifactType.SOURCE_CODE to urlResponse[ArtifactType.SOURCE_CODE]?.uploadId(),
228229
ArtifactType.BUILT_JARS to urlResponse[ArtifactType.BUILT_JARS]?.uploadId()
@@ -240,14 +241,14 @@ internal class CodeWhispererCodeScanSession(private val sessionContext: CodeScan
240241
}
241242
}
242243

243-
private fun getCodeScan(jobId: String): GetCodeScanResponse = try {
244+
fun getCodeScan(jobId: String): GetCodeScanResponse = try {
244245
codewhispererClient.getCodeScan { it.jobId(jobId) }
245246
} catch (e: Exception) {
246247
LOG.debug { "Getting security scan failed: ${e.message}" }
247248
throw e
248249
}
249250

250-
private fun listCodeScanFindings(jobId: String): ListCodeScanFindingsResponse = try {
251+
fun listCodeScanFindings(jobId: String): ListCodeScanFindingsResponse = try {
251252
codewhispererClient.listCodeScanFindings {
252253
it.jobId(jobId)
253254
it.codeScanFindingsSchema(CodeScanFindingsSchema.CODESCAN_FINDINGS_1_0)
@@ -257,15 +258,21 @@ internal class CodeWhispererCodeScanSession(private val sessionContext: CodeScan
257258
throw e
258259
}
259260

260-
private fun mapToCodeScanIssues(recommendations: List<String>): List<CodeWhispererCodeScanIssue> {
261+
fun mapToCodeScanIssues(recommendations: List<String>): List<CodeWhispererCodeScanIssue> {
261262
val scanRecommendations: List<CodeScanRecommendation> = recommendations.map {
262263
val value: List<CodeScanRecommendation> = MAPPER.readValue(it)
263264
value
264265
}.flatten()
266+
LOG.debug { "Total code scan issues returned from service: ${scanRecommendations.size}" }
265267
return scanRecommendations.mapNotNull {
266-
val file = LocalFileSystem.getInstance().findFileByIoFile(
267-
Path.of(File.separator, it.filePath).toFile()
268-
)
268+
val file = try {
269+
LocalFileSystem.getInstance().findFileByIoFile(
270+
Path.of(it.filePath).toFile()
271+
)
272+
} catch (e: Exception) {
273+
LOG.debug { "Cannot find file at location ${it.filePath}" }
274+
null
275+
}
269276
when (file?.isDirectory) {
270277
false -> {
271278
runReadAction {
@@ -293,13 +300,18 @@ internal class CodeWhispererCodeScanSession(private val sessionContext: CodeScan
293300
}
294301
}
295302
}
303+
304+
fun sleepThread() {
305+
sleep(CODE_SCAN_POLLING_INTERVAL_IN_SECONDS * TOTAL_MILLIS_IN_SECOND)
306+
}
307+
296308
companion object {
297309
private val LOG = getLogger<CodeWhispererCodeScanSession>()
298310
private val MAPPER = jacksonObjectMapper()
299311
.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
300-
private const val AES256 = "AES256"
301-
private const val CONTENT_MD5 = "Content-MD5"
302-
private const val SERVER_SIDE_ENCRYPTION = "x-amz-server-side-encryption"
312+
const val AES256 = "AES256"
313+
const val CONTENT_MD5 = "Content-MD5"
314+
const val SERVER_SIDE_ENCRYPTION = "x-amz-server-side-encryption"
303315
}
304316
}
305317

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/codescan/sessionconfig/CodeScanSessionConfig.kt

Lines changed: 117 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,35 +4,113 @@
44
package software.aws.toolkits.jetbrains.services.codewhisperer.codescan.sessionconfig
55

66
import com.intellij.openapi.project.Project
7+
import com.intellij.openapi.project.guessProjectDir
78
import com.intellij.openapi.vfs.LocalFileSystem
9+
import com.intellij.openapi.vfs.VfsUtil
810
import com.intellij.openapi.vfs.VirtualFile
911
import software.aws.toolkits.core.utils.createTemporaryZipFile
1012
import software.aws.toolkits.core.utils.debug
1113
import software.aws.toolkits.core.utils.getLogger
1214
import software.aws.toolkits.core.utils.putNextEntry
1315
import software.aws.toolkits.jetbrains.services.codewhisperer.codescan.fileFormatNotSupported
16+
import software.aws.toolkits.jetbrains.services.codewhisperer.codescan.fileTooLarge
17+
import software.aws.toolkits.jetbrains.services.codewhisperer.editor.CodeWhispererEditorUtil.codeWhispererLanguage
1418
import software.aws.toolkits.jetbrains.services.codewhisperer.util.CodeWhispererConstants.CODE_SCAN_CREATE_PAYLOAD_TIMEOUT_IN_SECONDS
1519
import software.aws.toolkits.jetbrains.services.codewhisperer.util.CodeWhispererConstants.TOTAL_BYTES_IN_KB
1620
import software.aws.toolkits.jetbrains.services.codewhisperer.util.CodeWhispererConstants.TOTAL_BYTES_IN_MB
1721
import software.aws.toolkits.telemetry.CodewhispererLanguage
1822
import java.io.File
23+
import java.nio.file.Files
1924
import java.nio.file.Path
25+
import java.time.Instant
2026

21-
internal sealed class CodeScanSessionConfig {
22-
abstract fun createPayload(): Payload
27+
internal sealed class CodeScanSessionConfig(
28+
private val selectedFile: VirtualFile,
29+
private val project: Project
30+
) {
31+
protected val projectRoot = project.guessProjectDir() ?: error("Cannot guess base directory for project ${project.name}")
32+
33+
abstract val sourceExt: String
2334

2435
/**
2536
* Timeout for the overall job - "Run Security Scan".
2637
*/
2738
abstract fun overallJobTimeoutInSeconds(): Long
2839

40+
abstract fun getPayloadLimitInBytes(): Int
41+
42+
open fun getImportedFiles(file: VirtualFile, includedSourceFiles: Set<String>): List<String> = listOf()
43+
44+
open fun createPayload(): Payload {
45+
// Fail fast if the selected file size is greater than the payload limit.
46+
if (selectedFile.length > getPayloadLimitInBytes()) {
47+
fileTooLarge(getPresentablePayloadLimit())
48+
}
49+
50+
val start = Instant.now().toEpochMilli()
51+
52+
LOG.debug { "Creating payload. File selected as root for the context truncation: ${selectedFile.path}" }
53+
54+
val (includedSourceFiles, payloadSize, totalLines, _) = includeDependencies()
55+
56+
// Copy all the included source files to the source zip
57+
val srcZip = zipFiles(includedSourceFiles.map { Path.of(it) })
58+
val payloadContext = PayloadContext(
59+
selectedFile.codeWhispererLanguage,
60+
totalLines,
61+
includedSourceFiles.size,
62+
Instant.now().toEpochMilli() - start,
63+
payloadSize,
64+
srcZip.length()
65+
)
66+
67+
return Payload(payloadContext, srcZip)
68+
}
69+
70+
open fun includeDependencies(): PayloadMetadata {
71+
val includedSourceFiles = mutableSetOf<String>()
72+
var currentTotalFileSize = 0L
73+
var currentTotalLines = 0L
74+
val files = getSourceFilesUnderProjectRoot(selectedFile)
75+
val queue = ArrayDeque<String>()
76+
77+
files.forEach { pivotFile ->
78+
val filePath = pivotFile.path
79+
queue.addLast(filePath)
80+
81+
// BFS
82+
while (queue.isNotEmpty()) {
83+
if (currentTotalFileSize.equals(getPayloadLimitInBytes())) {
84+
return PayloadMetadata(includedSourceFiles, currentTotalFileSize, currentTotalLines)
85+
}
86+
87+
val currentFilePath = queue.removeFirst()
88+
val currentFile = File(currentFilePath).toVirtualFile()
89+
if (includedSourceFiles.contains(currentFilePath) || currentFile == null) continue
90+
91+
val currentFileSize = currentFile.length
92+
93+
// Ignore file if including it exceeds the payload limit.
94+
if (currentTotalFileSize > getPayloadLimitInBytes() - currentFileSize) continue
95+
96+
currentTotalFileSize += currentFileSize
97+
currentTotalLines += Files.lines(currentFile.toNioPath()).count()
98+
includedSourceFiles.add(currentFilePath)
99+
100+
getImportedFiles(currentFile, includedSourceFiles).forEach {
101+
if (!includedSourceFiles.contains(it)) queue.addLast(it)
102+
}
103+
}
104+
}
105+
106+
return PayloadMetadata(includedSourceFiles, currentTotalFileSize, currentTotalLines)
107+
}
108+
29109
/**
30110
* Timeout for creating the payload [createPayload]
31111
*/
32112
open fun createPayloadTimeoutInSeconds(): Long = CODE_SCAN_CREATE_PAYLOAD_TIMEOUT_IN_SECONDS
33113

34-
abstract fun getPayloadLimitInBytes(): Int
35-
36114
open fun getPresentablePayloadLimit(): String = when (getPayloadLimitInBytes() >= TOTAL_BYTES_IN_MB) {
37115
true -> "${getPayloadLimitInBytes() / TOTAL_BYTES_IN_MB}MB"
38116
false -> "${getPayloadLimitInBytes() / TOTAL_BYTES_IN_KB}KB"
@@ -45,13 +123,38 @@ internal sealed class CodeScanSessionConfig {
45123
}
46124
}.toFile()
47125

126+
/**
127+
* Returns all the source files for a given payload type.
128+
*/
129+
open fun getSourceFilesUnderProjectRoot(selectedFile: VirtualFile): List<VirtualFile> {
130+
// Include the current selected file
131+
val files = mutableListOf(selectedFile)
132+
// Include other files only if the current file is in the project.
133+
if (selectedFile.path.startsWith(projectRoot.path)) {
134+
files.addAll(
135+
VfsUtil.collectChildrenRecursively(projectRoot).filter {
136+
it.path.endsWith(sourceExt) && it != selectedFile
137+
}
138+
)
139+
}
140+
return files
141+
}
142+
143+
protected fun getPath(root: String, relativePath: String = ""): Path? = try {
144+
Path.of(root, relativePath).normalize()
145+
} catch (e: Exception) {
146+
LOG.debug { "Cannot find file at path $relativePath relative to the root $root" }
147+
null
148+
}
149+
48150
protected fun File.toVirtualFile() = LocalFileSystem.getInstance().findFileByIoFile(this)
49151

50152
companion object {
51153
private val LOG = getLogger<CodeScanSessionConfig>()
52-
fun create(file: VirtualFile, project: Project): CodeScanSessionConfig = when (file.extension) {
53-
"java" -> JavaCodeScanSessionConfig(file, project)
54-
"py" -> PythonCodeScanSessionConfig(file, project)
154+
const val FILE_SEPARATOR = '/'
155+
fun create(file: VirtualFile, project: Project): CodeScanSessionConfig = when (file.codeWhispererLanguage) {
156+
CodewhispererLanguage.Java -> JavaCodeScanSessionConfig(file, project)
157+
CodewhispererLanguage.Python -> PythonCodeScanSessionConfig(file, project)
55158
else -> fileFormatNotSupported(file.extension ?: "")
56159
}
57160
}
@@ -73,3 +176,10 @@ data class PayloadContext(
73176
val buildPayloadSize: Long? = null,
74177
val buildZipFileSize: Long? = null
75178
)
179+
180+
data class PayloadMetadata(
181+
val sourceFiles: Set<String>,
182+
val payloadSize: Long,
183+
val linesScanned: Long,
184+
val buildPaths: Set<String> = setOf()
185+
)

0 commit comments

Comments
 (0)