@@ -28,11 +28,8 @@ import software.aws.toolkits.jetbrains.utils.sleepWithCancellation
2828import software.aws.toolkits.resources.AwsCoreBundle
2929import software.aws.toolkits.telemetry.AuthType
3030import software.aws.toolkits.telemetry.AwsTelemetry
31- import software.aws.toolkits.telemetry.CredentialModification
3231import software.aws.toolkits.telemetry.CredentialSourceId
3332import software.aws.toolkits.telemetry.Result
34- import java.io.FileNotFoundException
35- import java.io.IOException
3633import java.time.Clock
3734import java.time.Duration
3835import java.time.Instant
@@ -365,60 +362,59 @@ class SsoAccessTokenProvider(
365362 }
366363
367364 fun refreshToken (currentToken : AccessToken ): AccessToken {
365+ var stageName = RefreshCredentialStage .VALIDATE_REFRESH_TOKEN
368366 if (currentToken.refreshToken == null ) {
369367 val message = " Requested token refresh, but refresh token was null"
370368 sendRefreshCredentialsMetric(
371369 currentToken,
372- reason = " Null refresh token" ,
373- reasonDesc = message,
370+ reason = " Null refresh token: $stageName " ,
371+ reasonDesc = " $stageName : $ message" ,
374372 result = Result .Failed
375373 )
376374 throw InvalidRequestException .builder().message(message).build()
377375 }
378376
377+ stageName = RefreshCredentialStage .LOAD_REGISTRATION
379378 var registration: ClientRegistration ? = null
380379 try {
381380 registration = when (currentToken) {
382381 is DeviceAuthorizationGrantToken -> loadDagClientRegistration()
383382 is PKCEAuthorizationGrantToken -> loadPkceClientRegistration()
384383 }
385384 } catch (e: Exception ) {
386- val message = " Error loading client registration : ${e.message } "
385+ val message = e.message ? : " $stageName : ${e:: class .java.name } "
387386 sendRefreshCredentialsMetric(
388387 currentToken,
389- reason = " Failed to load client registration" ,
390- reasonDesc = " Step: Load Registration - $ message" ,
388+ reason = " Failed to load client registration: $stageName " ,
389+ reasonDesc = message,
391390 result = Result .Failed
392391 )
393392 throw InvalidClientException .builder().message(message).cause(e).build()
394393 }
395394
396- var isExpired = currentToken.expiresAt.isBefore(Instant .now(clock))
397- if (isExpired){
398- registration = reauthExpiredRegistration(currentToken)
399- isExpired = accessToken().expiresAt.isBefore(Instant .now(clock))
400- }
395+ stageName = RefreshCredentialStage .VALIDATE_REGISTRATION
401396 if (registration == null ) {
402- val (message, reason ) = when {
403- isExpired -> Pair (
404- " Client registration has expired and reauth failed " ,
397+ val (reason, message ) = when {
398+ currentToken.expiresAt.isBefore( Instant .now(clock)) -> Pair (
399+ " Reauth Required: $stageName " ,
405400 " Expired client registration"
406401 )
402+ // TODO: reauth
407403 else -> Pair (
408- " Unable to load client registration from cache" ,
404+ " Unable to load client registration from cache: $stageName " ,
409405 " Null client registration"
410406 )
411407 }
412408 sendRefreshCredentialsMetric(
413409 currentToken,
414410 reason = reason,
415- reasonDesc = " Step: Check Registration - $message " ,
411+ reasonDesc = " $stageName : $message " ,
416412 result = Result .Failed
417413 )
418414 throw InvalidClientException .builder().message(message).build()
419415 }
420416
421- var stageName = RefreshCredentialStage .CREATE_TOKEN
417+ stageName = RefreshCredentialStage .CREATE_TOKEN
422418 try {
423419 val newToken = client.createToken {
424420 it.clientId(registration.clientId)
@@ -457,7 +453,7 @@ class SsoAccessTokenProvider(
457453 sendRefreshCredentialsMetric(
458454 currentToken,
459455 reason = " Refresh access token request failed: $stageName " ,
460- reasonDesc = " Step: $ message" ,
456+ reasonDesc = message,
461457 requestId = requestId,
462458 result = Result .Failed
463459 )
@@ -466,115 +462,47 @@ class SsoAccessTokenProvider(
466462 }
467463
468464 private enum class RefreshCredentialStage {
465+ VALIDATE_REFRESH_TOKEN ,
466+ LOAD_REGISTRATION ,
467+ VALIDATE_REGISTRATION ,
469468 CREATE_TOKEN ,
470469 GET_TOKEN_DETAILS ,
471470 SAVE_TOKEN ,
472471 }
473472
474473 private fun loadDagClientRegistration (): ClientRegistration ? =
475- try {
476- cache.loadClientRegistration(dagClientRegistrationCacheKey)?.let {
477- return it
478- }
479- } catch (e: FileNotFoundException ) {
480- AwsTelemetry .openCredentials(
481- result = Result .Failed ,
482- reason = " Failed to load DAG client registration from cache" ,
483- reasonDesc = e.message
484- )
485- throw e
474+ cache.loadClientRegistration(dagClientRegistrationCacheKey)?.let {
475+ return it
486476 }
487477
488478 private fun loadPkceClientRegistration (): PKCEClientRegistration ? =
489- try {
490- cache.loadClientRegistration(pkceClientRegistrationCacheKey)?.let {
491- return it as PKCEClientRegistration
492- }
493- } catch (e: FileNotFoundException ) {
494- AwsTelemetry .openCredentials(
495- result = Result .Failed ,
496- reason = " Failed to load PKCE client registration from cache" ,
497- reasonDesc = e.message
498- )
499- throw e
479+ cache.loadClientRegistration(pkceClientRegistrationCacheKey)?.let {
480+ return it as PKCEClientRegistration
500481 }
501482
502483 private fun saveClientRegistration (registration : ClientRegistration ) {
503- val credentialType: String
504- try {
505- when (registration) {
506- is DeviceAuthorizationClientRegistration -> {
507- credentialType = DeviceAuthorizationClientRegistration ::class .java.name
508- cache.saveClientRegistration(dagClientRegistrationCacheKey, registration)
509- }
484+ when (registration) {
485+ is DeviceAuthorizationClientRegistration -> {
486+ cache.saveClientRegistration(dagClientRegistrationCacheKey, registration)
487+ }
510488
511- is PKCEClientRegistration -> {
512- credentialType = PKCEClientRegistration ::class .java.name
513- cache.saveClientRegistration(pkceClientRegistrationCacheKey, registration)
514- }
489+ is PKCEClientRegistration -> {
490+ cache.saveClientRegistration(pkceClientRegistrationCacheKey, registration)
515491 }
516- } catch (e: Exception ) {
517- AwsTelemetry .createCredentials(
518- result = Result .Failed ,
519- reason = " Failed to save client registration to cache" ,
520- reasonDesc = e.message
521- )
522- throw e
523492 }
524- AwsTelemetry .createCredentials(
525- result = Result .Succeeded ,
526- reason = " $credentialType successfully written to cache" ,
527- )
528493 }
529494
530495 private fun invalidateClientRegistration () {
531- try {
532- cache.invalidateClientRegistration(dagClientRegistrationCacheKey)
533- cache.invalidateClientRegistration(pkceClientRegistrationCacheKey)
534- } catch (e: IOException ) {
535- AwsTelemetry .modifyCredentials(
536- credentialModification = CredentialModification .Delete ,
537- result = Result .Failed ,
538- reason = " Failed to invalidate client registration" ,
539- reasonDesc = e.message,
540- source = " SsoAccessTokenProvider.invalidateClientRegistration"
541- )
542- }
543- }
544-
545- private fun reauthExpiredRegistration (expiredToken : AccessToken ): ClientRegistration ? {
546- when (expiredToken) {
547- is DeviceAuthorizationGrantToken -> registerDAGClient()
548- is PKCEAuthorizationGrantToken -> registerPkceClient()
549- }
550- try {
551- return when (expiredToken) {
552- is DeviceAuthorizationGrantToken -> loadDagClientRegistration()
553- is PKCEAuthorizationGrantToken -> loadPkceClientRegistration()
554- }
555- } catch (e: Exception ) {
556- val message = " Error loading client registration: ${e.message} "
557- sendRefreshCredentialsMetric(
558- expiredToken,
559- reason = " Failed to load client registration" ,
560- reasonDesc = " Step: Load Registration after reauth - $message " ,
561- result = Result .Failed
562- )
563- throw InvalidClientException .builder().message(message).cause(e).build()
564- }
496+ cache.invalidateClientRegistration(dagClientRegistrationCacheKey)
497+ cache.invalidateClientRegistration(pkceClientRegistrationCacheKey)
565498 }
566499
567500 private fun saveAccessToken (token : AccessToken ) {
568- try {
569- when (token) {
570- is DeviceAuthorizationGrantToken -> {
571- cache.saveAccessToken(dagAccessTokenCacheKey, token)
572- }
573-
574- is PKCEAuthorizationGrantToken -> cache.saveAccessToken(pkceAccessTokenCacheKey, token)
501+ when (token) {
502+ is DeviceAuthorizationGrantToken -> {
503+ cache.saveAccessToken(dagAccessTokenCacheKey, token)
575504 }
576- } catch (e: Exception ) {
577- throw e
505+ is PKCEAuthorizationGrantToken -> cache.saveAccessToken(pkceAccessTokenCacheKey, token)
578506 }
579507 }
580508
0 commit comments