Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class AmazonQToolWindowFactory : ToolWindowFactory, DumbAware {
project.messageBus.connect(toolWindow.disposable).subscribe(
BearerTokenProviderListener.TOPIC,
object : BearerTokenProviderListener {
override fun onChange(providerId: String, newScopes: List<String>?) {
override fun onProviderChange(providerId: String, newScopes: List<String>?) {
if (ToolkitConnectionManager.getInstance(project).connectionStateForFeature(QConnection.getInstance()) == BearerTokenAuthState.AUTHORIZED) {
preparePanelContent(project, qPanel)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ class CodeScanChatApp(private val scope: CoroutineScope) : AmazonQApp {
ApplicationManager.getApplication().messageBus.connect(this).subscribe(
BearerTokenProviderListener.TOPIC,
object : BearerTokenProviderListener {
override fun onChange(providerId: String, newScopes: List<String>?) {
override fun onProviderChange(providerId: String, newScopes: List<String>?) {
val qProvider = getQTokenProvider(context.project)
val isQ = qProvider?.id == providerId
val isAuthorized = qProvider?.state() == BearerTokenAuthState.AUTHORIZED
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,14 @@ import org.mockito.kotlin.spy
import org.mockito.kotlin.whenever
import software.aws.toolkits.core.TokenConnectionSettings
import software.aws.toolkits.core.credentials.ToolkitBearerTokenProvider
import software.aws.toolkits.core.utils.test.aString
import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
import software.aws.toolkits.jetbrains.core.credentials.sso.DeviceAuthorizationGrantToken
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProvider
import software.aws.toolkits.jetbrains.services.amazonq.clients.AmazonQStreamingClient
import software.aws.toolkits.jetbrains.utils.rules.CodeInsightTestFixtureRule
import software.aws.toolkits.jetbrains.utils.rules.HeavyJavaCodeInsightTestFixtureRule
import software.aws.toolkits.jetbrains.utils.rules.JavaCodeInsightTestFixtureRule
import software.aws.toolkits.jetbrains.utils.rules.addModule
import java.time.Instant

open class AmazonQTestBase(
@Rule @JvmField
Expand All @@ -47,11 +44,7 @@ open class AmazonQTestBase(
project = projectRule.project
toolkitConnectionManager = spy(ToolkitConnectionManager.getInstance(project))

val accessToken = DeviceAuthorizationGrantToken(aString(), aString(), aString(), aString(), Instant.MAX, Instant.now())

val provider = mock<BearerTokenProvider> {
doReturn(accessToken).whenever(it).refresh()
}
val provider = mock<BearerTokenProvider>()

val mockBearerProvider = mock<ToolkitBearerTokenProvider> {
doReturn(provider).whenever(it).delegate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,7 @@ open class FeatureDevTestBase(
open fun setup() {
project = projectRule.project
toolkitConnectionManager = spy(ToolkitConnectionManager.getInstance(project))
val accessToken = DeviceAuthorizationGrantToken(aString(), aString(), aString(), aString(), Instant.MAX, Instant.now())
val provider =
mock<BearerTokenProvider> {
doReturn(accessToken).whenever(it).refresh()
}
val provider = mock<BearerTokenProvider>()
val mockBearerProvider =
mock<ToolkitBearerTokenProvider> {
doReturn(provider).whenever(it).delegate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ class CodeTransformChatApp : AmazonQApp {
ApplicationManager.getApplication().messageBus.connect(this).subscribe(
BearerTokenProviderListener.TOPIC,
object : BearerTokenProviderListener {
override fun onChange(providerId: String, newScopes: List<String>?) {
override fun onProviderChange(providerId: String, newScopes: List<String>?) {
val qProvider = getQTokenProvider(context.project)
val isQ = qProvider?.id == providerId
val isAuthorized = qProvider?.state() == BearerTokenAuthState.AUTHORIZED
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,6 @@ suspend fun JobId.pollTransformationStatusAndPlan(
var transformationPlan: TransformationPlan? = null
var didSleepOnce = false
var hasSeenTransforming = false
val maxRefreshes = 10
var numRefreshes = 0

// refresh token at start of polling since local build just prior can take a long time
refreshToken(project)

try {
waitUntil(
Expand Down Expand Up @@ -138,13 +133,10 @@ suspend fun JobId.pollTransformationStatusAndPlan(
onStateChange(state, newStatus, transformationPlan)
}
state = newStatus
numRefreshes = 0
return@waitUntil state
} catch (e: AccessDeniedException) {
if (numRefreshes++ > maxRefreshes) throw e
refreshToken(project)
return@waitUntil state
} catch (e: InvalidGrantException) {
} catch (e: Exception) {
if (e !is AccessDeniedException && e !is InvalidGrantException) throw e

CodeTransformMessageListener.instance.onReauthStarted()
notifyStickyWarn(
message("codemodernizer.notification.warn.expired_credentials.title"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import com.intellij.openapi.vfs.VfsUtilCore
import com.intellij.openapi.vfs.VirtualFileManager
import software.amazon.awssdk.services.codewhispererruntime.model.TransformationLanguage
import software.amazon.awssdk.services.codewhispererruntime.model.TransformationStatus
import software.aws.toolkits.core.TokenConnectionSettings
import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection
Expand Down Expand Up @@ -43,12 +42,6 @@ val STATES_AFTER_STARTED = setOf(
*STATES_AFTER_INITIAL_BUILD.toTypedArray(),
)

fun refreshToken(project: Project) {
val connection = ToolkitConnectionManager.getInstance(project).activeConnectionForFeature(QConnection.getInstance())
val provider = (connection?.getConnectionSettings() as TokenConnectionSettings).tokenProvider.delegate as BearerTokenProvider
provider.refresh()
}

fun getAuthType(project: Project): CredentialSourceId? {
val connection = checkBearerConnectionValidity(project, BearerTokenFeatureSet.Q)
var authType: CredentialSourceId? = null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ import software.aws.toolkits.core.credentials.ToolkitBearerTokenProvider
import software.aws.toolkits.core.utils.test.aString
import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
import software.aws.toolkits.jetbrains.core.credentials.sso.DeviceAuthorizationGrantToken
import software.aws.toolkits.jetbrains.core.credentials.sso.PKCEAuthorizationGrantToken
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenAuthState
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProvider
Expand Down Expand Up @@ -250,11 +249,7 @@ open class CodeWhispererCodeModernizerTestBase(
project = projectRule.project
toolkitConnectionManager = spy(ToolkitConnectionManager.getInstance(project))

val accessToken = DeviceAuthorizationGrantToken(aString(), aString(), aString(), aString(), Instant.MAX, Instant.now())
val provider =
mock<BearerTokenProvider> {
doReturn(accessToken).whenever(it).refresh()
}
val provider = mock<BearerTokenProvider> { }
val mockBearerProvider =
mock<ToolkitBearerTokenProvider> {
doReturn(provider).whenever(it).delegate
Expand Down Expand Up @@ -340,7 +335,6 @@ open class CodeWhispererCodeModernizerTestBase(
val accessToken = PKCEAuthorizationGrantToken(aString(), aString(), aString(), aString(), Instant.MAX, Instant.now())

val provider = mock<BearerTokenProvider> {
doReturn(accessToken).whenever(it).refresh()
doReturn(accessToken).whenever(it).currentToken()
doReturn(authState).whenever(it).state()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ import software.aws.toolkits.jetbrains.services.codemodernizer.utils.getTableMap
import software.aws.toolkits.jetbrains.services.codemodernizer.utils.isPlanComplete
import software.aws.toolkits.jetbrains.services.codemodernizer.utils.parseBuildFile
import software.aws.toolkits.jetbrains.services.codemodernizer.utils.pollTransformationStatusAndPlan
import software.aws.toolkits.jetbrains.services.codemodernizer.utils.refreshToken
import software.aws.toolkits.jetbrains.services.codemodernizer.utils.validateCustomVersionsFile
import software.aws.toolkits.jetbrains.services.codemodernizer.utils.validateSctMetadata
import software.aws.toolkits.jetbrains.utils.notifyStickyWarn
Expand Down Expand Up @@ -90,18 +89,18 @@ class CodeWhispererCodeModernizerUtilsTest : CodeWhispererCodeModernizerTestBase
}

@Test
fun `refresh on access denied`() {
fun `show re-auth notification on access denied`() {
val mockAccessDeniedException = Mockito.mock(AccessDeniedException::class.java)

mockkStatic(::refreshToken)
every { refreshToken(any()) } just runs
mockkStatic(::notifyStickyWarn)
every { notifyStickyWarn(any(), any(), any(), any(), any()) } just runs

Mockito.doThrow(
mockAccessDeniedException
).doReturn(
exampleGetCodeMigrationResponse,
exampleGetCodeMigrationResponse.replace(TransformationStatus.STARTED),
exampleGetCodeMigrationResponse.replace(TransformationStatus.COMPLETED), // Should stop before this point
exampleGetCodeMigrationResponse.replace(TransformationStatus.COMPLETED),
).whenever(clientAdaptorSpy).getCodeModernizationJob(any())

Mockito.doReturn(exampleGetCodeMigrationPlanResponse)
Expand All @@ -128,7 +127,7 @@ class CodeWhispererCodeModernizerUtilsTest : CodeWhispererCodeModernizerTestBase
TransformationStatus.STARTED,
)
assertThat(expected).isEqualTo(mutableList)
io.mockk.verify { refreshToken(any()) }
verify { notifyStickyWarn(message("codemodernizer.notification.warn.expired_credentials.title"), any(), any(), any(), any()) }
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class CodeWhispererStatusBarWidget(project: Project) :
ApplicationManager.getApplication().messageBus.connect(this).subscribe(
BearerTokenProviderListener.TOPIC,
object : BearerTokenProviderListener {
override fun onChange(providerId: String, newScopes: List<String>?) {
override fun onProviderChange(providerId: String, newScopes: List<String>?) {
statusBar.updateWidget(ID)
}
}
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnection
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManagerListener
import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenAuthState
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProvider
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProviderListener
import software.aws.toolkits.jetbrains.services.amazonq.lsp.AmazonQLspService
Expand All @@ -43,15 +42,14 @@ class DefaultAuthCredentialsService(
private val project: Project,
private val encryptionManager: JwtEncryptionManager,
private val cs: CoroutineScope,
) : AuthCredentialsService,
BearerTokenProviderListener,
) : BearerTokenProviderListener,
ToolkitConnectionManagerListener,
QRegionProfileSelectedListener,
Disposable {

private val scheduler: ScheduledExecutorService = AppExecutorUtil.getAppScheduledExecutorService()
private var tokenSyncTask: ScheduledFuture<*>? = null
private val tokenSyncIntervalMinutes = 5L
private var tokenRefreshTask: ScheduledFuture<*>? = null
private val tokenRefreshInterval = 5L

init {
project.messageBus.connect(this).apply {
Expand All @@ -67,49 +65,37 @@ class DefaultAuthCredentialsService(
}
}

// Start periodic token sync
startPeriodicTokenSync()
// Start periodic token refresh
startPeriodicTokenRefresh()
}

private fun startPeriodicTokenSync() {
tokenSyncTask = scheduler.scheduleWithFixedDelay(
// TODO: we really only need a single application-wide instance of this
private fun startPeriodicTokenRefresh() {
tokenRefreshTask = scheduler.scheduleWithFixedDelay(
{
try {
if (isQConnected(project)) {
if (isQExpired(project)) {
val manager = ToolkitConnectionManager.getInstance(project)
val connection = manager.activeConnectionForFeature(QConnection.getInstance()) ?: return@scheduleWithFixedDelay

// Try to refresh the token if it's in NEEDS_REFRESH state
val tokenProvider = (connection.getConnectionSettings() as? TokenConnectionSettings)
?.tokenProvider
?.delegate
?.let { it as? BearerTokenProvider } ?: return@scheduleWithFixedDelay

if (tokenProvider.state() == BearerTokenAuthState.NEEDS_REFRESH) {
try {
tokenProvider.resolveToken()
// Now that the token is refreshed, update it in Flare
updateTokenFromActiveConnection()
} catch (e: Exception) {
LOG.warn(e) { "Failed to refresh bearer token" }
}
}
} else {
updateTokenFromActiveConnection()
}
val manager = ToolkitConnectionManager.getInstance(project)
val connection = manager.activeConnectionForFeature(QConnection.getInstance()) ?: return@scheduleWithFixedDelay

// periodically poll token to trigger a background refresh if needed
val tokenProvider = (connection.getConnectionSettings() as? TokenConnectionSettings)
?.tokenProvider
?.delegate
?.let { it as? BearerTokenProvider } ?: return@scheduleWithFixedDelay
tokenProvider.resolveToken()
Copy link
Contributor

@manodnyab manodnyab Jul 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this attempt to use the cached supplier to refresh? should the refresh called be used here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cached supplier will refresh if needed

}
} catch (e: Exception) {
LOG.warn(e) { "Failed to sync bearer token to Flare" }
LOG.warn(e) { "Failed to refresh bearer token" }
}
},
tokenSyncIntervalMinutes,
tokenSyncIntervalMinutes,
tokenRefreshInterval,
tokenRefreshInterval,
TimeUnit.MINUTES
)
}

override fun updateTokenCredentials(connection: ToolkitConnection, encrypted: Boolean): CompletableFuture<ResponseMessage> {
fun updateTokenCredentials(connection: ToolkitConnection, encrypted: Boolean): CompletableFuture<ResponseMessage> {
val payload = try {
createUpdateCredentialsPayload(connection, encrypted)
} catch (e: Exception) {
Expand All @@ -129,18 +115,26 @@ class DefaultAuthCredentialsService(
}.asCompletableFuture()
}

override fun deleteTokenCredentials() {
fun deleteTokenCredentials() {
cs.launch {
AmazonQLspService.executeAsyncIfRunning(project) { server ->
server.deleteTokenCredentials()
}
}
}

override fun onChange(providerId: String, newScopes: List<String>?) {
override fun onProviderChange(providerId: String, newScopes: List<String>?) {
updateTokenFromActiveConnection()
}

override fun onTokenModified(providerId: String) {
updateTokenFromActiveConnection()
}

override fun invalidate(providerId: String) {
deleteTokenCredentials()
}

override fun activeConnectionChanged(newConnection: ToolkitConnection?) {
val qConnection = ToolkitConnectionManager.getInstance(project)
.activeConnectionForFeature(QConnection.getInstance())
Expand All @@ -161,10 +155,6 @@ class DefaultAuthCredentialsService(
private fun updateTokenFromConnection(connection: ToolkitConnection): CompletableFuture<ResponseMessage> =
updateTokenCredentials(connection, true)

override fun invalidate(providerId: String) {
deleteTokenCredentials()
}

private fun createUpdateCredentialsPayload(connection: ToolkitConnection, encrypted: Boolean): UpdateCredentialsPayload {
val token = (connection.getConnectionSettings() as? TokenConnectionSettings)
?.tokenProvider
Expand Down Expand Up @@ -212,8 +202,8 @@ class DefaultAuthCredentialsService(
}

override fun dispose() {
tokenSyncTask?.cancel(false)
tokenSyncTask = null
tokenRefreshTask?.cancel(false)
tokenRefreshTask = null
}

companion object {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,12 @@ class QRegionProfileManager : PersistentStateComponent<QProfileState>, Disposabl
connectionIdToProfileCount[connection.id] = it.size
} ?: error("You don't have access to the resource")
} catch (e: Exception) {
LOG.warn(e) { "Failed to list region profiles: ${e.message}" }
if (e is AccessDeniedException) {
LOG.warn { "Failed to list region profiles: ${e.message}" }
} else {
LOG.warn(e) { "Failed to list region profiles" }
}

throw e
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ class DefaultAuthCredentialsServiceTest {
sut = DefaultAuthCredentialsService(project, mockEncryptionManager, this)
setupMockConnectionManager("updated-token")

sut.onChange("providerId", listOf("new-scope"))
sut.onProviderChange("providerId", listOf("new-scope"))

advanceUntilIdle()
verify(exactly = 1) { mockLanguageServer.updateTokenCredentials(any()) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ open class AwsClientManager : ToolkitClientManager(), Disposable {
busConnection.subscribe(
BearerTokenProviderListener.TOPIC,
object : BearerTokenProviderListener {
override fun onChange(providerId: String, newScopes: List<String>?) {
override fun onProviderChange(providerId: String, newScopes: List<String>?) {
invalidateSdks(providerId)
}

Expand Down
Loading
Loading