Skip to content

Commit 8e03806

Browse files
committed
tests
1 parent 77817fa commit 8e03806

File tree

2 files changed

+129
-32
lines changed

2 files changed

+129
-32
lines changed

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@ import org.eclipse.lsp4j.jsonrpc.messages.ResponseMessage
99
import software.aws.toolkits.core.TokenConnectionSettings
1010
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnection
1111
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
12+
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManagerListener
1213
import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection
1314
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProvider
1415
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProviderListener
15-
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManagerListener
1616
import software.aws.toolkits.jetbrains.services.amazonq.lsp.AmazonQLspService
1717
import software.aws.toolkits.jetbrains.services.amazonq.lsp.encryption.JwtEncryptionManager
1818
import software.aws.toolkits.jetbrains.services.amazonq.lsp.model.aws.credentials.BearerCredentials
@@ -36,7 +36,7 @@ class DefaultAuthCredentialsService(
3636
subscribe(ToolkitConnectionManagerListener.TOPIC, this@DefaultAuthCredentialsService)
3737
}
3838

39-
if(isQConnected(project) && !isQExpired(project)) {
39+
if (isQConnected(project) && !isQExpired(project)) {
4040
updateTokenFromActiveConnection()
4141
}
4242
}
@@ -108,5 +108,4 @@ class DefaultAuthCredentialsService(
108108
encrypted = false
109109
)
110110
}
111-
112111
}

plugins/amazonq/shared/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsServiceTest.kt

Lines changed: 127 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import com.intellij.util.messages.MessageBusConnection
1313
import io.mockk.every
1414
import io.mockk.just
1515
import io.mockk.mockk
16+
import io.mockk.mockkStatic
1617
import io.mockk.runs
1718
import io.mockk.spyk
1819
import io.mockk.verify
@@ -30,6 +31,8 @@ import software.aws.toolkits.jetbrains.services.amazonq.lsp.AmazonQLanguageServe
3031
import software.aws.toolkits.jetbrains.services.amazonq.lsp.AmazonQLspService
3132
import software.aws.toolkits.jetbrains.services.amazonq.lsp.encryption.JwtEncryptionManager
3233
import software.aws.toolkits.jetbrains.services.amazonq.lsp.model.aws.credentials.UpdateCredentialsPayload
34+
import software.aws.toolkits.jetbrains.utils.isQConnected
35+
import software.aws.toolkits.jetbrains.utils.isQExpired
3336
import java.time.Instant
3437
import java.util.concurrent.CompletableFuture
3538

@@ -38,72 +41,159 @@ class DefaultAuthCredentialsServiceTest {
3841
@JvmField
3942
@RegisterExtension
4043
val projectExtension = ProjectExtension()
44+
45+
private const val TEST_ACCESS_TOKEN = "test-access-token"
4146
}
4247

4348
private lateinit var project: Project
4449
private lateinit var mockLanguageServer: AmazonQLanguageServer
4550
private lateinit var mockEncryptionManager: JwtEncryptionManager
51+
private lateinit var mockConnectionManager: ToolkitConnectionManager
52+
private lateinit var mockConnection: AwsBearerTokenConnection
4653
private lateinit var sut: DefaultAuthCredentialsService
4754

48-
// maybe better to use real project via junit extension
4955
@BeforeEach
5056
fun setUp() {
5157
project = spyk(projectExtension.project)
58+
setupMockLspService()
59+
setupMockMessageBus()
60+
setupMockConnectionManager()
61+
}
62+
63+
private fun setupMockLspService() {
5264
mockLanguageServer = mockk<AmazonQLanguageServer>()
53-
mockEncryptionManager = mockk<JwtEncryptionManager>()
54-
every { mockEncryptionManager.encrypt(any()) } returns "mock-encrypted-data"
65+
mockEncryptionManager = mockk {
66+
every { encrypt(any()) } returns "mock-encrypted-data"
67+
}
5568

56-
// Mock the service methods on Project
5769
val mockLspService = mockk<AmazonQLspService>()
58-
every { project.getService(AmazonQLspService::class.java) } returns mockLspService
59-
every { project.serviceIfCreated<AmazonQLspService>() } returns mockLspService
60-
61-
// Mock the LSP service's executeSync method as a suspend function
6270
every {
6371
mockLspService.executeSync<CompletableFuture<ResponseMessage>>(any())
6472
} coAnswers {
6573
val func = firstArg<suspend AmazonQLspService.(AmazonQLanguageServer) -> CompletableFuture<ResponseMessage>>()
6674
func.invoke(mockLspService, mockLanguageServer)
6775
}
6876

69-
// Mock message bus
77+
every {
78+
mockLanguageServer.updateTokenCredentials(any())
79+
} returns CompletableFuture<ResponseMessage>()
80+
81+
every {
82+
mockLanguageServer.deleteTokenCredentials()
83+
} returns CompletableFuture.completedFuture(Unit)
84+
85+
every { project.getService(AmazonQLspService::class.java) } returns mockLspService
86+
every { project.serviceIfCreated<AmazonQLspService>() } returns mockLspService
87+
}
88+
89+
private fun setupMockMessageBus() {
7090
val messageBus = mockk<MessageBus>()
91+
val mockConnection = mockk<MessageBusConnection> {
92+
every { subscribe(any(), any()) } just runs
93+
}
7194
every { project.messageBus } returns messageBus
72-
val mockConnection = mockk<MessageBusConnection>()
7395
every { messageBus.connect(any<Disposable>()) } returns mockConnection
74-
every { mockConnection.subscribe(any(), any()) } just runs
75-
76-
// Mock ToolkitConnectionManager
77-
val connectionManager = mockk<ToolkitConnectionManager>()
78-
val connection = mockk<AwsBearerTokenConnection>()
79-
val connectionSettings = mockk<TokenConnectionSettings>()
80-
val provider = mockk<ToolkitBearerTokenProvider>()
81-
val tokenDelegate = mockk<InteractiveBearerTokenProvider>()
96+
}
97+
98+
private fun setupMockConnectionManager(accessToken: String = TEST_ACCESS_TOKEN) {
99+
mockConnection = createMockConnection(accessToken)
100+
mockConnectionManager = mockk {
101+
every { activeConnectionForFeature(any()) } returns mockConnection
102+
}
103+
every { project.service<ToolkitConnectionManager>() } returns mockConnectionManager
104+
mockkStatic("software.aws.toolkits.jetbrains.utils.FunctionUtilsKt")
105+
// these set so init doesn't always emit
106+
every { isQConnected(any()) } returns false
107+
every { isQExpired(any()) } returns true
108+
}
109+
110+
private fun createMockConnection(
111+
accessToken: String,
112+
connectionId: String = "test-connection-id",
113+
): AwsBearerTokenConnection = mockk {
114+
every { id } returns connectionId
115+
every { getConnectionSettings() } returns createMockTokenSettings(accessToken)
116+
}
117+
118+
private fun createMockTokenSettings(accessToken: String): TokenConnectionSettings {
82119
val token = PKCEAuthorizationGrantToken(
83120
issuerUrl = "https://example.com",
84121
refreshToken = "refreshToken",
85-
accessToken = "accessToken",
122+
accessToken = accessToken,
86123
expiresAt = Instant.MAX,
87124
createdAt = Instant.now(),
88125
region = "us-fake-1",
89126
)
90127

91-
every { project.service<ToolkitConnectionManager>() } returns connectionManager
92-
every { connectionManager.activeConnectionForFeature(any()) } returns connection
93-
every { connection.getConnectionSettings() } returns connectionSettings
94-
every { connectionSettings.tokenProvider } returns provider
95-
every { provider.delegate } returns tokenDelegate
96-
every { tokenDelegate.currentToken() } returns token
128+
val tokenDelegate = mockk<InteractiveBearerTokenProvider> {
129+
every { currentToken() } returns token
130+
}
97131

98-
every {
99-
mockLanguageServer.updateTokenCredentials(any())
100-
} returns CompletableFuture.completedFuture(ResponseMessage())
132+
val provider = mockk<ToolkitBearerTokenProvider> {
133+
every { delegate } returns tokenDelegate
134+
}
135+
136+
return mockk {
137+
every { tokenProvider } returns provider
138+
}
139+
}
140+
141+
@Test
142+
fun `activeConnectionChanged updates token when connection ID matches Q connection`() {
143+
sut = DefaultAuthCredentialsService(project, mockEncryptionManager, mockk())
144+
val newConnection = createMockConnection("new-token", "connection-id")
145+
every { mockConnection.id } returns "connection-id"
146+
147+
sut.activeConnectionChanged(newConnection)
148+
149+
verify(exactly = 1) { mockLanguageServer.updateTokenCredentials(any()) }
150+
}
151+
152+
@Test
153+
fun `activeConnectionChanged does not update token when connection ID differs`() {
154+
sut = DefaultAuthCredentialsService(project, mockEncryptionManager, mockk())
155+
val newConnection = createMockConnection("new-token", "different-id")
156+
every { mockConnection.id } returns "q-connection-id"
157+
158+
sut.activeConnectionChanged(newConnection)
159+
160+
verify(exactly = 0) { mockLanguageServer.updateTokenCredentials(any()) }
161+
}
162+
163+
@Test
164+
fun `onChange updates token with new connection`() {
165+
sut = DefaultAuthCredentialsService(project, mockEncryptionManager, mockk())
166+
setupMockConnectionManager("updated-token")
167+
168+
sut.onChange("providerId", listOf("new-scope"))
101169

102-
sut = DefaultAuthCredentialsService(project, this.mockEncryptionManager, mockk())
170+
verify(exactly = 1) { mockLanguageServer.updateTokenCredentials(any()) }
171+
}
172+
173+
@Test
174+
fun `init does not update token when Q is not connected`() {
175+
every { isQConnected(project) } returns false
176+
every { isQExpired(project) } returns false
177+
178+
sut = DefaultAuthCredentialsService(project, mockEncryptionManager, mockk())
179+
180+
verify(exactly = 0) { mockLanguageServer.updateTokenCredentials(any()) }
181+
}
182+
183+
@Test
184+
fun `init does not update token when Q is expired`() {
185+
every { isQConnected(project) } returns true
186+
every { isQExpired(project) } returns true
187+
188+
sut = DefaultAuthCredentialsService(project, mockEncryptionManager, mockk())
189+
190+
verify(exactly = 0) { mockLanguageServer.updateTokenCredentials(any()) }
103191
}
104192

105193
@Test
106194
fun `test updateTokenCredentials unencrypted success`() {
195+
sut = DefaultAuthCredentialsService(project, mockEncryptionManager, mockk())
196+
107197
val token = "unencryptedToken"
108198
val isEncrypted = false
109199

@@ -121,6 +211,8 @@ class DefaultAuthCredentialsServiceTest {
121211

122212
@Test
123213
fun `test updateTokenCredentials encrypted success`() {
214+
sut = DefaultAuthCredentialsService(project, mockEncryptionManager, mockk())
215+
124216
val encryptedToken = "encryptedToken"
125217
val decryptedToken = "decryptedToken"
126218
val isEncrypted = true
@@ -141,6 +233,8 @@ class DefaultAuthCredentialsServiceTest {
141233

142234
@Test
143235
fun `test deleteTokenCredentials success`() {
236+
sut = DefaultAuthCredentialsService(project, mockEncryptionManager, mockk())
237+
144238
every { mockLanguageServer.deleteTokenCredentials() } returns CompletableFuture.completedFuture(Unit)
145239

146240
sut.deleteTokenCredentials()
@@ -150,6 +244,10 @@ class DefaultAuthCredentialsServiceTest {
150244

151245
@Test
152246
fun `init results in token update`() {
247+
every { isQConnected(any()) } returns true
248+
every { isQExpired(any()) } returns false
249+
sut = DefaultAuthCredentialsService(project, mockEncryptionManager, mockk())
250+
153251
verify(exactly = 1) { mockLanguageServer.updateTokenCredentials(any()) }
154252
}
155253
}

0 commit comments

Comments
 (0)