@@ -13,11 +13,10 @@ import org.junit.Before
13
13
import org.junit.Rule
14
14
import org.junit.Test
15
15
import org.junit.jupiter.api.assertThrows
16
- import org.mockito.Mockito
17
16
import org.mockito.kotlin.any
18
17
import org.mockito.kotlin.argThat
19
- import org.mockito.kotlin.eq
20
18
import org.mockito.kotlin.mock
19
+ import org.mockito.kotlin.reset
21
20
import org.mockito.kotlin.spy
22
21
import org.mockito.kotlin.times
23
22
import org.mockito.kotlin.verify
@@ -49,6 +48,7 @@ import software.aws.toolkits.jetbrains.core.credentials.sso.DeviceAuthorizationG
49
48
import software.aws.toolkits.jetbrains.core.credentials.sso.DeviceGrantAccessTokenCacheKey
50
49
import software.aws.toolkits.jetbrains.core.credentials.sso.DiskCache
51
50
import software.aws.toolkits.jetbrains.core.credentials.sso.PKCEAccessTokenCacheKey
51
+ import java.time.Clock
52
52
import java.time.Instant
53
53
import java.time.temporal.ChronoUnit
54
54
@@ -158,7 +158,7 @@ class InteractiveBearerTokenProviderTest {
158
158
}
159
159
160
160
@Test
161
- fun `resolveToken does 't refresh if token was retrieved recently` () {
161
+ fun `resolveToken doesn 't refresh if token was retrieved recently` () {
162
162
stubClientRegistration()
163
163
whenever(diskCache.loadAccessToken(any<DeviceGrantAccessTokenCacheKey >())).thenReturn(
164
164
DeviceAuthorizationGrantToken (
@@ -173,11 +173,56 @@ class InteractiveBearerTokenProviderTest {
173
173
sut.resolveToken()
174
174
}
175
175
176
+ @Test
177
+ fun `resolveToken attempts to refresh token on first invoke if expired` () {
178
+ stubClientRegistration()
179
+ stubAccessToken()
180
+ whenever(diskCache.loadAccessToken(any<DeviceGrantAccessTokenCacheKey >())).thenReturn(
181
+ DeviceAuthorizationGrantToken (
182
+ startUrl = startUrl,
183
+ region = region,
184
+ accessToken = " accessToken" ,
185
+ refreshToken = " refreshToken" ,
186
+ expiresAt = Instant .now()
187
+ )
188
+ )
189
+ val sut = buildSut()
190
+ sut.resolveToken()
191
+
192
+ verify(oidcClient).createToken(any<CreateTokenRequest >())
193
+ }
194
+
195
+ @Test
196
+ fun `resolveToken refreshes on subsequent invokes if expired` () {
197
+ val mockClock = mock<Clock >()
198
+ whenever(mockClock.instant()).thenReturn(Instant .now())
199
+ stubClientRegistration()
200
+ stubAccessToken()
201
+ whenever(diskCache.loadAccessToken(any<DeviceGrantAccessTokenCacheKey >())).thenReturn(
202
+ DeviceAuthorizationGrantToken (
203
+ startUrl = startUrl,
204
+ region = region,
205
+ accessToken = " accessToken" ,
206
+ refreshToken = " refreshToken" ,
207
+ expiresAt = Instant .now().plus(1 , ChronoUnit .HOURS )
208
+ )
209
+ )
210
+ val sut = buildSut(mockClock)
211
+ // current token should be valid
212
+ assertThat(sut.resolveToken().accessToken).isEqualTo(" accessToken" )
213
+ verify(oidcClient, times(0 )).createToken(any<CreateTokenRequest >())
214
+
215
+ // then if we advance the clock it should refresh
216
+ whenever(mockClock.instant()).thenReturn(Instant .now().plus(100 , ChronoUnit .DAYS ))
217
+ assertThat(sut.resolveToken().accessToken).isEqualTo(" access1" )
218
+ verify(oidcClient, times(1 )).createToken(any<CreateTokenRequest >())
219
+ }
220
+
176
221
@Test
177
222
fun `resolveToken throws if reauthentication is needed` () {
178
223
stubClientRegistration()
179
224
stubAccessToken()
180
- Mockito . reset(oidcClient)
225
+ reset(oidcClient)
181
226
whenever(oidcClient.createToken(any<CreateTokenRequest >())).thenThrow(AccessDeniedException .create(" denied" , null ))
182
227
183
228
val sut = buildSut()
@@ -206,7 +251,8 @@ class InteractiveBearerTokenProviderTest {
206
251
sut.invalidate()
207
252
208
253
// initial load
209
- verify(diskCache).loadAccessToken(any<DeviceGrantAccessTokenCacheKey >())
254
+ // invalidate attempts to reload token from disk
255
+ verify(diskCache, times(2 )).loadAccessToken(any<DeviceGrantAccessTokenCacheKey >())
210
256
verify(diskCache).invalidateClientRegistration(region)
211
257
verify(diskCache).invalidateAccessToken(startUrl)
212
258
@@ -230,22 +276,22 @@ class InteractiveBearerTokenProviderTest {
230
276
stubAccessToken()
231
277
val sut = buildSut()
232
278
233
- assertThat(sut.currentToken()? .accessToken).isEqualTo(" accessToken " )
279
+ assertThat(sut.resolveToken() .accessToken).isEqualTo(" access1 " )
234
280
235
281
// and now instead of trying to stub out the entire OIDC device flow, abuse the fact that we short-circuit and read from disk if available
236
- Mockito . reset(diskCache)
282
+ reset(diskCache)
237
283
whenever(diskCache.loadAccessToken(any<DeviceGrantAccessTokenCacheKey >())).thenReturn(
238
284
DeviceAuthorizationGrantToken (
239
285
startUrl = startUrl,
240
286
region = region,
241
- accessToken = " access1 " ,
242
- refreshToken = " refresh1 " ,
287
+ accessToken = " access1234 " ,
288
+ refreshToken = " refresh1234 " ,
243
289
expiresAt = Instant .MAX
244
290
)
245
291
)
246
292
sut.reauthenticate()
247
293
248
- assertThat(sut.currentToken()? .accessToken).isEqualTo(" access1 " )
294
+ assertThat(sut.resolveToken() .accessToken).isEqualTo(" access1234 " )
249
295
}
250
296
251
297
@Test
@@ -263,16 +309,17 @@ class InteractiveBearerTokenProviderTest {
263
309
verify(mockListener, times(2 )).onProviderChange(sut.id)
264
310
}
265
311
266
- private fun buildSut () = InteractiveBearerTokenProvider (
312
+ private fun buildSut (clock : Clock = Clock .systemUTC() ) = InteractiveBearerTokenProvider (
267
313
startUrl = startUrl,
268
314
region = region,
269
315
scopes = scopes,
270
316
cache = diskCache,
271
- id = " test"
317
+ id = " test" ,
318
+ clock = clock,
272
319
)
273
320
274
321
private fun stubClientRegistration () {
275
- whenever(diskCache.loadClientRegistration(any<DeviceAuthorizationClientRegistrationCacheKey >(), eq( " testSource " ))).thenReturn(
322
+ whenever(diskCache.loadClientRegistration(any<DeviceAuthorizationClientRegistrationCacheKey >(), any( ))).thenReturn(
276
323
DeviceAuthorizationClientRegistration (
277
324
" " ,
278
325
" " ,
@@ -288,7 +335,7 @@ class InteractiveBearerTokenProviderTest {
288
335
region = region,
289
336
accessToken = " accessToken" ,
290
337
refreshToken = " refreshToken" ,
291
- expiresAt = Instant .MIN
338
+ expiresAt = Instant .now().minus( 100 , ChronoUnit . DAYS ),
292
339
)
293
340
)
294
341
whenever(oidcClient.createToken(any<CreateTokenRequest >())).thenReturn(
0 commit comments