Skip to content

Commit e4c918b

Browse files
authored
CodeWhisperer CodeScan - Handle file outside project (#3779)
1 parent f12c7cd commit e4c918b

File tree

2 files changed

+59
-7
lines changed

2 files changed

+59
-7
lines changed

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/codescan/sessionconfig/CodeScanSessionConfig.kt

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ sealed class CodeScanSessionConfig(
3333
private val selectedFile: VirtualFile,
3434
private val project: Project
3535
) {
36-
val projectRoot = project.guessProjectDir() ?: error("Cannot guess base directory for project ${project.name}")
36+
var projectRoot = project.guessProjectDir() ?: error("Cannot guess base directory for project ${project.name}")
37+
private set
3738

3839
abstract val sourceExt: String
3940

@@ -64,23 +65,38 @@ sealed class CodeScanSessionConfig(
6465

6566
LOG.debug { "Creating payload. File selected as root for the context truncation: ${selectedFile.path}" }
6667

67-
val (includedSourceFiles, payloadSize, totalLines, _) = includeDependencies()
68+
val payloadMetadata = when (selectedFile.path.startsWith(projectRoot.path)) {
69+
true -> includeDependencies()
70+
false -> {
71+
// Set project root as the parent of the selected file.
72+
projectRoot = selectedFile.parent
73+
includeFileOutsideProjectRoot()
74+
}
75+
}
6876

6977
// Copy all the included source files to the source zip
70-
val srcZip = zipFiles(includedSourceFiles.map { Path.of(it) })
78+
val srcZip = zipFiles(payloadMetadata.sourceFiles.map { Path.of(it) })
7179
val payloadContext = PayloadContext(
7280
selectedFile.programmingLanguage().toTelemetryType(),
73-
totalLines,
74-
includedSourceFiles.size,
81+
payloadMetadata.linesScanned,
82+
payloadMetadata.sourceFiles.size,
7583
Instant.now().toEpochMilli() - start,
76-
includedSourceFiles.mapNotNull { Path.of(it).toFile().toVirtualFile() },
77-
payloadSize,
84+
payloadMetadata.sourceFiles.mapNotNull { Path.of(it).toFile().toVirtualFile() },
85+
payloadMetadata.payloadSize,
7886
srcZip.length()
7987
)
8088

8189
return Payload(payloadContext, srcZip)
8290
}
8391

92+
open fun includeFileOutsideProjectRoot(): PayloadMetadata =
93+
// Handle the case where the selected file is outside the project root.
94+
PayloadMetadata(
95+
setOf(selectedFile.path),
96+
selectedFile.length,
97+
Files.lines(selectedFile.toNioPath()).count().toLong()
98+
)
99+
84100
open fun includeDependencies(): PayloadMetadata {
85101
val includedSourceFiles = mutableSetOf<String>()
86102
var currentTotalFileSize = 0L

jetbrains-core/tst/software/aws/toolkits/jetbrains/services/codewhisperer/codescan/CodeWhispererPythonCodeScanTest.kt

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,42 @@ class CodeWhispererPythonCodeScanTest : CodeWhispererCodeScanTestBase(PythonCode
149149
assertThat(filesInZip).isEqualTo(2)
150150
}
151151

152+
@Test
153+
fun `test createPayload for file outside project`() {
154+
val fileOutsideProjectPy = projectRule.fixture.addFileToProject(
155+
"../fileOutsideProject.py",
156+
"""
157+
import numpy as np
158+
import util
159+
a = 1
160+
"""
161+
).virtualFile
162+
val totalSize = fileOutsideProjectPy.length
163+
val totalLines = fileOutsideProjectPy.toNioPath().toFile().readLines().size.toLong()
164+
sessionConfigSpy = spy(CodeScanSessionConfig.create(fileOutsideProjectPy, project) as PythonCodeScanSessionConfig)
165+
166+
val payload = sessionConfigSpy.createPayload()
167+
assertNotNull(payload)
168+
assertThat(payload.context.totalFiles).isEqualTo(1)
169+
170+
assertThat(payload.context.scannedFiles.size).isEqualTo(1)
171+
assertThat(payload.context.scannedFiles).containsExactly(fileOutsideProjectPy)
172+
173+
assertThat(payload.context.srcPayloadSize).isEqualTo(totalSize)
174+
assertThat(payload.context.language).isEqualTo(CodewhispererLanguage.Python)
175+
assertThat(payload.context.totalLines).isEqualTo(totalLines)
176+
assertNotNull(payload.srcZip)
177+
178+
val bufferedInputStream = BufferedInputStream(payload.srcZip.inputStream())
179+
val zis = ZipInputStream(bufferedInputStream)
180+
var filesInZip = 0
181+
while (zis.nextEntry != null) {
182+
filesInZip += 1
183+
}
184+
185+
assertThat(filesInZip).isEqualTo(1)
186+
}
187+
152188
@Test
153189
fun `e2e happy path integration test`() {
154190
assertE2ERunsSuccessfully(sessionConfigSpy, project, totalLines, 3, totalSize, 2)

0 commit comments

Comments
 (0)