diff --git a/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/AmazonQLspService.kt b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/AmazonQLspService.kt index 8a2e905f8b2..293ec0e3107 100644 --- a/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/AmazonQLspService.kt +++ b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/AmazonQLspService.kt @@ -24,7 +24,7 @@ import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Deferred import kotlinx.coroutines.TimeoutCancellationException import kotlinx.coroutines.async -import kotlinx.coroutines.launch +import kotlinx.coroutines.future.asCompletableFuture import kotlinx.coroutines.runBlocking import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock @@ -316,11 +316,18 @@ private class AmazonQServerInstance(private val project: Project, private val cs initializeResult } - DefaultAuthCredentialsService(project, encryptionManager, this) - TextDocumentServiceHandler(project, this) - WorkspaceServiceHandler(project, this) - cs.launch { - DefaultModuleDependenciesService(project, this@AmazonQServerInstance) + // invokeOnCompletion results in weird lock/timeout error + initializeResult.asCompletableFuture().handleAsync { r, ex -> + if (ex != null) { + return@handleAsync + } + + this@AmazonQServerInstance.apply { + DefaultAuthCredentialsService(project, encryptionManager, this) + TextDocumentServiceHandler(project, this) + WorkspaceServiceHandler(project, this) + DefaultModuleDependenciesService(project, this) + } } } diff --git a/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt index 5509e6cb688..334b674ae5e 100644 --- a/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt +++ b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt @@ -26,6 +26,8 @@ class DefaultAuthCredentialsService( BearerTokenProviderListener { init { project.messageBus.connect(serverInstance).subscribe(BearerTokenProviderListener.TOPIC, this) + + onChange("init", null) } override fun onChange(providerId: String, newScopes: List?) { @@ -39,7 +41,8 @@ class DefaultAuthCredentialsService( ?: return provider.currentToken()?.accessToken?.let { token -> - updateTokenCredentials(token, false) + // assume encryption is always on + updateTokenCredentials(token, true) } } @@ -48,13 +51,7 @@ class DefaultAuthCredentialsService( } override fun updateTokenCredentials(accessToken: String, encrypted: Boolean): CompletableFuture { - val token = if (encrypted) { - encryptionManager.decrypt(accessToken) - } else { - accessToken - } - - val payload = createUpdateCredentialsPayload(token) + val payload = createUpdateCredentialsPayload(accessToken, encrypted) return AmazonQLspService.executeIfRunning(project) { server -> server.updateTokenCredentials(payload) @@ -69,13 +66,20 @@ class DefaultAuthCredentialsService( } ?: completableFuture.completeExceptionally(IllegalStateException("LSP Server not running")) } - private fun createUpdateCredentialsPayload(token: String): UpdateCredentialsPayload = - UpdateCredentialsPayload( - data = encryptionManager.encrypt( - UpdateCredentialsPayloadData( - BearerCredentials(token) - ) - ), - encrypted = true - ) + private fun createUpdateCredentialsPayload(token: String, encrypted: Boolean): UpdateCredentialsPayload = + if (encrypted) { + UpdateCredentialsPayload( + data = encryptionManager.encrypt( + UpdateCredentialsPayloadData( + BearerCredentials(token) + ) + ), + encrypted = true + ) + } else { + UpdateCredentialsPayload( + data = token, + encrypted = false + ) + } } diff --git a/plugins/amazonq/shared/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsServiceTest.kt b/plugins/amazonq/shared/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsServiceTest.kt index 388e2286082..bf7aa4cf4dd 100644 --- a/plugins/amazonq/shared/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsServiceTest.kt +++ b/plugins/amazonq/shared/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsServiceTest.kt @@ -4,32 +4,51 @@ package software.aws.toolkits.jetbrains.services.amazonq.lsp.auth import com.intellij.openapi.Disposable +import com.intellij.openapi.components.service import com.intellij.openapi.components.serviceIfCreated import com.intellij.openapi.project.Project +import com.intellij.testFramework.ProjectExtension import com.intellij.util.messages.MessageBus import com.intellij.util.messages.MessageBusConnection import io.mockk.every import io.mockk.just import io.mockk.mockk import io.mockk.runs +import io.mockk.spyk import io.mockk.verify import org.eclipse.lsp4j.jsonrpc.messages.ResponseMessage -import org.junit.Before -import org.junit.Test +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.RegisterExtension +import software.aws.toolkits.core.TokenConnectionSettings +import software.aws.toolkits.core.credentials.ToolkitBearerTokenProvider +import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection +import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager +import software.aws.toolkits.jetbrains.core.credentials.sso.PKCEAuthorizationGrantToken +import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.InteractiveBearerTokenProvider import software.aws.toolkits.jetbrains.services.amazonq.lsp.AmazonQLanguageServer import software.aws.toolkits.jetbrains.services.amazonq.lsp.AmazonQLspService import software.aws.toolkits.jetbrains.services.amazonq.lsp.encryption.JwtEncryptionManager +import software.aws.toolkits.jetbrains.services.amazonq.lsp.model.aws.credentials.UpdateCredentialsPayload +import java.time.Instant import java.util.concurrent.CompletableFuture class DefaultAuthCredentialsServiceTest { + companion object { + @JvmField + @RegisterExtension + val projectExtension = ProjectExtension() + } + private lateinit var project: Project private lateinit var mockLanguageServer: AmazonQLanguageServer private lateinit var mockEncryptionManager: JwtEncryptionManager private lateinit var sut: DefaultAuthCredentialsService - @Before + // maybe better to use real project via junit extension + @BeforeEach fun setUp() { - project = mockk() + project = spyk(projectExtension.project) mockLanguageServer = mockk() mockEncryptionManager = mockk() every { mockEncryptionManager.encrypt(any()) } returns "mock-encrypted-data" @@ -54,6 +73,32 @@ class DefaultAuthCredentialsServiceTest { every { messageBus.connect(any()) } returns mockConnection every { mockConnection.subscribe(any(), any()) } just runs + // Mock ToolkitConnectionManager + val connectionManager = mockk() + val connection = mockk() + val connectionSettings = mockk() + val provider = mockk() + val tokenDelegate = mockk() + val token = PKCEAuthorizationGrantToken( + issuerUrl = "https://example.com", + refreshToken = "refreshToken", + accessToken = "accessToken", + expiresAt = Instant.MAX, + createdAt = Instant.now(), + region = "us-fake-1", + ) + + every { project.service() } returns connectionManager + every { connectionManager.activeConnectionForFeature(any()) } returns connection + every { connection.getConnectionSettings() } returns connectionSettings + every { connectionSettings.tokenProvider } returns provider + every { provider.delegate } returns tokenDelegate + every { tokenDelegate.currentToken() } returns token + + every { + mockLanguageServer.updateTokenCredentials(any()) + } returns CompletableFuture.completedFuture(ResponseMessage()) + sut = DefaultAuthCredentialsService(project, this.mockEncryptionManager, mockk()) } @@ -62,17 +107,15 @@ class DefaultAuthCredentialsServiceTest { val token = "unencryptedToken" val isEncrypted = false - every { - mockLanguageServer.updateTokenCredentials(any()) - } returns CompletableFuture.completedFuture(ResponseMessage()) - sut.updateTokenCredentials(token, isEncrypted) - verify(exactly = 0) { - mockEncryptionManager.decrypt(any()) - } verify(exactly = 1) { - mockLanguageServer.updateTokenCredentials(any()) + mockLanguageServer.updateTokenCredentials( + UpdateCredentialsPayload( + token, + isEncrypted + ) + ) } } @@ -82,16 +125,18 @@ class DefaultAuthCredentialsServiceTest { val decryptedToken = "decryptedToken" val isEncrypted = true - every { mockEncryptionManager.decrypt(encryptedToken) } returns decryptedToken - every { mockEncryptionManager.encrypt(any()) } returns "mock-encrypted-data" - every { - mockLanguageServer.updateTokenCredentials(any()) - } returns CompletableFuture.completedFuture(ResponseMessage()) + every { mockEncryptionManager.encrypt(any()) } returns encryptedToken - sut.updateTokenCredentials(encryptedToken, isEncrypted) + sut.updateTokenCredentials(decryptedToken, isEncrypted) - verify(exactly = 1) { mockEncryptionManager.decrypt(encryptedToken) } - verify(exactly = 1) { mockLanguageServer.updateTokenCredentials(any()) } + verify(atLeast = 1) { + mockLanguageServer.updateTokenCredentials( + UpdateCredentialsPayload( + encryptedToken, + isEncrypted + ) + ) + } } @Test @@ -102,4 +147,9 @@ class DefaultAuthCredentialsServiceTest { verify(exactly = 1) { mockLanguageServer.deleteTokenCredentials() } } + + @Test + fun `init results in token update`() { + verify(exactly = 1) { mockLanguageServer.updateTokenCredentials(any()) } + } }