Skip to content

Commit 7d92d3a

Browse files
authored
fix(amazonq): always send creds on lsp init if available (#5459)
Previously lsp server gets token by luck depending on timing of credential events. Instead, client should explicitly send token after server finishes init
1 parent 526eb4a commit 7d92d3a

File tree

3 files changed

+104
-43
lines changed

3 files changed

+104
-43
lines changed

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/AmazonQLspService.kt

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import kotlinx.coroutines.CoroutineScope
2424
import kotlinx.coroutines.Deferred
2525
import kotlinx.coroutines.TimeoutCancellationException
2626
import kotlinx.coroutines.async
27-
import kotlinx.coroutines.launch
27+
import kotlinx.coroutines.future.asCompletableFuture
2828
import kotlinx.coroutines.runBlocking
2929
import kotlinx.coroutines.sync.Mutex
3030
import kotlinx.coroutines.sync.withLock
@@ -316,11 +316,18 @@ private class AmazonQServerInstance(private val project: Project, private val cs
316316
initializeResult
317317
}
318318

319-
DefaultAuthCredentialsService(project, encryptionManager, this)
320-
TextDocumentServiceHandler(project, this)
321-
WorkspaceServiceHandler(project, this)
322-
cs.launch {
323-
DefaultModuleDependenciesService(project, this@AmazonQServerInstance)
319+
// invokeOnCompletion results in weird lock/timeout error
320+
initializeResult.asCompletableFuture().handleAsync { r, ex ->
321+
if (ex != null) {
322+
return@handleAsync
323+
}
324+
325+
this@AmazonQServerInstance.apply {
326+
DefaultAuthCredentialsService(project, encryptionManager, this)
327+
TextDocumentServiceHandler(project, this)
328+
WorkspaceServiceHandler(project, this)
329+
DefaultModuleDependenciesService(project, this)
330+
}
324331
}
325332
}
326333

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ class DefaultAuthCredentialsService(
2626
BearerTokenProviderListener {
2727
init {
2828
project.messageBus.connect(serverInstance).subscribe(BearerTokenProviderListener.TOPIC, this)
29+
30+
onChange("init", null)
2931
}
3032

3133
override fun onChange(providerId: String, newScopes: List<String>?) {
@@ -39,7 +41,8 @@ class DefaultAuthCredentialsService(
3941
?: return
4042

4143
provider.currentToken()?.accessToken?.let { token ->
42-
updateTokenCredentials(token, false)
44+
// assume encryption is always on
45+
updateTokenCredentials(token, true)
4346
}
4447
}
4548

@@ -48,13 +51,7 @@ class DefaultAuthCredentialsService(
4851
}
4952

5053
override fun updateTokenCredentials(accessToken: String, encrypted: Boolean): CompletableFuture<ResponseMessage> {
51-
val token = if (encrypted) {
52-
encryptionManager.decrypt(accessToken)
53-
} else {
54-
accessToken
55-
}
56-
57-
val payload = createUpdateCredentialsPayload(token)
54+
val payload = createUpdateCredentialsPayload(accessToken, encrypted)
5855

5956
return AmazonQLspService.executeIfRunning(project) { server ->
6057
server.updateTokenCredentials(payload)
@@ -69,13 +66,20 @@ class DefaultAuthCredentialsService(
6966
} ?: completableFuture.completeExceptionally(IllegalStateException("LSP Server not running"))
7067
}
7168

72-
private fun createUpdateCredentialsPayload(token: String): UpdateCredentialsPayload =
73-
UpdateCredentialsPayload(
74-
data = encryptionManager.encrypt(
75-
UpdateCredentialsPayloadData(
76-
BearerCredentials(token)
77-
)
78-
),
79-
encrypted = true
80-
)
69+
private fun createUpdateCredentialsPayload(token: String, encrypted: Boolean): UpdateCredentialsPayload =
70+
if (encrypted) {
71+
UpdateCredentialsPayload(
72+
data = encryptionManager.encrypt(
73+
UpdateCredentialsPayloadData(
74+
BearerCredentials(token)
75+
)
76+
),
77+
encrypted = true
78+
)
79+
} else {
80+
UpdateCredentialsPayload(
81+
data = token,
82+
encrypted = false
83+
)
84+
}
8185
}

plugins/amazonq/shared/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsServiceTest.kt

Lines changed: 70 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,32 +4,51 @@
44
package software.aws.toolkits.jetbrains.services.amazonq.lsp.auth
55

66
import com.intellij.openapi.Disposable
7+
import com.intellij.openapi.components.service
78
import com.intellij.openapi.components.serviceIfCreated
89
import com.intellij.openapi.project.Project
10+
import com.intellij.testFramework.ProjectExtension
911
import com.intellij.util.messages.MessageBus
1012
import com.intellij.util.messages.MessageBusConnection
1113
import io.mockk.every
1214
import io.mockk.just
1315
import io.mockk.mockk
1416
import io.mockk.runs
17+
import io.mockk.spyk
1518
import io.mockk.verify
1619
import org.eclipse.lsp4j.jsonrpc.messages.ResponseMessage
17-
import org.junit.Before
18-
import org.junit.Test
20+
import org.junit.jupiter.api.BeforeEach
21+
import org.junit.jupiter.api.Test
22+
import org.junit.jupiter.api.extension.RegisterExtension
23+
import software.aws.toolkits.core.TokenConnectionSettings
24+
import software.aws.toolkits.core.credentials.ToolkitBearerTokenProvider
25+
import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection
26+
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
27+
import software.aws.toolkits.jetbrains.core.credentials.sso.PKCEAuthorizationGrantToken
28+
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.InteractiveBearerTokenProvider
1929
import software.aws.toolkits.jetbrains.services.amazonq.lsp.AmazonQLanguageServer
2030
import software.aws.toolkits.jetbrains.services.amazonq.lsp.AmazonQLspService
2131
import software.aws.toolkits.jetbrains.services.amazonq.lsp.encryption.JwtEncryptionManager
32+
import software.aws.toolkits.jetbrains.services.amazonq.lsp.model.aws.credentials.UpdateCredentialsPayload
33+
import java.time.Instant
2234
import java.util.concurrent.CompletableFuture
2335

2436
class DefaultAuthCredentialsServiceTest {
37+
companion object {
38+
@JvmField
39+
@RegisterExtension
40+
val projectExtension = ProjectExtension()
41+
}
42+
2543
private lateinit var project: Project
2644
private lateinit var mockLanguageServer: AmazonQLanguageServer
2745
private lateinit var mockEncryptionManager: JwtEncryptionManager
2846
private lateinit var sut: DefaultAuthCredentialsService
2947

30-
@Before
48+
// maybe better to use real project via junit extension
49+
@BeforeEach
3150
fun setUp() {
32-
project = mockk<Project>()
51+
project = spyk(projectExtension.project)
3352
mockLanguageServer = mockk<AmazonQLanguageServer>()
3453
mockEncryptionManager = mockk<JwtEncryptionManager>()
3554
every { mockEncryptionManager.encrypt(any()) } returns "mock-encrypted-data"
@@ -54,6 +73,32 @@ class DefaultAuthCredentialsServiceTest {
5473
every { messageBus.connect(any<Disposable>()) } returns mockConnection
5574
every { mockConnection.subscribe(any(), any()) } just runs
5675

76+
// Mock ToolkitConnectionManager
77+
val connectionManager = mockk<ToolkitConnectionManager>()
78+
val connection = mockk<AwsBearerTokenConnection>()
79+
val connectionSettings = mockk<TokenConnectionSettings>()
80+
val provider = mockk<ToolkitBearerTokenProvider>()
81+
val tokenDelegate = mockk<InteractiveBearerTokenProvider>()
82+
val token = PKCEAuthorizationGrantToken(
83+
issuerUrl = "https://example.com",
84+
refreshToken = "refreshToken",
85+
accessToken = "accessToken",
86+
expiresAt = Instant.MAX,
87+
createdAt = Instant.now(),
88+
region = "us-fake-1",
89+
)
90+
91+
every { project.service<ToolkitConnectionManager>() } returns connectionManager
92+
every { connectionManager.activeConnectionForFeature(any()) } returns connection
93+
every { connection.getConnectionSettings() } returns connectionSettings
94+
every { connectionSettings.tokenProvider } returns provider
95+
every { provider.delegate } returns tokenDelegate
96+
every { tokenDelegate.currentToken() } returns token
97+
98+
every {
99+
mockLanguageServer.updateTokenCredentials(any())
100+
} returns CompletableFuture.completedFuture(ResponseMessage())
101+
57102
sut = DefaultAuthCredentialsService(project, this.mockEncryptionManager, mockk())
58103
}
59104

@@ -62,17 +107,15 @@ class DefaultAuthCredentialsServiceTest {
62107
val token = "unencryptedToken"
63108
val isEncrypted = false
64109

65-
every {
66-
mockLanguageServer.updateTokenCredentials(any())
67-
} returns CompletableFuture.completedFuture(ResponseMessage())
68-
69110
sut.updateTokenCredentials(token, isEncrypted)
70111

71-
verify(exactly = 0) {
72-
mockEncryptionManager.decrypt(any())
73-
}
74112
verify(exactly = 1) {
75-
mockLanguageServer.updateTokenCredentials(any())
113+
mockLanguageServer.updateTokenCredentials(
114+
UpdateCredentialsPayload(
115+
token,
116+
isEncrypted
117+
)
118+
)
76119
}
77120
}
78121

@@ -82,16 +125,18 @@ class DefaultAuthCredentialsServiceTest {
82125
val decryptedToken = "decryptedToken"
83126
val isEncrypted = true
84127

85-
every { mockEncryptionManager.decrypt(encryptedToken) } returns decryptedToken
86-
every { mockEncryptionManager.encrypt(any()) } returns "mock-encrypted-data"
87-
every {
88-
mockLanguageServer.updateTokenCredentials(any())
89-
} returns CompletableFuture.completedFuture(ResponseMessage())
128+
every { mockEncryptionManager.encrypt(any()) } returns encryptedToken
90129

91-
sut.updateTokenCredentials(encryptedToken, isEncrypted)
130+
sut.updateTokenCredentials(decryptedToken, isEncrypted)
92131

93-
verify(exactly = 1) { mockEncryptionManager.decrypt(encryptedToken) }
94-
verify(exactly = 1) { mockLanguageServer.updateTokenCredentials(any()) }
132+
verify(atLeast = 1) {
133+
mockLanguageServer.updateTokenCredentials(
134+
UpdateCredentialsPayload(
135+
encryptedToken,
136+
isEncrypted
137+
)
138+
)
139+
}
95140
}
96141

97142
@Test
@@ -102,4 +147,9 @@ class DefaultAuthCredentialsServiceTest {
102147

103148
verify(exactly = 1) { mockLanguageServer.deleteTokenCredentials() }
104149
}
150+
151+
@Test
152+
fun `init results in token update`() {
153+
verify(exactly = 1) { mockLanguageServer.updateTokenCredentials(any()) }
154+
}
105155
}

0 commit comments

Comments
 (0)