Skip to content

Commit 88eb779

Browse files
committed
Move MFA handler to auth2
1 parent 6989096 commit 88eb779

File tree

2 files changed

+23
-22
lines changed

2 files changed

+23
-22
lines changed

packages/amazonq/src/lsp/client.ts

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@ import {
3030
ResponseError,
3131
LSPErrorCodes,
3232
updateConfigurationRequestType,
33-
GetMfaCodeParams,
34-
GetMfaCodeResult,
3533
} from '@aws/language-server-runtimes/protocol'
3634
import {
3735
AuthUtil,
@@ -58,7 +56,7 @@ import { processUtils } from 'aws-core-vscode/shared'
5856
import { activate as activateChat } from './chat/activation'
5957
import { activate as activeInlineChat } from '../inlineChat/activation'
6058
import { AmazonQResourcePaths } from './lspInstaller'
61-
import { auth2, getMfaTokenFromUser, getMfaSerialFromUser } from 'aws-core-vscode/auth'
59+
import { auth2 } from 'aws-core-vscode/auth'
6260
import { ConfigSection, isValidConfigSection, pushConfigUpdate, toAmazonQLSPLogLevel } from './config'
6361
import { telemetry } from 'aws-core-vscode/telemetry'
6462
import { SessionManager } from '../app/inline/sessionManager'
@@ -339,24 +337,6 @@ async function postStartLanguageServer(
339337
}
340338
)
341339

342-
// Handler for when Flare needs to assume a role with MFA code
343-
client.onRequest(
344-
auth2.notificationTypes.getMfaCode.method,
345-
async (params: GetMfaCodeParams): Promise<GetMfaCodeResult> => {
346-
if (params.mfaSerial) {
347-
globals.globalState.update('recentMfaSerial', { mfaSerial: params.mfaSerial })
348-
}
349-
const defaultMfaSerial = globals.globalState.tryGet('recentMfaSerial', Object, {
350-
mfaSerial: '',
351-
}).mfaSerial
352-
let mfaSerial = await getMfaSerialFromUser(defaultMfaSerial, params.profileName)
353-
mfaSerial = mfaSerial.trim()
354-
globals.globalState.update('recentMfaSerial', { mfaSerial: mfaSerial })
355-
const mfaCode = await getMfaTokenFromUser(mfaSerial, params.profileName)
356-
return { code: mfaCode ?? '', mfaSerial: mfaSerial ?? '' }
357-
}
358-
)
359-
360340
const sendProfileToLsp = async () => {
361341
try {
362342
const result = await client.sendRequest(updateConfigurationRequestType.method, {

packages/core/src/auth/auth2.ts

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ import {
5151
SsoSession,
5252
GetMfaCodeParams,
5353
getMfaCodeRequestType,
54+
GetMfaCodeResult,
5455
} from '@aws/language-server-runtimes/protocol'
5556
import { LanguageClient } from 'vscode-languageclient'
5657
import { getLogger } from '../shared/logger/logger'
@@ -59,6 +60,8 @@ import { useDeviceFlow } from './sso/ssoAccessTokenProvider'
5960
import { getCacheDir, getCacheFileWatcher, getFlareCacheFileName, getStsCacheDir } from './sso/cache'
6061
import { VSCODE_EXTENSION_ID } from '../shared/extensions'
6162
import { IamCredentials } from '@aws/language-server-runtimes-types'
63+
import globals from '../shared/extensionGlobals'
64+
import { getMfaSerialFromUser, getMfaTokenFromUser } from './credentials/utils'
6265

6366
export const notificationTypes = {
6467
updateIamCredential: new RequestType<UpdateCredentialsParams, ResponseMessage, Error>(
@@ -72,7 +75,6 @@ export const notificationTypes = {
7275
getConnectionMetadata: new RequestType<undefined, ConnectionMetadata, Error>(
7376
getConnectionMetadataRequestType.method
7477
),
75-
getMfaCode: new RequestType<GetMfaCodeParams, ResponseMessage, Error>(getMfaCodeRequestType.method),
7678
}
7779

7880
export type AuthState = 'notConnected' | 'connected' | 'expired'
@@ -291,6 +293,10 @@ export class LanguageClientAuth {
291293
this.client.onNotification(stsCredentialChangedRequestType.method, stsCredentialChangedHandler)
292294
}
293295

296+
registerGetMfaCodeHandler(getMfaCodeHandler: (params: GetMfaCodeParams) => Promise<GetMfaCodeResult>) {
297+
this.client.onRequest(getMfaCodeRequestType.method, getMfaCodeHandler)
298+
}
299+
294300
registerCacheWatcher(cacheChangedHandler: (event: cacheChangedEvent) => any) {
295301
this.cacheWatcher.onDidCreate(() => cacheChangedHandler('create'))
296302
this.cacheWatcher.onDidDelete(() => cacheChangedHandler('delete'))
@@ -536,6 +542,7 @@ export class IamLogin extends BaseLogin {
536542
lspAuth.registerStsCredentialChangedHandler((params: StsCredentialChangedParams) =>
537543
this.stsCredentialChangedHandler(params)
538544
)
545+
lspAuth.registerGetMfaCodeHandler((params: GetMfaCodeParams) => this.getMfaCodeHandler(params))
539546
}
540547

541548
async login(opts: { accessKey: string; secretKey: string; sessionToken?: string; roleArn?: string }) {
@@ -667,4 +674,18 @@ export class IamLogin extends BaseLogin {
667674
}
668675
}
669676
}
677+
678+
private async getMfaCodeHandler(params: GetMfaCodeParams): Promise<GetMfaCodeResult> {
679+
if (params.mfaSerial) {
680+
await globals.globalState.update('recentMfaSerial', { mfaSerial: params.mfaSerial })
681+
}
682+
const defaultMfaSerial = globals.globalState.tryGet('recentMfaSerial', Object, {
683+
mfaSerial: '',
684+
}).mfaSerial
685+
let mfaSerial = await getMfaSerialFromUser(defaultMfaSerial, params.profileName)
686+
mfaSerial = mfaSerial.trim()
687+
await globals.globalState.update('recentMfaSerial', { mfaSerial: mfaSerial })
688+
const mfaCode = await getMfaTokenFromUser(mfaSerial, params.profileName)
689+
return { code: mfaCode ?? '', mfaSerial: mfaSerial ?? '' }
690+
}
670691
}

0 commit comments

Comments
 (0)