Skip to content

Commit 383360e

Browse files
authored
Save refreshed PKCE token as PKCE token instead of DAG token (#4470)
1 parent 9e712c3 commit 383360e

File tree

2 files changed

+68
-3
lines changed

2 files changed

+68
-3
lines changed

plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/sso/SsoAccessTokenProvider.kt

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ class SsoAccessTokenProvider(
124124
val registerResponse = client.registerClient {
125125
it.clientType(PUBLIC_CLIENT_REGISTRATION_TYPE)
126126
it.scopes(scopes)
127-
it.clientName("AWS IDE Plugins for JetBrains")
127+
it.clientName(PKCE_CLIENT_NAME)
128128
}
129129

130130
val registeredClient = DeviceAuthorizationClientRegistration(
@@ -310,7 +310,11 @@ class SsoAccessTokenProvider(
310310
it.refreshToken(currentToken.refreshToken)
311311
}
312312

313-
val token = newToken.toDAGAccessToken(currentToken.createdAt)
313+
val token = when (currentToken) {
314+
is DeviceAuthorizationGrantToken -> newToken.toDAGAccessToken(currentToken.createdAt)
315+
is PKCEAuthorizationGrantToken -> newToken.toPKCEAccessToken(currentToken.createdAt)
316+
}
317+
314318
saveAccessToken(token)
315319

316320
return token
@@ -399,7 +403,7 @@ class SsoAccessTokenProvider(
399403
}
400404
}
401405

402-
private fun CreateTokenResponse.toDAGAccessToken(creationTime: Instant): AccessToken {
406+
private fun CreateTokenResponse.toDAGAccessToken(creationTime: Instant): DeviceAuthorizationGrantToken {
403407
val expirationTime = Instant.now(clock).plusSeconds(expiresIn().toLong())
404408

405409
return DeviceAuthorizationGrantToken(
@@ -412,6 +416,19 @@ class SsoAccessTokenProvider(
412416
)
413417
}
414418

419+
private fun CreateTokenResponse.toPKCEAccessToken(creationTime: Instant): PKCEAuthorizationGrantToken {
420+
val expirationTime = Instant.now(clock).plusSeconds(expiresIn().toLong())
421+
422+
return PKCEAuthorizationGrantToken(
423+
issuerUrl = ssoUrl,
424+
region = ssoRegion,
425+
accessToken = accessToken(),
426+
refreshToken = refreshToken(),
427+
expiresAt = expirationTime,
428+
createdAt = creationTime
429+
)
430+
}
431+
415432
private companion object {
416433
const val PUBLIC_CLIENT_REGISTRATION_TYPE = "public"
417434
const val DEVICE_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:device_code"

plugins/core/jetbrains-community/tst/software/aws/toolkits/jetbrains/core/credentials/sso/SsoAccessTokenProviderTest.kt

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,54 @@ class SsoAccessTokenProviderTest {
293293
verify(ssoCache).saveAccessToken(ssoUrl, refreshedToken)
294294
}
295295

296+
@Test
297+
fun `PKCE refresh access token saves PKCE token`() {
298+
val sut = SsoAccessTokenProvider(ssoUrl, "us-east-1", ssoCache, ssoOidcClient, scopes = listOf("dummy:scope"), clock = clock)
299+
300+
val expirationClientRegistration = clock.instant().plusSeconds(120)
301+
setupCacheStub(expirationClientRegistration)
302+
303+
val accessToken = PKCEAuthorizationGrantToken(ssoUrl, ssoRegion, "dummyToken", "refreshToken", clock.instant(), clock.instant())
304+
ssoCache.stub {
305+
on(
306+
ssoCache.loadAccessToken(any<PKCEAccessTokenCacheKey>())
307+
).thenReturn(
308+
accessToken
309+
)
310+
311+
on(
312+
ssoCache.loadClientRegistration(any<PKCEClientRegistrationCacheKey>())
313+
).thenReturn(
314+
PKCEClientRegistration(
315+
clientType = "public",
316+
redirectUris = listOf("uri"),
317+
grantTypes = listOf("grant"),
318+
issuerUrl = ssoUrl,
319+
region = ssoRegion,
320+
scopes = listOf("dummy:scope"),
321+
clientId = clientId,
322+
clientSecret = clientSecret,
323+
expiresAt = clock.instant()
324+
)
325+
)
326+
}
327+
328+
ssoOidcClient.stub {
329+
on(
330+
ssoOidcClient.createToken(refreshTokenRequest())
331+
).thenReturn(
332+
refreshTokenResponse()
333+
)
334+
}
335+
336+
val refreshedToken = runBlocking { sut.refreshToken(sut.accessToken()) }
337+
338+
verify(ssoCache).loadAccessToken(any<PKCEAccessTokenCacheKey>())
339+
verify(ssoCache).loadClientRegistration(any<PKCEClientRegistrationCacheKey>())
340+
verify(ssoOidcClient).createToken(any<CreateTokenRequest>())
341+
verify(ssoCache).saveAccessToken(any<PKCEAccessTokenCacheKey>(), eq(refreshedToken))
342+
}
343+
296344
@Test
297345
fun exceptionStopsPolling() {
298346
val expirationClientRegistration = clock.instant().plusSeconds(120)

0 commit comments

Comments
 (0)