Skip to content

Commit 7bc23fb

Browse files
committed
tst/lint
1 parent 5ae453f commit 7bc23fb

File tree

4 files changed

+86
-29
lines changed

4 files changed

+86
-29
lines changed

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/auth/DefaultAuthCredentialsService.kt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,6 @@ class DefaultAuthCredentialsService(
165165
private fun updateTokenFromConnection(connection: ToolkitConnection): CompletableFuture<ResponseMessage> =
166166
updateTokenCredentials(connection, true)
167167

168-
169168
private fun createUpdateCredentialsPayload(connection: ToolkitConnection, encrypted: Boolean): UpdateCredentialsPayload {
170169
val token = (connection.getConnectionSettings() as? TokenConnectionSettings)
171170
?.tokenProvider

plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/sso/SsoAccessTokenProvider.kt

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ class SsoAccessTokenProvider(
194194

195195
@Deprecated("Device authorization grant flow is deprecated")
196196
private fun registerDAGClient(): ClientRegistration {
197-
loadDagClientRegistration(SourceOfLoadRegistration.REGISTER_CLIENT.toString())?.let {
197+
loadDagClientRegistration(SourceOfLoadRegistration.REGISTER_CLIENT)?.let {
198198
return it
199199
}
200200

@@ -235,7 +235,7 @@ class SsoAccessTokenProvider(
235235
}
236236

237237
private fun registerPkceClient(): PKCEClientRegistration {
238-
loadPkceClientRegistration(SourceOfLoadRegistration.REGISTER_CLIENT.toString())?.let {
238+
loadPkceClientRegistration(SourceOfLoadRegistration.REGISTER_CLIENT)?.let {
239239
return it
240240
}
241241

@@ -431,8 +431,8 @@ class SsoAccessTokenProvider(
431431
stageName = RefreshCredentialStage.LOAD_REGISTRATION
432432
val registration = try {
433433
when (currentToken) {
434-
is DeviceAuthorizationGrantToken -> loadDagClientRegistration(SourceOfLoadRegistration.REFRESH_TOKEN.toString())
435-
is PKCEAuthorizationGrantToken -> loadPkceClientRegistration(SourceOfLoadRegistration.REFRESH_TOKEN.toString())
434+
is DeviceAuthorizationGrantToken -> loadDagClientRegistration(SourceOfLoadRegistration.REFRESH_TOKEN)
435+
is PKCEAuthorizationGrantToken -> loadPkceClientRegistration(SourceOfLoadRegistration.REFRESH_TOKEN)
436436
}
437437
} catch (e: Exception) {
438438
val message = e.message ?: "$stageName: ${e::class.java.name}"
@@ -519,13 +519,13 @@ class SsoAccessTokenProvider(
519519
SAVE_TOKEN,
520520
}
521521

522-
private fun loadDagClientRegistration(source: String): ClientRegistration? =
523-
cache.loadClientRegistration(dagClientRegistrationCacheKey, source)?.let {
522+
private fun loadDagClientRegistration(source: SourceOfLoadRegistration): ClientRegistration? =
523+
cache.loadClientRegistration(dagClientRegistrationCacheKey, source.toString())?.let {
524524
return it
525525
}
526526

527-
private fun loadPkceClientRegistration(source: String): PKCEClientRegistration? =
528-
cache.loadClientRegistration(pkceClientRegistrationCacheKey, source)?.let {
527+
private fun loadPkceClientRegistration(source: SourceOfLoadRegistration): PKCEClientRegistration? =
528+
cache.loadClientRegistration(pkceClientRegistrationCacheKey, source.toString())?.let {
529529
return it as PKCEClientRegistration
530530
}
531531

plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/sso/bearer/BearerTokenProvider.kt

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ import software.aws.toolkits.jetbrains.core.credentials.sso.DiskCache
3737
import software.aws.toolkits.jetbrains.core.credentials.sso.PendingAuthorization
3838
import software.aws.toolkits.jetbrains.core.credentials.sso.SsoAccessTokenProvider
3939
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProviderListener.Companion.TOPIC
40+
import java.time.Clock
4041
import java.time.Duration
41-
import java.time.Instant
4242
import java.util.concurrent.atomic.AtomicBoolean
4343
import java.util.concurrent.atomic.AtomicReference
4444
import java.util.function.Supplier
@@ -75,11 +75,11 @@ interface BearerTokenProvider : SdkTokenProvider, SdkAutoCloseable, ToolkitBeare
7575
}
7676

7777
companion object {
78-
internal fun tokenExpired(accessToken: AccessToken) = Instant.now().isAfter(accessToken.expiresAt)
78+
private fun tokenExpired(accessToken: AccessToken, clock: Clock) = clock.instant().isAfter(accessToken.expiresAt)
7979

80-
internal fun state(accessToken: AccessToken?) = when {
80+
internal fun state(accessToken: AccessToken?, clock: Clock = Clock.systemUTC()) = when {
8181
accessToken == null -> BearerTokenAuthState.NOT_AUTHENTICATED
82-
tokenExpired(accessToken) -> {
82+
tokenExpired(accessToken, clock) -> {
8383
if (accessToken.refreshToken != null) {
8484
BearerTokenAuthState.NEEDS_REFRESH
8585
} else {
@@ -98,6 +98,7 @@ class InteractiveBearerTokenProvider(
9898
val scopes: List<String>,
9999
override val id: String,
100100
cache: DiskCache = diskCache,
101+
private val clock: Clock = Clock.systemUTC(),
101102
) : BearerTokenProvider, BearerTokenLogoutSupport, Disposable {
102103
override val displayName = ToolkitBearerTokenProvider.ssoDisplayName(startUrl)
103104

@@ -146,7 +147,7 @@ class InteractiveBearerTokenProvider(
146147
SupplierWithInitialValue(initialValue, accessTokenProvider).let {
147148
SupplierHolder(
148149
it,
149-
CachedSupplier.builder(it).prefetchStrategy(NonBlocking("AWS SSO bearer token refresher")).build()
150+
CachedSupplier.builder(it).clock(clock).prefetchStrategy(NonBlocking("AWS SSO bearer token refresher")).build()
150151
)
151152
}
152153

@@ -163,8 +164,14 @@ class InteractiveBearerTokenProvider(
163164
val token = if (hasCalledAtLeastOnce.getAndSet(true)) {
164165
refresh()
165166
} else {
166-
initialValue ?: throw NoTokenInitializedException("Token refresh started before session initialized")
167+
// on initial call, refresh if needed
168+
if (initialValue != null && initialValue.expiresAt.minus(DEFAULT_PREFETCH_DURATION) < clock.instant()) {
169+
refresh()
170+
} else {
171+
initialValue ?: throw NoTokenInitializedException("Token provider initialized with no token")
172+
}
167173
}
174+
168175
return RefreshResult.builder(token)
169176
.staleTime(token.expiresAt.minus(DEFAULT_STALE_DURATION))
170177
.prefetchTime(token.expiresAt.minus(DEFAULT_PREFETCH_DURATION))
@@ -187,6 +194,8 @@ class InteractiveBearerTokenProvider(
187194
}
188195
}
189196

197+
override fun state() = BearerTokenProvider.state(currentToken(), clock)
198+
190199
// how we expect consumers to obtain a token
191200
override fun resolveToken() = supplier.cachedSupplier.get()
192201

@@ -209,6 +218,7 @@ class InteractiveBearerTokenProvider(
209218

210219
override fun invalidate() {
211220
accessTokenProvider.invalidate()
221+
supplier.cachedSupplier.close()
212222
supplier = supplier()
213223
BearerTokenProviderListener.notifyCredUpdate(id)
214224
}
@@ -217,6 +227,7 @@ class InteractiveBearerTokenProvider(
217227
// we probably don't need to invalidate this, but we might as well since we need to login again anyways
218228
invalidate()
219229
accessTokenProvider.accessToken().also {
230+
supplier.cachedSupplier.close()
220231
supplier = supplier(it)
221232
BearerTokenProviderListener.notifyCredUpdate(id)
222233
}

plugins/core/jetbrains-community/tst/software/aws/toolkits/jetbrains/core/credentials/sso/bearer/InteractiveBearerTokenProviderTest.kt

Lines changed: 61 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,10 @@ import org.junit.Before
1313
import org.junit.Rule
1414
import org.junit.Test
1515
import org.junit.jupiter.api.assertThrows
16-
import org.mockito.Mockito
1716
import org.mockito.kotlin.any
1817
import org.mockito.kotlin.argThat
19-
import org.mockito.kotlin.eq
2018
import org.mockito.kotlin.mock
19+
import org.mockito.kotlin.reset
2120
import org.mockito.kotlin.spy
2221
import org.mockito.kotlin.times
2322
import org.mockito.kotlin.verify
@@ -49,6 +48,7 @@ import software.aws.toolkits.jetbrains.core.credentials.sso.DeviceAuthorizationG
4948
import software.aws.toolkits.jetbrains.core.credentials.sso.DeviceGrantAccessTokenCacheKey
5049
import software.aws.toolkits.jetbrains.core.credentials.sso.DiskCache
5150
import software.aws.toolkits.jetbrains.core.credentials.sso.PKCEAccessTokenCacheKey
51+
import java.time.Clock
5252
import java.time.Instant
5353
import java.time.temporal.ChronoUnit
5454

@@ -158,7 +158,7 @@ class InteractiveBearerTokenProviderTest {
158158
}
159159

160160
@Test
161-
fun `resolveToken does't refresh if token was retrieved recently`() {
161+
fun `resolveToken doesn't refresh if token was retrieved recently`() {
162162
stubClientRegistration()
163163
whenever(diskCache.loadAccessToken(any<DeviceGrantAccessTokenCacheKey>())).thenReturn(
164164
DeviceAuthorizationGrantToken(
@@ -173,11 +173,56 @@ class InteractiveBearerTokenProviderTest {
173173
sut.resolveToken()
174174
}
175175

176+
@Test
177+
fun `resolveToken attempts to refresh token on first invoke if expired`() {
178+
stubClientRegistration()
179+
stubAccessToken()
180+
whenever(diskCache.loadAccessToken(any<DeviceGrantAccessTokenCacheKey>())).thenReturn(
181+
DeviceAuthorizationGrantToken(
182+
startUrl = startUrl,
183+
region = region,
184+
accessToken = "accessToken",
185+
refreshToken = "refreshToken",
186+
expiresAt = Instant.now()
187+
)
188+
)
189+
val sut = buildSut()
190+
sut.resolveToken()
191+
192+
verify(oidcClient).createToken(any<CreateTokenRequest>())
193+
}
194+
195+
@Test
196+
fun `resolveToken refreshes on subsequent invokes if expired`() {
197+
val mockClock = mock<Clock>()
198+
whenever(mockClock.instant()).thenReturn(Instant.now())
199+
stubClientRegistration()
200+
stubAccessToken()
201+
whenever(diskCache.loadAccessToken(any<DeviceGrantAccessTokenCacheKey>())).thenReturn(
202+
DeviceAuthorizationGrantToken(
203+
startUrl = startUrl,
204+
region = region,
205+
accessToken = "accessToken",
206+
refreshToken = "refreshToken",
207+
expiresAt = Instant.now().plus(1, ChronoUnit.HOURS)
208+
)
209+
)
210+
val sut = buildSut(mockClock)
211+
// current token should be valid
212+
assertThat(sut.resolveToken().accessToken).isEqualTo("accessToken")
213+
verify(oidcClient, times(0)).createToken(any<CreateTokenRequest>())
214+
215+
// then if we advance the clock it should refresh
216+
whenever(mockClock.instant()).thenReturn(Instant.now().plus(100, ChronoUnit.DAYS))
217+
assertThat(sut.resolveToken().accessToken).isEqualTo("access1")
218+
verify(oidcClient, times(1)).createToken(any<CreateTokenRequest>())
219+
}
220+
176221
@Test
177222
fun `resolveToken throws if reauthentication is needed`() {
178223
stubClientRegistration()
179224
stubAccessToken()
180-
Mockito.reset(oidcClient)
225+
reset(oidcClient)
181226
whenever(oidcClient.createToken(any<CreateTokenRequest>())).thenThrow(AccessDeniedException.create("denied", null))
182227

183228
val sut = buildSut()
@@ -206,7 +251,8 @@ class InteractiveBearerTokenProviderTest {
206251
sut.invalidate()
207252

208253
// initial load
209-
verify(diskCache).loadAccessToken(any<DeviceGrantAccessTokenCacheKey>())
254+
// invalidate attempts to reload token from disk
255+
verify(diskCache, times(2)).loadAccessToken(any<DeviceGrantAccessTokenCacheKey>())
210256
verify(diskCache).invalidateClientRegistration(region)
211257
verify(diskCache).invalidateAccessToken(startUrl)
212258

@@ -230,22 +276,22 @@ class InteractiveBearerTokenProviderTest {
230276
stubAccessToken()
231277
val sut = buildSut()
232278

233-
assertThat(sut.currentToken()?.accessToken).isEqualTo("accessToken")
279+
assertThat(sut.resolveToken().accessToken).isEqualTo("access1")
234280

235281
// and now instead of trying to stub out the entire OIDC device flow, abuse the fact that we short-circuit and read from disk if available
236-
Mockito.reset(diskCache)
282+
reset(diskCache)
237283
whenever(diskCache.loadAccessToken(any<DeviceGrantAccessTokenCacheKey>())).thenReturn(
238284
DeviceAuthorizationGrantToken(
239285
startUrl = startUrl,
240286
region = region,
241-
accessToken = "access1",
242-
refreshToken = "refresh1",
287+
accessToken = "access1234",
288+
refreshToken = "refresh1234",
243289
expiresAt = Instant.MAX
244290
)
245291
)
246292
sut.reauthenticate()
247293

248-
assertThat(sut.currentToken()?.accessToken).isEqualTo("access1")
294+
assertThat(sut.resolveToken().accessToken).isEqualTo("access1234")
249295
}
250296

251297
@Test
@@ -263,16 +309,17 @@ class InteractiveBearerTokenProviderTest {
263309
verify(mockListener, times(2)).onProviderChange(sut.id)
264310
}
265311

266-
private fun buildSut() = InteractiveBearerTokenProvider(
312+
private fun buildSut(clock: Clock = Clock.systemUTC()) = InteractiveBearerTokenProvider(
267313
startUrl = startUrl,
268314
region = region,
269315
scopes = scopes,
270316
cache = diskCache,
271-
id = "test"
317+
id = "test",
318+
clock = clock,
272319
)
273320

274321
private fun stubClientRegistration() {
275-
whenever(diskCache.loadClientRegistration(any<DeviceAuthorizationClientRegistrationCacheKey>(), eq("testSource"))).thenReturn(
322+
whenever(diskCache.loadClientRegistration(any<DeviceAuthorizationClientRegistrationCacheKey>(), any())).thenReturn(
276323
DeviceAuthorizationClientRegistration(
277324
"",
278325
"",
@@ -288,7 +335,7 @@ class InteractiveBearerTokenProviderTest {
288335
region = region,
289336
accessToken = "accessToken",
290337
refreshToken = "refreshToken",
291-
expiresAt = Instant.MIN
338+
expiresAt = Instant.now().minus(100, ChronoUnit.DAYS),
292339
)
293340
)
294341
whenever(oidcClient.createToken(any<CreateTokenRequest>())).thenReturn(

0 commit comments

Comments
 (0)