3
3
* SPDX-License-Identifier: Apache-2.0
4
4
*/
5
5
6
- import * as vscode from 'vscode'
7
6
import globals from '../shared/extensionGlobals'
7
+
8
+ import * as nls from 'vscode-nls'
9
+ const localize = nls . loadMessageBundle ( )
10
+
11
+ import * as vscode from 'vscode'
8
12
import { CancellationError } from '../shared/utilities/timeoutUtils'
9
13
import { AwsContext } from '../shared/awsContext'
10
14
import { getAccountId } from '../shared/credentials/accountId'
11
15
import { getLogger } from '../shared/logger'
12
- import { recordAwsValidateCredentials , recordVscodeActiveRegions , Result } from '../shared/telemetry/telemetry'
16
+ import {
17
+ CredentialSourceId ,
18
+ CredentialType ,
19
+ recordAwsRefreshCredentials ,
20
+ recordAwsValidateCredentials ,
21
+ recordVscodeActiveRegions ,
22
+ Result ,
23
+ } from '../shared/telemetry/telemetry'
13
24
import { CredentialsStore } from './credentialsStore'
14
25
import { CredentialsSettings , notifyUserInvalidCredentials } from './credentialsUtilities'
15
26
import {
@@ -24,9 +35,9 @@ import { getIdeProperties, isCloud9 } from '../shared/extensionUtilities'
24
35
import { SharedCredentialsProvider } from './providers/sharedCredentialsProvider'
25
36
import { showViewLogsMessage } from '../shared/utilities/messages'
26
37
import { isAutomation } from '../shared/vscode/env'
27
-
28
- import * as nls from 'vscode-nls '
29
- const localize = nls . loadMessageBundle ( )
38
+ import { Credentials } from '@aws-sdk/types'
39
+ import { ToolkitError } from '../shared/toolkitError '
40
+ import * as localizedText from '../shared/localizedText'
30
41
31
42
export class LoginManager {
32
43
private readonly defaultCredentialsRegion = 'us-east-1'
@@ -51,27 +62,25 @@ export class LoginManager {
51
62
let telemetryResult : Result = 'Failed'
52
63
53
64
try {
54
- provider = await CredentialsProviderManager . getInstance ( ) . getCredentialsProvider ( args . providerId )
55
- if ( ! provider ) {
56
- throw new Error ( `Could not find Credentials Provider for ${ asString ( args . providerId ) } ` )
57
- }
65
+ provider = await getProvider ( args . providerId )
58
66
59
- const storedCredentials = await this . store . upsertCredentials ( args . providerId , provider )
60
- if ( ! storedCredentials ) {
67
+ const credentials = ( await this . store . upsertCredentials ( args . providerId , provider ) ) ?. credentials
68
+ if ( ! credentials ) {
61
69
throw new Error ( `No credentials found for id ${ asString ( args . providerId ) } ` )
62
70
}
63
71
64
72
const credentialsRegion = provider . getDefaultRegion ( ) ?? this . defaultCredentialsRegion
65
- const accountId = await getAccountId ( storedCredentials . credentials , credentialsRegion )
73
+ const accountId = await getAccountId ( credentials , credentialsRegion )
66
74
if ( ! accountId ) {
67
75
throw new Error ( 'Could not determine Account Id for credentials' )
68
76
}
69
77
recordVscodeActiveRegions ( { value : ( await this . awsContext . getExplorerRegions ( ) ) . length } )
70
78
79
+ this . awsContext . credentialsShim = createCredentialsShim ( this . store , args . providerId , credentials )
71
80
await this . awsContext . setCredentials ( {
72
- credentials : storedCredentials . credentials ,
73
- credentialsId : asString ( args . providerId ) ,
81
+ credentials,
74
82
accountId : accountId ,
83
+ credentialsId : asString ( args . providerId ) ,
75
84
defaultRegion : provider . getDefaultRegion ( ) ,
76
85
} )
77
86
@@ -91,6 +100,7 @@ export class LoginManager {
91
100
92
101
await this . logout ( )
93
102
this . store . invalidateCredentials ( args . providerId )
103
+ this . awsContext . credentialsShim = undefined
94
104
return false
95
105
} finally {
96
106
const credType = provider ?. getTelemetryType ( )
@@ -233,3 +243,119 @@ function tryMakeCredentialsProviderId(credentials: string): CredentialsId | unde
233
243
return undefined
234
244
}
235
245
}
246
+
247
+ async function getProvider ( id : CredentialsId ) : Promise < CredentialsProvider > {
248
+ const provider = await CredentialsProviderManager . getInstance ( ) . getCredentialsProvider ( id )
249
+ if ( ! provider ) {
250
+ throw new Error ( `Could not find Credentials Provider for ${ asString ( id ) } ` )
251
+ }
252
+ return provider
253
+ }
254
+
255
+ /**
256
+ * The Toolkit implementation has a good amount of custom logic (SSO, source profiles, etc.)
257
+ * that was written to fill feature gaps in both AWS SDK V2 and V3. This code works imperfectly with
258
+ * pre-existing SDK refresh logic, leading to users experiencing issues with credentials expiring and
259
+ * forcing them to re-select a profile to refresh.
260
+ *
261
+ * So, this interface sits between everything else to mediate the refreshes. Adding a thin interface
262
+ * is preferred over bulking up existing ones since it allows for clean(-ish) refactors against the
263
+ * credentials subsystem. The pure data structure shape that existed previously was better, but it
264
+ * just wouldn't be able to support refreshes on its own.
265
+ */
266
+ export interface CredentialsShim {
267
+ /**
268
+ * Fetches credentials, attempting a refresh if needed.
269
+ */
270
+ get : ( ) => Promise < Credentials >
271
+
272
+ /**
273
+ * Removes the stored credentials and performs a refresh, allowing for prompts.
274
+ *
275
+ * Calling this function while a refresh is still pending returns the already pending promise.
276
+ */
277
+ refresh : ( ) => Promise < Credentials >
278
+ }
279
+
280
+ /**
281
+ * Collapses a single {@link CredentialsProvider} (referenced by id) into something a bit simpler.
282
+ *
283
+ * We don't pass in a provider directly since {@link CredentialsProviderManager} is the true
284
+ * source of credential state, at least as far as the Toolkit is concerned.
285
+ */
286
+ function createCredentialsShim (
287
+ store : CredentialsStore ,
288
+ providerId : CredentialsId ,
289
+ creds : Credentials
290
+ ) : CredentialsShim {
291
+ interface State {
292
+ credentials : Promise < Credentials >
293
+ pendingRefresh : Promise < Credentials >
294
+ }
295
+
296
+ const state : Partial < State > = { credentials : Promise . resolve ( creds ) }
297
+
298
+ async function refresh ( ) : Promise < Credentials > {
299
+ let result : Result = 'Failed'
300
+ let credentialType : CredentialType | undefined
301
+ let credentialSourceId : CredentialSourceId | undefined
302
+
303
+ try {
304
+ getLogger ( ) . debug ( `credentials: refreshing provider: ${ asString ( providerId ) } ` )
305
+
306
+ const provider = await getProvider ( providerId )
307
+ const formatProviderId = ( ) => asString ( provider . getCredentialsId ( ) )
308
+
309
+ credentialType = provider . getTelemetryType ( )
310
+ credentialSourceId = credentialsProviderToTelemetryType ( provider . getProviderType ( ) )
311
+
312
+ if ( ! provider . canAutoConnect ( ) ) {
313
+ const message = localize ( 'aws.credentials.expired' , 'Credentials are expired or invalid, login again?' )
314
+ const resp = await vscode . window . showInformationMessage ( message , localizedText . yes , localizedText . no )
315
+
316
+ if ( resp === localizedText . no ) {
317
+ throw new ToolkitError ( 'User cancelled login' , { cancelled : true } )
318
+ }
319
+ }
320
+
321
+ const credentials = await provider . getCredentials ( )
322
+ store . setCredentials ( credentials , provider )
323
+ getLogger ( ) . debug ( `credentials: refresh succeeded for: ${ formatProviderId ( ) } ` )
324
+ result = 'Succeeded'
325
+
326
+ return credentials
327
+ } catch ( error ) {
328
+ if ( error instanceof ToolkitError && error . cancelled ) {
329
+ result = 'Cancelled'
330
+ } else {
331
+ showViewLogsMessage ( `Failed to refresh credentials: ${ ( error as any ) ?. message } ` )
332
+ }
333
+
334
+ state . credentials = undefined
335
+ store . invalidateCredentials ( providerId )
336
+ globals . awsContext . credentialsShim = undefined
337
+ globals . awsContext . setCredentials ( undefined , true )
338
+
339
+ throw error
340
+ } finally {
341
+ recordAwsRefreshCredentials ( {
342
+ result,
343
+ passive : true ,
344
+ credentialType,
345
+ credentialSourceId,
346
+ } )
347
+ }
348
+ }
349
+
350
+ const shim = {
351
+ get : ( ) => ( state . credentials ??= shim . refresh ( ) ) ,
352
+ refresh : ( ) => {
353
+ const clear = ( ) => ( state . pendingRefresh = undefined )
354
+ state . credentials = state . pendingRefresh ??= refresh ( ) . finally ( clear )
355
+
356
+ return state . credentials
357
+ } ,
358
+ }
359
+
360
+ return shim
361
+ }
0 commit comments