Skip to content

Commit e6d8159

Browse files
fix(amazonq): lsp token/update emits for init, logins, and token refresh (#5477)
* emit on login * reorder * tests
1 parent 38fe91e commit e6d8159

File tree

2 files changed

+175
-51
lines changed

2 files changed

+175
-51
lines changed

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt

Lines changed: 48 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ import com.intellij.openapi.Disposable
77
import com.intellij.openapi.project.Project
88
import org.eclipse.lsp4j.jsonrpc.messages.ResponseMessage
99
import software.aws.toolkits.core.TokenConnectionSettings
10+
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnection
1011
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
12+
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManagerListener
1113
import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection
1214
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProvider
1315
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProviderListener
@@ -16,38 +18,27 @@ import software.aws.toolkits.jetbrains.services.amazonq.lsp.encryption.JwtEncryp
1618
import software.aws.toolkits.jetbrains.services.amazonq.lsp.model.aws.credentials.BearerCredentials
1719
import software.aws.toolkits.jetbrains.services.amazonq.lsp.model.aws.credentials.UpdateCredentialsPayload
1820
import software.aws.toolkits.jetbrains.services.amazonq.lsp.model.aws.credentials.UpdateCredentialsPayloadData
21+
import software.aws.toolkits.jetbrains.utils.isQConnected
22+
import software.aws.toolkits.jetbrains.utils.isQExpired
1923
import java.util.concurrent.CompletableFuture
2024

2125
class DefaultAuthCredentialsService(
2226
private val project: Project,
2327
private val encryptionManager: JwtEncryptionManager,
2428
serverInstance: Disposable,
2529
) : AuthCredentialsService,
26-
BearerTokenProviderListener {
27-
init {
28-
project.messageBus.connect(serverInstance).subscribe(BearerTokenProviderListener.TOPIC, this)
29-
30-
onChange("init", null)
31-
}
32-
33-
override fun onChange(providerId: String, newScopes: List<String>?) {
34-
val connection = ToolkitConnectionManager.getInstance(project)
35-
.activeConnectionForFeature(QConnection.getInstance())
36-
?: return
30+
BearerTokenProviderListener,
31+
ToolkitConnectionManagerListener {
3732

38-
val provider = (connection.getConnectionSettings() as? TokenConnectionSettings)
39-
?.tokenProvider
40-
?.delegate as? BearerTokenProvider
41-
?: return
42-
43-
provider.currentToken()?.accessToken?.let { token ->
44-
// assume encryption is always on
45-
updateTokenCredentials(token, true)
33+
init {
34+
project.messageBus.connect(serverInstance).apply {
35+
subscribe(BearerTokenProviderListener.TOPIC, this@DefaultAuthCredentialsService)
36+
subscribe(ToolkitConnectionManagerListener.TOPIC, this@DefaultAuthCredentialsService)
4637
}
47-
}
4838

49-
override fun invalidate(providerId: String) {
50-
deleteTokenCredentials()
39+
if (isQConnected(project) && !isQExpired(project)) {
40+
updateTokenFromActiveConnection()
41+
}
5142
}
5243

5344
override fun updateTokenCredentials(accessToken: String, encrypted: Boolean): CompletableFuture<ResponseMessage> {
@@ -66,6 +57,41 @@ class DefaultAuthCredentialsService(
6657
} ?: completableFuture.completeExceptionally(IllegalStateException("LSP Server not running"))
6758
}
6859

60+
override fun onChange(providerId: String, newScopes: List<String>?) {
61+
updateTokenFromActiveConnection()
62+
}
63+
64+
override fun activeConnectionChanged(newConnection: ToolkitConnection?) {
65+
val qConnection = ToolkitConnectionManager.getInstance(project)
66+
.activeConnectionForFeature(QConnection.getInstance())
67+
?: return
68+
if (newConnection?.id != qConnection.id) return
69+
70+
updateTokenFromConnection(newConnection)
71+
}
72+
73+
private fun updateTokenFromActiveConnection() {
74+
val connection = ToolkitConnectionManager.getInstance(project)
75+
.activeConnectionForFeature(QConnection.getInstance())
76+
?: return
77+
78+
updateTokenFromConnection(connection)
79+
}
80+
81+
private fun updateTokenFromConnection(connection: ToolkitConnection) {
82+
(connection.getConnectionSettings() as? TokenConnectionSettings)
83+
?.tokenProvider
84+
?.delegate
85+
?.let { it as? BearerTokenProvider }
86+
?.currentToken()
87+
?.accessToken
88+
?.let { token -> updateTokenCredentials(token, true) }
89+
}
90+
91+
override fun invalidate(providerId: String) {
92+
deleteTokenCredentials()
93+
}
94+
6995
private fun createUpdateCredentialsPayload(token: String, encrypted: Boolean): UpdateCredentialsPayload =
7096
if (encrypted) {
7197
UpdateCredentialsPayload(

plugins/amazonq/shared/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsServiceTest.kt

Lines changed: 127 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import com.intellij.util.messages.MessageBusConnection
1313
import io.mockk.every
1414
import io.mockk.just
1515
import io.mockk.mockk
16+
import io.mockk.mockkStatic
1617
import io.mockk.runs
1718
import io.mockk.spyk
1819
import io.mockk.verify
@@ -30,6 +31,8 @@ import software.aws.toolkits.jetbrains.services.amazonq.lsp.AmazonQLanguageServe
3031
import software.aws.toolkits.jetbrains.services.amazonq.lsp.AmazonQLspService
3132
import software.aws.toolkits.jetbrains.services.amazonq.lsp.encryption.JwtEncryptionManager
3233
import software.aws.toolkits.jetbrains.services.amazonq.lsp.model.aws.credentials.UpdateCredentialsPayload
34+
import software.aws.toolkits.jetbrains.utils.isQConnected
35+
import software.aws.toolkits.jetbrains.utils.isQExpired
3336
import java.time.Instant
3437
import java.util.concurrent.CompletableFuture
3538

@@ -38,72 +41,159 @@ class DefaultAuthCredentialsServiceTest {
3841
@JvmField
3942
@RegisterExtension
4043
val projectExtension = ProjectExtension()
44+
45+
private const val TEST_ACCESS_TOKEN = "test-access-token"
4146
}
4247

4348
private lateinit var project: Project
4449
private lateinit var mockLanguageServer: AmazonQLanguageServer
4550
private lateinit var mockEncryptionManager: JwtEncryptionManager
51+
private lateinit var mockConnectionManager: ToolkitConnectionManager
52+
private lateinit var mockConnection: AwsBearerTokenConnection
4653
private lateinit var sut: DefaultAuthCredentialsService
4754

48-
// maybe better to use real project via junit extension
4955
@BeforeEach
5056
fun setUp() {
5157
project = spyk(projectExtension.project)
58+
setupMockLspService()
59+
setupMockMessageBus()
60+
setupMockConnectionManager()
61+
}
62+
63+
private fun setupMockLspService() {
5264
mockLanguageServer = mockk<AmazonQLanguageServer>()
53-
mockEncryptionManager = mockk<JwtEncryptionManager>()
54-
every { mockEncryptionManager.encrypt(any()) } returns "mock-encrypted-data"
65+
mockEncryptionManager = mockk {
66+
every { encrypt(any()) } returns "mock-encrypted-data"
67+
}
5568

56-
// Mock the service methods on Project
5769
val mockLspService = mockk<AmazonQLspService>()
58-
every { project.getService(AmazonQLspService::class.java) } returns mockLspService
59-
every { project.serviceIfCreated<AmazonQLspService>() } returns mockLspService
60-
61-
// Mock the LSP service's executeSync method as a suspend function
6270
every {
6371
mockLspService.executeSync<CompletableFuture<ResponseMessage>>(any())
6472
} coAnswers {
6573
val func = firstArg<suspend AmazonQLspService.(AmazonQLanguageServer) -> CompletableFuture<ResponseMessage>>()
6674
func.invoke(mockLspService, mockLanguageServer)
6775
}
6876

69-
// Mock message bus
77+
every {
78+
mockLanguageServer.updateTokenCredentials(any())
79+
} returns CompletableFuture<ResponseMessage>()
80+
81+
every {
82+
mockLanguageServer.deleteTokenCredentials()
83+
} returns CompletableFuture.completedFuture(Unit)
84+
85+
every { project.getService(AmazonQLspService::class.java) } returns mockLspService
86+
every { project.serviceIfCreated<AmazonQLspService>() } returns mockLspService
87+
}
88+
89+
private fun setupMockMessageBus() {
7090
val messageBus = mockk<MessageBus>()
91+
val mockConnection = mockk<MessageBusConnection> {
92+
every { subscribe(any(), any()) } just runs
93+
}
7194
every { project.messageBus } returns messageBus
72-
val mockConnection = mockk<MessageBusConnection>()
7395
every { messageBus.connect(any<Disposable>()) } returns mockConnection
74-
every { mockConnection.subscribe(any(), any()) } just runs
75-
76-
// Mock ToolkitConnectionManager
77-
val connectionManager = mockk<ToolkitConnectionManager>()
78-
val connection = mockk<AwsBearerTokenConnection>()
79-
val connectionSettings = mockk<TokenConnectionSettings>()
80-
val provider = mockk<ToolkitBearerTokenProvider>()
81-
val tokenDelegate = mockk<InteractiveBearerTokenProvider>()
96+
}
97+
98+
private fun setupMockConnectionManager(accessToken: String = TEST_ACCESS_TOKEN) {
99+
mockConnection = createMockConnection(accessToken)
100+
mockConnectionManager = mockk {
101+
every { activeConnectionForFeature(any()) } returns mockConnection
102+
}
103+
every { project.service<ToolkitConnectionManager>() } returns mockConnectionManager
104+
mockkStatic("software.aws.toolkits.jetbrains.utils.FunctionUtilsKt")
105+
// these set so init doesn't always emit
106+
every { isQConnected(any()) } returns false
107+
every { isQExpired(any()) } returns true
108+
}
109+
110+
private fun createMockConnection(
111+
accessToken: String,
112+
connectionId: String = "test-connection-id",
113+
): AwsBearerTokenConnection = mockk {
114+
every { id } returns connectionId
115+
every { getConnectionSettings() } returns createMockTokenSettings(accessToken)
116+
}
117+
118+
private fun createMockTokenSettings(accessToken: String): TokenConnectionSettings {
82119
val token = PKCEAuthorizationGrantToken(
83120
issuerUrl = "https://example.com",
84121
refreshToken = "refreshToken",
85-
accessToken = "accessToken",
122+
accessToken = accessToken,
86123
expiresAt = Instant.MAX,
87124
createdAt = Instant.now(),
88125
region = "us-fake-1",
89126
)
90127

91-
every { project.service<ToolkitConnectionManager>() } returns connectionManager
92-
every { connectionManager.activeConnectionForFeature(any()) } returns connection
93-
every { connection.getConnectionSettings() } returns connectionSettings
94-
every { connectionSettings.tokenProvider } returns provider
95-
every { provider.delegate } returns tokenDelegate
96-
every { tokenDelegate.currentToken() } returns token
128+
val tokenDelegate = mockk<InteractiveBearerTokenProvider> {
129+
every { currentToken() } returns token
130+
}
97131

98-
every {
99-
mockLanguageServer.updateTokenCredentials(any())
100-
} returns CompletableFuture.completedFuture(ResponseMessage())
132+
val provider = mockk<ToolkitBearerTokenProvider> {
133+
every { delegate } returns tokenDelegate
134+
}
135+
136+
return mockk {
137+
every { tokenProvider } returns provider
138+
}
139+
}
140+
141+
@Test
142+
fun `activeConnectionChanged updates token when connection ID matches Q connection`() {
143+
sut = DefaultAuthCredentialsService(project, mockEncryptionManager, mockk())
144+
val newConnection = createMockConnection("new-token", "connection-id")
145+
every { mockConnection.id } returns "connection-id"
146+
147+
sut.activeConnectionChanged(newConnection)
148+
149+
verify(exactly = 1) { mockLanguageServer.updateTokenCredentials(any()) }
150+
}
151+
152+
@Test
153+
fun `activeConnectionChanged does not update token when connection ID differs`() {
154+
sut = DefaultAuthCredentialsService(project, mockEncryptionManager, mockk())
155+
val newConnection = createMockConnection("new-token", "different-id")
156+
every { mockConnection.id } returns "q-connection-id"
157+
158+
sut.activeConnectionChanged(newConnection)
159+
160+
verify(exactly = 0) { mockLanguageServer.updateTokenCredentials(any()) }
161+
}
162+
163+
@Test
164+
fun `onChange updates token with new connection`() {
165+
sut = DefaultAuthCredentialsService(project, mockEncryptionManager, mockk())
166+
setupMockConnectionManager("updated-token")
167+
168+
sut.onChange("providerId", listOf("new-scope"))
101169

102-
sut = DefaultAuthCredentialsService(project, this.mockEncryptionManager, mockk())
170+
verify(exactly = 1) { mockLanguageServer.updateTokenCredentials(any()) }
171+
}
172+
173+
@Test
174+
fun `init does not update token when Q is not connected`() {
175+
every { isQConnected(project) } returns false
176+
every { isQExpired(project) } returns false
177+
178+
sut = DefaultAuthCredentialsService(project, mockEncryptionManager, mockk())
179+
180+
verify(exactly = 0) { mockLanguageServer.updateTokenCredentials(any()) }
181+
}
182+
183+
@Test
184+
fun `init does not update token when Q is expired`() {
185+
every { isQConnected(project) } returns true
186+
every { isQExpired(project) } returns true
187+
188+
sut = DefaultAuthCredentialsService(project, mockEncryptionManager, mockk())
189+
190+
verify(exactly = 0) { mockLanguageServer.updateTokenCredentials(any()) }
103191
}
104192

105193
@Test
106194
fun `test updateTokenCredentials unencrypted success`() {
195+
sut = DefaultAuthCredentialsService(project, mockEncryptionManager, mockk())
196+
107197
val token = "unencryptedToken"
108198
val isEncrypted = false
109199

@@ -121,6 +211,8 @@ class DefaultAuthCredentialsServiceTest {
121211

122212
@Test
123213
fun `test updateTokenCredentials encrypted success`() {
214+
sut = DefaultAuthCredentialsService(project, mockEncryptionManager, mockk())
215+
124216
val encryptedToken = "encryptedToken"
125217
val decryptedToken = "decryptedToken"
126218
val isEncrypted = true
@@ -141,6 +233,8 @@ class DefaultAuthCredentialsServiceTest {
141233

142234
@Test
143235
fun `test deleteTokenCredentials success`() {
236+
sut = DefaultAuthCredentialsService(project, mockEncryptionManager, mockk())
237+
144238
every { mockLanguageServer.deleteTokenCredentials() } returns CompletableFuture.completedFuture(Unit)
145239

146240
sut.deleteTokenCredentials()
@@ -150,6 +244,10 @@ class DefaultAuthCredentialsServiceTest {
150244

151245
@Test
152246
fun `init results in token update`() {
247+
every { isQConnected(any()) } returns true
248+
every { isQExpired(any()) } returns false
249+
sut = DefaultAuthCredentialsService(project, mockEncryptionManager, mockk())
250+
153251
verify(exactly = 1) { mockLanguageServer.updateTokenCredentials(any()) }
154252
}
155253
}

0 commit comments

Comments
 (0)