Skip to content

Commit b58d156

Browse files
rlileigaol
authored andcommitted
feat(amazonq): hook up lsp payload encryption (aws#5370)
1 parent d51ae1a commit b58d156

File tree

5 files changed

+174
-2
lines changed

5 files changed

+174
-2
lines changed

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/AmazonQLspService.kt

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ import org.slf4j.event.Level
3737
import software.aws.toolkits.core.utils.getLogger
3838
import software.aws.toolkits.core.utils.warn
3939
import software.aws.toolkits.jetbrains.isDeveloperMode
40+
import software.aws.toolkits.jetbrains.services.amazonq.lsp.encryption.JwtEncryptionManager
4041
import software.aws.toolkits.jetbrains.services.amazonq.lsp.model.createExtendedClientMetadata
4142
import software.aws.toolkits.jetbrains.services.telemetry.ClientMetadata
4243
import java.io.IOException
@@ -111,6 +112,8 @@ class AmazonQLspService(private val project: Project, private val cs: CoroutineS
111112
}
112113

113114
private class AmazonQServerInstance(private val project: Project, private val cs: CoroutineScope) : Disposable {
115+
private val encryptionManager = JwtEncryptionManager()
116+
114117
private val launcher: Launcher<AmazonQLanguageServer>
115118

116119
private val languageServer: AmazonQLanguageServer
@@ -172,7 +175,11 @@ private class AmazonQServerInstance(private val project: Project, private val cs
172175
}
173176

174177
init {
175-
val cmd = GeneralCommandLine("amazon-q-lsp")
178+
val cmd = GeneralCommandLine(
179+
"amazon-q-lsp",
180+
"--stdio",
181+
"--set-credentials-encryption-key",
182+
)
176183

177184
launcherHandler = KillableColoredProcessHandler.Silent(cmd)
178185
val inputWrapper = LSPProcessListener()
@@ -207,6 +214,9 @@ private class AmazonQServerInstance(private val project: Project, private val cs
207214
launcherFuture = launcher.startListening()
208215

209216
cs.launch {
217+
// encryption info must be sent within 5s or Flare process will exit
218+
encryptionManager.writeInitializationPayload(launcherHandler.process.outputStream)
219+
210220
val initializeResult = try {
211221
withTimeout(Duration.ofSeconds(10)) {
212222
languageServer.initialize(createInitializeParams()).await()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
// Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package software.aws.toolkits.jetbrains.services.amazonq.lsp.encryption
5+
6+
import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper
7+
import com.nimbusds.jose.EncryptionMethod
8+
import com.nimbusds.jose.JWEAlgorithm
9+
import com.nimbusds.jose.JWEHeader
10+
import com.nimbusds.jose.JWEObject
11+
import com.nimbusds.jose.Payload
12+
import com.nimbusds.jose.crypto.DirectDecrypter
13+
import com.nimbusds.jose.crypto.DirectEncrypter
14+
import software.aws.toolkits.jetbrains.services.amazonq.lsp.model.EncryptionInitializationRequest
15+
import java.io.OutputStream
16+
import java.security.SecureRandom
17+
import java.util.Base64
18+
import javax.crypto.SecretKey
19+
import javax.crypto.spec.SecretKeySpec
20+
21+
class JwtEncryptionManager(private val key: SecretKey) {
22+
constructor() : this(generateHmacKey())
23+
24+
private val mapper = jacksonObjectMapper()
25+
26+
fun writeInitializationPayload(os: OutputStream) {
27+
val payload = EncryptionInitializationRequest(
28+
EncryptionInitializationRequest.Version.V1_0,
29+
EncryptionInitializationRequest.Mode.JWT,
30+
Base64.getUrlEncoder().withoutPadding().encodeToString(key.encoded)
31+
)
32+
33+
// write directly to stream because utils are closing the underlying stream
34+
os.write("${mapper.writeValueAsString(payload)}\n".toByteArray())
35+
}
36+
37+
fun encrypt(data: Any): String {
38+
val header = JWEHeader(JWEAlgorithm.DIR, EncryptionMethod.A256GCM)
39+
val payload = if (data is String) {
40+
Payload(data)
41+
} else {
42+
Payload(mapper.writeValueAsBytes(data))
43+
}
44+
45+
val jweObject = JWEObject(header, payload)
46+
jweObject.encrypt(DirectEncrypter(key))
47+
48+
return jweObject.serialize()
49+
}
50+
51+
fun decrypt(jwt: String): String {
52+
val jweObject = JWEObject.parse(jwt)
53+
jweObject.decrypt(DirectDecrypter(key))
54+
55+
return jweObject.payload.toString()
56+
}
57+
58+
private companion object {
59+
private fun generateHmacKey(): SecretKey {
60+
val keyBytes = ByteArray(32)
61+
SecureRandom().nextBytes(keyBytes)
62+
return SecretKeySpec(keyBytes, "HmacSHA256")
63+
}
64+
}
65+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package software.aws.toolkits.jetbrains.services.amazonq.lsp.model
5+
6+
import com.fasterxml.jackson.annotation.JsonValue
7+
8+
data class EncryptionInitializationRequest(
9+
val version: Version,
10+
val mode: Mode,
11+
val key: String,
12+
) {
13+
enum class Version(@JsonValue val value: String) {
14+
V1_0("1.0"),
15+
}
16+
17+
enum class Mode(@JsonValue val value: String) {
18+
JWT("JWT"),
19+
}
20+
}

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/model/aws/credentials/UpdateCredentialsPayload.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ package software.aws.toolkits.jetbrains.services.amazonq.lsp.model.aws.credentia
55

66
data class UpdateCredentialsPayload(
77
val data: String,
8-
val encrypted: String,
8+
val encrypted: Boolean,
99
)
1010

1111
data class UpdateCredentialsPayloadData(
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package software.aws.toolkits.jetbrains.services.amazonq.lsp.encryption
5+
6+
import com.nimbusds.jose.JOSEException
7+
import org.assertj.core.api.Assertions.assertThat
8+
import org.junit.jupiter.api.Test
9+
import org.junit.jupiter.api.assertThrows
10+
import java.io.ByteArrayOutputStream
11+
import java.util.concurrent.atomic.AtomicBoolean
12+
import javax.crypto.spec.SecretKeySpec
13+
import kotlin.random.Random
14+
15+
class JwtEncryptionManagerTest {
16+
@Test
17+
fun `uses a different encryption key for each instance`() {
18+
val blob = Random.Default.nextBytes(256)
19+
val sut1 = JwtEncryptionManager()
20+
val encrypted = sut1.encrypt(blob)
21+
22+
assertThrows<JOSEException> {
23+
assertThat(sut1.decrypt(encrypted))
24+
.isNotEqualTo(JwtEncryptionManager().decrypt(encrypted))
25+
}
26+
}
27+
28+
@Test
29+
@OptIn(ExperimentalStdlibApi::class)
30+
fun `encryption is stable with static key`() {
31+
val blob = Random.Default.nextBytes(256)
32+
val bytes = "DEADBEEF".repeat(8).hexToByteArray() // 32 bytes
33+
val key = SecretKeySpec(bytes, "HmacSHA256")
34+
val sut1 = JwtEncryptionManager(key)
35+
val encrypted = sut1.encrypt(blob)
36+
37+
// each encrypt() call will use a different IV so we can't just directly compare
38+
assertThat(sut1.decrypt(encrypted))
39+
.isEqualTo(JwtEncryptionManager(key).decrypt(encrypted))
40+
}
41+
42+
@Test
43+
fun `encryption can be round-tripped`() {
44+
val sut = JwtEncryptionManager()
45+
val blob = "DEADBEEF".repeat(8)
46+
assertThat(sut.decrypt(sut.encrypt(blob))).isEqualTo(blob)
47+
}
48+
49+
@Test
50+
@OptIn(ExperimentalStdlibApi::class)
51+
fun writeInitializationPayload() {
52+
val bytes = "DEADBEEF".repeat(8).hexToByteArray() // 32 bytes
53+
val key = SecretKeySpec(bytes, "HmacSHA256")
54+
55+
val closed = AtomicBoolean(false)
56+
val os = object : ByteArrayOutputStream() {
57+
override fun close() {
58+
closed.set(true)
59+
}
60+
}
61+
JwtEncryptionManager(key).writeInitializationPayload(os)
62+
assertThat(os.toString())
63+
// Flare requires encryption ends with new line
64+
// https://github.com/aws/language-server-runtimes/blob/4d7f81295dc12b59ed2e1c0ebaedb85ccb86cf76/runtimes/README.md#encryption
65+
.endsWith("\n")
66+
.isEqualTo(
67+
// language=JSON
68+
"""
69+
|{"version":"1.0","mode":"JWT","key":"3q2-796tvu_erb7v3q2-796tvu_erb7v3q2-796tvu8"}
70+
|
71+
""".trimMargin()
72+
)
73+
74+
// writeInitializationPayload should not close the stream
75+
assertThat(closed.get()).isFalse
76+
}
77+
}

0 commit comments

Comments
 (0)