Skip to content

Commit 324c47b

Browse files
committed
add sts management and mfa
1 parent 996497e commit 324c47b

File tree

6 files changed

+57
-23
lines changed

6 files changed

+57
-23
lines changed

packages/amazonq/src/lsp/client.ts

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ import {
3030
ResponseError,
3131
LSPErrorCodes,
3232
updateConfigurationRequestType,
33+
GetMfaCodeParams,
34+
GetMfaCodeResult,
3335
} from '@aws/language-server-runtimes/protocol'
3436
import {
3537
AuthUtil,
@@ -56,7 +58,7 @@ import { processUtils } from 'aws-core-vscode/shared'
5658
import { activate as activateChat } from './chat/activation'
5759
import { activate as activeInlineChat } from '../inlineChat/activation'
5860
import { AmazonQResourcePaths } from './lspInstaller'
59-
import { auth2 } from 'aws-core-vscode/auth'
61+
import { auth2, getMfaTokenFromUser, getMfaSerialFromUser } from 'aws-core-vscode/auth'
6062
import { ConfigSection, isValidConfigSection, pushConfigUpdate, toAmazonQLSPLogLevel } from './config'
6163
import { telemetry } from 'aws-core-vscode/telemetry'
6264
import { SessionManager } from '../app/inline/sessionManager'
@@ -337,6 +339,24 @@ async function postStartLanguageServer(
337339
}
338340
)
339341

342+
// Handler for when Flare needs to assume a role with MFA code
343+
client.onRequest(
344+
auth2.notificationTypes.getMfaCode.method,
345+
async (params: GetMfaCodeParams): Promise<GetMfaCodeResult> => {
346+
if (params.mfaSerial) {
347+
globals.globalState.update('recentMfaSerial', { mfaSerial: params.mfaSerial })
348+
}
349+
const defaultMfaSerial = globals.globalState.tryGet('recentMfaSerial', Object, {
350+
mfaSerial: '',
351+
}).mfaSerial
352+
let mfaSerial = await getMfaSerialFromUser(defaultMfaSerial, params.profileName)
353+
mfaSerial = mfaSerial.trim()
354+
await globals.globalState.update('recentMfaSerial', { mfaSerial: mfaSerial })
355+
const mfaCode = await getMfaTokenFromUser(mfaSerial, params.profileName)
356+
return { code: mfaCode ?? '', mfaSerial: mfaSerial ?? '' }
357+
}
358+
)
359+
340360
const sendProfileToLsp = async () => {
341361
try {
342362
const result = await client.sendRequest(updateConfigurationRequestType.method, {

packages/core/src/auth/auth2.ts

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ import { LanguageClient } from 'vscode-languageclient'
5757
import { getLogger } from '../shared/logger/logger'
5858
import { ToolkitError } from '../shared/errors'
5959
import { useDeviceFlow } from './sso/ssoAccessTokenProvider'
60-
import { getCacheDir, getCacheFileWatcher, getFlareCacheFileName } from './sso/cache'
60+
import { getCacheDir, getCacheFileWatcher, getFlareCacheFileName, getStsCacheDir } from './sso/cache'
6161
import { VSCODE_EXTENSION_ID } from '../shared/extensions'
6262
import { IamCredentials } from '@aws/language-server-runtimes-types'
6363
import globals from '../shared/extensionGlobals'
@@ -75,6 +75,7 @@ export const notificationTypes = {
7575
getConnectionMetadata: new RequestType<undefined, ConnectionMetadata, Error>(
7676
getConnectionMetadataRequestType.method
7777
),
78+
getMfaCode: new RequestType<GetMfaCodeParams, ResponseMessage, Error>(getMfaCodeRequestType.method),
7879
}
7980

8081
export type AuthState = 'notConnected' | 'connected' | 'expired'
@@ -89,6 +90,8 @@ export type LoginType = (typeof LoginTypes)[keyof typeof LoginTypes]
8990

9091
export type cacheChangedEvent = 'delete' | 'create'
9192

93+
export type stsCacheChangedEvent = 'delete' | 'create'
94+
9295
export type Login = SsoLogin | IamLogin
9396

9497
export type TokenSource = IamIdentityCenterSsoTokenSource | AwsBuilderIdSsoTokenSource
@@ -114,6 +117,10 @@ const IamProfileOptionsDefaults = {
114117
*/
115118
export class LanguageClientAuth {
116119
readonly #ssoCacheWatcher = getCacheFileWatcher(getCacheDir(), getFlareCacheFileName(VSCODE_EXTENSION_ID.amazonq))
120+
readonly #stsCacheWatcher = getCacheFileWatcher(
121+
getStsCacheDir(),
122+
getFlareCacheFileName(VSCODE_EXTENSION_ID.amazonq)
123+
)
117124

118125
constructor(
119126
private readonly client: LanguageClient,
@@ -125,6 +132,10 @@ export class LanguageClientAuth {
125132
return this.#ssoCacheWatcher
126133
}
127134

135+
public get stsCacheWatcher() {
136+
return this.#stsCacheWatcher
137+
}
138+
128139
getSsoToken(
129140
tokenSource: TokenSource,
130141
login: boolean = false,
@@ -281,6 +292,11 @@ export class LanguageClientAuth {
281292
this.cacheWatcher.onDidCreate(() => cacheChangedHandler('create'))
282293
this.cacheWatcher.onDidDelete(() => cacheChangedHandler('delete'))
283294
}
295+
296+
registerStsCacheWatcher(stsCacheChangedHandler: (event: stsCacheChangedEvent) => any) {
297+
this.stsCacheWatcher.onDidCreate(() => stsCacheChangedHandler('create'))
298+
this.stsCacheWatcher.onDidDelete(() => stsCacheChangedHandler('delete'))
299+
}
284300
}
285301

286302
/**
@@ -357,13 +373,8 @@ export abstract class BaseLogin {
357373
* Decrypts an encrypted string, removes its quotes, and returns the resulting string
358374
*/
359375
protected async decrypt(encrypted: string): Promise<string> {
360-
try {
361-
const decrypted = await jose.compactDecrypt(encrypted, this.lspAuth.encryptionKey)
362-
return decrypted.plaintext.toString().replaceAll('"', '')
363-
} catch (e) {
364-
getLogger().error(`Failed to decrypt: ${encrypted}`)
365-
return encrypted
366-
}
376+
const decrypted = await jose.compactDecrypt(encrypted, this.lspAuth.encryptionKey)
377+
return decrypted.plaintext.toString().replaceAll('"', '')
367378
}
368379
}
369380

@@ -575,14 +586,6 @@ export class IamLogin extends BaseLogin {
575586
* Restore the connection state and connection details to memory, if they exist.
576587
*/
577588
async restore() {
578-
const sessionData = await this.getProfile()
579-
const credentials = sessionData?.profile?.settings
580-
if (credentials?.aws_access_key_id && credentials?.aws_secret_access_key) {
581-
this._data = {
582-
accessKey: credentials.aws_access_key_id,
583-
secretKey: credentials.aws_secret_access_key,
584-
}
585-
}
586589
try {
587590
await this._getIamCredential(false)
588591
} catch (err) {

packages/core/src/auth/sso/cache.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@ export interface SsoCache {
3636
}
3737

3838
const defaultCacheDir = () => path.join(fs.getUserHomeDir(), '.aws/sso/cache')
39+
const defaultStsCacheDir = () => path.join(fs.getUserHomeDir(), '.aws/cli/cache')
3940
export const getCacheDir = () => DevSettings.instance.get('ssoCacheDirectory', defaultCacheDir())
41+
export const getStsCacheDir = () => DevSettings.instance.get('stsCacheDirectory', defaultStsCacheDir())
4042

4143
export function getCache(directory = getCacheDir()): SsoCache {
4244
return {

packages/core/src/codewhisperer/util/authUtil.ts

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ import { telemetry } from '../../shared/telemetry/telemetry'
3333
import {
3434
AuthStateEvent,
3535
cacheChangedEvent,
36+
stsCacheChangedEvent,
3637
LanguageClientAuth,
3738
Login,
3839
SsoLogin,
@@ -116,6 +117,7 @@ export class AuthUtil implements IAuthProvider {
116117
await this.setVscodeContextProps()
117118
})
118119
lspAuth.registerCacheWatcher(async (event: cacheChangedEvent) => await this.cacheChangedHandler(event))
120+
lspAuth.registerStsCacheWatcher(async (event: stsCacheChangedEvent) => await this.stsCacheChangedHandler(event))
119121
}
120122

121123
// Do NOT use this in production code, only used for testing
@@ -148,12 +150,13 @@ export class AuthUtil implements IAuthProvider {
148150
this.session = new SsoLogin(this.profileName, this.lspAuth, this.eventEmitter)
149151
await this.session.restore()
150152
if (!this.isConnected()) {
153+
await this.session?.logout()
151154
// Try to restore an IAM session
152155
this.session = new IamLogin(this.profileName, this.lspAuth, this.eventEmitter)
153156
await this.session.restore()
154157
if (!this.isConnected()) {
155158
// If both fail, reset the session
156-
this.session = undefined
159+
await this.session?.logout()
157160
}
158161
}
159162
}
@@ -262,10 +265,6 @@ export class AuthUtil implements IAuthProvider {
262265
return Boolean(this.connection?.startUrl && this.connection?.startUrl !== builderIdStartUrl)
263266
}
264267

265-
isIamConnection() {
266-
return Boolean(this.connection?.accessKey && this.connection?.secretKey)
267-
}
268-
269268
isInternalAmazonUser(): boolean {
270269
return this.isConnected() && this.connection?.startUrl === internalStartUrl
271270
}
@@ -365,6 +364,15 @@ export class AuthUtil implements IAuthProvider {
365364
}
366365
}
367366

367+
private async stsCacheChangedHandler(event: stsCacheChangedEvent) {
368+
this.logger.debug(`Sts Cache change event received: ${event}`)
369+
if (event === 'delete') {
370+
await this.logout()
371+
} else if (event === 'create') {
372+
await this.restore()
373+
}
374+
}
375+
368376
private async stateChangeHandler(e: AuthStateEvent) {
369377
if (e.state === 'refreshed') {
370378
if (this.isSsoSession()) {

packages/core/src/login/webview/vue/backend.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ export abstract class CommonAuthWebview extends VueWebview {
187187
abstract fetchConnections(): Promise<AwsConnection[] | undefined>
188188

189189
async errorNotification(e: AuthError) {
190-
await showMessage('error', e.text)
190+
showMessage('error', e.text)
191191
}
192192

193193
abstract quitLoginScreen(): Promise<void>

packages/core/src/test/testAuthUtil.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ export async function createTestAuthUtil() {
5858
updateIamProfile: sinon.stub().resolves(),
5959
invalidateSsoToken: sinon.stub().resolves(),
6060
registerCacheWatcher: sinon.stub().resolves(),
61+
registerStsCacheWatcher: sinon.stub().resolves(),
6162
encryptionKey,
6263
}
6364

0 commit comments

Comments
 (0)