@@ -13,6 +13,7 @@ import com.intellij.util.messages.MessageBusConnection
1313import io.mockk.every
1414import io.mockk.just
1515import io.mockk.mockk
16+ import io.mockk.mockkStatic
1617import io.mockk.runs
1718import io.mockk.spyk
1819import io.mockk.verify
@@ -30,6 +31,8 @@ import software.aws.toolkits.jetbrains.services.amazonq.lsp.AmazonQLanguageServe
3031import software.aws.toolkits.jetbrains.services.amazonq.lsp.AmazonQLspService
3132import software.aws.toolkits.jetbrains.services.amazonq.lsp.encryption.JwtEncryptionManager
3233import software.aws.toolkits.jetbrains.services.amazonq.lsp.model.aws.credentials.UpdateCredentialsPayload
34+ import software.aws.toolkits.jetbrains.utils.isQConnected
35+ import software.aws.toolkits.jetbrains.utils.isQExpired
3336import java.time.Instant
3437import java.util.concurrent.CompletableFuture
3538
@@ -38,72 +41,159 @@ class DefaultAuthCredentialsServiceTest {
3841 @JvmField
3942 @RegisterExtension
4043 val projectExtension = ProjectExtension ()
44+
45+ private const val TEST_ACCESS_TOKEN = " test-access-token"
4146 }
4247
4348 private lateinit var project: Project
4449 private lateinit var mockLanguageServer: AmazonQLanguageServer
4550 private lateinit var mockEncryptionManager: JwtEncryptionManager
51+ private lateinit var mockConnectionManager: ToolkitConnectionManager
52+ private lateinit var mockConnection: AwsBearerTokenConnection
4653 private lateinit var sut: DefaultAuthCredentialsService
4754
48- // maybe better to use real project via junit extension
4955 @BeforeEach
5056 fun setUp () {
5157 project = spyk(projectExtension.project)
58+ setupMockLspService()
59+ setupMockMessageBus()
60+ setupMockConnectionManager()
61+ }
62+
63+ private fun setupMockLspService () {
5264 mockLanguageServer = mockk<AmazonQLanguageServer >()
53- mockEncryptionManager = mockk<JwtEncryptionManager >()
54- every { mockEncryptionManager.encrypt(any()) } returns " mock-encrypted-data"
65+ mockEncryptionManager = mockk {
66+ every { encrypt(any()) } returns " mock-encrypted-data"
67+ }
5568
56- // Mock the service methods on Project
5769 val mockLspService = mockk<AmazonQLspService >()
58- every { project.getService(AmazonQLspService ::class .java) } returns mockLspService
59- every { project.serviceIfCreated<AmazonQLspService >() } returns mockLspService
60-
61- // Mock the LSP service's executeSync method as a suspend function
6270 every {
6371 mockLspService.executeSync<CompletableFuture <ResponseMessage >>(any())
6472 } coAnswers {
6573 val func = firstArg< suspend AmazonQLspService .(AmazonQLanguageServer ) -> CompletableFuture <ResponseMessage >>()
6674 func.invoke(mockLspService, mockLanguageServer)
6775 }
6876
69- // Mock message bus
77+ every {
78+ mockLanguageServer.updateTokenCredentials(any())
79+ } returns CompletableFuture <ResponseMessage >()
80+
81+ every {
82+ mockLanguageServer.deleteTokenCredentials()
83+ } returns CompletableFuture .completedFuture(Unit )
84+
85+ every { project.getService(AmazonQLspService ::class .java) } returns mockLspService
86+ every { project.serviceIfCreated<AmazonQLspService >() } returns mockLspService
87+ }
88+
89+ private fun setupMockMessageBus () {
7090 val messageBus = mockk<MessageBus >()
91+ val mockConnection = mockk<MessageBusConnection > {
92+ every { subscribe(any(), any()) } just runs
93+ }
7194 every { project.messageBus } returns messageBus
72- val mockConnection = mockk<MessageBusConnection >()
7395 every { messageBus.connect(any<Disposable >()) } returns mockConnection
74- every { mockConnection.subscribe(any(), any()) } just runs
75-
76- // Mock ToolkitConnectionManager
77- val connectionManager = mockk<ToolkitConnectionManager >()
78- val connection = mockk<AwsBearerTokenConnection >()
79- val connectionSettings = mockk<TokenConnectionSettings >()
80- val provider = mockk<ToolkitBearerTokenProvider >()
81- val tokenDelegate = mockk<InteractiveBearerTokenProvider >()
96+ }
97+
98+ private fun setupMockConnectionManager (accessToken : String = TEST_ACCESS_TOKEN ) {
99+ mockConnection = createMockConnection(accessToken)
100+ mockConnectionManager = mockk {
101+ every { activeConnectionForFeature(any()) } returns mockConnection
102+ }
103+ every { project.service<ToolkitConnectionManager >() } returns mockConnectionManager
104+ mockkStatic(" software.aws.toolkits.jetbrains.utils.FunctionUtilsKt" )
105+ // these set so init doesn't always emit
106+ every { isQConnected(any()) } returns false
107+ every { isQExpired(any()) } returns true
108+ }
109+
110+ private fun createMockConnection (
111+ accessToken : String ,
112+ connectionId : String = "test-connection-id",
113+ ): AwsBearerTokenConnection = mockk {
114+ every { id } returns connectionId
115+ every { getConnectionSettings() } returns createMockTokenSettings(accessToken)
116+ }
117+
118+ private fun createMockTokenSettings (accessToken : String ): TokenConnectionSettings {
82119 val token = PKCEAuthorizationGrantToken (
83120 issuerUrl = " https://example.com" ,
84121 refreshToken = " refreshToken" ,
85- accessToken = " accessToken" ,
122+ accessToken = accessToken,
86123 expiresAt = Instant .MAX ,
87124 createdAt = Instant .now(),
88125 region = " us-fake-1" ,
89126 )
90127
91- every { project.service<ToolkitConnectionManager >() } returns connectionManager
92- every { connectionManager.activeConnectionForFeature(any()) } returns connection
93- every { connection.getConnectionSettings() } returns connectionSettings
94- every { connectionSettings.tokenProvider } returns provider
95- every { provider.delegate } returns tokenDelegate
96- every { tokenDelegate.currentToken() } returns token
128+ val tokenDelegate = mockk<InteractiveBearerTokenProvider > {
129+ every { currentToken() } returns token
130+ }
97131
98- every {
99- mockLanguageServer.updateTokenCredentials(any())
100- } returns CompletableFuture .completedFuture(ResponseMessage ())
132+ val provider = mockk<ToolkitBearerTokenProvider > {
133+ every { delegate } returns tokenDelegate
134+ }
135+
136+ return mockk {
137+ every { tokenProvider } returns provider
138+ }
139+ }
140+
141+ @Test
142+ fun `activeConnectionChanged updates token when connection ID matches Q connection` () {
143+ sut = DefaultAuthCredentialsService (project, mockEncryptionManager, mockk())
144+ val newConnection = createMockConnection(" new-token" , " connection-id" )
145+ every { mockConnection.id } returns " connection-id"
146+
147+ sut.activeConnectionChanged(newConnection)
148+
149+ verify(exactly = 1 ) { mockLanguageServer.updateTokenCredentials(any()) }
150+ }
151+
152+ @Test
153+ fun `activeConnectionChanged does not update token when connection ID differs` () {
154+ sut = DefaultAuthCredentialsService (project, mockEncryptionManager, mockk())
155+ val newConnection = createMockConnection(" new-token" , " different-id" )
156+ every { mockConnection.id } returns " q-connection-id"
157+
158+ sut.activeConnectionChanged(newConnection)
159+
160+ verify(exactly = 0 ) { mockLanguageServer.updateTokenCredentials(any()) }
161+ }
162+
163+ @Test
164+ fun `onChange updates token with new connection` () {
165+ sut = DefaultAuthCredentialsService (project, mockEncryptionManager, mockk())
166+ setupMockConnectionManager(" updated-token" )
167+
168+ sut.onChange(" providerId" , listOf (" new-scope" ))
101169
102- sut = DefaultAuthCredentialsService (project, this .mockEncryptionManager, mockk())
170+ verify(exactly = 1 ) { mockLanguageServer.updateTokenCredentials(any()) }
171+ }
172+
173+ @Test
174+ fun `init does not update token when Q is not connected` () {
175+ every { isQConnected(project) } returns false
176+ every { isQExpired(project) } returns false
177+
178+ sut = DefaultAuthCredentialsService (project, mockEncryptionManager, mockk())
179+
180+ verify(exactly = 0 ) { mockLanguageServer.updateTokenCredentials(any()) }
181+ }
182+
183+ @Test
184+ fun `init does not update token when Q is expired` () {
185+ every { isQConnected(project) } returns true
186+ every { isQExpired(project) } returns true
187+
188+ sut = DefaultAuthCredentialsService (project, mockEncryptionManager, mockk())
189+
190+ verify(exactly = 0 ) { mockLanguageServer.updateTokenCredentials(any()) }
103191 }
104192
105193 @Test
106194 fun `test updateTokenCredentials unencrypted success` () {
195+ sut = DefaultAuthCredentialsService (project, mockEncryptionManager, mockk())
196+
107197 val token = " unencryptedToken"
108198 val isEncrypted = false
109199
@@ -121,6 +211,8 @@ class DefaultAuthCredentialsServiceTest {
121211
122212 @Test
123213 fun `test updateTokenCredentials encrypted success` () {
214+ sut = DefaultAuthCredentialsService (project, mockEncryptionManager, mockk())
215+
124216 val encryptedToken = " encryptedToken"
125217 val decryptedToken = " decryptedToken"
126218 val isEncrypted = true
@@ -141,6 +233,8 @@ class DefaultAuthCredentialsServiceTest {
141233
142234 @Test
143235 fun `test deleteTokenCredentials success` () {
236+ sut = DefaultAuthCredentialsService (project, mockEncryptionManager, mockk())
237+
144238 every { mockLanguageServer.deleteTokenCredentials() } returns CompletableFuture .completedFuture(Unit )
145239
146240 sut.deleteTokenCredentials()
@@ -150,6 +244,10 @@ class DefaultAuthCredentialsServiceTest {
150244
151245 @Test
152246 fun `init results in token update` () {
247+ every { isQConnected(any()) } returns true
248+ every { isQExpired(any()) } returns false
249+ sut = DefaultAuthCredentialsService (project, mockEncryptionManager, mockk())
250+
153251 verify(exactly = 1 ) { mockLanguageServer.updateTokenCredentials(any()) }
154252 }
155253}
0 commit comments