Skip to content

Commit 9d7191f

Browse files
authored
codewhisperer: golang support (#4037)
* feat(codewhisperer): security scans for golang * fix styling issue * add changelog * address comments * address comments
1 parent 6be09e0 commit 9d7191f

File tree

5 files changed

+321
-0
lines changed

5 files changed

+321
-0
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type" : "feature",
3+
"description" : "CodeWhisperer security scans now support Go files."
4+
}

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/codescan/sessionconfig/CodeScanSessionConfig.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ sealed class CodeScanSessionConfig(
223223
CodewhispererLanguage.Json -> CloudFormationJsonCodeScanSessionConfig(file, project)
224224
CodewhispererLanguage.Tf,
225225
CodewhispererLanguage.Hcl -> TerraformCodeScanSessionConfig(file, project)
226+
CodewhispererLanguage.Go -> GoCodeScanSessionConfig(file, project)
226227
CodewhispererLanguage.Ruby -> RubyCodeScanSessionConfig(file, project)
227228
else -> fileFormatNotSupported(file.extension ?: "")
228229
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
// Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package software.aws.toolkits.jetbrains.services.codewhisperer.codescan.sessionconfig
5+
6+
import com.intellij.openapi.project.Project
7+
import com.intellij.openapi.roots.ProjectRootManager
8+
import com.intellij.openapi.vfs.VirtualFile
9+
import com.intellij.util.containers.addIfNotNull
10+
import software.aws.toolkits.core.utils.exists
11+
import software.aws.toolkits.jetbrains.services.codewhisperer.util.CodeWhispererConstants
12+
import software.aws.toolkits.resources.message
13+
import java.io.IOException
14+
import java.nio.file.Path
15+
import java.util.stream.Collectors
16+
import kotlin.io.path.isDirectory
17+
import kotlin.io.path.listDirectoryEntries
18+
19+
internal class GoCodeScanSessionConfig(
20+
private val selectedFile: VirtualFile,
21+
private val project: Project
22+
) : CodeScanSessionConfig(selectedFile, project) {
23+
private val importRegex = Regex("^\\s*import\\s+([^(]+?\$|\\([^)]+\\))", RegexOption.MULTILINE)
24+
private val moduleRegex = Regex("\"[^\"\\r\\n]+\"", RegexOption.MULTILINE)
25+
26+
private val projectContentRoots = ProjectRootManager.getInstance(project).contentRoots
27+
28+
override val sourceExt: List<String> = listOf(".go")
29+
30+
override fun overallJobTimeoutInSeconds(): Long = CodeWhispererConstants.GO_CODE_SCAN_TIMEOUT_IN_SECONDS
31+
32+
override fun getPayloadLimitInBytes(): Int = CodeWhispererConstants.GO_PAYLOAD_LIMIT_IN_BYTES
33+
34+
private fun extractModulePaths(importGroup: String): Set<String> {
35+
val modulePaths = mutableSetOf<String>()
36+
val moduleMatcher = moduleRegex.toPattern().matcher(importGroup)
37+
while (moduleMatcher.find()) {
38+
val match = moduleMatcher.group()
39+
modulePaths.add(match.substring(1, match.length - 1))
40+
}
41+
return modulePaths.toSet()
42+
}
43+
44+
fun parseImports(file: VirtualFile): List<String> {
45+
val imports = mutableSetOf<String>()
46+
try {
47+
file.inputStream.use {
48+
val lines = it.bufferedReader().lines().collect(Collectors.joining("\n"))
49+
val importMatcher = importRegex.toPattern().matcher(lines)
50+
while (importMatcher.find()) {
51+
val goalImports = extractModulePaths(importMatcher.group())
52+
imports.addAll(goalImports)
53+
}
54+
}
55+
} catch (e: IOException) {
56+
error(message("codewhisperer.codescan.cannot_read_file", file.path))
57+
}
58+
return imports.toList()
59+
}
60+
61+
private fun generateSourceFilePath(modulePath: String, dirPath: String): Path? {
62+
if (modulePath.isEmpty()) {
63+
return null
64+
}
65+
val packageDir = getPath(dirPath, modulePath)
66+
val slashPos = modulePath.indexOf("/")
67+
val newModulePath = if (slashPos != -1) modulePath.substring(slashPos + 1) else ""
68+
return if (packageDir?.exists() == true) packageDir else generateSourceFilePath(newModulePath, dirPath)
69+
}
70+
71+
private fun getImportedPackages(file: VirtualFile): List<Path> {
72+
val importedPackages = mutableListOf<Path>()
73+
val imports = parseImports(file)
74+
projectContentRoots.forEach { root ->
75+
imports.forEach { importPath ->
76+
val importedFilePath = generateSourceFilePath(importPath, root.path)
77+
importedPackages.addIfNotNull(importedFilePath)
78+
}
79+
}
80+
return importedPackages
81+
}
82+
83+
private fun getSiblingFiles(file: VirtualFile): List<Path> = listGoFilesInDir(file.parent.toNioPath()).filter {
84+
it.fileName.toString() != file.name
85+
}
86+
87+
private fun listGoFilesInDir(path: Path): List<Path> = path.listDirectoryEntries().filter {
88+
!it.isDirectory() && it.fileName.toString().endsWith(sourceExt[0])
89+
}
90+
91+
override fun getImportedFiles(file: VirtualFile, includedSourceFiles: Set<String>): List<String> {
92+
val importedFiles = mutableListOf<String>()
93+
val importedFilePaths = mutableListOf<String>()
94+
95+
val siblingFiles = getSiblingFiles(file)
96+
siblingFiles.forEach { sibling ->
97+
importedFilePaths.addIfNotNull(sibling.toFile().toVirtualFile()?.path)
98+
}
99+
100+
val importedPackages = getImportedPackages(file)
101+
importedPackages.forEach { pkg ->
102+
val files = listGoFilesInDir(pkg)
103+
.mapNotNull { it.toFile().toVirtualFile()?.path }
104+
importedFilePaths.addAll(files)
105+
}
106+
107+
val validSourceFiles = importedFilePaths.filter { !includedSourceFiles.contains(it) }
108+
validSourceFiles.forEach { validFile ->
109+
importedFiles.add(validFile)
110+
}
111+
112+
return importedFiles
113+
}
114+
}

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/util/CodeWhispererConstants.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ object CodeWhispererConstants {
7979
const val PYTHON_PAYLOAD_LIMIT_IN_BYTES = 1024 * 200 // 200KB
8080
const val JS_CODE_SCAN_TIMEOUT_IN_SECONDS: Long = 60
8181
const val JS_PAYLOAD_LIMIT_IN_BYTES = 1024 * 200 // 200KB
82+
const val GO_CODE_SCAN_TIMEOUT_IN_SECONDS: Long = 60
83+
const val GO_PAYLOAD_LIMIT_IN_BYTES = 1024 * 1024 // 1MB
8284
const val CODE_SCAN_POLLING_INTERVAL_IN_SECONDS: Long = 5
8385
const val CODE_SCAN_CREATE_PAYLOAD_TIMEOUT_IN_SECONDS: Long = 10
8486
const val TOTAL_BYTES_IN_KB = 1024
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
// Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package software.aws.toolkits.jetbrains.services.codewhisperer.codescan
5+
6+
import com.intellij.openapi.vfs.VirtualFile
7+
import org.assertj.core.api.Assertions.assertThat
8+
import org.junit.Before
9+
import org.junit.Test
10+
import org.junit.jupiter.api.assertThrows
11+
import org.mockito.kotlin.any
12+
import org.mockito.kotlin.spy
13+
import org.mockito.kotlin.stub
14+
import software.aws.toolkits.jetbrains.services.codewhisperer.codescan.sessionconfig.CodeScanSessionConfig
15+
import software.aws.toolkits.jetbrains.services.codewhisperer.codescan.sessionconfig.GoCodeScanSessionConfig
16+
import software.aws.toolkits.jetbrains.utils.rules.PythonCodeInsightTestFixtureRule
17+
import software.aws.toolkits.telemetry.CodewhispererLanguage
18+
import java.io.BufferedInputStream
19+
import java.util.zip.ZipInputStream
20+
import kotlin.io.path.relativeTo
21+
import kotlin.test.assertNotNull
22+
23+
class CodeWhispererGoCodeScanTest : CodeWhispererCodeScanTestBase(PythonCodeInsightTestFixtureRule()) {
24+
internal lateinit var mainGo: VirtualFile
25+
internal lateinit var helpGo: VirtualFile
26+
internal lateinit var numberGo: VirtualFile
27+
internal lateinit var sessionConfigSpy: GoCodeScanSessionConfig
28+
29+
private var totalSize: Long = 0
30+
private var totalLines: Long = 0
31+
32+
@Before
33+
override fun setup() {
34+
super.setup()
35+
setupGoProject()
36+
sessionConfigSpy = spy(CodeScanSessionConfig.create(mainGo, project) as GoCodeScanSessionConfig)
37+
setupResponse(mainGo.toNioPath().relativeTo(sessionConfigSpy.projectRoot.toNioPath()))
38+
39+
mockClient.stub {
40+
onGeneric { createUploadUrl(any()) }.thenReturn(fakeCreateUploadUrlResponse)
41+
onGeneric { createCodeScan(any(), any()) }.thenReturn(fakeCreateCodeScanResponse)
42+
onGeneric { getCodeScan(any(), any()) }.thenReturn(fakeGetCodeScanResponse)
43+
onGeneric { listCodeScanFindings(any(), any()) }.thenReturn(fakeListCodeScanFindingsResponse)
44+
}
45+
}
46+
47+
@Test
48+
fun `test createPayload`() {
49+
val payload = sessionConfigSpy.createPayload()
50+
assertNotNull(payload)
51+
assertThat(payload.context.totalFiles).isEqualTo(3)
52+
53+
assertThat(payload.context.scannedFiles.size).isEqualTo(3)
54+
assertThat(payload.context.scannedFiles).containsExactly(mainGo, helpGo, numberGo)
55+
56+
assertThat(payload.context.srcPayloadSize).isEqualTo(totalSize)
57+
assertThat(payload.context.language).isEqualTo(CodewhispererLanguage.Go)
58+
assertThat(payload.context.totalLines).isEqualTo(totalLines)
59+
assertNotNull(payload.srcZip)
60+
61+
val bufferedInputStream = BufferedInputStream(payload.srcZip.inputStream())
62+
val zis = ZipInputStream(bufferedInputStream)
63+
var filesInZip = 0
64+
while (zis.nextEntry != null) {
65+
filesInZip += 1
66+
}
67+
68+
assertThat(filesInZip).isEqualTo(3)
69+
}
70+
71+
@Test
72+
fun `test getSourceFilesUnderProjectRoot`() {
73+
assertThat(sessionConfigSpy.getSourceFilesUnderProjectRoot(mainGo).size).isEqualTo(3)
74+
}
75+
76+
@Test
77+
fun `test parseImport()`() {
78+
val mainGoImports = sessionConfigSpy.parseImports(mainGo)
79+
assertThat(mainGoImports.size).isEqualTo(2)
80+
81+
val helpGoImports = sessionConfigSpy.parseImports(helpGo)
82+
assertThat(helpGoImports.size).isEqualTo(1)
83+
84+
val numberGoImports = sessionConfigSpy.parseImports(numberGo)
85+
assertThat(numberGoImports.size).isEqualTo(1)
86+
}
87+
88+
@Test
89+
fun `test getImportedFiles()`() {
90+
val files = sessionConfigSpy.getImportedFiles(mainGo, setOf())
91+
assertNotNull(files)
92+
assertThat(files).hasSize(2)
93+
assertThat(files).contains(helpGo.path)
94+
assertThat(files).contains(numberGo.path)
95+
}
96+
97+
@Test
98+
fun `test includeDependencies()`() {
99+
val payloadMetadata = sessionConfigSpy.includeDependencies()
100+
assertNotNull(payloadMetadata)
101+
assertThat(sessionConfigSpy.isProjectTruncated()).isFalse
102+
assertThat(payloadMetadata.sourceFiles.size).isEqualTo(3)
103+
assertThat(payloadMetadata.payloadSize).isEqualTo(totalSize)
104+
assertThat(payloadMetadata.linesScanned).isEqualTo(totalLines)
105+
assertThat(payloadMetadata.buildPaths).hasSize(0)
106+
}
107+
108+
@Test
109+
fun `selected file larger than payload limit throws exception`() {
110+
sessionConfigSpy.stub {
111+
onGeneric { getPayloadLimitInBytes() }.thenReturn(100)
112+
}
113+
assertThrows<CodeWhispererCodeScanException> {
114+
sessionConfigSpy.createPayload()
115+
}
116+
}
117+
118+
@Test
119+
fun `test createPayload with custom payload limit`() {
120+
sessionConfigSpy.stub {
121+
onGeneric { getPayloadLimitInBytes() }.thenReturn(300)
122+
}
123+
val payload = sessionConfigSpy.createPayload()
124+
assertNotNull(payload)
125+
assertThat(sessionConfigSpy.isProjectTruncated()).isTrue
126+
assertThat(payload.context.totalFiles).isEqualTo(2)
127+
128+
assertThat(payload.context.scannedFiles.size).isEqualTo(2)
129+
assertThat(payload.context.scannedFiles).containsExactly(mainGo, helpGo)
130+
131+
assertThat(payload.context.srcPayloadSize).isEqualTo(220)
132+
assertThat(payload.context.language).isEqualTo(CodewhispererLanguage.Go)
133+
assertThat(payload.context.totalLines).isEqualTo(17)
134+
assertNotNull(payload.srcZip)
135+
136+
val bufferedInputStream = BufferedInputStream(payload.srcZip.inputStream())
137+
val zis = ZipInputStream(bufferedInputStream)
138+
var filesInZip = 0
139+
while (zis.nextEntry != null) {
140+
filesInZip += 1
141+
}
142+
143+
assertThat(filesInZip).isEqualTo(2)
144+
}
145+
146+
@Test
147+
fun `test e2e with session run() function`() {
148+
assertE2ERunsSuccessfully(sessionConfigSpy, projectRule.project, totalLines, 3, totalSize, 2)
149+
}
150+
151+
private fun setupGoProject() {
152+
mainGo = projectRule.fixture.addFileToProject(
153+
"/main.go",
154+
"""
155+
package main
156+
157+
import (
158+
"example/random-number/util"
159+
"fmt"
160+
)
161+
162+
func main() {
163+
fmt.Printf("Number: %d\n", util.RandomNumber())
164+
}
165+
""".trimIndent()
166+
).virtualFile
167+
totalSize += mainGo.length
168+
totalLines += mainGo.toNioPath().toFile().readLines().size
169+
170+
helpGo = projectRule.fixture.addFileToProject(
171+
"/help.go",
172+
"""
173+
package main
174+
175+
import "fmt"
176+
177+
func Help() {
178+
fmt.Printf("./main")
179+
}
180+
""".trimIndent()
181+
).virtualFile
182+
totalSize += helpGo.length
183+
totalLines += helpGo.toNioPath().toFile().readLines().size
184+
185+
numberGo = projectRule.fixture.addFileToProject(
186+
"/util/number.go",
187+
"""
188+
package util
189+
190+
import "math/rand"
191+
192+
func RandomNumber() int {
193+
return rand.Intn(100)
194+
}
195+
""".trimIndent()
196+
).virtualFile
197+
totalSize += numberGo.length
198+
totalLines += numberGo.toNioPath().toFile().readLines().size
199+
}
200+
}

0 commit comments

Comments
 (0)