Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 54 additions & 2 deletions packages/core/src/auth/sso/ssoAccessTokenProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,15 @@ export abstract class SsoAccessTokenProvider {
private static logIfChanged = onceChanged((s: string) => getLogger().info(s))
private readonly className = 'SsoAccessTokenProvider'

/**
* Prevents concurrent token refresh operations.
* Maps tokenCacheKey to an in-flight refresh promise.
*/
private static refreshPromises = new Map<
string,
Promise<{ token: SsoToken; registration: ClientRegistration; region: string; startUrl: string }>
>()

public static set authSource(val: string) {
SsoAccessTokenProvider._authSource = val
}
Expand Down Expand Up @@ -108,15 +117,43 @@ export abstract class SsoAccessTokenProvider {
true
)
)

if (!data || !isExpired(data.token)) {
getLogger().debug('Auth: token is valid, returning cached token (key=%s)', this.tokenCacheKey)
return data?.token
}

getLogger().info(
`Auth: bearer token expired (expires at ${data.token.expiresAt}), attempting refresh (key=${this.tokenCacheKey})`
)

if (data.registration && !isExpired(data.registration) && hasProps(data.token, 'refreshToken')) {
const refreshed = await this.refreshToken(data.token, data.registration)
getLogger().debug(`Auth: refresh token available, calling refreshToken() (key=${this.tokenCacheKey})`)
// Check if a refresh is already in progress for this token
const existingRefresh = SsoAccessTokenProvider.refreshPromises.get(this.tokenCacheKey)
if (existingRefresh) {
getLogger().debug(
'SsoAccessTokenProvider: Token refresh already in progress, waiting for existing refresh'
)
const refreshed = await existingRefresh
return refreshed.token
}

// Start a new refresh and store the promise
const refreshPromise = this.refreshToken(data.token, data.registration)
SsoAccessTokenProvider.refreshPromises.set(this.tokenCacheKey, refreshPromise)

return refreshed.token
try {
const refreshed = await refreshPromise
return refreshed.token
} finally {
// Clean up the promise from the map once complete (success or failure)
SsoAccessTokenProvider.refreshPromises.delete(this.tokenCacheKey)
}
} else {
getLogger().warn(
`getToken: cannot refresh - registration expired or no refresh token available (key=${this.tokenCacheKey})`
)
await this.invalidate('allCacheExpired')
}
}
Expand Down Expand Up @@ -172,10 +209,18 @@ export abstract class SsoAccessTokenProvider {

try {
const clientInfo = selectFrom(registration, 'clientId', 'clientSecret')
getLogger().debug(`Auth refreshToken: calling OIDC createToken API (key=${this.tokenCacheKey})`)
const response = await this.oidc.createToken({ ...clientInfo, ...token, grantType: refreshGrantType })

getLogger().debug(`Auth refreshToken: got response, now saving to cache...`)

const refreshed = this.formatToken(response, registration)
getLogger().debug(`refreshToken: saving refreshed token to cache (key=${this.tokenCacheKey})`)
await this.cache.token.save(this.tokenCacheKey, refreshed)

getLogger().info(
`Auth refreshToken: token refresh successful (key=${this.tokenCacheKey}, new expiry=${response.expiresAt})`
)
telemetry.aws_refreshCredentials.emit({
result: 'Succeeded',
requestId: response.requestId,
Expand All @@ -184,6 +229,10 @@ export abstract class SsoAccessTokenProvider {

return refreshed
} catch (err) {
getLogger().error(
`Auth refreshToken: token refresh failed (key=${this.tokenCacheKey}): ${getErrorMsg(err as unknown as Error)}`
)

if (err instanceof DiskCacheError) {
/**
* Background:
Expand All @@ -197,6 +246,9 @@ export abstract class SsoAccessTokenProvider {
* to the logs where the error was logged. Hopefully they can use this information to fix the issue,
* or at least hint for them to provide the logs in a bug report.
*/
getLogger().warn(
`Auth refreshToken: DiskCacheError during refresh, not invalidating session (key=${this.tokenCacheKey})`
)
void DiskCacheErrorMessage.instance.showMessageThrottled(err)
} else if (!isNetworkError(err)) {
const reason = getTelemetryReason(err)
Expand Down
Loading