diff --git a/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt index 334b674ae5e..c5be0abe56c 100644 --- a/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt +++ b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt @@ -7,7 +7,9 @@ import com.intellij.openapi.Disposable import com.intellij.openapi.project.Project import org.eclipse.lsp4j.jsonrpc.messages.ResponseMessage import software.aws.toolkits.core.TokenConnectionSettings +import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnection import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager +import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManagerListener import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProvider import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProviderListener @@ -16,6 +18,8 @@ import software.aws.toolkits.jetbrains.services.amazonq.lsp.encryption.JwtEncryp import software.aws.toolkits.jetbrains.services.amazonq.lsp.model.aws.credentials.BearerCredentials import software.aws.toolkits.jetbrains.services.amazonq.lsp.model.aws.credentials.UpdateCredentialsPayload import software.aws.toolkits.jetbrains.services.amazonq.lsp.model.aws.credentials.UpdateCredentialsPayloadData +import software.aws.toolkits.jetbrains.utils.isQConnected +import software.aws.toolkits.jetbrains.utils.isQExpired import java.util.concurrent.CompletableFuture class DefaultAuthCredentialsService( @@ -23,31 +27,18 @@ class DefaultAuthCredentialsService( private val encryptionManager: JwtEncryptionManager, serverInstance: Disposable, ) : AuthCredentialsService, - BearerTokenProviderListener { - init { - project.messageBus.connect(serverInstance).subscribe(BearerTokenProviderListener.TOPIC, this) - - onChange("init", null) - } - - override fun onChange(providerId: String, newScopes: List?) { - val connection = ToolkitConnectionManager.getInstance(project) - .activeConnectionForFeature(QConnection.getInstance()) - ?: return + BearerTokenProviderListener, + ToolkitConnectionManagerListener { - val provider = (connection.getConnectionSettings() as? TokenConnectionSettings) - ?.tokenProvider - ?.delegate as? BearerTokenProvider - ?: return - - provider.currentToken()?.accessToken?.let { token -> - // assume encryption is always on - updateTokenCredentials(token, true) + init { + project.messageBus.connect(serverInstance).apply { + subscribe(BearerTokenProviderListener.TOPIC, this@DefaultAuthCredentialsService) + subscribe(ToolkitConnectionManagerListener.TOPIC, this@DefaultAuthCredentialsService) } - } - override fun invalidate(providerId: String) { - deleteTokenCredentials() + if (isQConnected(project) && !isQExpired(project)) { + updateTokenFromActiveConnection() + } } override fun updateTokenCredentials(accessToken: String, encrypted: Boolean): CompletableFuture { @@ -66,6 +57,41 @@ class DefaultAuthCredentialsService( } ?: completableFuture.completeExceptionally(IllegalStateException("LSP Server not running")) } + override fun onChange(providerId: String, newScopes: List?) { + updateTokenFromActiveConnection() + } + + override fun activeConnectionChanged(newConnection: ToolkitConnection?) { + val qConnection = ToolkitConnectionManager.getInstance(project) + .activeConnectionForFeature(QConnection.getInstance()) + ?: return + if (newConnection?.id != qConnection.id) return + + updateTokenFromConnection(newConnection) + } + + private fun updateTokenFromActiveConnection() { + val connection = ToolkitConnectionManager.getInstance(project) + .activeConnectionForFeature(QConnection.getInstance()) + ?: return + + updateTokenFromConnection(connection) + } + + private fun updateTokenFromConnection(connection: ToolkitConnection) { + (connection.getConnectionSettings() as? TokenConnectionSettings) + ?.tokenProvider + ?.delegate + ?.let { it as? BearerTokenProvider } + ?.currentToken() + ?.accessToken + ?.let { token -> updateTokenCredentials(token, true) } + } + + override fun invalidate(providerId: String) { + deleteTokenCredentials() + } + private fun createUpdateCredentialsPayload(token: String, encrypted: Boolean): UpdateCredentialsPayload = if (encrypted) { UpdateCredentialsPayload( diff --git a/plugins/amazonq/shared/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsServiceTest.kt b/plugins/amazonq/shared/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsServiceTest.kt index bf7aa4cf4dd..d141268d2c3 100644 --- a/plugins/amazonq/shared/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsServiceTest.kt +++ b/plugins/amazonq/shared/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsServiceTest.kt @@ -13,6 +13,7 @@ import com.intellij.util.messages.MessageBusConnection import io.mockk.every import io.mockk.just import io.mockk.mockk +import io.mockk.mockkStatic import io.mockk.runs import io.mockk.spyk import io.mockk.verify @@ -30,6 +31,8 @@ import software.aws.toolkits.jetbrains.services.amazonq.lsp.AmazonQLanguageServe import software.aws.toolkits.jetbrains.services.amazonq.lsp.AmazonQLspService import software.aws.toolkits.jetbrains.services.amazonq.lsp.encryption.JwtEncryptionManager import software.aws.toolkits.jetbrains.services.amazonq.lsp.model.aws.credentials.UpdateCredentialsPayload +import software.aws.toolkits.jetbrains.utils.isQConnected +import software.aws.toolkits.jetbrains.utils.isQExpired import java.time.Instant import java.util.concurrent.CompletableFuture @@ -38,27 +41,32 @@ class DefaultAuthCredentialsServiceTest { @JvmField @RegisterExtension val projectExtension = ProjectExtension() + + private const val TEST_ACCESS_TOKEN = "test-access-token" } private lateinit var project: Project private lateinit var mockLanguageServer: AmazonQLanguageServer private lateinit var mockEncryptionManager: JwtEncryptionManager + private lateinit var mockConnectionManager: ToolkitConnectionManager + private lateinit var mockConnection: AwsBearerTokenConnection private lateinit var sut: DefaultAuthCredentialsService - // maybe better to use real project via junit extension @BeforeEach fun setUp() { project = spyk(projectExtension.project) + setupMockLspService() + setupMockMessageBus() + setupMockConnectionManager() + } + + private fun setupMockLspService() { mockLanguageServer = mockk() - mockEncryptionManager = mockk() - every { mockEncryptionManager.encrypt(any()) } returns "mock-encrypted-data" + mockEncryptionManager = mockk { + every { encrypt(any()) } returns "mock-encrypted-data" + } - // Mock the service methods on Project val mockLspService = mockk() - every { project.getService(AmazonQLspService::class.java) } returns mockLspService - every { project.serviceIfCreated() } returns mockLspService - - // Mock the LSP service's executeSync method as a suspend function every { mockLspService.executeSync>(any()) } coAnswers { @@ -66,44 +74,126 @@ class DefaultAuthCredentialsServiceTest { func.invoke(mockLspService, mockLanguageServer) } - // Mock message bus + every { + mockLanguageServer.updateTokenCredentials(any()) + } returns CompletableFuture() + + every { + mockLanguageServer.deleteTokenCredentials() + } returns CompletableFuture.completedFuture(Unit) + + every { project.getService(AmazonQLspService::class.java) } returns mockLspService + every { project.serviceIfCreated() } returns mockLspService + } + + private fun setupMockMessageBus() { val messageBus = mockk() + val mockConnection = mockk { + every { subscribe(any(), any()) } just runs + } every { project.messageBus } returns messageBus - val mockConnection = mockk() every { messageBus.connect(any()) } returns mockConnection - every { mockConnection.subscribe(any(), any()) } just runs - - // Mock ToolkitConnectionManager - val connectionManager = mockk() - val connection = mockk() - val connectionSettings = mockk() - val provider = mockk() - val tokenDelegate = mockk() + } + + private fun setupMockConnectionManager(accessToken: String = TEST_ACCESS_TOKEN) { + mockConnection = createMockConnection(accessToken) + mockConnectionManager = mockk { + every { activeConnectionForFeature(any()) } returns mockConnection + } + every { project.service() } returns mockConnectionManager + mockkStatic("software.aws.toolkits.jetbrains.utils.FunctionUtilsKt") + // these set so init doesn't always emit + every { isQConnected(any()) } returns false + every { isQExpired(any()) } returns true + } + + private fun createMockConnection( + accessToken: String, + connectionId: String = "test-connection-id", + ): AwsBearerTokenConnection = mockk { + every { id } returns connectionId + every { getConnectionSettings() } returns createMockTokenSettings(accessToken) + } + + private fun createMockTokenSettings(accessToken: String): TokenConnectionSettings { val token = PKCEAuthorizationGrantToken( issuerUrl = "https://example.com", refreshToken = "refreshToken", - accessToken = "accessToken", + accessToken = accessToken, expiresAt = Instant.MAX, createdAt = Instant.now(), region = "us-fake-1", ) - every { project.service() } returns connectionManager - every { connectionManager.activeConnectionForFeature(any()) } returns connection - every { connection.getConnectionSettings() } returns connectionSettings - every { connectionSettings.tokenProvider } returns provider - every { provider.delegate } returns tokenDelegate - every { tokenDelegate.currentToken() } returns token + val tokenDelegate = mockk { + every { currentToken() } returns token + } - every { - mockLanguageServer.updateTokenCredentials(any()) - } returns CompletableFuture.completedFuture(ResponseMessage()) + val provider = mockk { + every { delegate } returns tokenDelegate + } + + return mockk { + every { tokenProvider } returns provider + } + } + + @Test + fun `activeConnectionChanged updates token when connection ID matches Q connection`() { + sut = DefaultAuthCredentialsService(project, mockEncryptionManager, mockk()) + val newConnection = createMockConnection("new-token", "connection-id") + every { mockConnection.id } returns "connection-id" + + sut.activeConnectionChanged(newConnection) + + verify(exactly = 1) { mockLanguageServer.updateTokenCredentials(any()) } + } + + @Test + fun `activeConnectionChanged does not update token when connection ID differs`() { + sut = DefaultAuthCredentialsService(project, mockEncryptionManager, mockk()) + val newConnection = createMockConnection("new-token", "different-id") + every { mockConnection.id } returns "q-connection-id" + + sut.activeConnectionChanged(newConnection) + + verify(exactly = 0) { mockLanguageServer.updateTokenCredentials(any()) } + } + + @Test + fun `onChange updates token with new connection`() { + sut = DefaultAuthCredentialsService(project, mockEncryptionManager, mockk()) + setupMockConnectionManager("updated-token") + + sut.onChange("providerId", listOf("new-scope")) - sut = DefaultAuthCredentialsService(project, this.mockEncryptionManager, mockk()) + verify(exactly = 1) { mockLanguageServer.updateTokenCredentials(any()) } + } + + @Test + fun `init does not update token when Q is not connected`() { + every { isQConnected(project) } returns false + every { isQExpired(project) } returns false + + sut = DefaultAuthCredentialsService(project, mockEncryptionManager, mockk()) + + verify(exactly = 0) { mockLanguageServer.updateTokenCredentials(any()) } + } + + @Test + fun `init does not update token when Q is expired`() { + every { isQConnected(project) } returns true + every { isQExpired(project) } returns true + + sut = DefaultAuthCredentialsService(project, mockEncryptionManager, mockk()) + + verify(exactly = 0) { mockLanguageServer.updateTokenCredentials(any()) } } @Test fun `test updateTokenCredentials unencrypted success`() { + sut = DefaultAuthCredentialsService(project, mockEncryptionManager, mockk()) + val token = "unencryptedToken" val isEncrypted = false @@ -121,6 +211,8 @@ class DefaultAuthCredentialsServiceTest { @Test fun `test updateTokenCredentials encrypted success`() { + sut = DefaultAuthCredentialsService(project, mockEncryptionManager, mockk()) + val encryptedToken = "encryptedToken" val decryptedToken = "decryptedToken" val isEncrypted = true @@ -141,6 +233,8 @@ class DefaultAuthCredentialsServiceTest { @Test fun `test deleteTokenCredentials success`() { + sut = DefaultAuthCredentialsService(project, mockEncryptionManager, mockk()) + every { mockLanguageServer.deleteTokenCredentials() } returns CompletableFuture.completedFuture(Unit) sut.deleteTokenCredentials() @@ -150,6 +244,10 @@ class DefaultAuthCredentialsServiceTest { @Test fun `init results in token update`() { + every { isQConnected(any()) } returns true + every { isQExpired(any()) } returns false + sut = DefaultAuthCredentialsService(project, mockEncryptionManager, mockk()) + verify(exactly = 1) { mockLanguageServer.updateTokenCredentials(any()) } } }