Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import com.intellij.openapi.project.Project
import org.eclipse.lsp4j.jsonrpc.messages.ResponseMessage
import software.aws.toolkits.core.TokenConnectionSettings
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnection
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManagerListener
import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProvider
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProviderListener
Expand All @@ -16,38 +18,27 @@
import software.aws.toolkits.jetbrains.services.amazonq.lsp.model.aws.credentials.BearerCredentials
import software.aws.toolkits.jetbrains.services.amazonq.lsp.model.aws.credentials.UpdateCredentialsPayload
import software.aws.toolkits.jetbrains.services.amazonq.lsp.model.aws.credentials.UpdateCredentialsPayloadData
import software.aws.toolkits.jetbrains.utils.isQConnected
import software.aws.toolkits.jetbrains.utils.isQExpired
import java.util.concurrent.CompletableFuture

class DefaultAuthCredentialsService(
private val project: Project,
private val encryptionManager: JwtEncryptionManager,
serverInstance: Disposable,
) : AuthCredentialsService,
BearerTokenProviderListener {
init {
project.messageBus.connect(serverInstance).subscribe(BearerTokenProviderListener.TOPIC, this)

onChange("init", null)
}

override fun onChange(providerId: String, newScopes: List<String>?) {
val connection = ToolkitConnectionManager.getInstance(project)
.activeConnectionForFeature(QConnection.getInstance())
?: return
BearerTokenProviderListener,
ToolkitConnectionManagerListener {

val provider = (connection.getConnectionSettings() as? TokenConnectionSettings)
?.tokenProvider
?.delegate as? BearerTokenProvider
?: return

provider.currentToken()?.accessToken?.let { token ->
// assume encryption is always on
updateTokenCredentials(token, true)
init {
project.messageBus.connect(serverInstance).apply {
subscribe(BearerTokenProviderListener.TOPIC, this@DefaultAuthCredentialsService)
subscribe(ToolkitConnectionManagerListener.TOPIC, this@DefaultAuthCredentialsService)

Check warning on line 36 in plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt

View check run for this annotation

Codecov / codecov/patch

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt#L33-L36

Added lines #L33 - L36 were not covered by tests
}
}

override fun invalidate(providerId: String) {
deleteTokenCredentials()
if (isQConnected(project) && !isQExpired(project)) {
updateTokenFromActiveConnection()

Check warning on line 40 in plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt

View check run for this annotation

Codecov / codecov/patch

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt#L40

Added line #L40 was not covered by tests
}
}

override fun updateTokenCredentials(accessToken: String, encrypted: Boolean): CompletableFuture<ResponseMessage> {
Expand All @@ -66,6 +57,41 @@
} ?: completableFuture.completeExceptionally(IllegalStateException("LSP Server not running"))
}

override fun onChange(providerId: String, newScopes: List<String>?) {
updateTokenFromActiveConnection()
}

Check warning on line 62 in plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt

View check run for this annotation

Codecov / codecov/patch

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt#L61-L62

Added lines #L61 - L62 were not covered by tests

override fun activeConnectionChanged(newConnection: ToolkitConnection?) {
val qConnection = ToolkitConnectionManager.getInstance(project)
.activeConnectionForFeature(QConnection.getInstance())
?: return

Check warning on line 67 in plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt

View check run for this annotation

Codecov / codecov/patch

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt#L66-L67

Added lines #L66 - L67 were not covered by tests
if (newConnection?.id != qConnection.id) return

Comment on lines +65 to +69
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

non-blocking, but are we sure these lines are not duplicated? I think they already exist somewhere in CW

updateTokenFromConnection(newConnection)
}

Check warning on line 71 in plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt

View check run for this annotation

Codecov / codecov/patch

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt#L70-L71

Added lines #L70 - L71 were not covered by tests

private fun updateTokenFromActiveConnection() {
val connection = ToolkitConnectionManager.getInstance(project)
.activeConnectionForFeature(QConnection.getInstance())
?: return

Check warning on line 76 in plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt

View check run for this annotation

Codecov / codecov/patch

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt#L75-L76

Added lines #L75 - L76 were not covered by tests

updateTokenFromConnection(connection)
}

Check warning on line 79 in plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt

View check run for this annotation

Codecov / codecov/patch

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt#L78-L79

Added lines #L78 - L79 were not covered by tests

private fun updateTokenFromConnection(connection: ToolkitConnection) {
(connection.getConnectionSettings() as? TokenConnectionSettings)
?.tokenProvider
?.delegate
?.let { it as? BearerTokenProvider }
Comment on lines +82 to +85
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we move this to a util

?.currentToken()
?.accessToken
?.let { token -> updateTokenCredentials(token, true) }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you please add the var name before the boolean?

}

Check warning on line 89 in plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt

View check run for this annotation

Codecov / codecov/patch

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt#L89

Added line #L89 was not covered by tests

override fun invalidate(providerId: String) {
deleteTokenCredentials()
}

Check warning on line 93 in plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt

View check run for this annotation

Codecov / codecov/patch

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt#L92-L93

Added lines #L92 - L93 were not covered by tests

private fun createUpdateCredentialsPayload(token: String, encrypted: Boolean): UpdateCredentialsPayload =
if (encrypted) {
UpdateCredentialsPayload(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import com.intellij.util.messages.MessageBusConnection
import io.mockk.every
import io.mockk.just
import io.mockk.mockk
import io.mockk.mockkStatic
import io.mockk.runs
import io.mockk.spyk
import io.mockk.verify
Expand All @@ -30,6 +31,8 @@ import software.aws.toolkits.jetbrains.services.amazonq.lsp.AmazonQLanguageServe
import software.aws.toolkits.jetbrains.services.amazonq.lsp.AmazonQLspService
import software.aws.toolkits.jetbrains.services.amazonq.lsp.encryption.JwtEncryptionManager
import software.aws.toolkits.jetbrains.services.amazonq.lsp.model.aws.credentials.UpdateCredentialsPayload
import software.aws.toolkits.jetbrains.utils.isQConnected
import software.aws.toolkits.jetbrains.utils.isQExpired
import java.time.Instant
import java.util.concurrent.CompletableFuture

Expand All @@ -38,72 +41,159 @@ class DefaultAuthCredentialsServiceTest {
@JvmField
@RegisterExtension
val projectExtension = ProjectExtension()

private const val TEST_ACCESS_TOKEN = "test-access-token"
}

private lateinit var project: Project
private lateinit var mockLanguageServer: AmazonQLanguageServer
private lateinit var mockEncryptionManager: JwtEncryptionManager
private lateinit var mockConnectionManager: ToolkitConnectionManager
private lateinit var mockConnection: AwsBearerTokenConnection
private lateinit var sut: DefaultAuthCredentialsService

// maybe better to use real project via junit extension
@BeforeEach
fun setUp() {
project = spyk(projectExtension.project)
setupMockLspService()
setupMockMessageBus()
setupMockConnectionManager()
}

private fun setupMockLspService() {
mockLanguageServer = mockk<AmazonQLanguageServer>()
mockEncryptionManager = mockk<JwtEncryptionManager>()
every { mockEncryptionManager.encrypt(any()) } returns "mock-encrypted-data"
mockEncryptionManager = mockk {
every { encrypt(any()) } returns "mock-encrypted-data"
}

// Mock the service methods on Project
val mockLspService = mockk<AmazonQLspService>()
every { project.getService(AmazonQLspService::class.java) } returns mockLspService
every { project.serviceIfCreated<AmazonQLspService>() } returns mockLspService

// Mock the LSP service's executeSync method as a suspend function
every {
mockLspService.executeSync<CompletableFuture<ResponseMessage>>(any())
} coAnswers {
val func = firstArg<suspend AmazonQLspService.(AmazonQLanguageServer) -> CompletableFuture<ResponseMessage>>()
func.invoke(mockLspService, mockLanguageServer)
}

// Mock message bus
every {
mockLanguageServer.updateTokenCredentials(any())
} returns CompletableFuture<ResponseMessage>()

every {
mockLanguageServer.deleteTokenCredentials()
} returns CompletableFuture.completedFuture(Unit)

every { project.getService(AmazonQLspService::class.java) } returns mockLspService
every { project.serviceIfCreated<AmazonQLspService>() } returns mockLspService
}

private fun setupMockMessageBus() {
val messageBus = mockk<MessageBus>()
val mockConnection = mockk<MessageBusConnection> {
every { subscribe(any(), any()) } just runs
}
every { project.messageBus } returns messageBus
val mockConnection = mockk<MessageBusConnection>()
every { messageBus.connect(any<Disposable>()) } returns mockConnection
every { mockConnection.subscribe(any(), any()) } just runs

// Mock ToolkitConnectionManager
val connectionManager = mockk<ToolkitConnectionManager>()
val connection = mockk<AwsBearerTokenConnection>()
val connectionSettings = mockk<TokenConnectionSettings>()
val provider = mockk<ToolkitBearerTokenProvider>()
val tokenDelegate = mockk<InteractiveBearerTokenProvider>()
}

private fun setupMockConnectionManager(accessToken: String = TEST_ACCESS_TOKEN) {
mockConnection = createMockConnection(accessToken)
mockConnectionManager = mockk {
every { activeConnectionForFeature(any()) } returns mockConnection
}
every { project.service<ToolkitConnectionManager>() } returns mockConnectionManager
mockkStatic("software.aws.toolkits.jetbrains.utils.FunctionUtilsKt")
// these set so init doesn't always emit
every { isQConnected(any()) } returns false
every { isQExpired(any()) } returns true
}

private fun createMockConnection(
accessToken: String,
connectionId: String = "test-connection-id",
): AwsBearerTokenConnection = mockk {
every { id } returns connectionId
every { getConnectionSettings() } returns createMockTokenSettings(accessToken)
}

private fun createMockTokenSettings(accessToken: String): TokenConnectionSettings {
val token = PKCEAuthorizationGrantToken(
issuerUrl = "https://example.com",
refreshToken = "refreshToken",
accessToken = "accessToken",
accessToken = accessToken,
expiresAt = Instant.MAX,
createdAt = Instant.now(),
region = "us-fake-1",
)

every { project.service<ToolkitConnectionManager>() } returns connectionManager
every { connectionManager.activeConnectionForFeature(any()) } returns connection
every { connection.getConnectionSettings() } returns connectionSettings
every { connectionSettings.tokenProvider } returns provider
every { provider.delegate } returns tokenDelegate
every { tokenDelegate.currentToken() } returns token
val tokenDelegate = mockk<InteractiveBearerTokenProvider> {
every { currentToken() } returns token
}

every {
mockLanguageServer.updateTokenCredentials(any())
} returns CompletableFuture.completedFuture(ResponseMessage())
val provider = mockk<ToolkitBearerTokenProvider> {
every { delegate } returns tokenDelegate
}

return mockk {
every { tokenProvider } returns provider
}
}

@Test
fun `activeConnectionChanged updates token when connection ID matches Q connection`() {
sut = DefaultAuthCredentialsService(project, mockEncryptionManager, mockk())
val newConnection = createMockConnection("new-token", "connection-id")
every { mockConnection.id } returns "connection-id"

sut.activeConnectionChanged(newConnection)

verify(exactly = 1) { mockLanguageServer.updateTokenCredentials(any()) }
}

@Test
fun `activeConnectionChanged does not update token when connection ID differs`() {
sut = DefaultAuthCredentialsService(project, mockEncryptionManager, mockk())
val newConnection = createMockConnection("new-token", "different-id")
every { mockConnection.id } returns "q-connection-id"

sut.activeConnectionChanged(newConnection)

verify(exactly = 0) { mockLanguageServer.updateTokenCredentials(any()) }
}

@Test
fun `onChange updates token with new connection`() {
sut = DefaultAuthCredentialsService(project, mockEncryptionManager, mockk())
setupMockConnectionManager("updated-token")

sut.onChange("providerId", listOf("new-scope"))

sut = DefaultAuthCredentialsService(project, this.mockEncryptionManager, mockk())
verify(exactly = 1) { mockLanguageServer.updateTokenCredentials(any()) }
}

@Test
fun `init does not update token when Q is not connected`() {
every { isQConnected(project) } returns false
every { isQExpired(project) } returns false

sut = DefaultAuthCredentialsService(project, mockEncryptionManager, mockk())

verify(exactly = 0) { mockLanguageServer.updateTokenCredentials(any()) }
}

@Test
fun `init does not update token when Q is expired`() {
every { isQConnected(project) } returns true
every { isQExpired(project) } returns true

sut = DefaultAuthCredentialsService(project, mockEncryptionManager, mockk())

verify(exactly = 0) { mockLanguageServer.updateTokenCredentials(any()) }
}

@Test
fun `test updateTokenCredentials unencrypted success`() {
sut = DefaultAuthCredentialsService(project, mockEncryptionManager, mockk())

val token = "unencryptedToken"
val isEncrypted = false

Expand All @@ -121,6 +211,8 @@ class DefaultAuthCredentialsServiceTest {

@Test
fun `test updateTokenCredentials encrypted success`() {
sut = DefaultAuthCredentialsService(project, mockEncryptionManager, mockk())

val encryptedToken = "encryptedToken"
val decryptedToken = "decryptedToken"
val isEncrypted = true
Expand All @@ -141,6 +233,8 @@ class DefaultAuthCredentialsServiceTest {

@Test
fun `test deleteTokenCredentials success`() {
sut = DefaultAuthCredentialsService(project, mockEncryptionManager, mockk())

every { mockLanguageServer.deleteTokenCredentials() } returns CompletableFuture.completedFuture(Unit)

sut.deleteTokenCredentials()
Expand All @@ -150,6 +244,10 @@ class DefaultAuthCredentialsServiceTest {

@Test
fun `init results in token update`() {
every { isQConnected(any()) } returns true
every { isQExpired(any()) } returns false
sut = DefaultAuthCredentialsService(project, mockEncryptionManager, mockk())

verify(exactly = 1) { mockLanguageServer.updateTokenCredentials(any()) }
}
}
Loading