diff --git a/package-lock.json b/package-lock.json index 7ad7d2486f8..6e832688fc4 100644 --- a/package-lock.json +++ b/package-lock.json @@ -10900,7 +10900,7 @@ "license": "Apache-2.0", "dependencies": { "@aws/chat-client-ui-types": "^0.1.53", - "@aws/language-server-runtimes-types": "^0.1.47", + "@aws/language-server-runtimes-types": "^0.1.52", "@aws/mynah-ui": "^4.28.0" } }, @@ -10909,15 +10909,15 @@ "dev": true, "license": "Apache-2.0", "dependencies": { - "@aws/language-server-runtimes-types": "^0.1.47" + "@aws/language-server-runtimes-types": "^0.1.52" } }, "node_modules/@aws/language-server-runtimes": { - "version": "0.2.111", + "version": "0.2.120", "dev": true, "license": "Apache-2.0", "dependencies": { - "@aws/language-server-runtimes-types": "^0.1.47", + "@aws/language-server-runtimes-types": "^0.1.52", "@opentelemetry/api": "^1.9.0", "@opentelemetry/api-logs": "^0.200.0", "@opentelemetry/core": "^2.0.0", @@ -10941,7 +10941,7 @@ } }, "node_modules/@aws/language-server-runtimes-types": { - "version": "0.1.47", + "version": "0.1.52", "dev": true, "license": "Apache-2.0", "dependencies": { @@ -25498,8 +25498,8 @@ "@aws-sdk/types": "^3.13.1", "@aws/chat-client": "^0.1.4", "@aws/chat-client-ui-types": "^0.1.53", - "@aws/language-server-runtimes": "^0.2.111", - "@aws/language-server-runtimes-types": "^0.1.47", + "@aws/language-server-runtimes": "^0.2.120", + "@aws/language-server-runtimes-types": "^0.1.52", "@cspotcode/source-map-support": "^0.8.1", "@sinonjs/fake-timers": "^10.0.2", "@types/adm-zip": "^0.4.34", diff --git a/packages/amazonq/package.json b/packages/amazonq/package.json index f01a40c1110..e5415dc956f 100644 --- a/packages/amazonq/package.json +++ b/packages/amazonq/package.json @@ -1196,106 +1196,134 @@ "description": "AWS Contributed Icon", "default": { "fontPath": "./resources/fonts/aws-toolkit-icons.woff", - "fontCharacter": "\\f1d1" + "fontCharacter": "\\f1d3" } }, "aws-mynah-MynahIconBlack": { "description": "AWS Contributed Icon", "default": { "fontPath": "./resources/fonts/aws-toolkit-icons.woff", - "fontCharacter": "\\f1d2" + "fontCharacter": "\\f1d4" } }, "aws-mynah-MynahIconWhite": { "description": "AWS Contributed Icon", "default": { "fontPath": "./resources/fonts/aws-toolkit-icons.woff", - "fontCharacter": "\\f1d3" + "fontCharacter": "\\f1d5" } }, "aws-mynah-logo": { "description": "AWS Contributed Icon", "default": { "fontPath": "./resources/fonts/aws-toolkit-icons.woff", - "fontCharacter": "\\f1d4" + "fontCharacter": "\\f1d6" } }, "aws-redshift-cluster": { "description": "AWS Contributed Icon", "default": { "fontPath": "./resources/fonts/aws-toolkit-icons.woff", - "fontCharacter": "\\f1d5" + "fontCharacter": "\\f1d7" } }, "aws-redshift-cluster-connected": { "description": "AWS Contributed Icon", "default": { "fontPath": "./resources/fonts/aws-toolkit-icons.woff", - "fontCharacter": "\\f1d6" + "fontCharacter": "\\f1d8" } }, "aws-redshift-database": { "description": "AWS Contributed Icon", "default": { "fontPath": "./resources/fonts/aws-toolkit-icons.woff", - "fontCharacter": "\\f1d7" + "fontCharacter": "\\f1d9" } }, "aws-redshift-redshift-cluster-connected": { "description": "AWS Contributed Icon", "default": { "fontPath": "./resources/fonts/aws-toolkit-icons.woff", - "fontCharacter": "\\f1d8" + "fontCharacter": "\\f1da" } }, "aws-redshift-schema": { "description": "AWS Contributed Icon", "default": { "fontPath": "./resources/fonts/aws-toolkit-icons.woff", - "fontCharacter": "\\f1d9" + "fontCharacter": "\\f1db" } }, "aws-redshift-table": { "description": "AWS Contributed Icon", "default": { "fontPath": "./resources/fonts/aws-toolkit-icons.woff", - "fontCharacter": "\\f1da" + "fontCharacter": "\\f1dc" } }, "aws-s3-bucket": { "description": "AWS Contributed Icon", "default": { "fontPath": "./resources/fonts/aws-toolkit-icons.woff", - "fontCharacter": "\\f1db" + "fontCharacter": "\\f1dd" } }, "aws-s3-create-bucket": { "description": "AWS Contributed Icon", "default": { "fontPath": "./resources/fonts/aws-toolkit-icons.woff", - "fontCharacter": "\\f1dc" + "fontCharacter": "\\f1de" } }, "aws-schemas-registry": { "description": "AWS Contributed Icon", "default": { "fontPath": "./resources/fonts/aws-toolkit-icons.woff", - "fontCharacter": "\\f1dd" + "fontCharacter": "\\f1e1" } }, "aws-schemas-schema": { "description": "AWS Contributed Icon", "default": { "fontPath": "./resources/fonts/aws-toolkit-icons.woff", - "fontCharacter": "\\f1de" + "fontCharacter": "\\f1e2" } }, "aws-stepfunctions-preview": { + "description": "AWS Contributed Icon", + "default": { + "fontPath": "./resources/fonts/aws-toolkit-icons.woff", + "fontCharacter": "\\f1e3" + } + }, + "aws-lambda-create-stack": { + "description": "AWS Contributed Icon", + "default": { + "fontPath": "./resources/fonts/aws-toolkit-icons.woff", + "fontCharacter": "\\f1d1" + } + }, + "aws-lambda-create-stack-light": { + "description": "AWS Contributed Icon", + "default": { + "fontPath": "./resources/fonts/aws-toolkit-icons.woff", + "fontCharacter": "\\f1d2" + } + }, + "aws-sagemaker-code-editor": { "description": "AWS Contributed Icon", "default": { "fontPath": "./resources/fonts/aws-toolkit-icons.woff", "fontCharacter": "\\f1df" } + }, + "aws-sagemaker-jupyter-lab": { + "description": "AWS Contributed Icon", + "default": { + "fontPath": "./resources/fonts/aws-toolkit-icons.woff", + "fontCharacter": "\\f1e0" + } } }, "walkthroughs": [ diff --git a/packages/amazonq/src/lsp/client.ts b/packages/amazonq/src/lsp/client.ts index 4395ade9a2c..97c5e44af5d 100644 --- a/packages/amazonq/src/lsp/client.ts +++ b/packages/amazonq/src/lsp/client.ts @@ -164,6 +164,9 @@ export async function startLanguageServer( }, credentials: { providesBearerToken: true, + // Add IAM credentials support + providesIamCredentials: true, + supportsAssumeRole: true, }, }, /** @@ -211,9 +214,10 @@ export async function startLanguageServer( /** All must be setup before {@link AuthUtil.restore} otherwise they may not trigger when expected */ AuthUtil.instance.regionProfileManager.onDidChangeRegionProfile(async () => { + const activeProfile = AuthUtil.instance.regionProfileManager.activeRegionProfile void pushConfigUpdate(client, { type: 'profile', - profileArn: AuthUtil.instance.regionProfileManager.activeRegionProfile?.arn, + profileArn: activeProfile?.arn, }) }) diff --git a/packages/amazonq/test/unit/codewhisperer/util/authUtil.test.ts b/packages/amazonq/test/unit/codewhisperer/util/authUtil.test.ts index 233a7371425..9d7de2b4558 100644 --- a/packages/amazonq/test/unit/codewhisperer/util/authUtil.test.ts +++ b/packages/amazonq/test/unit/codewhisperer/util/authUtil.test.ts @@ -11,13 +11,29 @@ import { createTestAuthUtil, TestFolder } from 'aws-core-vscode/test' import { constants, cache } from 'aws-core-vscode/auth' import { auth2 } from 'aws-core-vscode/auth' import { mementoUtils, fs } from 'aws-core-vscode/shared' +import { GetIamCredentialResult } from '@aws/language-server-runtimes/protocol' describe('AuthUtil', async function () { let auth: any + let mockResponse: GetIamCredentialResult beforeEach(async function () { await createTestAuthUtil() auth = AuthUtil.instance + mockResponse = { + credential: { + id: 'test-credential-id', + kinds: [], + credentials: { + accessKeyId: 'encrypted-access-key', + secretAccessKey: 'encrypted-secret-key', + sessionToken: 'encrypted-session-token', + }, + }, + updateCredentialsParams: { + data: 'credential-data', + }, + } satisfies GetIamCredentialResult }) afterEach(async function () { @@ -412,18 +428,6 @@ describe('AuthUtil', async function () { describe('loginIam', function () { it('creates IAM session and logs in', async function () { - const mockResponse = { - id: 'test-credential-id', - credentials: { - accessKeyId: 'encrypted-access-key', - secretAccessKey: 'encrypted-secret-key', - sessionToken: 'encrypted-session-token', - }, - updateCredentialsParams: { - data: 'credential-data', - }, - } - const mockIamLogin = { login: sinon.stub().resolves(mockResponse), loginType: 'iam', @@ -431,17 +435,44 @@ describe('AuthUtil', async function () { sinon.stub(auth2, 'IamLogin').returns(mockIamLogin as any) - const response = await auth.loginIam('accessKey', 'secretKey', 'sessionToken') + const response = await auth.loginIam({ + accessKey: 'testAccessKey', + secretKey: 'testSecretKey', + sessionToken: 'testSessionToken', + }) assert.ok(mockIamLogin.login.calledOnce) assert.ok( - mockIamLogin.login.calledWith({ - accessKey: 'accessKey', - secretKey: 'secretKey', + mockIamLogin.login.calledWithMatch({ + accessKey: 'testAccessKey', + secretKey: 'testSecretKey', + sessionToken: 'testSessionToken', }) ) assert.strictEqual(response, mockResponse) }) + + it('creates IAM session with role ARN', async function () { + const mockIamLoginArn = { + login: sinon.stub().resolves(mockResponse), + loginType: 'iam', + } + + sinon.stub(auth2, 'IamLogin').returns(mockIamLoginArn as any) + + const opts: auth2.IamProfileOptions = { + accessKey: 'testAccessKey', + secretKey: 'testSecretKey', + sessionToken: 'testSessionToken', + roleArn: 'arn:aws:iam::123456789012:role/TestRole', + } + + const response = await auth.loginIam(opts) + + assert.ok(mockIamLoginArn.login.calledOnce) + assert.ok(mockIamLoginArn.login.calledWith(opts)) + assert.strictEqual(response, mockResponse) + }) }) describe('getIamCredential', function () { diff --git a/packages/core/package.json b/packages/core/package.json index 8e697bb8b52..f53e23cb136 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -443,8 +443,8 @@ "@aws-sdk/types": "^3.13.1", "@aws/chat-client": "^0.1.4", "@aws/chat-client-ui-types": "^0.1.53", - "@aws/language-server-runtimes": "^0.2.111", - "@aws/language-server-runtimes-types": "^0.1.47", + "@aws/language-server-runtimes": "^0.2.120", + "@aws/language-server-runtimes-types": "^0.1.52", "@cspotcode/source-map-support": "^0.8.1", "@sinonjs/fake-timers": "^10.0.2", "@types/adm-zip": "^0.4.34", diff --git a/packages/core/src/auth/auth2.ts b/packages/core/src/auth/auth2.ts index e69c6730a9a..3ee72b3dc7b 100644 --- a/packages/core/src/auth/auth2.ts +++ b/packages/core/src/auth/auth2.ts @@ -12,16 +12,20 @@ import { GetIamCredentialParams, getIamCredentialRequestType, GetIamCredentialResult, + InvalidateStsCredentialResult, IamIdentityCenterSsoTokenSource, InvalidateSsoTokenParams, + InvalidateStsCredentialParams, invalidateSsoTokenRequestType, + invalidateStsCredentialRequestType, ProfileKind, UpdateProfileParams, updateProfileRequestType, SsoTokenChangedParams, - // StsCredentialChangedParams, + StsCredentialChangedParams, + StsCredentialChangedKind, ssoTokenChangedRequestType, - // stsCredentialChangedRequestType, + stsCredentialChangedRequestType, AwsBuilderIdSsoTokenSource, UpdateCredentialsParams, AwsErrorCodes, @@ -36,6 +40,7 @@ import { iamCredentialsDeleteNotificationType, bearerCredentialsDeleteNotificationType, bearerCredentialsUpdateRequestType, + SsoTokenChangedKind, RequestType, ResponseMessage, NotificationType, @@ -44,10 +49,9 @@ import { iamCredentialsUpdateRequestType, Profile, SsoSession, - SsoTokenChangedKind, - // invalidateStsCredentialRequestType, - // InvalidateStsCredentialParams, - // InvalidateStsCredentialResult, + GetMfaCodeParams, + GetMfaCodeResult, + getMfaCodeRequestType, } from '@aws/language-server-runtimes/protocol' import { LanguageClient } from 'vscode-languageclient' import { getLogger } from '../shared/logger/logger' @@ -56,6 +60,8 @@ import { useDeviceFlow } from './sso/ssoAccessTokenProvider' import { getCacheDir, getCacheFileWatcher, getFlareCacheFileName } from './sso/cache' import { VSCODE_EXTENSION_ID } from '../shared/extensions' import { IamCredentials } from '@aws/language-server-runtimes-types' +import globals from '../shared/extensionGlobals' +import { getMfaSerialFromUser, getMfaTokenFromUser } from './credentials/utils' export const notificationTypes = { updateIamCredential: new RequestType( @@ -87,6 +93,22 @@ export type Login = SsoLogin | IamLogin export type TokenSource = IamIdentityCenterSsoTokenSource | AwsBuilderIdSsoTokenSource +export type IamProfileOptions = { + accessKey?: string + secretKey?: string + sessionToken?: string + roleArn?: string + sourceProfile?: string +} + +const IamProfileOptionsDefaults = { + accessKey: '', + secretKey: '', + sessionToken: '', + roleArn: '', + sourceProfile: '', +} satisfies IamProfileOptions + /** * Handles auth requests to the Identity Server in the Amazon Q LSP. */ @@ -155,6 +177,7 @@ export class LanguageClientAuth { sso_session: profileName, aws_access_key_id: '', aws_secret_access_key: '', + role_arn: '', }, }, ssoSession: { @@ -168,58 +191,32 @@ export class LanguageClientAuth { } satisfies UpdateProfileParams) } - updateIamProfile( - profileName: string, - accessKey: string, - secretKey: string, - sessionToken?: string, - roleArn?: string, - sourceProfile?: string - ): Promise { - // Add credentials and delete SSO settings from profile - let profile: Profile - if (roleArn && sourceProfile) { - profile = { - kinds: [ProfileKind.IamSourceProfileProfile], - name: profileName, - settings: { - sso_session: '', - aws_access_key_id: '', - aws_secret_access_key: '', - aws_session_token: '', - role_arn: roleArn, - source_profile: sourceProfile, - }, - } - } else if (accessKey && secretKey) { - profile = { - kinds: [ProfileKind.IamCredentialsProfile], - name: profileName, - settings: { - sso_session: '', - aws_access_key_id: accessKey, - aws_secret_access_key: secretKey, - aws_session_token: sessionToken, - role_arn: '', - source_profile: '', - }, - } + updateIamProfile(profileName: string, opts: IamProfileOptions): Promise { + // Substitute missing fields for defaults + const fields = { ...IamProfileOptionsDefaults, ...opts } + // Get the profile kind matching the provided fields + let kind: ProfileKind + if (fields.roleArn && fields.sourceProfile) { + kind = ProfileKind.IamSourceProfileProfile + } else if (fields.accessKey && fields.secretKey) { + kind = ProfileKind.IamCredentialsProfile } else { - profile = { - kinds: [ProfileKind.Unknown], + kind = ProfileKind.Unknown + } + + return this.client.sendRequest(updateProfileRequestType.method, { + profile: { + kinds: [kind], name: profileName, settings: { - aws_access_key_id: '', - aws_secret_access_key: '', - aws_session_token: '', - role_arn: '', - source_profile: '', + aws_access_key_id: fields.accessKey, + aws_secret_access_key: fields.secretKey, + aws_session_token: fields.sessionToken, + role_arn: fields.roleArn, + source_profile: fields.sourceProfile, }, - } - } - return this.client.sendRequest(updateProfileRequestType.method, { - profile: profile, - } satisfies UpdateProfileParams) + }, + }) } listProfiles() { @@ -262,10 +259,24 @@ export class LanguageClientAuth { } satisfies InvalidateSsoTokenParams) as Promise } + invalidateStsCredential(tokenId: string) { + return this.client.sendRequest(invalidateStsCredentialRequestType.method, { + iamCredentialId: tokenId, + } satisfies InvalidateStsCredentialParams) as Promise + } + registerSsoTokenChangedHandler(ssoTokenChangedHandler: (params: SsoTokenChangedParams) => any) { this.client.onNotification(ssoTokenChangedRequestType.method, ssoTokenChangedHandler) } + registerStsCredentialChangedHandler(stsCredentialChangedHandler: (params: StsCredentialChangedParams) => any) { + this.client.onNotification(stsCredentialChangedRequestType.method, stsCredentialChangedHandler) + } + + registerGetMfaCodeHandler(getMfaCodeHandler: (params: GetMfaCodeParams) => Promise) { + this.client.onRequest(getMfaCodeRequestType.method, getMfaCodeHandler) + } + registerCacheWatcher(cacheChangedHandler: (event: cacheChangedEvent) => any) { this.cacheWatcher.onDidCreate(() => cacheChangedHandler('create')) this.cacheWatcher.onDidDelete(() => cacheChangedHandler('delete')) @@ -279,7 +290,9 @@ export abstract class BaseLogin { protected loginType: LoginType | undefined protected connectionState: AuthState = 'notConnected' protected cancellationToken: CancellationTokenSource | undefined - protected _data: { startUrl?: string; region?: string; accessKey?: string; secretKey?: string } | undefined + protected _data: + | { startUrl?: string; region?: string; accessKey?: string; secretKey?: string; sessionToken?: string } + | undefined constructor( public readonly profileName: string, @@ -502,16 +515,17 @@ export class SsoLogin extends BaseLogin { export class IamLogin extends BaseLogin { // Cached information from the identity server for easy reference override readonly loginType = LoginTypes.IAM - // private iamCredentialId: string | undefined + private iamCredentialId: string | undefined constructor(profileName: string, lspAuth: LanguageClientAuth, eventEmitter: vscode.EventEmitter) { super(profileName, lspAuth, eventEmitter) - // lspAuth.registerStsCredentialChangedHandler((params: StsCredentialChangedParams) => - // this.stsCredentialChangedHandler(params) - // ) + lspAuth.registerStsCredentialChangedHandler((params: StsCredentialChangedParams) => + this.stsCredentialChangedHandler(params) + ) + lspAuth.registerGetMfaCodeHandler((params: GetMfaCodeParams) => this.getMfaCodeHandler(params)) } - async login(opts: { accessKey: string; secretKey: string }) { + async login(opts: IamProfileOptions) { await this.updateProfile(opts) return this._getIamCredential(true) } @@ -524,21 +538,36 @@ export class IamLogin extends BaseLogin { } async logout() { - // if (this.iamCredentialId) { - // await this.lspAuth.invalidateIamCredential(this.iamCredentialId) - // } - await this.lspAuth.updateIamProfile(this.profileName, '', '', '', '', '') - await this.lspAuth.updateIamProfile(this.profileName + '-source', '', '', '', '', '') + if (this.iamCredentialId) { + await this.lspAuth.invalidateStsCredential(this.iamCredentialId) + } + await this.lspAuth.updateIamProfile(this.profileName, {}) + await this.lspAuth.updateIamProfile(this.profileName + '-source', {}) this.updateConnectionState('notConnected') this._data = undefined // TODO: DeleteProfile api in Identity Service (this doesn't exist yet) } - async updateProfile(opts: { accessKey: string; secretKey: string }) { - await this.lspAuth.updateIamProfile(this.profileName, opts.accessKey, opts.secretKey) - this._data = { - accessKey: opts.accessKey, - secretKey: opts.secretKey, + async updateProfile(opts: IamProfileOptions) { + if (opts.roleArn) { + // Create the source and target profiles + const sourceProfile = this.profileName + '-source' + await this.lspAuth.updateIamProfile(sourceProfile, { + accessKey: opts.accessKey, + secretKey: opts.secretKey, + sessionToken: opts.sessionToken, + }) + await this.lspAuth.updateIamProfile(this.profileName, { + roleArn: opts.roleArn, + sourceProfile: sourceProfile, + }) + } else { + // Create the target profile + await this.lspAuth.updateIamProfile(this.profileName, { + accessKey: opts.accessKey, + secretKey: opts.secretKey, + sessionToken: opts.sessionToken, + }) } } @@ -568,10 +597,10 @@ export class IamLogin extends BaseLogin { async getCredential() { const response = await this._getIamCredential(false) const credentials: IamCredentials = { - accessKeyId: await this.decrypt(response.credentials.accessKeyId), - secretAccessKey: await this.decrypt(response.credentials.secretAccessKey), - sessionToken: response.credentials.sessionToken - ? await this.decrypt(response.credentials.sessionToken) + accessKeyId: await this.decrypt(response.credential.credentials.accessKeyId), + secretAccessKey: await this.decrypt(response.credential.credentials.secretAccessKey), + sessionToken: response.credential.credentials.sessionToken + ? await this.decrypt(response.credential.credentials.sessionToken) : undefined, } return { @@ -593,8 +622,10 @@ export class IamLogin extends BaseLogin { } catch (err: any) { switch (err.data?.awsErrorCode) { case AwsErrorCodes.E_CANCELLED: - case AwsErrorCodes.E_SSO_SESSION_NOT_FOUND: + case AwsErrorCodes.E_INVALID_PROFILE: case AwsErrorCodes.E_PROFILE_NOT_FOUND: + case AwsErrorCodes.E_CANNOT_CREATE_STS_CREDENTIAL: + case AwsErrorCodes.E_INVALID_STS_CREDENTIAL: this.updateConnectionState('notConnected') break default: @@ -607,19 +638,41 @@ export class IamLogin extends BaseLogin { this.cancellationToken = undefined } - // this.iamCredentialId = response.id + // Update cached credentials and credential ID + if (response.credential?.credentials?.accessKeyId && response.credential?.credentials?.secretAccessKey) { + this._data = { + accessKey: response.credential.credentials.accessKeyId, + secretKey: response.credential.credentials.secretAccessKey, + sessionToken: response.credential.credentials.sessionToken, + } + this.iamCredentialId = response.credential.id + } this.updateConnectionState('connected') return response } - // private stsCredentialChangedHandler(params: StsCredentialChangedParams) { - // if (params.stsCredentialId === this.iamCredentialId) { - // if (params.kind === CredentialChangedKind.Expired) { - // this.updateConnectionState('expired') - // return - // } else if (params.kind === CredentialChangedKind.Refreshed) { - // this.eventEmitter.fire({ id: this.profileName, state: 'refreshed' }) - // } - // } - // } + private stsCredentialChangedHandler(params: StsCredentialChangedParams) { + if (params.stsCredentialId === this.iamCredentialId) { + if (params.kind === StsCredentialChangedKind.Expired) { + this.updateConnectionState('expired') + return + } else if (params.kind === StsCredentialChangedKind.Refreshed) { + this.eventEmitter.fire({ id: this.iamCredentialId, state: 'refreshed' }) + } + } + } + + private async getMfaCodeHandler(params: GetMfaCodeParams): Promise { + if (params.mfaSerial) { + await globals.globalState.update('recentMfaSerial', { mfaSerial: params.mfaSerial }) + } + const defaultMfaSerial = globals.globalState.tryGet('recentMfaSerial', Object, { + mfaSerial: '', + }).mfaSerial + let mfaSerial = await getMfaSerialFromUser(defaultMfaSerial, params.profileName) + mfaSerial = mfaSerial.trim() + await globals.globalState.update('recentMfaSerial', { mfaSerial: mfaSerial }) + const mfaCode = await getMfaTokenFromUser(mfaSerial, params.profileName) + return { code: mfaCode ?? '', mfaSerial: mfaSerial ?? '' } + } } diff --git a/packages/core/src/auth/credentials/utils.ts b/packages/core/src/auth/credentials/utils.ts index 885a4fb1f87..4f4eda5027c 100644 --- a/packages/core/src/auth/credentials/utils.ts +++ b/packages/core/src/auth/credentials/utils.ts @@ -103,14 +103,35 @@ export class CredentialsSettings extends fromExtensionManifest('aws', { profile: const errorMessageUserCancelled = localize('AWS.error.mfa.userCancelled', 'User cancelled entering authentication code') /** - * @description Prompts user for MFA token + * @description Prompts user for MFA serial number * - * Entered token is passed to the callback. - * If user cancels out, the callback is passed an error with a fixed message string. + * @param defaultSerial Default MFA serial number to pre-fill + * @param profileName Name of Credentials profile we are asking an MFA serial for + */ +export async function getMfaSerialFromUser(defaultSerial: string, profileName: string): Promise { + const inputBox = createInputBox({ + ignoreFocusOut: true, + placeholder: localize('AWS.prompt.mfa.enterCode.placeholder', 'Enter mfaSerial Number Here'), + title: localize('AWS.prompt.mfa.enterCode.title', 'MFA Challenge for {0}', profileName), + prompt: localize('AWS.prompt.mfa.enterCode.prompt', 'Enter Serial Number for MFA device', defaultSerial), + value: defaultSerial, // Pre-fill with default value + }) + + const token = await inputBox.prompt() + + // Distinguish user cancel vs code entry issues with the error message + if (!isValidResponse(token)) { + throw new Error(errorMessageUserCancelled) + } + + return token +} + +/** + * @description Prompts user for MFA token * * @param mfaSerial Serial arn of MFA device * @param profileName Name of Credentials profile we are asking an MFA Token for - * @param callback tokens/errors are passed through here */ export async function getMfaTokenFromUser(mfaSerial: string, profileName: string): Promise { const inputBox = createInputBox({ diff --git a/packages/core/src/auth/index.ts b/packages/core/src/auth/index.ts index 2dd361f9804..727277f0fcd 100644 --- a/packages/core/src/auth/index.ts +++ b/packages/core/src/auth/index.ts @@ -22,6 +22,7 @@ export { } from './connection' export { Auth } from './auth' export { CredentialsStore } from './credentials/store' +export { getMfaTokenFromUser, getMfaSerialFromUser } from './credentials/utils' export { LoginManager } from './deprecated/loginManager' export * as constants from './sso/constants' export * as cache from './sso/cache' diff --git a/packages/core/src/codewhisperer/util/authUtil.ts b/packages/core/src/codewhisperer/util/authUtil.ts index a4a2eff051e..cee02407033 100644 --- a/packages/core/src/codewhisperer/util/authUtil.ts +++ b/packages/core/src/codewhisperer/util/authUtil.ts @@ -39,6 +39,7 @@ import { IamLogin, AuthState, LoginTypes, + IamProfileOptions, } from '../../auth/auth2' import { builderIdStartUrl, internalStartUrl } from '../../auth/sso/constants' import { VSCODE_EXTENSION_ID } from '../../shared/extensions' @@ -73,7 +74,13 @@ export interface IAuthProvider { getToken(): Promise getIamCredential(): Promise readonly profileName: string - readonly connection?: { startUrl?: string; region?: string; accessKey?: string; secretKey?: string } + readonly connection?: { + startUrl?: string + region?: string + accessKey?: string + secretKey?: string + sessionToken?: string + } } /** @@ -171,30 +178,26 @@ export class AuthUtil implements IAuthProvider { // Log in using SSO async loginSso(startUrl: string, region: string): Promise { - let response: GetSsoTokenResult | undefined // Create SSO login session if (!this.isSsoSession()) { this.session = new SsoLogin(this.profileName, this.lspAuth, this.eventEmitter) } - // eslint-disable-next-line prefer-const - response = await (this.session as SsoLogin).login({ startUrl: startUrl, region: region, scopes: amazonQScopes }) + const response = await (this.session as SsoLogin).login({ + startUrl: startUrl, + region: region, + scopes: amazonQScopes, + }) await showAmazonQWalkthroughOnce() return response } // Log in using IAM or STS credentials - async loginIam( - accessKey: string, - secretKey: string, - sessionToken?: string - ): Promise { - let response: GetIamCredentialResult | undefined + async loginIam(opts: IamProfileOptions): Promise { // Create IAM login session if (!this.isIamSession()) { this.session = new IamLogin(this.profileName, this.lspAuth, this.eventEmitter) } - // eslint-disable-next-line prefer-const - response = await (this.session as IamLogin).login({ accessKey: accessKey, secretKey: secretKey }) + const response = await (this.session as IamLogin).login(opts) await showAmazonQWalkthroughOnce() return response } @@ -531,9 +534,10 @@ export class AuthUtil implements IAuthProvider { scopes: amazonQScopes, } - if (this.session instanceof SsoLogin) { - await this.session.updateProfile(registrationKey) + if (!this.isSsoSession()) { + this.session = new SsoLogin(this.profileName, this.lspAuth, this.eventEmitter) } + await (this.session as SsoLogin).updateProfile(registrationKey) const cacheDir = getCacheDir() diff --git a/packages/core/src/login/webview/vue/amazonq/backend_amazonq.ts b/packages/core/src/login/webview/vue/amazonq/backend_amazonq.ts index 6cd3c4d1309..36338318431 100644 --- a/packages/core/src/login/webview/vue/amazonq/backend_amazonq.ts +++ b/packages/core/src/login/webview/vue/amazonq/backend_amazonq.ts @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ import * as vscode from 'vscode' -import { AwsConnection, SsoConnection } from '../../../../auth/connection' +import { AwsConnection, IamProfile, SsoConnection } from '../../../../auth/connection' import { AuthUtil } from '../../../../codewhisperer/util/authUtil' import { CommonAuthWebview } from '../backend' import { awsIdSignIn } from '../../../../codewhisperer/util/showSsoPrompt' @@ -196,16 +196,21 @@ export class AmazonQLoginWebview extends CommonAuthWebview { async startIamCredentialSetup( profileName: string, accessKey: string, - secretKey: string + secretKey: string, + sessionToken?: string, + roleArn?: string ): Promise { getLogger().debug(`called startIamCredentialSetup()`) // Defining separate auth function to emit telemetry before returning from this method + await globals.globalState.update('recentIamKeys', { accessKey: accessKey }) + await globals.globalState.update('recentRoleArn', { roleArn: roleArn }) const runAuth = async (): Promise => { try { - await AuthUtil.instance.loginIam(accessKey, secretKey) + await AuthUtil.instance.loginIam({ accessKey, secretKey, sessionToken, roleArn }) } catch (e) { getLogger().error('Failed submitting credentials %O', e) - return { id: this.id, text: e as string } + const message = e instanceof Error ? e.message : (e as string) + return { id: this.id, text: message } } // Enable code suggestions vsCodeState.isFreeTierLimitReached = false @@ -230,6 +235,12 @@ export class AmazonQLoginWebview extends CommonAuthWebview { /** If users are unauthenticated in Q/CW, we should always display the auth screen. */ async quitLoginScreen() {} + async listIamCredentialProfiles(): Promise { + // Amazon Q only supports 1 connection at a time, + // so there isn't a need to de-duplicate connections. + return [] + } + /** * The purpose of returning Error.message is to notify vue frontend that API call fails and to render corresponding error message to users * @returns ProfileList when API call succeeds, otherwise Error.message diff --git a/packages/core/src/login/webview/vue/backend.ts b/packages/core/src/login/webview/vue/backend.ts index b455c205ec8..8c93d1702a0 100644 --- a/packages/core/src/login/webview/vue/backend.ts +++ b/packages/core/src/login/webview/vue/backend.ts @@ -19,6 +19,7 @@ import { scopesCodeWhispererChat, scopesSsoAccountAccess, SsoConnection, + IamProfile, TelemetryMetadata, } from '../../../auth/connection' import { Auth } from '../../../auth/auth' @@ -174,7 +175,9 @@ export abstract class CommonAuthWebview extends VueWebview { abstract startIamCredentialSetup( profileName: string, accessKey: string, - secretKey: string + secretKey: string, + sessionToken?: string, + roleArn?: string ): Promise async showResourceExplorer(): Promise { @@ -208,6 +211,8 @@ export abstract class CommonAuthWebview extends VueWebview { abstract listRegionProfiles(): Promise + abstract listIamCredentialProfiles(): Promise + abstract selectRegionProfile(profile: RegionProfile, source: ProfileSwitchIntent): Promise /** @@ -306,6 +311,10 @@ export abstract class CommonAuthWebview extends VueWebview { return globals.globalState.tryGet('recentIamKeys', Object, { accessKey: '' }) } + getDefaultRoleArn(): { roleArn: string } { + return globals.globalState.tryGet('recentRoleArn', Object, { roleArn: '' }) + } + cancelAuthFlow() { AuthSSOServer.lastInstance?.cancelCurrentFlow() } diff --git a/packages/core/src/login/webview/vue/login.vue b/packages/core/src/login/webview/vue/login.vue index 90eabc09db4..d58ce53b69a 100644 --- a/packages/core/src/login/webview/vue/login.vue +++ b/packages/core/src/login/webview/vue/login.vue @@ -279,6 +279,26 @@ v-model="secretKey" @keydown.enter="handleContinueClick()" /> +
+
Session Token (Optional)
+ +
Role ARN (Optional)
+ +
@@ -367,6 +387,8 @@ export default defineComponent({ profileName: '', accessKey: '', secretKey: '', + sessionToken: '', + roleArn: '', } }, async created() { @@ -374,7 +396,9 @@ export default defineComponent({ this.startUrl = defaultSso.startUrl this.selectedRegion = defaultSso.region const defaultIamAccessKey = await this.getDefaultIamAccessKey() + const defaultRoleArn = await this.getDefaultRoleArn() this.accessKey = defaultIamAccessKey.accessKey + this.roleArn = defaultRoleArn.roleArn await this.emitUpdate('created') }, @@ -497,7 +521,13 @@ export default defineComponent({ return } this.stage = 'AUTHENTICATING' - const error = await client.startIamCredentialSetup(this.profileName, this.accessKey, this.secretKey) + const error = await client.startIamCredentialSetup( + this.profileName, + this.accessKey, + this.secretKey, + this.sessionToken, + this.roleArn + ) if (error) { this.stage = 'START' void client.errorNotification(error) @@ -610,6 +640,9 @@ export default defineComponent({ async getDefaultIamAccessKey() { return await client.getDefaultIamKeys() }, + async getDefaultRoleArn() { + return await client.getDefaultRoleArn() + }, handleHelpLinkClick() { void client.emitUiClick('auth_helpLink') }, diff --git a/packages/core/src/login/webview/vue/toolkit/backend_toolkit.ts b/packages/core/src/login/webview/vue/toolkit/backend_toolkit.ts index caec2c764bc..d08bc771297 100644 --- a/packages/core/src/login/webview/vue/toolkit/backend_toolkit.ts +++ b/packages/core/src/login/webview/vue/toolkit/backend_toolkit.ts @@ -9,6 +9,7 @@ import { getLogger } from '../../../../shared/logger/logger' import { CommonAuthWebview } from '../backend' import { AwsConnection, + IamProfile, SsoConnection, TelemetryMetadata, createSsoProfile, @@ -157,6 +158,10 @@ export class ToolkitLoginWebview extends CommonAuthWebview { return (await Auth.instance.listConnections()).filter((conn) => isSsoConnection(conn)) as SsoConnection[] } + async listIamCredentialProfiles(): Promise { + return [] + } + override reauthenticateConnection(): Promise { throw new Error('Method not implemented.') } diff --git a/packages/core/src/shared/globalState.ts b/packages/core/src/shared/globalState.ts index 2a11321e5d3..41712b30e87 100644 --- a/packages/core/src/shared/globalState.ts +++ b/packages/core/src/shared/globalState.ts @@ -72,6 +72,8 @@ export type globalKey = | 'recentCredentials' | 'recentSso' | 'recentIamKeys' + | 'recentRoleArn' + | 'recentMfaSerial' // List of regions enabled in AWS Explorer. | 'region' // TODO: implement this via `PromptSettings` instead of globalState. diff --git a/packages/core/src/shared/settings.ts b/packages/core/src/shared/settings.ts index 75a6b03780b..96b19e3cbf5 100644 --- a/packages/core/src/shared/settings.ts +++ b/packages/core/src/shared/settings.ts @@ -779,6 +779,7 @@ const devSettings = { amazonqLsp: Record(String, String), amazonqWorkspaceLsp: Record(String, String), ssoCacheDirectory: String, + stsCacheDirectory: String, autofillStartUrl: String, autofillAccessKey: String, webAuth: Boolean, diff --git a/packages/core/src/test/amazonqDoc/utils.ts b/packages/core/src/test/amazonqDoc/utils.ts index d6d74e7ac3c..37c2252918c 100644 --- a/packages/core/src/test/amazonqDoc/utils.ts +++ b/packages/core/src/test/amazonqDoc/utils.ts @@ -107,6 +107,8 @@ export async function sessionWriteFile(session: Session, uri: vscode.Uri, encode export function createMockAuthUtil(sandbox: sinon.SinonSandbox) { const mockLspAuth: Partial = { registerSsoTokenChangedHandler: sinon.stub().resolves(), + registerStsCredentialChangedHandler: sinon.stub().resolves(), + registerGetMfaCodeHandler: sinon.stub().resolves(), } AuthUtil.create(mockLspAuth as LanguageClientAuth) sandbox.stub(AuthUtil.instance.regionProfileManager, 'onDidChangeRegionProfile').resolves() diff --git a/packages/core/src/test/credentials/auth2.test.ts b/packages/core/src/test/credentials/auth2.test.ts index ee90295b56d..c6f9b31e868 100644 --- a/packages/core/src/test/credentials/auth2.test.ts +++ b/packages/core/src/test/credentials/auth2.test.ts @@ -15,13 +15,17 @@ import { ListProfilesResult, UpdateCredentialsParams, SsoTokenChangedParams, + StsCredentialChangedParams, bearerCredentialsUpdateRequestType, bearerCredentialsDeleteNotificationType, iamCredentialsUpdateRequestType, iamCredentialsDeleteNotificationType, ssoTokenChangedRequestType, + stsCredentialChangedRequestType, SsoTokenChangedKind, + StsCredentialChangedKind, invalidateSsoTokenRequestType, + invalidateStsCredentialRequestType, ProfileKind, AwsErrorCodes, } from '@aws/language-server-runtimes/protocol' @@ -100,7 +104,11 @@ describe('LanguageClientAuth', () => { }) it('sends correct IAM profile update parameters', async () => { - await auth.updateIamProfile(profileName, 'accessKey', 'secretKey', 'sessionToken') + await auth.updateIamProfile(profileName, { + accessKey: 'myAccessKey', + secretKey: 'mySecretKey', + sessionToken: 'mySessionToken', + }) sinon.assert.calledOnce(client.sendRequest) const requestParams = client.sendRequest.firstCall.args[1] @@ -109,9 +117,11 @@ describe('LanguageClientAuth', () => { kinds: [ProfileKind.IamCredentialsProfile], }) sinon.assert.match(requestParams.profile.settings, { - aws_access_key_id: 'accessKey', - aws_secret_access_key: 'secretKey', - aws_session_token: 'sessionToken', + aws_access_key_id: 'myAccessKey', + aws_secret_access_key: 'mySecretKey', + aws_session_token: 'mySessionToken', + role_arn: '', + source_profile: '', }) }) }) @@ -219,6 +229,40 @@ describe('LanguageClientAuth', () => { }) }) + describe('invalidateStsCredential', () => { + it('sends request', async () => { + client.sendRequest.resolves({ success: true }) + const result = await auth.invalidateStsCredential(profileName) + + sinon.assert.calledOnce(client.sendRequest) + sinon.assert.calledWith(client.sendRequest, invalidateStsCredentialRequestType.method, { + iamCredentialId: profileName, + }) + sinon.assert.match(result, { success: true }) + }) + }) + + describe('registerStsCredentialChangedHandler', () => { + it('registers the handler correctly', () => { + const handler = sinon.spy() + + auth.registerStsCredentialChangedHandler(handler) + + sinon.assert.calledOnce(client.onNotification) + sinon.assert.calledWith(client.onNotification, stsCredentialChangedRequestType.method, sinon.match.func) + + const credentialChangedParams: StsCredentialChangedParams = { + kind: StsCredentialChangedKind.Refreshed, + stsCredentialId: 'test-credential-id', + } + const registeredHandler = client.onNotification.firstCall.args[1] + registeredHandler(credentialChangedParams) + + sinon.assert.calledOnce(handler) + sinon.assert.calledWith(handler, credentialChangedParams) + }) + }) + describe('invalidateSsoToken', () => { it('sends request', async () => { client.sendRequest.resolves({ success: true }) @@ -601,11 +645,14 @@ describe('IamLogin', () => { } const mockGetIamCredentialResponse: GetIamCredentialResult = { - id: 'test-credential-id', - credentials: { - accessKeyId: 'encrypted-access-key', - secretAccessKey: 'encrypted-secret-key', - sessionToken: 'encrypted-session-token', + credential: { + id: 'test-credential-id', + kinds: [], + credentials: { + accessKeyId: 'encrypted-access-key', + secretAccessKey: 'encrypted-secret-key', + sessionToken: 'encrypted-session-token', + }, }, updateCredentialsParams: { data: 'credential-data', @@ -633,10 +680,10 @@ describe('IamLogin', () => { const response = await iamLogin.login(loginOpts) sinon.assert.calledOnce(lspAuth.updateIamProfile) - sinon.assert.calledWith(lspAuth.updateIamProfile, profileName, loginOpts.accessKey, loginOpts.secretKey) + sinon.assert.calledWith(lspAuth.updateIamProfile, profileName, loginOpts) sinon.assert.calledOnce(lspAuth.getIamCredential) sinon.assert.match(iamLogin.getConnectionState(), 'connected') - sinon.assert.match(response.id, 'test-credential-id') + sinon.assert.match(response.credential.id, 'test-credential-id') }) }) @@ -659,7 +706,22 @@ describe('IamLogin', () => { sinon.assert.calledOnce(lspAuth.getIamCredential) sinon.assert.match(iamLogin.getConnectionState(), 'connected') - sinon.assert.match(response.id, 'test-credential-id') + sinon.assert.match(response.credential.id, 'test-credential-id') + }) + }) + + describe('logout', () => { + it('invalidates credential and updates state', async () => { + ;(iamLogin as any).iamCredentialId = 'test-credential-id' + lspAuth.invalidateStsCredential.resolves({ success: true }) + lspAuth.updateIamProfile.resolves() + + await iamLogin.logout() + + sinon.assert.calledOnce(lspAuth.invalidateStsCredential) + sinon.assert.calledWith(lspAuth.invalidateStsCredential, 'test-credential-id') + sinon.assert.match(iamLogin.getConnectionState(), 'notConnected') + sinon.assert.match(iamLogin.data, undefined) }) }) @@ -717,4 +779,20 @@ describe('IamLogin', () => { // sinon.assert.match((iamLogin as any).iamCredentialId, 'test-credential-id') }) }) + + describe('stsCredentialChangedHandler', () => { + beforeEach(() => { + ;(iamLogin as any).iamCredentialId = 'test-credential-id' + ;(iamLogin as any).connectionState = 'connected' + }) + + it('updates state when credential expires', () => { + ;(iamLogin as any).stsCredentialChangedHandler({ + kind: StsCredentialChangedKind.Expired, + stsCredentialId: 'test-credential-id', + }) + + sinon.assert.match(iamLogin.getConnectionState(), 'expired') + }) + }) }) diff --git a/packages/core/src/test/testAuthUtil.ts b/packages/core/src/test/testAuthUtil.ts index 929eca99b65..9e93ebbbdff 100644 --- a/packages/core/src/test/testAuthUtil.ts +++ b/packages/core/src/test/testAuthUtil.ts @@ -27,10 +27,14 @@ export async function createTestAuthUtil() { } const fakeCredential = { - credentials: { - accessKeyId: 'fake-access-key-id', - secretAccessKey: 'fake-secret-access-key', - sessionToken: 'fake-session-token', + credential: { + id: 'fake-id', + kinds: [], + credentials: { + accessKeyId: 'fake-access-key-id', + secretAccessKey: 'fake-secret-access-key', + sessionToken: 'fake-session-token', + }, }, updateCredentialsParams: { data: 'fake-data', @@ -39,6 +43,8 @@ export async function createTestAuthUtil() { const mockLspAuth: Partial = { registerSsoTokenChangedHandler: sinon.stub().resolves(), + registerStsCredentialChangedHandler: sinon.stub().resolves(), + registerGetMfaCodeHandler: sinon.stub().resolves(), updateSsoProfile: sinon.stub().resolves(), getSsoToken: sinon.stub().resolves(fakeToken), getIamCredential: sinon.stub().resolves(fakeCredential),