diff --git a/packages/core/src/auth/sso/ssoAccessTokenProvider.ts b/packages/core/src/auth/sso/ssoAccessTokenProvider.ts index e753fb2ef90..edf460c6656 100644 --- a/packages/core/src/auth/sso/ssoAccessTokenProvider.ts +++ b/packages/core/src/auth/sso/ssoAccessTokenProvider.ts @@ -59,6 +59,15 @@ export abstract class SsoAccessTokenProvider { private static logIfChanged = onceChanged((s: string) => getLogger().info(s)) private readonly className = 'SsoAccessTokenProvider' + /** + * Prevents concurrent token refresh operations. + * Maps tokenCacheKey to an in-flight refresh promise. + */ + private static refreshPromises = new Map< + string, + Promise<{ token: SsoToken; registration: ClientRegistration; region: string; startUrl: string }> + >() + public static set authSource(val: string) { SsoAccessTokenProvider._authSource = val } @@ -108,15 +117,43 @@ export abstract class SsoAccessTokenProvider { true ) ) + if (!data || !isExpired(data.token)) { + getLogger().debug('Auth: token is valid, returning cached token (key=%s)', this.tokenCacheKey) return data?.token } + getLogger().info( + `Auth: bearer token expired (expires at ${data.token.expiresAt}), attempting refresh (key=${this.tokenCacheKey})` + ) + if (data.registration && !isExpired(data.registration) && hasProps(data.token, 'refreshToken')) { - const refreshed = await this.refreshToken(data.token, data.registration) + getLogger().debug(`Auth: refresh token available, calling refreshToken() (key=${this.tokenCacheKey})`) + // Check if a refresh is already in progress for this token + const existingRefresh = SsoAccessTokenProvider.refreshPromises.get(this.tokenCacheKey) + if (existingRefresh) { + getLogger().debug( + 'SsoAccessTokenProvider: Token refresh already in progress, waiting for existing refresh' + ) + const refreshed = await existingRefresh + return refreshed.token + } + + // Start a new refresh and store the promise + const refreshPromise = this.refreshToken(data.token, data.registration) + SsoAccessTokenProvider.refreshPromises.set(this.tokenCacheKey, refreshPromise) - return refreshed.token + try { + const refreshed = await refreshPromise + return refreshed.token + } finally { + // Clean up the promise from the map once complete (success or failure) + SsoAccessTokenProvider.refreshPromises.delete(this.tokenCacheKey) + } } else { + getLogger().warn( + `getToken: cannot refresh - registration expired or no refresh token available (key=${this.tokenCacheKey})` + ) await this.invalidate('allCacheExpired') } } @@ -172,10 +209,18 @@ export abstract class SsoAccessTokenProvider { try { const clientInfo = selectFrom(registration, 'clientId', 'clientSecret') + getLogger().debug(`Auth refreshToken: calling OIDC createToken API (key=${this.tokenCacheKey})`) const response = await this.oidc.createToken({ ...clientInfo, ...token, grantType: refreshGrantType }) + + getLogger().debug(`Auth refreshToken: got response, now saving to cache...`) + const refreshed = this.formatToken(response, registration) + getLogger().debug(`refreshToken: saving refreshed token to cache (key=${this.tokenCacheKey})`) await this.cache.token.save(this.tokenCacheKey, refreshed) + getLogger().info( + `Auth refreshToken: token refresh successful (key=${this.tokenCacheKey}, new expiry=${response.expiresAt})` + ) telemetry.aws_refreshCredentials.emit({ result: 'Succeeded', requestId: response.requestId, @@ -184,6 +229,10 @@ export abstract class SsoAccessTokenProvider { return refreshed } catch (err) { + getLogger().error( + `Auth refreshToken: token refresh failed (key=${this.tokenCacheKey}): ${getErrorMsg(err as unknown as Error)}` + ) + if (err instanceof DiskCacheError) { /** * Background: @@ -197,6 +246,9 @@ export abstract class SsoAccessTokenProvider { * to the logs where the error was logged. Hopefully they can use this information to fix the issue, * or at least hint for them to provide the logs in a bug report. */ + getLogger().warn( + `Auth refreshToken: DiskCacheError during refresh, not invalidating session (key=${this.tokenCacheKey})` + ) void DiskCacheErrorMessage.instance.showMessageThrottled(err) } else if (!isNetworkError(err)) { const reason = getTelemetryReason(err)