diff --git a/packages/amazonq/src/lsp/auth.ts b/packages/amazonq/src/lsp/auth.ts deleted file mode 100644 index 0637019a1ab..00000000000 --- a/packages/amazonq/src/lsp/auth.ts +++ /dev/null @@ -1,111 +0,0 @@ -/*! - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -import { - ConnectionMetadata, - NotificationType, - RequestType, - ResponseMessage, -} from '@aws/language-server-runtimes/protocol' -import * as jose from 'jose' -import * as crypto from 'crypto' -import { LanguageClient } from 'vscode-languageclient' -import { AuthUtil } from 'aws-core-vscode/codewhisperer' -import { Writable } from 'stream' -import { onceChanged } from 'aws-core-vscode/utils' -import { getLogger, oneMinute } from 'aws-core-vscode/shared' - -export const encryptionKey = crypto.randomBytes(32) - -/** - * Sends a json payload to the language server, who is waiting to know what the encryption key is. - * Code reference: https://github.com/aws/language-servers/blob/7da212185a5da75a72ce49a1a7982983f438651a/client/vscode/src/credentialsActivation.ts#L77 - */ -export function writeEncryptionInit(stream: Writable): void { - const request = { - version: '1.0', - mode: 'JWT', - key: encryptionKey.toString('base64'), - } - stream.write(JSON.stringify(request)) - stream.write('\n') -} - -/** - * Request for custom notifications that Update Credentials and tokens. - * See core\aws-lsp-core\src\credentials\updateCredentialsRequest.ts for details - */ -export interface UpdateCredentialsRequest { - /** - * Encrypted token (JWT or PASETO) - * The token's contents differ whether IAM or Bearer token is sent - */ - data: string - /** - * Used by the runtime based language servers. - * Signals that this client will encrypt its credentials payloads. - */ - encrypted: boolean -} - -export const notificationTypes = { - updateBearerToken: new RequestType( - 'aws/credentials/token/update' - ), - deleteBearerToken: new NotificationType('aws/credentials/token/delete'), - getConnectionMetadata: new RequestType( - 'aws/credentials/getConnectionMetadata' - ), -} - -/** - * Facade over our VSCode Auth that does crud operations on the language server auth - */ -export class AmazonQLspAuth { - constructor(private readonly client: LanguageClient) {} - - async refreshConnection() { - const activeConnection = AuthUtil.instance.auth.activeConnection - if (activeConnection?.type === 'sso') { - // send the token to the language server - const token = await AuthUtil.instance.getBearerToken() - await this.updateBearerToken(token) - } - } - - public updateBearerToken = onceChanged(this._updateBearerToken.bind(this)) - private async _updateBearerToken(token: string) { - const request = await this.createUpdateCredentialsRequest({ - token, - }) - - await this.client.sendRequest(notificationTypes.updateBearerToken.method, request) - - this.client.info(`UpdateBearerToken: ${JSON.stringify(request)}`) - } - - public startTokenRefreshInterval(pollingTime: number = oneMinute) { - const interval = setInterval(async () => { - await this.refreshConnection().catch((e) => { - getLogger('amazonqLsp').error('Unable to update bearer token: %s', (e as Error).message) - clearInterval(interval) - }) - }, pollingTime) - return interval - } - - private async createUpdateCredentialsRequest(data: any) { - const payload = new TextEncoder().encode(JSON.stringify({ data })) - - const jwt = await new jose.CompactEncrypt(payload) - .setProtectedHeader({ alg: 'dir', enc: 'A256GCM' }) - .encrypt(encryptionKey) - - return { - data: jwt, - encrypted: true, - } - } -} diff --git a/packages/amazonq/test/unit/amazonq/lsp/auth.test.ts b/packages/amazonq/test/unit/amazonq/lsp/auth.test.ts deleted file mode 100644 index d55fef85f39..00000000000 --- a/packages/amazonq/test/unit/amazonq/lsp/auth.test.ts +++ /dev/null @@ -1,33 +0,0 @@ -/*! - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -import assert from 'assert' -import { AmazonQLspAuth } from '../../../../src/lsp/auth' -import { LanguageClient } from 'vscode-languageclient' - -describe('AmazonQLspAuth', function () { - describe('updateBearerToken', function () { - it('makes request to LSP when token changes', async function () { - // Note: this token will be encrypted - let lastSentToken = {} - const auth = new AmazonQLspAuth({ - sendRequest: (_method: string, param: any) => { - lastSentToken = param - }, - info: (_message: string, _data: any) => {}, - } as LanguageClient) - - await auth.updateBearerToken('firstToken') - assert.notDeepStrictEqual(lastSentToken, {}) - const encryptedFirstToken = lastSentToken - - await auth.updateBearerToken('secondToken') - assert.notDeepStrictEqual(lastSentToken, encryptedFirstToken) - const encryptedSecondToken = lastSentToken - - await auth.updateBearerToken('secondToken') - assert.deepStrictEqual(lastSentToken, encryptedSecondToken) - }) - }) -}) diff --git a/packages/core/src/auth/auth2.ts b/packages/core/src/auth/auth2.ts new file mode 100644 index 00000000000..7cb9127ee12 --- /dev/null +++ b/packages/core/src/auth/auth2.ts @@ -0,0 +1,325 @@ +/*! + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +import * as vscode from 'vscode' +import * as jose from 'jose' +import { + GetSsoTokenParams, + getSsoTokenRequestType, + GetSsoTokenResult, + IamIdentityCenterSsoTokenSource, + InvalidateSsoTokenParams, + invalidateSsoTokenRequestType, + ProfileKind, + UpdateProfileParams, + updateProfileRequestType, + SsoTokenChangedParams, + ssoTokenChangedRequestType, + AwsBuilderIdSsoTokenSource, + UpdateCredentialsParams, + AwsErrorCodes, + SsoTokenSourceKind, + listProfilesRequestType, + ListProfilesResult, + UpdateProfileResult, + InvalidateSsoTokenResult, + AuthorizationFlowKind, + CancellationToken, + CancellationTokenSource, + bearerCredentialsDeleteNotificationType, + bearerCredentialsUpdateRequestType, + SsoTokenChangedKind, +} from '@aws/language-server-runtimes/protocol' +import { LanguageClient } from 'vscode-languageclient' +import { getLogger } from '../shared/logger/logger' +import { ToolkitError } from '../shared/errors' +import { useDeviceFlow } from './sso/ssoAccessTokenProvider' + +export type AuthState = 'notConnected' | 'connected' | 'expired' + +export type AuthStateEvent = { id: string; state: AuthState | 'refreshed' } + +export const LoginTypes = { + SSO: 'sso', + IAM: 'iam', +} as const +export type LoginType = (typeof LoginTypes)[keyof typeof LoginTypes] + +interface BaseLogin { + readonly loginType: LoginType +} + +export type Login = SsoLogin // TODO: add IamLogin type when supported + +export type TokenSource = IamIdentityCenterSsoTokenSource | AwsBuilderIdSsoTokenSource + +/** + * Handles auth requests to the Identity Server in the Amazon Q LSP. + */ +export class LanguageClientAuth { + constructor( + private readonly client: LanguageClient, + private readonly clientName: string, + public readonly encryptionKey: Buffer + ) {} + + getSsoToken( + tokenSource: TokenSource, + login: boolean = false, + cancellationToken?: CancellationToken + ): Promise { + return this.client.sendRequest( + getSsoTokenRequestType.method, + { + clientName: this.clientName, + source: tokenSource, + options: { + loginOnInvalidToken: login, + authorizationFlow: useDeviceFlow() ? AuthorizationFlowKind.DeviceCode : AuthorizationFlowKind.Pkce, + }, + } satisfies GetSsoTokenParams, + cancellationToken + ) + } + + updateProfile( + profileName: string, + startUrl: string, + region: string, + scopes: string[] + ): Promise { + return this.client.sendRequest(updateProfileRequestType.method, { + profile: { + kinds: [ProfileKind.SsoTokenProfile], + name: profileName, + settings: { + region, + sso_session: profileName, + }, + }, + ssoSession: { + name: profileName, + settings: { + sso_region: region, + sso_start_url: startUrl, + sso_registration_scopes: scopes, + }, + }, + } satisfies UpdateProfileParams) + } + + listProfiles() { + return this.client.sendRequest(listProfilesRequestType.method, {}) as Promise + } + + /** + * Returns a profile by name along with its linked sso_session. + * Does not currently exist as an API in the Identity Service. + */ + async getProfile(profileName: string) { + const response = await this.listProfiles() + const profile = response.profiles.find((profile) => profile.name === profileName) + const ssoSession = profile?.settings?.sso_session + ? response.ssoSessions.find((session) => session.name === profile!.settings!.sso_session) + : undefined + + return { profile, ssoSession } + } + + updateBearerToken(request: UpdateCredentialsParams) { + return this.client.sendRequest(bearerCredentialsUpdateRequestType.method, request) + } + + deleteBearerToken() { + return this.client.sendNotification(bearerCredentialsDeleteNotificationType.method) + } + + invalidateSsoToken(tokenId: string) { + return this.client.sendRequest(invalidateSsoTokenRequestType.method, { + ssoTokenId: tokenId, + } satisfies InvalidateSsoTokenParams) as Promise + } + + registerSsoTokenChangedHandler(ssoTokenChangedHandler: (params: SsoTokenChangedParams) => any) { + this.client.onNotification(ssoTokenChangedRequestType.method, ssoTokenChangedHandler) + } +} + +/** + * Manages an SSO connection. + */ +export class SsoLogin implements BaseLogin { + readonly loginType = LoginTypes.SSO + private readonly eventEmitter = new vscode.EventEmitter() + + // Cached information from the identity server for easy reference + private ssoTokenId: string | undefined + private connectionState: AuthState = 'connected' + private _data: { startUrl: string; region: string } | undefined + + private cancellationToken: CancellationTokenSource | undefined + + constructor( + public readonly profileName: string, + private readonly lspAuth: LanguageClientAuth + ) { + lspAuth.registerSsoTokenChangedHandler((params: SsoTokenChangedParams) => this.ssoTokenChangedHandler(params)) + } + + get data() { + return this._data + } + + async login(opts: { startUrl: string; region: string; scopes: string[] }) { + await this.updateProfile(opts) + return this._getSsoToken(true) + } + + async reauthenticate() { + if (this.connectionState === 'notConnected') { + throw new ToolkitError('Cannot reauthenticate when not connected.') + } + return this._getSsoToken(true) + } + + async logout() { + if (this.ssoTokenId) { + await this.lspAuth.invalidateSsoToken(this.ssoTokenId) + } + this.updateConnectionState('notConnected') + this._data = undefined + // TODO: DeleteProfile api in Identity Service (this doesn't exist yet) + } + + async updateProfile(opts: { startUrl: string; region: string; scopes: string[] }) { + await this.lspAuth.updateProfile(this.profileName, opts.startUrl, opts.region, opts.scopes) + this._data = { + startUrl: opts.startUrl, + region: opts.region, + } + } + + /** + * Restore the connection state and connection details to memory, if they exist. + */ + async restore() { + const sessionData = await this.lspAuth.getProfile(this.profileName) + const ssoSession = sessionData?.ssoSession?.settings + if (ssoSession?.sso_region && ssoSession?.sso_start_url) { + this._data = { + startUrl: ssoSession.sso_start_url, + region: ssoSession.sso_region, + } + } + + try { + await this._getSsoToken(false) + } catch (err) { + getLogger().error('Restoring connection failed: %s', err) + } + } + + /** + * Cancels running active login flows. + */ + cancelLogin() { + this.cancellationToken?.cancel() + this.cancellationToken?.dispose() + this.cancellationToken = undefined + } + + /** + * Returns both the decrypted access token and the payload to send to the `updateCredentials` LSP API + * with encrypted token + */ + async getToken() { + const response = await this._getSsoToken(false) + const decryptedKey = await jose.compactDecrypt(response.ssoToken.accessToken, this.lspAuth.encryptionKey) + return { + token: decryptedKey.plaintext.toString().replaceAll('"', ''), + updateCredentialsParams: response.updateCredentialsParams, + } + } + + /** + * Returns the response from `getSsoToken` LSP API and sets the connection state based on the errors/result + * of the call. + */ + private async _getSsoToken(login: boolean) { + let response: GetSsoTokenResult + this.cancellationToken = new CancellationTokenSource() + + try { + response = await this.lspAuth.getSsoToken( + { + /** + * Note that we do not use SsoTokenSourceKind.AwsBuilderId here. + * This is because it does not leave any state behind on disk, so + * we cannot infer that a builder ID connection exists via the + * Identity Server alone. + */ + kind: SsoTokenSourceKind.IamIdentityCenter, + profileName: this.profileName, + } satisfies IamIdentityCenterSsoTokenSource, + login, + this.cancellationToken.token + ) + } catch (err: any) { + switch (err.data?.awsErrorCode) { + case AwsErrorCodes.E_CANCELLED: + case AwsErrorCodes.E_SSO_SESSION_NOT_FOUND: + case AwsErrorCodes.E_PROFILE_NOT_FOUND: + case AwsErrorCodes.E_INVALID_SSO_TOKEN: + this.updateConnectionState('notConnected') + break + case AwsErrorCodes.E_CANNOT_REFRESH_SSO_TOKEN: + this.updateConnectionState('expired') + break + // TODO: implement when identity server emits E_NETWORK_ERROR, E_FILESYSTEM_ERROR + // case AwsErrorCodes.E_NETWORK_ERROR: + // case AwsErrorCodes.E_FILESYSTEM_ERROR: + // // do stuff, probably nothing at all + // break + default: + getLogger().error('SsoLogin: unknown error when requesting token: %s', err) + break + } + throw err + } finally { + this.cancellationToken?.dispose() + this.cancellationToken = undefined + } + + this.ssoTokenId = response.ssoToken.id + this.updateConnectionState('connected') + return response + } + + getConnectionState() { + return this.connectionState + } + + onDidChangeConnectionState(handler: (e: AuthStateEvent) => any) { + return this.eventEmitter.event(handler) + } + + private updateConnectionState(state: AuthState) { + if (this.connectionState !== state) { + this.eventEmitter.fire({ id: this.profileName, state }) + } + this.connectionState = state + } + + private ssoTokenChangedHandler(params: SsoTokenChangedParams) { + if (params.ssoTokenId === this.ssoTokenId) { + if (params.kind === SsoTokenChangedKind.Expired) { + this.updateConnectionState('expired') + return + } else if (params.kind === SsoTokenChangedKind.Refreshed) { + this.eventEmitter.fire({ id: this.profileName, state: 'refreshed' }) + } + } + } +} diff --git a/packages/core/src/auth/sso/ssoAccessTokenProvider.ts b/packages/core/src/auth/sso/ssoAccessTokenProvider.ts index e753fb2ef90..7caf2638e1b 100644 --- a/packages/core/src/auth/sso/ssoAccessTokenProvider.ts +++ b/packages/core/src/auth/sso/ssoAccessTokenProvider.ts @@ -289,17 +289,7 @@ export abstract class SsoAccessTokenProvider { profile: Pick, cache = getCache(), oidc: OidcClient = OidcClient.create(profile.region), - reAuthState?: ReAuthState, - useDeviceFlow: () => boolean = () => { - /** - * Device code flow is neccessary when: - * 1. We are in a workspace connected through ssh (codecatalyst, etc) - * 2. We are connected to a remote backend through the web browser (code server, openshift dev spaces) - * - * Since we are unable to serve the final authorization page - */ - return getExtRuntimeContext().extensionHost === 'remote' - } + reAuthState?: ReAuthState ) { if (DevSettings.instance.get('webAuth', false) && getExtRuntimeContext().extensionHost === 'webworker') { return new WebAuthorization(profile, cache, oidc, reAuthState) @@ -400,6 +390,17 @@ function getSessionDuration(id: string) { return creationDate !== undefined ? globals.clock.Date.now() - creationDate : undefined } +export function useDeviceFlow(): boolean { + /** + * Device code flow is neccessary when: + * 1. We are in a workspace connected through ssh (codecatalyst, etc) + * 2. We are connected to a remote backend through the web browser (code server, openshift dev spaces) + * + * Since we are unable to serve the final authorization page + */ + return getExtRuntimeContext().extensionHost === 'remote' +} + /** * SSO "device code" flow (RFC: https://tools.ietf.org/html/rfc8628) * 1. Get a client id (SSO-OIDC identifier, formatted per RFC6749). diff --git a/packages/core/src/codewhisperer/util/authUtil.ts b/packages/core/src/codewhisperer/util/authUtil.ts index 0898493b6db..08be6c26e4b 100644 --- a/packages/core/src/codewhisperer/util/authUtil.ts +++ b/packages/core/src/codewhisperer/util/authUtil.ts @@ -5,358 +5,131 @@ import * as vscode from 'vscode' import * as localizedText from '../../shared/localizedText' -import { Auth } from '../../auth/auth' -import { ToolkitError, isNetworkError, tryRun } from '../../shared/errors' -import { getSecondaryAuth, setScopes } from '../../auth/secondaryAuth' -import { isSageMaker } from '../../shared/extensionUtilities' +import * as nls from 'vscode-nls' +import { ToolkitError } from '../../shared/errors' import { AmazonQPromptSettings } from '../../shared/settings' -import { - scopesCodeWhispererCore, - createBuilderIdProfile, - hasScopes, - SsoConnection, - createSsoProfile, - Connection, - isIamConnection, - isSsoConnection, - isBuilderIdConnection, - scopesCodeWhispererChat, - scopesFeatureDev, - scopesGumby, - isIdcSsoConnection, - hasExactScopes, - getTelemetryMetadataForConn, - ProfileNotFoundError, -} from '../../auth/connection' +import { scopesCodeWhispererCore, scopesCodeWhispererChat, scopesFeatureDev, scopesGumby } from '../../auth/connection' import { getLogger } from '../../shared/logger/logger' -import { Commands, placeholder } from '../../shared/vscode/commands2' +import { Commands } from '../../shared/vscode/commands2' import { vsCodeState } from '../models/model' -import { onceChanged, once } from '../../shared/utilities/functionUtils' -import { indent } from '../../shared/utilities/textUtilities' import { showReauthenticateMessage } from '../../shared/utilities/messages' import { showAmazonQWalkthroughOnce } from '../../amazonq/onboardingPage/walkthrough' import { setContext } from '../../shared/vscode/setContext' -import { isInDevEnv } from '../../shared/vscode/env' import { openUrl } from '../../shared/utilities/vsCodeUtils' -import * as nls from 'vscode-nls' -const localize = nls.loadMessageBundle() import { telemetry } from '../../shared/telemetry/telemetry' -import { asStringifiedStack } from '../../shared/telemetry/spans' -import { withTelemetryContext } from '../../shared/telemetry/util' -import { focusAmazonQPanel } from '../../codewhispererChat/commands/registerCommands' -import { throttle } from 'lodash' -import { RegionProfileManager } from '../region/regionProfileManager' +import { AuthStateEvent, AuthStates, LanguageClientAuth, LoginTypes, SsoLogin } from '../../auth/auth2' +import { builderIdStartUrl } from '../../auth/sso/constants' +import { VSCODE_EXTENSION_ID } from '../../shared/extensions' + +const localize = nls.loadMessageBundle() + /** Backwards compatibility for connections w pre-chat scopes */ export const codeWhispererCoreScopes = [...scopesCodeWhispererCore] export const codeWhispererChatScopes = [...codeWhispererCoreScopes, ...scopesCodeWhispererChat] export const amazonQScopes = [...codeWhispererChatScopes, ...scopesGumby, ...scopesFeatureDev] /** - * "Core" are the CW scopes that existed before the addition of new scopes - * for Amazon Q. + * Handles authentication within Amazon Q. + * Amazon Q only supports a single connection at a time. */ -export const isValidCodeWhispererCoreConnection = (conn?: Connection): conn is Connection => { - return ( - (isSageMaker() && isIamConnection(conn)) || (isSsoConnection(conn) && hasScopes(conn, codeWhispererCoreScopes)) - ) -} -/** Superset that includes all of CodeWhisperer + Amazon Q */ -export const isValidAmazonQConnection = (conn?: Connection): conn is Connection => { - return ( - (isSageMaker() && isIamConnection(conn)) || - ((isSsoConnection(conn) || isBuilderIdConnection(conn)) && - isValidCodeWhispererCoreConnection(conn) && - hasScopes(conn, amazonQScopes)) - ) -} - -const authClassName = 'AuthQ' - export class AuthUtil { - static #instance: AuthUtil - protected static readonly logIfChanged = onceChanged((s: string) => getLogger().info(s)) + public readonly profileName = VSCODE_EXTENSION_ID.amazonq - private reauthenticatePromptShown: boolean = false - private _isCustomizationFeatureEnabled: boolean = false - - // user should only see that screen once. - // TODO: move to memento - public hasAlreadySeenMigrationAuthScreen: boolean = false + // IAM login currently not supported + private session: SsoLogin - public get isCustomizationFeatureEnabled(): boolean { - return this._isCustomizationFeatureEnabled + static create(lspAuth: LanguageClientAuth) { + return (this.#instance ??= new this(lspAuth)) } - // This boolean controls whether the Select Customization node will be visible. A change to this value - // means that the old UX was wrong and must refresh the devTool tree. - public set isCustomizationFeatureEnabled(value: boolean) { - if (this._isCustomizationFeatureEnabled === value) { - return + static #instance: AuthUtil + public static get instance() { + if (!this.#instance) { + throw new ToolkitError('AuthUtil not ready. Was it initialized with a running LSP?') } - this._isCustomizationFeatureEnabled = value - void Commands.tryExecute('aws.amazonq.refreshStatusBar') + return this.#instance } - public readonly secondaryAuth = getSecondaryAuth( - this.auth, - 'codewhisperer', - 'Amazon Q', - isValidCodeWhispererCoreConnection - ) - public readonly restore = () => this.secondaryAuth.restoreConnection() - - public constructor( - public readonly auth = Auth.instance, - public readonly regionProfileManager = new RegionProfileManager(() => this.conn) - ) {} - - public initCodeWhispererHooks = once(() => { - this.auth.onDidChangeConnectionState(async (e) => { - getLogger().info(`codewhisperer: connection changed to ${e.state}: ${e.id}`) - if (e.state !== 'authenticating') { - await this.refreshCodeWhisperer() - } - - await this.setVscodeContextProps() - }) - - this.secondaryAuth.onDidChangeActiveConnection(async () => { - getLogger().info(`codewhisperer: active connection changed`) - if (this.isValidEnterpriseSsoInUse()) { - void vscode.commands.executeCommand('aws.amazonq.notifyNewCustomizations') - await this.regionProfileManager.restoreProfileSelection() - } - vsCodeState.isFreeTierLimitReached = false - await Promise.all([ - // onDidChangeActiveConnection may trigger before these modules are activated. - Commands.tryExecute('aws.amazonq.refreshStatusBar'), - Commands.tryExecute('aws.amazonq.updateReferenceLog'), - ]) - - await this.setVscodeContextProps() - - // To check valid connection - if (this.isValidEnterpriseSsoInUse() || (this.isBuilderIdInUse() && !this.isConnectionExpired())) { - await showAmazonQWalkthroughOnce() - } - - if (!this.isConnected()) { - await this.regionProfileManager.invalidateProfile(this.regionProfileManager.activeRegionProfile?.arn) - } - }) - - this.regionProfileManager.onDidChangeRegionProfile(async () => { - await this.setVscodeContextProps() - }) - }) - - public async setVscodeContextProps() { - // if users are "pending profile selection", they're not fully connected and require profile selection for Q usage - // requireProfileSelection() always returns false for builderID users - await setContext('aws.codewhisperer.connected', this.isConnected() && !this.requireProfileSelection()) - const doShowAmazonQLoginView = - !this.isConnected() || this.isConnectionExpired() || this.requireProfileSelection() - await setContext('aws.amazonq.showLoginView', doShowAmazonQLoginView) - await setContext('aws.codewhisperer.connectionExpired', this.isConnectionExpired()) - await setContext('aws.amazonq.connectedSsoIdc', isIdcSsoConnection(this.conn)) + private constructor(private readonly lspAuth: LanguageClientAuth) { + this.session = new SsoLogin(this.profileName, this.lspAuth) + this.onDidChangeConnectionState((e: AuthStateEvent) => this.stateChangeHandler(e)) } - public reformatStartUrl(startUrl: string | undefined) { - return !startUrl ? undefined : startUrl.replace(/[\/#]+$/g, '') + isSsoSession() { + return this.session.loginType === LoginTypes.SSO } - // current active cwspr connection - public get conn() { - return this.secondaryAuth.activeConnection + async restore() { + await this.session.restore() } - // TODO: move this to the shared auth.ts - public get startUrl(): string | undefined { - // Reformat the url to remove any trailing '/' and `#` - // e.g. https://view.awsapps.com/start/# will become https://view.awsapps.com/start - return isSsoConnection(this.conn) ? this.reformatStartUrl(this.conn?.startUrl) : undefined - } - - public get isUsingSavedConnection() { - return this.conn !== undefined && this.secondaryAuth.hasSavedConnection - } - - public isConnected(): boolean { - return this.conn !== undefined - } - - public isEnterpriseSsoInUse(): boolean { - const conn = this.conn - // we have an sso that isn't builder id, must be IdC by process of elimination - const isUsingEnterpriseSso = conn?.type === 'sso' && !isBuilderIdConnection(conn) - return conn !== undefined && isUsingEnterpriseSso - } + async login(startUrl: string, region: string) { + const response = await this.session.login({ startUrl, region, scopes: amazonQScopes }) + await showAmazonQWalkthroughOnce() - // If there is an active SSO connection - public isValidEnterpriseSsoInUse(): boolean { - return this.isEnterpriseSsoInUse() && !this.isConnectionExpired() + return response } - public isBuilderIdInUse(): boolean { - return this.conn !== undefined && isBuilderIdConnection(this.conn) - } - - @withTelemetryContext({ name: 'connectToAwsBuilderId', class: authClassName }) - public async connectToAwsBuilderId(): Promise { - let conn = (await this.auth.listConnections()).find(isBuilderIdConnection) - - if (!conn) { - conn = await this.auth.createConnection(createBuilderIdProfile(amazonQScopes)) - } else if (!isValidAmazonQConnection(conn)) { - conn = await this.secondaryAuth.addScopes(conn, amazonQScopes) - } - - if (this.auth.getConnectionState(conn) === 'invalid') { - conn = await this.auth.reauthenticate(conn) + reauthenticate() { + if (!this.isSsoSession()) { + throw new ToolkitError('Cannot reauthenticate non-SSO session.') } - return (await this.secondaryAuth.useNewConnection(conn)) as SsoConnection + return this.session.reauthenticate() } - @withTelemetryContext({ name: 'connectToEnterpriseSso', class: authClassName }) - public async connectToEnterpriseSso(startUrl: string, region: string): Promise { - let conn = (await this.auth.listConnections()).find( - (conn): conn is SsoConnection => - isSsoConnection(conn) && conn.startUrl.toLowerCase() === startUrl.toLowerCase() - ) - - if (!conn) { - conn = await this.auth.createConnection(createSsoProfile(startUrl, region, amazonQScopes)) - } else if (!isValidAmazonQConnection(conn)) { - conn = await this.secondaryAuth.addScopes(conn, amazonQScopes) - } - - if (this.auth.getConnectionState(conn) === 'invalid') { - conn = await this.auth.reauthenticate(conn) + logout() { + if (!this.isSsoSession()) { + // Only SSO requires logout + return } - - return (await this.secondaryAuth.useNewConnection(conn)) as SsoConnection + this.lspAuth.deleteBearerToken() + return this.session.logout() } - public static get instance() { - if (this.#instance !== undefined) { - return this.#instance + async getToken() { + if (this.isSsoSession()) { + return (await this.session.getToken()).token + } else { + throw new ToolkitError('Cannot get token for non-SSO session.') } - - const self = (this.#instance = new this()) - return self } - @withTelemetryContext({ name: 'getBearerToken', class: authClassName }) - public async getBearerToken(): Promise { - await this.restore() - - if (this.conn === undefined) { - throw new ToolkitError('No connection found', { code: 'NoConnection' }) - } - - if (!isSsoConnection(this.conn)) { - throw new ToolkitError('Connection is not an SSO connection', { code: 'BadConnectionType' }) - } - - try { - const bearerToken = await this.conn.getToken() - return bearerToken.accessToken - } catch (err) { - if (err instanceof ProfileNotFoundError) { - // Expected that connection would be deleted by conn.getToken() - void focusAmazonQPanel.execute(placeholder, 'profileNotFoundSignout') - } - throw err - } + get connection() { + return this.session.data } - @withTelemetryContext({ name: 'getCredentials', class: authClassName }) - public async getCredentials() { - await this.restore() - - if (this.conn === undefined) { - throw new ToolkitError('No connection found', { code: 'NoConnection' }) - } - - if (!isIamConnection(this.conn)) { - throw new ToolkitError('Connection is not an IAM connection', { code: 'BadConnectionType' }) - } - - return this.conn.getCredentials() + getAuthState() { + return this.session.getConnectionState() } - public isConnectionValid(log: boolean = true): boolean { - const connectionValid = this.conn !== undefined && !this.secondaryAuth.isConnectionExpired - - if (log) { - this.logConnection() - } - - return connectionValid + isConnected() { + return this.getAuthState() === AuthStates.CONNECTED } - public isConnectionExpired(log: boolean = true): boolean { - const connectionExpired = - this.secondaryAuth.isConnectionExpired && - this.conn !== undefined && - isValidCodeWhispererCoreConnection(this.conn) - - if (log) { - this.logConnection() - } - - return connectionExpired + isConnectionExpired() { + return this.getAuthState() === AuthStates.EXPIRED } - requireProfileSelection(): boolean { - if (isBuilderIdConnection(this.conn)) { - return false - } - return isIdcSsoConnection(this.conn) && this.regionProfileManager.activeRegionProfile === undefined + isBuilderIdConnection() { + return this.connection?.startUrl === builderIdStartUrl } - private logConnection() { - const logStr = indent( - `codewhisperer: connection states - connection isValid=${this.isConnectionValid(false)}, - connection isValidCodewhispererCoreConnection=${isValidCodeWhispererCoreConnection(this.conn)}, - connection isExpired=${this.isConnectionExpired(false)}, - secondaryAuth isExpired=${this.secondaryAuth.isConnectionExpired}, - connection isUndefined=${this.conn === undefined}`, - 4, - true - ) - - AuthUtil.logIfChanged(logStr) + isIdcConnection() { + return this.connection?.startUrl && this.connection?.startUrl !== builderIdStartUrl } - @withTelemetryContext({ name: 'reauthenticate', class: authClassName }) - public async reauthenticate() { - try { - if (this.conn?.type !== 'sso') { - return - } - - if (!hasExactScopes(this.conn, amazonQScopes)) { - const conn = await setScopes(this.conn, amazonQScopes, this.auth) - await this.secondaryAuth.useNewConnection(conn) - } - - await this.auth.reauthenticate(this.conn) - } catch (err) { - throw ToolkitError.chain(err, 'Unable to authenticate connection') - } finally { - await this.setVscodeContextProps() - } + onDidChangeConnectionState(handler: (e: AuthStateEvent) => any) { + return this.session.onDidChangeConnectionState(handler) } - public async refreshCodeWhisperer() { - vsCodeState.isFreeTierLimitReached = false - await Commands.tryExecute('aws.amazonq.refreshStatusBar') + public async setVscodeContextProps(state = this.getAuthState()) { + await setContext('aws.codewhisperer.connected', state === AuthStates.CONNECTED) + await setContext('aws.amazonq.showLoginView', state !== AuthStates.CONNECTED) // Login view also handles expired state. + await setContext('aws.codewhisperer.connectionExpired', state === AuthStates.EXPIRED) } - @withTelemetryContext({ name: 'showReauthenticatePrompt', class: authClassName }) + private reauthenticatePromptShown: boolean = false public async showReauthenticatePrompt(isAutoTrigger?: boolean) { if (isAutoTrigger && this.reauthenticatePromptShown) { return @@ -377,6 +150,26 @@ export class AuthUtil { } } + private _isCustomizationFeatureEnabled: boolean = false + public get isCustomizationFeatureEnabled(): boolean { + return this._isCustomizationFeatureEnabled + } + + // This boolean controls whether the Select Customization node will be visible. A change to this value + // means that the old UX was wrong and must refresh the devTool tree. + public set isCustomizationFeatureEnabled(value: boolean) { + if (this._isCustomizationFeatureEnabled === value) { + return + } + this._isCustomizationFeatureEnabled = value + void Commands.tryExecute('aws.amazonq.refreshStatusBar') + } + + public async notifyReauthenticate(isAutoTrigger?: boolean) { + void this.showReauthenticatePrompt(isAutoTrigger) + await this.setVscodeContextProps() + } + public async notifySessionConfiguration() { const suppressId = 'amazonQSessionConfigurationMessage' const settings = AmazonQPromptSettings.instance @@ -409,175 +202,35 @@ export class AuthUtil { }) } - @withTelemetryContext({ name: 'notifyReauthenticate', class: authClassName }) - public async notifyReauthenticate(isAutoTrigger?: boolean) { - void this.showReauthenticatePrompt(isAutoTrigger) - await this.setVscodeContextProps() - } - - public isValidCodeTransformationAuthUser(): boolean { - return (this.isEnterpriseSsoInUse() || this.isBuilderIdInUse()) && this.isConnectionValid() - } - - /** - * Asynchronously returns a snapshot of the overall auth state of CodeWhisperer + Chat features. - * It guarantees the latest state is correct at the risk of modifying connection state. - * If this guarantee is not required, use sync method getChatAuthStateSync() - * - * By default, network errors are ignored when determining auth state since they may be silently - * recoverable later. - * - * THROTTLE: This function is called in rapid succession by Amazon Q features and can lead to - * a barrage of disk access and/or token refreshes. We throttle to deal with this. - * - * Note we do an explicit cast of the return type due to Lodash types incorrectly indicating - * a FeatureAuthState or undefined can be returned. But since we set `leading: true` - * it will always return FeatureAuthState - */ - public getChatAuthState = throttle(() => this._getChatAuthState(), 2000, { - leading: true, - }) as () => Promise - /** - * IMPORTANT: Only use this if you do NOT want to swallow network errors, otherwise use {@link getChatAuthState()} - * @param ignoreNetErr swallows network errors - */ - @withTelemetryContext({ name: 'getChatAuthState', class: authClassName }) - public async _getChatAuthState(ignoreNetErr: boolean = true): Promise { - // The state of the connection may not have been properly validated - // and the current state we see may be stale, so refresh for latest state. - if (ignoreNetErr) { - await tryRun( - () => this.auth.refreshConnectionState(this.conn), - (err) => !isNetworkError(err), - 'getChatAuthState: Cannot refresh connection state due to network error: %s' - ) + private async stateChangeHandler(e: AuthStateEvent) { + if (e.state === 'refreshed') { + const params = this.isSsoSession() ? (await this.session.getToken()).updateCredentialsParams : undefined + await this.lspAuth.updateBearerToken(params!) + return } else { - await this.auth.refreshConnectionState(this.conn) + getLogger().info(`codewhisperer: connection changed to ${e.state}`) + await this.refreshState(e.state) } - - return this.getChatAuthStateSync(this.conn) } - /** - * Synchronously returns a snapshot of the overall auth state of CodeWhisperer + Chat features without - * validating or modifying the connection state. It is possible that the connection - * is invalid/valid, but the current state displays something else. To guarantee the true state, - * use async method getChatAuthState() - */ - public getChatAuthStateSync(conn = this.conn): FeatureAuthState { - if (conn === undefined) { - return buildFeatureAuthState(AuthStates.disconnected) + private async refreshState(state = this.getAuthState()) { + if (state === AuthStates.EXPIRED || state === AuthStates.NOT_CONNECTED) { + this.lspAuth.deleteBearerToken() } - - if (!isSsoConnection(conn) && !isSageMaker()) { - throw new ToolkitError(`Connection "${conn.id}" is not a valid type: ${conn.type}`) + if (state === AuthStates.CONNECTED) { + const bearerTokenParams = (await this.session.getToken()).updateCredentialsParams + await this.lspAuth.updateBearerToken(bearerTokenParams) } - // default to expired to indicate reauth is needed if unmodified - const state: FeatureAuthState = buildFeatureAuthState(AuthStates.expired) - - if (this.isConnectionExpired()) { - return state - } - - if (isBuilderIdConnection(conn) || isIdcSsoConnection(conn) || isSageMaker()) { - // TODO: refactor - if (isValidCodeWhispererCoreConnection(conn)) { - if (this.requireProfileSelection()) { - state[Features.codewhispererCore] = AuthStates.pendingProfileSelection - } else { - state[Features.codewhispererCore] = AuthStates.connected - } - } - if (isValidAmazonQConnection(conn)) { - if (this.requireProfileSelection()) { - for (const v of Object.values(Features)) { - state[v as Feature] = AuthStates.pendingProfileSelection - } - } else { - for (const v of Object.values(Features)) { - state[v as Feature] = AuthStates.connected - } - } - } - } - - return state - } + vsCodeState.isFreeTierLimitReached = false + await this.setVscodeContextProps(state) + await Promise.all([ + Commands.tryExecute('aws.amazonq.refreshStatusBar'), + Commands.tryExecute('aws.amazonq.updateReferenceLog'), + ]) - /** - * Edge Case: Due to a change in behaviour/functionality, there are potential extra - * auth connections that the Amazon Q extension has cached. We need to remove these - * as they are irrelevant to the Q extension and can cause issues. - */ - public async clearExtraConnections(): Promise { - const currentQConn = this.conn - // Q currently only maintains 1 connection at a time, so we assume everything else is extra. - // IMPORTANT: In the case Q starts to manage multiple connections, this implementation will need to be updated. - const allOtherConnections = (await this.auth.listConnections()).filter((c) => c.id !== currentQConn?.id) - for (const conn of allOtherConnections) { - getLogger().warn(`forgetting extra amazon q connection: %O`, conn) - await telemetry.auth_modifyConnection.run( - async () => { - telemetry.record({ - connectionState: Auth.instance.getConnectionState(conn) ?? 'undefined', - source: asStringifiedStack(telemetry.getFunctionStack()), - ...(await getTelemetryMetadataForConn(conn)), - }) - - if (isInDevEnv()) { - telemetry.record({ action: 'forget' }) - // in a Dev Env the connection may be used by code catalyst, so we forget instead of fully deleting - await this.auth.forgetConnection(conn) - } else { - telemetry.record({ action: 'delete' }) - await this.auth.deleteConnection(conn) - } - }, - { functionId: { name: 'clearExtraConnections', class: authClassName } } - ) + if (state === AuthStates.CONNECTED && this.isIdcConnection()) { + void vscode.commands.executeCommand('aws.amazonq.notifyNewCustomizations') } } } - -export type FeatureAuthState = { [feature in Feature]: AuthState } -export type Feature = (typeof Features)[keyof typeof Features] -export type AuthState = (typeof AuthStates)[keyof typeof AuthStates] - -export const AuthStates = { - /** The current connection is working and supports this feature. */ - connected: 'connected', - /** No connection exists, so this feature cannot be used*/ - disconnected: 'disconnected', - /** - * The current connection exists, but needs to be reauthenticated for this feature to work - * - * Look to use {@link AuthUtil.reauthenticate()} - */ - expired: 'expired', - /** - * A connection exists, but does not support this feature. - * - * Eg: We are currently using Builder ID, but must use Identity Center. - */ - unsupported: 'unsupported', - /** - * The current connection exists and isn't expired, - * but fetching/refreshing the token resulted in a network error. - */ - connectedWithNetworkError: 'connectedWithNetworkError', - pendingProfileSelection: 'pendingProfileSelection', -} as const -const Features = { - codewhispererCore: 'codewhispererCore', - codewhispererChat: 'codewhispererChat', - amazonQ: 'amazonQ', -} as const - -function buildFeatureAuthState(state: AuthState): FeatureAuthState { - return { - codewhispererCore: state, - codewhispererChat: state, - amazonQ: state, - } -} diff --git a/packages/core/src/test/credentials/auth2.test.ts b/packages/core/src/test/credentials/auth2.test.ts new file mode 100644 index 00000000000..f0b56c4ff6b --- /dev/null +++ b/packages/core/src/test/credentials/auth2.test.ts @@ -0,0 +1,530 @@ +/*! + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +import * as sinon from 'sinon' +import * as vscode from 'vscode' +import { LanguageClientAuth, SsoLogin, AuthStates } from '../../auth/auth2' +import { LanguageClient } from 'vscode-languageclient' +import { + GetSsoTokenResult, + SsoTokenSourceKind, + AuthorizationFlowKind, + ListProfilesResult, + UpdateCredentialsParams, + SsoTokenChangedParams, + bearerCredentialsUpdateRequestType, + bearerCredentialsDeleteNotificationType, + ssoTokenChangedRequestType, + SsoTokenChangedKind, + invalidateSsoTokenRequestType, + ProfileKind, + AwsErrorCodes, +} from '@aws/language-server-runtimes/protocol' +import * as ssoProvider from '../../auth/sso/ssoAccessTokenProvider' + +const profileName = 'test-profile' +const sessionName = 'test-session' +const region = 'us-east-1' +const startUrl = 'test-url' +const tokenId = 'test-token' + +describe('LanguageClientAuth', () => { + let client: sinon.SinonStubbedInstance + let auth: LanguageClientAuth + const encryptionKey = Buffer.from('test-key') + let useDeviceFlowStub: sinon.SinonStub + + beforeEach(() => { + client = sinon.createStubInstance(LanguageClient) + auth = new LanguageClientAuth(client as unknown as LanguageClient, 'testClient', encryptionKey) + useDeviceFlowStub = sinon.stub(ssoProvider, 'useDeviceFlow') + }) + + afterEach(() => { + sinon.restore() + }) + + describe('getSsoToken', () => { + async function testGetSsoToken(useDeviceFlow: boolean) { + const tokenSource = { + kind: SsoTokenSourceKind.IamIdentityCenter, + profileName, + } + useDeviceFlowStub.returns(useDeviceFlow ? true : false) + + await auth.getSsoToken(tokenSource, true) + + sinon.assert.calledOnce(client.sendRequest) + sinon.assert.calledWith( + client.sendRequest, + sinon.match.any, + sinon.match({ + clientName: 'testClient', + source: tokenSource, + options: { + loginOnInvalidToken: true, + authorizationFlow: useDeviceFlow + ? AuthorizationFlowKind.DeviceCode + : AuthorizationFlowKind.Pkce, + }, + }) + ) + } + + it('sends correct request parameters for pkce flow', async () => { + await testGetSsoToken(false) + }) + + it('sends correct request parameters for device code flow', async () => { + await testGetSsoToken(true) + }) + }) + + describe('updateProfile', () => { + it('sends correct profile update parameters', async () => { + await auth.updateProfile(profileName, startUrl, region, ['scope1']) + + sinon.assert.calledOnce(client.sendRequest) + const requestParams = client.sendRequest.firstCall.args[1] + sinon.assert.match(requestParams.profile, { + name: profileName, + }) + sinon.assert.match(requestParams.ssoSession.settings, { + sso_region: region, + }) + }) + }) + + describe('getProfile', () => { + const profile = { name: profileName, settings: { sso_session: sessionName } } + const ssoSession = { name: sessionName, settings: { sso_region: region, sso_start_url: startUrl } } + + it('returns the correct profile and sso session', async () => { + const mockListProfilesResult: ListProfilesResult = { + profiles: [ + { + ...profile, + kinds: [], + }, + ], + ssoSessions: [ssoSession], + } + client.sendRequest.resolves(mockListProfilesResult) + + const result = await auth.getProfile(profileName) + + sinon.assert.calledOnce(client.sendRequest) + sinon.assert.match(result, { + profile, + ssoSession, + }) + }) + + it('returns undefined for non-existent profile', async () => { + const mockListProfilesResult: ListProfilesResult = { + profiles: [], + ssoSessions: [], + } + client.sendRequest.resolves(mockListProfilesResult) + + const result = await auth.getProfile('non-existent-profile') + + sinon.assert.calledOnce(client.sendRequest) + sinon.assert.match(result, { profile: undefined, ssoSession: undefined }) + }) + }) + + describe('updateBearerToken', () => { + it('sends request', async () => { + const updateParams: UpdateCredentialsParams = { + data: 'token-data', + encrypted: true, + } + + await auth.updateBearerToken(updateParams) + + sinon.assert.calledOnce(client.sendRequest) + sinon.assert.calledWith(client.sendRequest, bearerCredentialsUpdateRequestType.method, updateParams) + }) + }) + + describe('deleteBearerToken', () => { + it('sends notification', async () => { + auth.deleteBearerToken() + + sinon.assert.calledOnce(client.sendNotification) + sinon.assert.calledWith(client.sendNotification, bearerCredentialsDeleteNotificationType.method) + }) + }) + + describe('invalidateSsoToken', () => { + it('sends request', async () => { + client.sendRequest.resolves({ success: true }) + const result = await auth.invalidateSsoToken(tokenId) + + sinon.assert.calledOnce(client.sendRequest) + sinon.assert.calledWith(client.sendRequest, invalidateSsoTokenRequestType.method, { ssoTokenId: tokenId }) + sinon.assert.match(result, { success: true }) + }) + }) + + describe('registerSsoTokenChangedHandler', () => { + it('registers the handler correctly', () => { + const handler = sinon.spy() + + auth.registerSsoTokenChangedHandler(handler) + + sinon.assert.calledOnce(client.onNotification) + sinon.assert.calledWith(client.onNotification, ssoTokenChangedRequestType.method, sinon.match.func) + + // Simulate a token changed notification + const tokenChangedParams: SsoTokenChangedParams = { + kind: SsoTokenChangedKind.Refreshed, + ssoTokenId: tokenId, + } + const registeredHandler = client.onNotification.firstCall.args[1] + registeredHandler(tokenChangedParams) + + sinon.assert.calledOnce(handler) + sinon.assert.calledWith(handler, tokenChangedParams) + }) + }) +}) + +describe('SsoLogin', () => { + let lspAuth: sinon.SinonStubbedInstance + let ssoLogin: SsoLogin + let eventEmitter: vscode.EventEmitter + let fireEventSpy: sinon.SinonSpy + + const loginOpts = { + startUrl, + region, + scopes: ['scope1'], + } + + const mockGetSsoTokenResponse: GetSsoTokenResult = { + ssoToken: { + id: tokenId, + accessToken: 'encrypted-token', + }, + updateCredentialsParams: { + data: '', + }, + } + + beforeEach(() => { + lspAuth = sinon.createStubInstance(LanguageClientAuth) + eventEmitter = new vscode.EventEmitter() + fireEventSpy = sinon.spy(eventEmitter, 'fire') + ssoLogin = new SsoLogin(profileName, lspAuth as any) + ;(ssoLogin as any).eventEmitter = eventEmitter + ;(ssoLogin as any).connectionState = AuthStates.NOT_CONNECTED + }) + + afterEach(() => { + sinon.restore() + eventEmitter.dispose() + }) + + describe('login', () => { + it('updates profile and returns SSO token', async () => { + lspAuth.updateProfile.resolves() + lspAuth.getSsoToken.resolves(mockGetSsoTokenResponse) + + const response = await ssoLogin.login(loginOpts) + + sinon.assert.calledOnce(lspAuth.updateProfile) + sinon.assert.calledWith( + lspAuth.updateProfile, + profileName, + loginOpts.startUrl, + loginOpts.region, + loginOpts.scopes + ) + sinon.assert.calledOnce(lspAuth.getSsoToken) + sinon.assert.match(ssoLogin.getConnectionState(), AuthStates.CONNECTED) + sinon.assert.match(ssoLogin.data, { + startUrl: loginOpts.startUrl, + region: loginOpts.region, + }) + sinon.assert.match(response.ssoToken.id, tokenId) + sinon.assert.match(response.updateCredentialsParams, mockGetSsoTokenResponse.updateCredentialsParams) + }) + }) + + describe('reauthenticate', () => { + it('throws when not connected', async () => { + ;(ssoLogin as any).connectionState = AuthStates.NOT_CONNECTED + try { + await ssoLogin.reauthenticate() + sinon.assert.fail('Should have thrown an error') + } catch (err) { + sinon.assert.match((err as Error).message, 'Cannot reauthenticate when not connected.') + } + }) + + it('returns new SSO token when connected', async () => { + ;(ssoLogin as any).connectionState = AuthStates.CONNECTED + lspAuth.getSsoToken.resolves(mockGetSsoTokenResponse) + + const response = await ssoLogin.reauthenticate() + + sinon.assert.calledOnce(lspAuth.getSsoToken) + sinon.assert.match(ssoLogin.getConnectionState(), AuthStates.CONNECTED) + sinon.assert.match(response.ssoToken.id, tokenId) + sinon.assert.match(response.updateCredentialsParams, mockGetSsoTokenResponse.updateCredentialsParams) + }) + }) + + describe('logout', () => { + it('invalidates token and updates state', async () => { + await ssoLogin.logout() + + sinon.assert.match(ssoLogin.getConnectionState(), AuthStates.NOT_CONNECTED) + sinon.assert.match(ssoLogin.data, undefined) + }) + + it('emits state change event', async () => { + ;(ssoLogin as any).connectionState = AuthStates.CONNECTED + ;(ssoLogin as any).ssoTokenId = tokenId + ;(ssoLogin as any)._data = { + startUrl: loginOpts.startUrl, + region: loginOpts.region, + } + ;(ssoLogin as any).eventEmitter = eventEmitter + + lspAuth.invalidateSsoToken.resolves({ success: true }) + + await ssoLogin.logout() + + sinon.assert.calledOnce(fireEventSpy) + sinon.assert.calledWith(fireEventSpy, { + id: profileName, + state: AuthStates.NOT_CONNECTED, + }) + }) + }) + + describe('restore', () => { + const mockProfile = { + profile: { + kinds: [ProfileKind.SsoTokenProfile], + name: profileName, + }, + ssoSession: { + name: sessionName, + settings: { + sso_region: region, + sso_start_url: startUrl, + }, + }, + } + + it('restores connection state from existing profile', async () => { + lspAuth.getProfile.resolves(mockProfile) + lspAuth.getSsoToken.resolves(mockGetSsoTokenResponse) + + await ssoLogin.restore() + + sinon.assert.calledOnce(lspAuth.getProfile) + sinon.assert.calledWith(lspAuth.getProfile, mockProfile.profile.name) + sinon.assert.calledOnce(lspAuth.getSsoToken) + sinon.assert.calledWith( + lspAuth.getSsoToken, + sinon.match({ + kind: SsoTokenSourceKind.IamIdentityCenter, + profileName: mockProfile.profile.name, + }), + false // login parameter + ) + + sinon.assert.match(ssoLogin.data, { + region: region, + startUrl: startUrl, + }) + sinon.assert.match(ssoLogin.getConnectionState(), AuthStates.CONNECTED) + sinon.assert.match((ssoLogin as any).ssoTokenId, tokenId) + }) + + it('does not connect for non-existent profile', async () => { + lspAuth.getProfile.resolves({ profile: undefined, ssoSession: undefined }) + + await ssoLogin.restore() + + sinon.assert.calledOnce(lspAuth.getProfile) + sinon.assert.calledOnce(lspAuth.getSsoToken) + sinon.assert.match(ssoLogin.data, undefined) + sinon.assert.match(ssoLogin.getConnectionState(), AuthStates.NOT_CONNECTED) + }) + + it('emits state change event on successful restore', async () => { + ;(ssoLogin as any).eventEmitter = eventEmitter + + lspAuth.getProfile.resolves(mockProfile) + lspAuth.getSsoToken.resolves(mockGetSsoTokenResponse) + + await ssoLogin.restore() + + sinon.assert.calledOnce(fireEventSpy) + sinon.assert.calledWith(fireEventSpy, { + id: profileName, + state: AuthStates.CONNECTED, + }) + }) + }) + + describe('cancelLogin', () => { + it('cancels and dispose token source', async () => { + await ssoLogin.login(loginOpts).catch(() => {}) + + ssoLogin.cancelLogin() + + const tokenSource = (ssoLogin as any).cancellationToken + sinon.assert.match(tokenSource, undefined) + }) + }) + + describe('_getSsoToken', () => { + beforeEach(() => { + ;(ssoLogin as any).connectionState = AuthStates.CONNECTED + }) + + const testErrorHandling = async (errorCode: string, expectedState: string, shouldEmitEvent: boolean = true) => { + const error = new Error('Token error') + ;(error as any).data = { awsErrorCode: errorCode } + lspAuth.getSsoToken.rejects(error) + + try { + await (ssoLogin as any)._getSsoToken(false) + sinon.assert.fail('Should have thrown an error') + } catch (err) { + sinon.assert.match(err, error) + } + + sinon.assert.match(ssoLogin.getConnectionState(), expectedState) + + if (shouldEmitEvent) { + sinon.assert.calledWith(fireEventSpy, { + id: profileName, + state: expectedState, + }) + } + + sinon.assert.match((ssoLogin as any).cancellationToken, undefined) + } + + const notConnectedErrors = [ + AwsErrorCodes.E_CANCELLED, + AwsErrorCodes.E_SSO_SESSION_NOT_FOUND, + AwsErrorCodes.E_PROFILE_NOT_FOUND, + AwsErrorCodes.E_INVALID_SSO_TOKEN, + ] + + for (const errorCode of notConnectedErrors) { + it(`handles ${errorCode} error`, async () => { + await testErrorHandling(errorCode, AuthStates.NOT_CONNECTED) + }) + } + + it('handles token refresh error', async () => { + await testErrorHandling(AwsErrorCodes.E_CANNOT_REFRESH_SSO_TOKEN, AuthStates.EXPIRED) + }) + + it('handles unknown errors', async () => { + await testErrorHandling('UNKNOWN_ERROR', ssoLogin.getConnectionState(), false) + }) + + it('returns correct response and cleans up cancellation token', async () => { + lspAuth.getSsoToken.resolves(mockGetSsoTokenResponse) + + const response = await (ssoLogin as any)._getSsoToken(true) + + sinon.assert.calledWith( + lspAuth.getSsoToken, + sinon.match({ + kind: SsoTokenSourceKind.IamIdentityCenter, + profileName, + }), + true + ) + + sinon.assert.match(response, mockGetSsoTokenResponse) + sinon.assert.match((ssoLogin as any).cancellationToken, undefined) + }) + + it('updates state when token is retrieved successfully', async () => { + ;(ssoLogin as any).connectionState = AuthStates.NOT_CONNECTED + lspAuth.getSsoToken.resolves(mockGetSsoTokenResponse) + + await (ssoLogin as any)._getSsoToken(true) + + sinon.assert.match(ssoLogin.getConnectionState(), AuthStates.CONNECTED) + sinon.assert.match((ssoLogin as any).ssoTokenId, tokenId) + sinon.assert.calledWith(fireEventSpy, { + id: profileName, + state: AuthStates.CONNECTED, + }) + }) + }) + + describe('onDidChangeConnectionState', () => { + it('should register handler for connection state changes', () => { + const handler = sinon.spy() + ssoLogin.onDidChangeConnectionState(handler) + + // Simulate state change + ;(ssoLogin as any).updateConnectionState(AuthStates.CONNECTED) + + sinon.assert.calledWith(handler, { + id: profileName, + state: AuthStates.CONNECTED, + }) + }) + }) + + describe('ssoTokenChangedHandler', () => { + beforeEach(() => { + ;(ssoLogin as any).ssoTokenId = tokenId + ;(ssoLogin as any).connectionState = AuthStates.CONNECTED + }) + + it('updates state when token expires', () => { + ;(ssoLogin as any).ssoTokenChangedHandler({ + kind: 'Expired', + ssoTokenId: tokenId, + }) + + sinon.assert.match(ssoLogin.getConnectionState(), AuthStates.EXPIRED) + sinon.assert.calledOnce(fireEventSpy) + sinon.assert.calledWith(fireEventSpy, { + id: profileName, + state: AuthStates.EXPIRED, + }) + }) + + it('emits refresh event when token is refreshed', () => { + ;(ssoLogin as any).ssoTokenChangedHandler({ + kind: 'Refreshed', + ssoTokenId: tokenId, + }) + + sinon.assert.calledOnce(fireEventSpy) + sinon.assert.calledWith(fireEventSpy, { + id: profileName, + state: 'refreshed', + }) + }) + + it('does not emit event for different token ID', () => { + ;(ssoLogin as any).ssoTokenChangedHandler({ + kind: 'Refreshed', + ssoTokenId: 'different-token-id', + }) + + sinon.assert.notCalled(fireEventSpy) + }) + }) +}) diff --git a/packages/core/src/test/credentials/sso/ssoAccessTokenProvider.test.ts b/packages/core/src/test/credentials/sso/ssoAccessTokenProvider.test.ts index 2cb98193224..bd7f264f557 100644 --- a/packages/core/src/test/credentials/sso/ssoAccessTokenProvider.test.ts +++ b/packages/core/src/test/credentials/sso/ssoAccessTokenProvider.test.ts @@ -83,7 +83,7 @@ describe('SsoAccessTokenProvider', function () { tempDir = await makeTemporaryTokenCacheFolder() cache = getCache(tempDir) reAuthState = new TestReAuthState() - sut = SsoAccessTokenProvider.create({ region, startUrl }, cache, oidcClient, reAuthState, () => true) + sut = SsoAccessTokenProvider.create({ region, startUrl }, cache, oidcClient, reAuthState) }) afterEach(async function () { @@ -271,13 +271,7 @@ describe('SsoAccessTokenProvider', function () { await sut.createToken() // Mimic when we sign out then in again with the same region+startUrl. The ID is the only thing different. - sut = SsoAccessTokenProvider.create( - { region, startUrl, identifier: 'bbb' }, - cache, - oidcClient, - reAuthState, - () => true - ) + sut = SsoAccessTokenProvider.create({ region, startUrl, identifier: 'bbb' }, cache, oidcClient, reAuthState) await sut.createToken() assertTelemetry('aws_loginWithBrowser', [