diff --git a/packages/amazonq/test/unit/codewhisperer/util/authUtil.test.ts b/packages/amazonq/test/unit/codewhisperer/util/authUtil.test.ts index 38fa1613022..97e245fecd3 100644 --- a/packages/amazonq/test/unit/codewhisperer/util/authUtil.test.ts +++ b/packages/amazonq/test/unit/codewhisperer/util/authUtil.test.ts @@ -140,6 +140,34 @@ describe('AuthUtil', async function () { }) }) + describe('cacheChangedHandler', function () { + it('calls logout when event is delete', async function () { + const logoutSpy = sinon.spy(auth, 'logout') + + await (auth as any).cacheChangedHandler('delete') + + assert.ok(logoutSpy.calledOnce) + }) + + it('calls restore when event is create', async function () { + const restoreSpy = sinon.spy(auth, 'restore') + + await (auth as any).cacheChangedHandler('create') + + assert.ok(restoreSpy.calledOnce) + }) + + it('does nothing for other events', async function () { + const logoutSpy = sinon.spy(auth, 'logout') + const restoreSpy = sinon.spy(auth, 'restore') + + await (auth as any).cacheChangedHandler('unknown') + + assert.ok(logoutSpy.notCalled) + assert.ok(restoreSpy.notCalled) + }) + }) + describe('stateChangeHandler', function () { let mockLspAuth: any let regionProfileManager: any diff --git a/packages/core/src/auth/auth2.ts b/packages/core/src/auth/auth2.ts index 75bc7523f0c..bff664b7e7b 100644 --- a/packages/core/src/auth/auth2.ts +++ b/packages/core/src/auth/auth2.ts @@ -41,6 +41,7 @@ import { LanguageClient } from 'vscode-languageclient' import { getLogger } from '../shared/logger/logger' import { ToolkitError } from '../shared/errors' import { useDeviceFlow } from './sso/ssoAccessTokenProvider' +import { getCacheFileWatcher } from './sso/cache' export const notificationTypes = { updateBearerToken: new RequestType( @@ -66,6 +67,8 @@ interface BaseLogin { readonly loginType: LoginType } +export type cacheChangedEvent = 'delete' | 'create' + export type Login = SsoLogin // TODO: add IamLogin type when supported export type TokenSource = IamIdentityCenterSsoTokenSource | AwsBuilderIdSsoTokenSource @@ -74,12 +77,18 @@ export type TokenSource = IamIdentityCenterSsoTokenSource | AwsBuilderIdSsoToken * Handles auth requests to the Identity Server in the Amazon Q LSP. */ export class LanguageClientAuth { + readonly #ssoCacheWatcher = getCacheFileWatcher() + constructor( private readonly client: LanguageClient, private readonly clientName: string, public readonly encryptionKey: Buffer ) {} + public get cacheWatcher() { + return this.#ssoCacheWatcher + } + getSsoToken( tokenSource: TokenSource, login: boolean = false, @@ -160,6 +169,11 @@ export class LanguageClientAuth { registerSsoTokenChangedHandler(ssoTokenChangedHandler: (params: SsoTokenChangedParams) => any) { this.client.onNotification(ssoTokenChangedRequestType.method, ssoTokenChangedHandler) } + + registerCacheWatcher(cacheChangedHandler: (event: cacheChangedEvent) => any) { + this.cacheWatcher.onDidCreate(() => cacheChangedHandler('create')) + this.cacheWatcher.onDidDelete(() => cacheChangedHandler('delete')) + } } /** diff --git a/packages/core/src/codewhisperer/activation.ts b/packages/core/src/codewhisperer/activation.ts index 952edb4ac42..e11c7f5fe80 100644 --- a/packages/core/src/codewhisperer/activation.ts +++ b/packages/core/src/codewhisperer/activation.ts @@ -498,6 +498,7 @@ export async function activate(context: ExtContext): Promise { export async function shutdown() { RecommendationHandler.instance.reportUserDecisions(-1) await CodeWhispererTracker.getTracker().shutdown() + AuthUtil.instance.regionProfileManager.globalStatePoller.kill() } function toggleIssuesVisibility(visibleCondition: (issue: CodeScanIssue, filePath: string) => boolean) { diff --git a/packages/core/src/codewhisperer/region/regionProfileManager.ts b/packages/core/src/codewhisperer/region/regionProfileManager.ts index e965052064d..4ba64570521 100644 --- a/packages/core/src/codewhisperer/region/regionProfileManager.ts +++ b/packages/core/src/codewhisperer/region/regionProfileManager.ts @@ -24,6 +24,7 @@ import { localize } from '../../shared/utilities/vsCodeUtils' import { IAuthProvider } from '../util/authUtil' import { Commands } from '../../shared/vscode/commands2' import { CachedResource } from '../../shared/utilities/resourceCache' +import { GlobalStatePoller } from '../../shared/globalState' // TODO: is there a better way to manage all endpoint strings in one place? export const defaultServiceConfig: CodeWhispererConfig = { @@ -37,6 +38,9 @@ const endpoints = createConstantMap({ 'eu-central-1': 'https://q.eu-central-1.amazonaws.com/', }) +const getRegionProfile = () => + globals.globalState.tryGet<{ [label: string]: RegionProfile }>('aws.amazonq.regionProfiles', Object, {}) + /** * 'user' -> users change the profile through Q menu * 'auth' -> users change the profile through webview profile selector page @@ -79,6 +83,17 @@ export class RegionProfileManager { } })(this.listRegionProfile.bind(this)) + // This is a poller that handles synchornization of selected region profiles between different IDE windows. + // It checks for changes in global state of region profile, invoking the change handler to switch profiles + public globalStatePoller = GlobalStatePoller.create({ + getState: getRegionProfile, + changeHandler: async () => { + const profile = this.loadPersistedRegionProfle() + void this._switchRegionProfile(profile[this.authProvider.profileName], 'reload') + }, + pollIntervalInMs: 2000, + }) + get activeRegionProfile() { if (this.authProvider.isBuilderIdConnection()) { return undefined @@ -232,6 +247,10 @@ export class RegionProfileManager { } private async _switchRegionProfile(regionProfile: RegionProfile | undefined, source: ProfileSwitchIntent) { + if (this._activeRegionProfile?.arn === regionProfile?.arn) { + return + } + this._activeRegionProfile = regionProfile this._onDidChangeRegionProfile.fire({ @@ -293,13 +312,7 @@ export class RegionProfileManager { } private loadPersistedRegionProfle(): { [label: string]: RegionProfile } { - const previousPersistedState = globals.globalState.tryGet<{ [label: string]: RegionProfile }>( - 'aws.amazonq.regionProfiles', - Object, - {} - ) - - return previousPersistedState + return getRegionProfile() } async persistSelectRegionProfile() { @@ -309,11 +322,7 @@ export class RegionProfileManager { } // persist connectionId to profileArn - const previousPersistedState = globals.globalState.tryGet<{ [label: string]: RegionProfile }>( - 'aws.amazonq.regionProfiles', - Object, - {} - ) + const previousPersistedState = getRegionProfile() previousPersistedState[this.authProvider.profileName] = this.activeRegionProfile await globals.globalState.update('aws.amazonq.regionProfiles', previousPersistedState) diff --git a/packages/core/src/codewhisperer/util/authUtil.ts b/packages/core/src/codewhisperer/util/authUtil.ts index eac12d3cf80..ec9a2ff91e1 100644 --- a/packages/core/src/codewhisperer/util/authUtil.ts +++ b/packages/core/src/codewhisperer/util/authUtil.ts @@ -31,7 +31,7 @@ import { showAmazonQWalkthroughOnce } from '../../amazonq/onboardingPage/walkthr import { setContext } from '../../shared/vscode/setContext' import { openUrl } from '../../shared/utilities/vsCodeUtils' import { telemetry } from '../../shared/telemetry/telemetry' -import { AuthStateEvent, LanguageClientAuth, LoginTypes, SsoLogin } from '../../auth/auth2' +import { AuthStateEvent, cacheChangedEvent, LanguageClientAuth, LoginTypes, SsoLogin } from '../../auth/auth2' import { builderIdStartUrl, internalStartUrl } from '../../auth/sso/constants' import { VSCODE_EXTENSION_ID } from '../../shared/extensions' import { RegionProfileManager } from '../region/regionProfileManager' @@ -90,6 +90,7 @@ export class AuthUtil implements IAuthProvider { this.regionProfileManager.onDidChangeRegionProfile(async () => { await this.setVscodeContextProps() }) + lspAuth.registerCacheWatcher(async (event: cacheChangedEvent) => await this.cacheChangedHandler(event)) } // Do NOT use this in production code, only used for testing @@ -276,6 +277,14 @@ export class AuthUtil implements IAuthProvider { }) } + private async cacheChangedHandler(event: cacheChangedEvent) { + if (event === 'delete') { + await this.logout() + } else if (event === 'create') { + await this.restore() + } + } + private async stateChangeHandler(e: AuthStateEvent) { if (e.state === 'refreshed') { const params = this.isSsoSession() ? (await this.session.getToken()).updateCredentialsParams : undefined diff --git a/packages/core/src/shared/globalState.ts b/packages/core/src/shared/globalState.ts index c38344608b9..65d761412b8 100644 --- a/packages/core/src/shared/globalState.ts +++ b/packages/core/src/shared/globalState.ts @@ -318,3 +318,68 @@ export class GlobalState implements vscode.Memento { return all?.[id] } } + +export interface GlobalStatePollerProps { + getState: () => any + changeHandler: () => void + pollIntervalInMs: number +} + +/** + * Utility class that polls a state value at regular intervals and triggers a callback when the state changes. + * + * This class can be used to monitor changes in global state and react to those changes. + */ +export class GlobalStatePoller { + protected oldValue: any + protected pollIntervalInMs: number + protected getState: () => any + protected changeHandler: () => void + protected intervalId?: NodeJS.Timeout + + constructor(props: GlobalStatePollerProps) { + this.getState = props.getState + this.changeHandler = props.changeHandler + this.pollIntervalInMs = props.pollIntervalInMs + this.oldValue = this.getState() + } + + /** + * Factory method that creates and starts a GlobalStatePoller instance. + * + * @param getState - Function that returns the current state value to monitor, e.g. globals.globalState.tryGet + * @param changeHandler - Callback function that is invoked when the state changes + * @returns A new GlobalStatePoller instance that has already started polling + */ + static create(props: GlobalStatePollerProps) { + const instance = new GlobalStatePoller(props) + instance.poll() + return instance + } + + /** + * Starts polling the state value. When a change is detected, the changeHandler callback is invoked. + */ + private poll() { + if (this.intervalId) { + this.kill() + } + this.intervalId = setInterval(() => { + const newValue = this.getState() + if (this.oldValue !== newValue) { + this.oldValue = newValue + this.changeHandler() + } + }, this.pollIntervalInMs) + } + + /** + * Stops the polling interval. + */ + kill() { + if (this.intervalId) { + clearInterval(this.intervalId) + this.intervalId = undefined + } + } +} diff --git a/packages/core/src/test/testAuthUtil.ts b/packages/core/src/test/testAuthUtil.ts index 4feefec2d68..595f8bf45ef 100644 --- a/packages/core/src/test/testAuthUtil.ts +++ b/packages/core/src/test/testAuthUtil.ts @@ -36,6 +36,7 @@ export async function createTestAuthUtil() { deleteBearerToken: sinon.stub().resolves(), updateBearerToken: sinon.stub().resolves(), invalidateSsoToken: sinon.stub().resolves(), + registerCacheWatcher: sinon.stub().resolves(), encryptionKey, }