Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class AmazonQToolWindowFactory : ToolWindowFactory, DumbAware {
project.messageBus.connect(toolWindow.disposable).subscribe(
BearerTokenProviderListener.TOPIC,
object : BearerTokenProviderListener {
override fun onChange(providerId: String, newScopes: List<String>?) {
override fun onProviderChange(providerId: String, newScopes: List<String>?) {
if (ToolkitConnectionManager.getInstance(project).connectionStateForFeature(QConnection.getInstance()) == BearerTokenAuthState.AUTHORIZED) {
preparePanelContent(project, qPanel)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ class CodeScanChatApp(private val scope: CoroutineScope) : AmazonQApp {
ApplicationManager.getApplication().messageBus.connect(this).subscribe(
BearerTokenProviderListener.TOPIC,
object : BearerTokenProviderListener {
override fun onChange(providerId: String, newScopes: List<String>?) {
override fun onProviderChange(providerId: String, newScopes: List<String>?) {
val qProvider = getQTokenProvider(context.project)
val isQ = qProvider?.id == providerId
val isAuthorized = qProvider?.state() == BearerTokenAuthState.AUTHORIZED
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,14 @@ import org.mockito.kotlin.spy
import org.mockito.kotlin.whenever
import software.aws.toolkits.core.TokenConnectionSettings
import software.aws.toolkits.core.credentials.ToolkitBearerTokenProvider
import software.aws.toolkits.core.utils.test.aString
import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
import software.aws.toolkits.jetbrains.core.credentials.sso.DeviceAuthorizationGrantToken
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProvider
import software.aws.toolkits.jetbrains.services.amazonq.clients.AmazonQStreamingClient
import software.aws.toolkits.jetbrains.utils.rules.CodeInsightTestFixtureRule
import software.aws.toolkits.jetbrains.utils.rules.HeavyJavaCodeInsightTestFixtureRule
import software.aws.toolkits.jetbrains.utils.rules.JavaCodeInsightTestFixtureRule
import software.aws.toolkits.jetbrains.utils.rules.addModule
import java.time.Instant

open class AmazonQTestBase(
@Rule @JvmField
Expand All @@ -47,11 +44,7 @@ open class AmazonQTestBase(
project = projectRule.project
toolkitConnectionManager = spy(ToolkitConnectionManager.getInstance(project))

val accessToken = DeviceAuthorizationGrantToken(aString(), aString(), aString(), aString(), Instant.MAX, Instant.now())

val provider = mock<BearerTokenProvider> {
doReturn(accessToken).whenever(it).refresh()
}
val provider = mock<BearerTokenProvider>()

val mockBearerProvider = mock<ToolkitBearerTokenProvider> {
doReturn(provider).whenever(it).delegate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,8 @@ import software.amazon.awssdk.services.codewhispererruntime.model.SendTelemetryE
import software.amazon.awssdk.services.codewhispererruntime.model.StartTaskAssistCodeGenerationResponse
import software.aws.toolkits.core.TokenConnectionSettings
import software.aws.toolkits.core.credentials.ToolkitBearerTokenProvider
import software.aws.toolkits.core.utils.test.aString
import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
import software.aws.toolkits.jetbrains.core.credentials.sso.DeviceAuthorizationGrantToken
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProvider
import software.aws.toolkits.jetbrains.services.amazonqFeatureDev.clients.FeatureDevClient
import software.aws.toolkits.jetbrains.services.amazonqFeatureDev.session.CodeGenerationStreamResult
Expand All @@ -41,7 +39,6 @@ import software.aws.toolkits.jetbrains.utils.rules.HeavyJavaCodeInsightTestFixtu
import software.aws.toolkits.jetbrains.utils.rules.JavaCodeInsightTestFixtureRule
import software.aws.toolkits.jetbrains.utils.rules.addModule
import java.io.File
import java.time.Instant

open class FeatureDevTestBase(
@Rule @JvmField
Expand Down Expand Up @@ -164,11 +161,7 @@ open class FeatureDevTestBase(
open fun setup() {
project = projectRule.project
toolkitConnectionManager = spy(ToolkitConnectionManager.getInstance(project))
val accessToken = DeviceAuthorizationGrantToken(aString(), aString(), aString(), aString(), Instant.MAX, Instant.now())
val provider =
mock<BearerTokenProvider> {
doReturn(accessToken).whenever(it).refresh()
}
val provider = mock<BearerTokenProvider>()
val mockBearerProvider =
mock<ToolkitBearerTokenProvider> {
doReturn(provider).whenever(it).delegate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ class CodeTransformChatApp : AmazonQApp {
ApplicationManager.getApplication().messageBus.connect(this).subscribe(
BearerTokenProviderListener.TOPIC,
object : BearerTokenProviderListener {
override fun onChange(providerId: String, newScopes: List<String>?) {
override fun onProviderChange(providerId: String, newScopes: List<String>?) {
val qProvider = getQTokenProvider(context.project)
val isQ = qProvider?.id == providerId
val isAuthorized = qProvider?.state() == BearerTokenAuthState.AUTHORIZED
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,6 @@
var transformationPlan: TransformationPlan? = null
var didSleepOnce = false
var hasSeenTransforming = false
val maxRefreshes = 10
var numRefreshes = 0

// refresh token at start of polling since local build just prior can take a long time
refreshToken(project)

try {
waitUntil(
Expand Down Expand Up @@ -138,13 +133,10 @@
onStateChange(state, newStatus, transformationPlan)
}
state = newStatus
numRefreshes = 0
return@waitUntil state
} catch (e: AccessDeniedException) {
if (numRefreshes++ > maxRefreshes) throw e
refreshToken(project)
return@waitUntil state
} catch (e: InvalidGrantException) {
} catch (e: Exception) {

Check warning on line 137 in plugins/amazonq/codetransform/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codemodernizer/utils/CodeTransformApiUtils.kt

View check run for this annotation

Codecov / codecov/patch

plugins/amazonq/codetransform/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codemodernizer/utils/CodeTransformApiUtils.kt#L137

Added line #L137 was not covered by tests
if (e !is AccessDeniedException && e !is InvalidGrantException) throw e

CodeTransformMessageListener.instance.onReauthStarted()
notifyStickyWarn(
message("codemodernizer.notification.warn.expired_credentials.title"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import com.intellij.openapi.vfs.VfsUtilCore
import com.intellij.openapi.vfs.VirtualFileManager
import software.amazon.awssdk.services.codewhispererruntime.model.TransformationLanguage
import software.amazon.awssdk.services.codewhispererruntime.model.TransformationStatus
import software.aws.toolkits.core.TokenConnectionSettings
import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection
Expand Down Expand Up @@ -43,12 +42,6 @@ val STATES_AFTER_STARTED = setOf(
*STATES_AFTER_INITIAL_BUILD.toTypedArray(),
)

fun refreshToken(project: Project) {
val connection = ToolkitConnectionManager.getInstance(project).activeConnectionForFeature(QConnection.getInstance())
val provider = (connection?.getConnectionSettings() as TokenConnectionSettings).tokenProvider.delegate as BearerTokenProvider
provider.refresh()
}

fun getAuthType(project: Project): CredentialSourceId? {
val connection = checkBearerConnectionValidity(project, BearerTokenFeatureSet.Q)
var authType: CredentialSourceId? = null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ import software.aws.toolkits.core.credentials.ToolkitBearerTokenProvider
import software.aws.toolkits.core.utils.test.aString
import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
import software.aws.toolkits.jetbrains.core.credentials.sso.DeviceAuthorizationGrantToken
import software.aws.toolkits.jetbrains.core.credentials.sso.PKCEAuthorizationGrantToken
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenAuthState
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProvider
Expand Down Expand Up @@ -250,11 +249,7 @@ open class CodeWhispererCodeModernizerTestBase(
project = projectRule.project
toolkitConnectionManager = spy(ToolkitConnectionManager.getInstance(project))

val accessToken = DeviceAuthorizationGrantToken(aString(), aString(), aString(), aString(), Instant.MAX, Instant.now())
val provider =
mock<BearerTokenProvider> {
doReturn(accessToken).whenever(it).refresh()
}
val provider = mock<BearerTokenProvider> { }
val mockBearerProvider =
mock<ToolkitBearerTokenProvider> {
doReturn(provider).whenever(it).delegate
Expand Down Expand Up @@ -340,7 +335,6 @@ open class CodeWhispererCodeModernizerTestBase(
val accessToken = PKCEAuthorizationGrantToken(aString(), aString(), aString(), aString(), Instant.MAX, Instant.now())

val provider = mock<BearerTokenProvider> {
doReturn(accessToken).whenever(it).refresh()
doReturn(accessToken).whenever(it).currentToken()
doReturn(authState).whenever(it).state()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ import software.aws.toolkits.jetbrains.services.codemodernizer.utils.getTableMap
import software.aws.toolkits.jetbrains.services.codemodernizer.utils.isPlanComplete
import software.aws.toolkits.jetbrains.services.codemodernizer.utils.parseBuildFile
import software.aws.toolkits.jetbrains.services.codemodernizer.utils.pollTransformationStatusAndPlan
import software.aws.toolkits.jetbrains.services.codemodernizer.utils.refreshToken
import software.aws.toolkits.jetbrains.services.codemodernizer.utils.validateCustomVersionsFile
import software.aws.toolkits.jetbrains.services.codemodernizer.utils.validateSctMetadata
import software.aws.toolkits.jetbrains.utils.notifyStickyWarn
Expand Down Expand Up @@ -90,18 +89,18 @@ class CodeWhispererCodeModernizerUtilsTest : CodeWhispererCodeModernizerTestBase
}

@Test
fun `refresh on access denied`() {
fun `show re-auth notification on access denied`() {
val mockAccessDeniedException = Mockito.mock(AccessDeniedException::class.java)

mockkStatic(::refreshToken)
every { refreshToken(any()) } just runs
mockkStatic(::notifyStickyWarn)
every { notifyStickyWarn(any(), any(), any(), any(), any()) } just runs

Mockito.doThrow(
mockAccessDeniedException
).doReturn(
exampleGetCodeMigrationResponse,
exampleGetCodeMigrationResponse.replace(TransformationStatus.STARTED),
exampleGetCodeMigrationResponse.replace(TransformationStatus.COMPLETED), // Should stop before this point
exampleGetCodeMigrationResponse.replace(TransformationStatus.COMPLETED),
).whenever(clientAdaptorSpy).getCodeModernizationJob(any())

Mockito.doReturn(exampleGetCodeMigrationPlanResponse)
Expand All @@ -128,7 +127,7 @@ class CodeWhispererCodeModernizerUtilsTest : CodeWhispererCodeModernizerTestBase
TransformationStatus.STARTED,
)
assertThat(expected).isEqualTo(mutableList)
io.mockk.verify { refreshToken(any()) }
verify { notifyStickyWarn(message("codemodernizer.notification.warn.expired_credentials.title"), any(), any(), any(), any()) }
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class CodeWhispererStatusBarWidget(project: Project) :
ApplicationManager.getApplication().messageBus.connect(this).subscribe(
BearerTokenProviderListener.TOPIC,
object : BearerTokenProviderListener {
override fun onChange(providerId: String, newScopes: List<String>?) {
override fun onProviderChange(providerId: String, newScopes: List<String>?) {
statusBar.updateWidget(ID)
}
}
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManagerListener
import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenAuthState
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProvider
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProviderListener
import software.aws.toolkits.jetbrains.services.amazonq.lsp.AmazonQLspService
Expand All @@ -43,15 +42,14 @@
private val project: Project,
private val encryptionManager: JwtEncryptionManager,
private val cs: CoroutineScope,
) : AuthCredentialsService,
BearerTokenProviderListener,
) : BearerTokenProviderListener,
ToolkitConnectionManagerListener,
QRegionProfileSelectedListener,
Disposable {

private val scheduler: ScheduledExecutorService = AppExecutorUtil.getAppScheduledExecutorService()
private var tokenSyncTask: ScheduledFuture<*>? = null
private val tokenSyncIntervalMinutes = 5L
private var tokenRefreshTask: ScheduledFuture<*>? = null
private val tokenRefreshInterval = 5L

Check warning on line 52 in plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt

View check run for this annotation

Codecov / codecov/patch

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt#L52

Added line #L52 was not covered by tests

init {
project.messageBus.connect(this).apply {
Expand All @@ -67,49 +65,37 @@
}
}

// Start periodic token sync
startPeriodicTokenSync()
// Start periodic token refresh
startPeriodicTokenRefresh()

Check warning on line 69 in plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt

View check run for this annotation

Codecov / codecov/patch

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt#L69

Added line #L69 was not covered by tests
}

private fun startPeriodicTokenSync() {
tokenSyncTask = scheduler.scheduleWithFixedDelay(
// TODO: we really only need a single application-wide instance of this
private fun startPeriodicTokenRefresh() {
tokenRefreshTask = scheduler.scheduleWithFixedDelay(

Check warning on line 74 in plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt

View check run for this annotation

Codecov / codecov/patch

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt#L74

Added line #L74 was not covered by tests
{
try {
if (isQConnected(project)) {
if (isQExpired(project)) {
val manager = ToolkitConnectionManager.getInstance(project)
val connection = manager.activeConnectionForFeature(QConnection.getInstance()) ?: return@scheduleWithFixedDelay

// Try to refresh the token if it's in NEEDS_REFRESH state
val tokenProvider = (connection.getConnectionSettings() as? TokenConnectionSettings)
?.tokenProvider
?.delegate
?.let { it as? BearerTokenProvider } ?: return@scheduleWithFixedDelay

if (tokenProvider.state() == BearerTokenAuthState.NEEDS_REFRESH) {
try {
tokenProvider.resolveToken()
// Now that the token is refreshed, update it in Flare
updateTokenFromActiveConnection()
} catch (e: Exception) {
LOG.warn(e) { "Failed to refresh bearer token" }
}
}
} else {
updateTokenFromActiveConnection()
}
val manager = ToolkitConnectionManager.getInstance(project)

Check warning on line 78 in plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt

View check run for this annotation

Codecov / codecov/patch

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt#L78

Added line #L78 was not covered by tests
val connection = manager.activeConnectionForFeature(QConnection.getInstance()) ?: return@scheduleWithFixedDelay

// periodically poll token to trigger a background refresh if needed
val tokenProvider = (connection.getConnectionSettings() as? TokenConnectionSettings)
?.tokenProvider
?.delegate
?.let { it as? BearerTokenProvider } ?: return@scheduleWithFixedDelay
tokenProvider.resolveToken()

Check warning on line 86 in plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt

View check run for this annotation

Codecov / codecov/patch

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt#L86

Added line #L86 was not covered by tests
Copy link
Contributor

@manodnyab manodnyab Jul 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this attempt to use the cached supplier to refresh? should the refresh called be used here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cached supplier will refresh if needed

}
} catch (e: Exception) {
LOG.warn(e) { "Failed to sync bearer token to Flare" }
LOG.warn(e) { "Failed to refresh bearer token" }

Check warning on line 89 in plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt

View check run for this annotation

Codecov / codecov/patch

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt#L89

Added line #L89 was not covered by tests
}
},
tokenSyncIntervalMinutes,
tokenSyncIntervalMinutes,
tokenRefreshInterval,
tokenRefreshInterval,

Check warning on line 93 in plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt

View check run for this annotation

Codecov / codecov/patch

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt#L92-L93

Added lines #L92 - L93 were not covered by tests
TimeUnit.MINUTES
)
}

override fun updateTokenCredentials(connection: ToolkitConnection, encrypted: Boolean): CompletableFuture<ResponseMessage> {
fun updateTokenCredentials(connection: ToolkitConnection, encrypted: Boolean): CompletableFuture<ResponseMessage> {
val payload = try {
createUpdateCredentialsPayload(connection, encrypted)
} catch (e: Exception) {
Expand All @@ -129,18 +115,26 @@
}.asCompletableFuture()
}

override fun deleteTokenCredentials() {
fun deleteTokenCredentials() {
cs.launch {
AmazonQLspService.executeAsyncIfRunning(project) { server ->
server.deleteTokenCredentials()
}
}
}

override fun onChange(providerId: String, newScopes: List<String>?) {
override fun onProviderChange(providerId: String, newScopes: List<String>?) {
updateTokenFromActiveConnection()
}

override fun onTokenModified(providerId: String) {
updateTokenFromActiveConnection()
}

Check warning on line 132 in plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt

View check run for this annotation

Codecov / codecov/patch

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt#L131-L132

Added lines #L131 - L132 were not covered by tests

override fun invalidate(providerId: String) {
deleteTokenCredentials()
}

Check warning on line 136 in plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt

View check run for this annotation

Codecov / codecov/patch

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt#L135-L136

Added lines #L135 - L136 were not covered by tests

override fun activeConnectionChanged(newConnection: ToolkitConnection?) {
val qConnection = ToolkitConnectionManager.getInstance(project)
.activeConnectionForFeature(QConnection.getInstance())
Expand All @@ -161,10 +155,6 @@
private fun updateTokenFromConnection(connection: ToolkitConnection): CompletableFuture<ResponseMessage> =
updateTokenCredentials(connection, true)

override fun invalidate(providerId: String) {
deleteTokenCredentials()
}

private fun createUpdateCredentialsPayload(connection: ToolkitConnection, encrypted: Boolean): UpdateCredentialsPayload {
val token = (connection.getConnectionSettings() as? TokenConnectionSettings)
?.tokenProvider
Expand Down Expand Up @@ -212,8 +202,8 @@
}

override fun dispose() {
tokenSyncTask?.cancel(false)
tokenSyncTask = null
tokenRefreshTask?.cancel(false)
tokenRefreshTask = null

Check warning on line 206 in plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt

View check run for this annotation

Codecov / codecov/patch

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt#L206

Added line #L206 was not covered by tests
}

companion object {
Expand Down
Loading
Loading