Skip to content

Commit 97c6eea

Browse files
authored
fix(auth): automatically refresh credentials #2618
Problem: Credentials are not refreshed automatically, mostly because of caching but also because not every source of credentials is "refreshable". Solution: Hide all the auth logic behind a simple interface then implement a provider that is compatible with the SDK v2. This implementation may prompt the user on expiration.
1 parent 8ff0737 commit 97c6eea

File tree

11 files changed

+269
-39
lines changed

11 files changed

+269
-39
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "Bug Fix",
3+
"description": "Credential profiles that do not require user-input now correctly refresh when expired"
4+
}

src/credentials/credentialsStore.ts

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,11 @@ export class CredentialsStore {
6161
let credentials = await this.getCredentials(credentialsId)
6262

6363
if (!credentials) {
64-
credentials = await this.setCredentials(credentialsId, credentialsProvider)
64+
credentials = await this.consumeProvider(credentialsId, credentialsProvider)
6565
} else if (credentialsProvider.getHashCode() !== credentials.credentialsHashCode) {
6666
getLogger().verbose(`Using updated credentials: ${asString(credentialsId)}`)
6767
this.invalidateCredentials(credentialsId)
68-
credentials = await this.setCredentials(credentialsId, credentialsProvider)
68+
credentials = await this.consumeProvider(credentialsId, credentialsProvider)
6969
}
7070

7171
return credentials
@@ -78,7 +78,14 @@ export class CredentialsStore {
7878
delete this.credentialsCache[asString(credentialsId)]
7979
}
8080

81-
private async setCredentials(
81+
public async setCredentials(credentials: AWS.Credentials, provider: CredentialsProvider): Promise<void> {
82+
this.credentialsCache[asString(provider.getCredentialsId())] = {
83+
credentials,
84+
credentialsHashCode: provider.getHashCode(),
85+
}
86+
}
87+
88+
private async consumeProvider(
8289
credentialsId: CredentialsId,
8390
credentialsProvider: CredentialsProvider
8491
): Promise<CachedCredentials> {

src/credentials/loginManager.ts

Lines changed: 140 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,24 @@
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

6-
import * as vscode from 'vscode'
76
import globals from '../shared/extensionGlobals'
7+
8+
import * as nls from 'vscode-nls'
9+
const localize = nls.loadMessageBundle()
10+
11+
import * as vscode from 'vscode'
812
import { CancellationError } from '../shared/utilities/timeoutUtils'
913
import { AwsContext } from '../shared/awsContext'
1014
import { getAccountId } from '../shared/credentials/accountId'
1115
import { getLogger } from '../shared/logger'
12-
import { recordAwsValidateCredentials, recordVscodeActiveRegions, Result } from '../shared/telemetry/telemetry'
16+
import {
17+
CredentialSourceId,
18+
CredentialType,
19+
recordAwsRefreshCredentials,
20+
recordAwsValidateCredentials,
21+
recordVscodeActiveRegions,
22+
Result,
23+
} from '../shared/telemetry/telemetry'
1324
import { CredentialsStore } from './credentialsStore'
1425
import { CredentialsSettings, notifyUserInvalidCredentials } from './credentialsUtilities'
1526
import {
@@ -24,9 +35,9 @@ import { getIdeProperties, isCloud9 } from '../shared/extensionUtilities'
2435
import { SharedCredentialsProvider } from './providers/sharedCredentialsProvider'
2536
import { showViewLogsMessage } from '../shared/utilities/messages'
2637
import { isAutomation } from '../shared/vscode/env'
27-
28-
import * as nls from 'vscode-nls'
29-
const localize = nls.loadMessageBundle()
38+
import { Credentials } from '@aws-sdk/types'
39+
import { ToolkitError } from '../shared/toolkitError'
40+
import * as localizedText from '../shared/localizedText'
3041

3142
export class LoginManager {
3243
private readonly defaultCredentialsRegion = 'us-east-1'
@@ -51,27 +62,25 @@ export class LoginManager {
5162
let telemetryResult: Result = 'Failed'
5263

5364
try {
54-
provider = await CredentialsProviderManager.getInstance().getCredentialsProvider(args.providerId)
55-
if (!provider) {
56-
throw new Error(`Could not find Credentials Provider for ${asString(args.providerId)}`)
57-
}
65+
provider = await getProvider(args.providerId)
5866

59-
const storedCredentials = await this.store.upsertCredentials(args.providerId, provider)
60-
if (!storedCredentials) {
67+
const credentials = (await this.store.upsertCredentials(args.providerId, provider))?.credentials
68+
if (!credentials) {
6169
throw new Error(`No credentials found for id ${asString(args.providerId)}`)
6270
}
6371

6472
const credentialsRegion = provider.getDefaultRegion() ?? this.defaultCredentialsRegion
65-
const accountId = await getAccountId(storedCredentials.credentials, credentialsRegion)
73+
const accountId = await getAccountId(credentials, credentialsRegion)
6674
if (!accountId) {
6775
throw new Error('Could not determine Account Id for credentials')
6876
}
6977
recordVscodeActiveRegions({ value: (await this.awsContext.getExplorerRegions()).length })
7078

79+
this.awsContext.credentialsShim = createCredentialsShim(this.store, args.providerId, credentials)
7180
await this.awsContext.setCredentials({
72-
credentials: storedCredentials.credentials,
73-
credentialsId: asString(args.providerId),
81+
credentials,
7482
accountId: accountId,
83+
credentialsId: asString(args.providerId),
7584
defaultRegion: provider.getDefaultRegion(),
7685
})
7786

@@ -91,6 +100,7 @@ export class LoginManager {
91100

92101
await this.logout()
93102
this.store.invalidateCredentials(args.providerId)
103+
this.awsContext.credentialsShim = undefined
94104
return false
95105
} finally {
96106
const credType = provider?.getTelemetryType()
@@ -233,3 +243,119 @@ function tryMakeCredentialsProviderId(credentials: string): CredentialsId | unde
233243
return undefined
234244
}
235245
}
246+
247+
async function getProvider(id: CredentialsId): Promise<CredentialsProvider> {
248+
const provider = await CredentialsProviderManager.getInstance().getCredentialsProvider(id)
249+
if (!provider) {
250+
throw new Error(`Could not find Credentials Provider for ${asString(id)}`)
251+
}
252+
return provider
253+
}
254+
255+
/**
256+
* The Toolkit implementation has a good amount of custom logic (SSO, source profiles, etc.)
257+
* that was written to fill feature gaps in both AWS SDK V2 and V3. This code works imperfectly with
258+
* pre-existing SDK refresh logic, leading to users experiencing issues with credentials expiring and
259+
* forcing them to re-select a profile to refresh.
260+
*
261+
* So, this interface sits between everything else to mediate the refreshes. Adding a thin interface
262+
* is preferred over bulking up existing ones since it allows for clean(-ish) refactors against the
263+
* credentials subsystem. The pure data structure shape that existed previously was better, but it
264+
* just wouldn't be able to support refreshes on its own.
265+
*/
266+
export interface CredentialsShim {
267+
/**
268+
* Fetches credentials, attempting a refresh if needed.
269+
*/
270+
get: () => Promise<Credentials>
271+
272+
/**
273+
* Removes the stored credentials and performs a refresh, allowing for prompts.
274+
*
275+
* Calling this function while a refresh is still pending returns the already pending promise.
276+
*/
277+
refresh: () => Promise<Credentials>
278+
}
279+
280+
/**
281+
* Collapses a single {@link CredentialsProvider} (referenced by id) into something a bit simpler.
282+
*
283+
* We don't pass in a provider directly since {@link CredentialsProviderManager} is the true
284+
* source of credential state, at least as far as the Toolkit is concerned.
285+
*/
286+
function createCredentialsShim(
287+
store: CredentialsStore,
288+
providerId: CredentialsId,
289+
creds: Credentials
290+
): CredentialsShim {
291+
interface State {
292+
credentials: Promise<Credentials>
293+
pendingRefresh: Promise<Credentials>
294+
}
295+
296+
const state: Partial<State> = { credentials: Promise.resolve(creds) }
297+
298+
async function refresh(): Promise<Credentials> {
299+
let result: Result = 'Failed'
300+
let credentialType: CredentialType | undefined
301+
let credentialSourceId: CredentialSourceId | undefined
302+
303+
try {
304+
getLogger().debug(`credentials: refreshing provider: ${asString(providerId)}`)
305+
306+
const provider = await getProvider(providerId)
307+
const formatProviderId = () => asString(provider.getCredentialsId())
308+
309+
credentialType = provider.getTelemetryType()
310+
credentialSourceId = credentialsProviderToTelemetryType(provider.getProviderType())
311+
312+
if (!provider.canAutoConnect()) {
313+
const message = localize('aws.credentials.expired', 'Credentials are expired or invalid, login again?')
314+
const resp = await vscode.window.showInformationMessage(message, localizedText.yes, localizedText.no)
315+
316+
if (resp === localizedText.no) {
317+
throw new ToolkitError('User cancelled login', { cancelled: true })
318+
}
319+
}
320+
321+
const credentials = await provider.getCredentials()
322+
store.setCredentials(credentials, provider)
323+
getLogger().debug(`credentials: refresh succeeded for: ${formatProviderId()}`)
324+
result = 'Succeeded'
325+
326+
return credentials
327+
} catch (error) {
328+
if (error instanceof ToolkitError && error.cancelled) {
329+
result = 'Cancelled'
330+
} else {
331+
showViewLogsMessage(`Failed to refresh credentials: ${(error as any)?.message}`)
332+
}
333+
334+
state.credentials = undefined
335+
store.invalidateCredentials(providerId)
336+
globals.awsContext.credentialsShim = undefined
337+
globals.awsContext.setCredentials(undefined, true)
338+
339+
throw error
340+
} finally {
341+
recordAwsRefreshCredentials({
342+
result,
343+
passive: true,
344+
credentialType,
345+
credentialSourceId,
346+
})
347+
}
348+
}
349+
350+
const shim = {
351+
get: () => (state.credentials ??= shim.refresh()),
352+
refresh: () => {
353+
const clear = () => (state.pendingRefresh = undefined)
354+
state.credentials = state.pendingRefresh ??= refresh().finally(clear)
355+
356+
return state.credentials
357+
},
358+
}
359+
360+
return shim
361+
}

src/credentials/providers/ssoCredentialProvider.ts

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import * as vscode from 'vscode'
77
import { Credentials } from '@aws-sdk/types'
8-
import { SSO } from '@aws-sdk/client-sso'
8+
import { SSO, UnauthorizedException } from '@aws-sdk/client-sso'
99
import { getLogger } from '../../shared/logger'
1010
import { SsoAccessTokenProvider } from '../sso/ssoAccessTokenProvider'
1111
import * as nls from 'vscode-nls'
@@ -23,19 +23,24 @@ export class SsoCredentialProvider {
2323
public async refreshCredentials(): Promise<Credentials> {
2424
try {
2525
const accessToken = await this.ssoAccessTokenProvider.accessToken()
26-
const roleCredentials = await this.ssoClient.getRoleCredentials({
27-
accountId: this.ssoAccount,
28-
roleName: this.ssoRole,
29-
accessToken: accessToken.accessToken,
30-
})
26+
const roleCredentials = await this.ssoClient
27+
.getRoleCredentials({
28+
accountId: this.ssoAccount,
29+
roleName: this.ssoRole,
30+
accessToken: accessToken.accessToken,
31+
})
32+
.then(resp => resp.roleCredentials)
33+
34+
const expiration = roleCredentials?.expiration ? new Date(roleCredentials.expiration) : undefined
3135

3236
return {
33-
accessKeyId: roleCredentials.roleCredentials!.accessKeyId!,
34-
secretAccessKey: roleCredentials.roleCredentials!.secretAccessKey!,
35-
sessionToken: roleCredentials.roleCredentials?.sessionToken,
37+
accessKeyId: roleCredentials!.accessKeyId!,
38+
secretAccessKey: roleCredentials!.secretAccessKey!,
39+
sessionToken: roleCredentials?.sessionToken,
40+
expiration,
3641
}
3742
} catch (err) {
38-
if ((err as { code: string }).code === 'UnauthorizedException') {
43+
if (err instanceof UnauthorizedException) {
3944
this.ssoAccessTokenProvider.invalidate()
4045
}
4146
vscode.window.showErrorMessage(

src/shared/awsClientBuilder.ts

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

6-
import { Request, AWSError } from 'aws-sdk'
6+
import { Request, AWSError, Credentials } from 'aws-sdk'
7+
import { CredentialsOptions } from 'aws-sdk/lib/credentials'
78
import { ServiceConfigurationOptions } from 'aws-sdk/lib/service'
89
import { env, version } from 'vscode'
910
import { AwsContext } from './awsContext'
@@ -60,11 +61,7 @@ export interface AWSClientBuilder {
6061
}
6162

6263
export class DefaultAWSClientBuilder implements AWSClientBuilder {
63-
private readonly _awsContext: AwsContext
64-
65-
public constructor(awsContext: AwsContext) {
66-
this._awsContext = awsContext
67-
}
64+
public constructor(private readonly awsContext: AwsContext) {}
6865

6966
public async createAwsService<T extends AWS.Service>(
7067
type: new (o: ServiceConfigurationOptions) => T,
@@ -78,7 +75,46 @@ export class DefaultAWSClientBuilder implements AWSClientBuilder {
7875
delete opt.onRequestSetup
7976

8077
if (!opt.credentials) {
81-
opt.credentials = await this._awsContext.getCredentials()
78+
const shim = this.awsContext.credentialsShim
79+
80+
if (!shim) {
81+
throw new Error('Toolkit is not logged-in.')
82+
}
83+
84+
opt.credentials = new (class extends Credentials {
85+
public constructor() {
86+
// The class doesn't like being instantiated with empty creds
87+
super({ accessKeyId: '???', secretAccessKey: '???' })
88+
}
89+
90+
public override get(callback: (err?: AWSError) => void): void {
91+
// Always try to fetch the latest creds first, attempting a refresh if needed
92+
// A 'passive' refresh is attempted first, before trying an 'active' one if certain criteria are met
93+
shim.get()
94+
.then(creds => {
95+
this.loadCreds(creds)
96+
this.needsRefresh() ? this.refresh(callback) : callback()
97+
})
98+
.catch(callback)
99+
}
100+
101+
public override refresh(callback: (err?: AWSError) => void): void {
102+
shim.refresh()
103+
.then(creds => {
104+
this.loadCreds(creds)
105+
callback()
106+
})
107+
.catch(callback)
108+
}
109+
110+
private loadCreds(creds: CredentialsOptions & { expiration?: Date }) {
111+
this.expired = false
112+
this.accessKeyId = creds.accessKeyId
113+
this.secretAccessKey = creds.secretAccessKey
114+
this.sessionToken = creds.sessionToken ?? this.sessionToken
115+
this.expireTime = creds.expiration ?? this.expireTime
116+
}
117+
})()
82118
}
83119

84120
if (!opt.region && region) {

src/shared/awsContext.ts

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import * as AWS from '@aws-sdk/types'
88
import { regionSettingKey } from './constants'
99
import { getLogger } from '../shared/logger'
1010
import { ClassToInterfaceType } from './utilities/tsUtils'
11+
import { CredentialsShim } from '../credentials/loginManager'
1112

1213
export interface AwsContextCredentials {
1314
readonly credentials: AWS.Credentials
@@ -39,6 +40,7 @@ const DEFAULT_REGION = 'us-east-1'
3940
export class DefaultAwsContext implements AwsContext {
4041
public readonly onDidChangeContext: vscode.Event<ContextChangeEventsArgs>
4142
private readonly _onDidChangeContext: vscode.EventEmitter<ContextChangeEventsArgs>
43+
private shim?: CredentialsShim
4244

4345
// the collection of regions the user has expressed an interest in working with in
4446
// the current workspace
@@ -54,6 +56,14 @@ export class DefaultAwsContext implements AwsContext {
5456
this.explorerRegions = persistedRegions || []
5557
}
5658

59+
public get credentialsShim(): CredentialsShim | undefined {
60+
return this.shim
61+
}
62+
63+
public set credentialsShim(shim: CredentialsShim | undefined) {
64+
this.shim = shim
65+
}
66+
5767
/**
5868
* Sets the credentials to be used by the Toolkit.
5969
* Passing in undefined represents that there are no active credentials.
@@ -74,7 +84,12 @@ export class DefaultAwsContext implements AwsContext {
7484
* @description Gets the Credentials currently used by the Toolkit.
7585
*/
7686
public async getCredentials(): Promise<AWS.Credentials | undefined> {
77-
return this.currentCredentials?.credentials
87+
return (
88+
this.shim?.get().catch(error => {
89+
getLogger().warn(`credentials: failed to retrieve latest credentials: ${error.message}`)
90+
return undefined
91+
}) ?? this.currentCredentials?.credentials
92+
)
7893
}
7994

8095
// returns the configured profile, if any

0 commit comments

Comments
 (0)