Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions packages/amazonq/test/unit/codewhisperer/util/authUtil.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,34 @@ describe('AuthUtil', async function () {
})
})

describe('cacheChangedHandler', function () {
it('calls logout when event is delete', async function () {
const logoutSpy = sinon.spy(auth, 'logout')

await (auth as any).cacheChangedHandler('delete')

assert.ok(logoutSpy.calledOnce)
})

it('calls restore when event is create', async function () {
const restoreSpy = sinon.spy(auth, 'restore')

await (auth as any).cacheChangedHandler('create')

assert.ok(restoreSpy.calledOnce)
})

it('does nothing for other events', async function () {
const logoutSpy = sinon.spy(auth, 'logout')
const restoreSpy = sinon.spy(auth, 'restore')

await (auth as any).cacheChangedHandler('unknown')

assert.ok(logoutSpy.notCalled)
assert.ok(restoreSpy.notCalled)
})
})

describe('stateChangeHandler', function () {
let mockLspAuth: any
let regionProfileManager: any
Expand Down
12 changes: 12 additions & 0 deletions packages/core/src/auth/auth2.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import { LanguageClient } from 'vscode-languageclient'
import { getLogger } from '../shared/logger/logger'
import { ToolkitError } from '../shared/errors'
import { useDeviceFlow } from './sso/ssoAccessTokenProvider'
import { getCacheFileWatcher } from './sso/cache'

export const notificationTypes = {
updateBearerToken: new RequestType<UpdateCredentialsParams, ResponseMessage, Error>(
Expand Down Expand Up @@ -74,12 +75,18 @@ export type TokenSource = IamIdentityCenterSsoTokenSource | AwsBuilderIdSsoToken
* Handles auth requests to the Identity Server in the Amazon Q LSP.
*/
export class LanguageClientAuth {
readonly #ssoCacheWatcher = getCacheFileWatcher()

constructor(
private readonly client: LanguageClient,
private readonly clientName: string,
public readonly encryptionKey: Buffer
) {}

public get cacheWatcher() {
return this.#ssoCacheWatcher
}

getSsoToken(
tokenSource: TokenSource,
login: boolean = false,
Expand Down Expand Up @@ -160,6 +167,11 @@ export class LanguageClientAuth {
registerSsoTokenChangedHandler(ssoTokenChangedHandler: (params: SsoTokenChangedParams) => any) {
this.client.onNotification(ssoTokenChangedRequestType.method, ssoTokenChangedHandler)
}

registerCacheWatcher(cacheChangedHandler: (event: string) => any) {
this.cacheWatcher.onDidCreate(() => cacheChangedHandler('create'))
this.cacheWatcher.onDidDelete(() => cacheChangedHandler('delete'))
}
}

/**
Expand Down
18 changes: 17 additions & 1 deletion packages/core/src/codewhisperer/region/regionProfileManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import { localize } from '../../shared/utilities/vsCodeUtils'
import { IAuthProvider } from '../util/authUtil'
import { Commands } from '../../shared/vscode/commands2'
import { CachedResource } from '../../shared/utilities/resourceCache'
import { GlobalStatePoller } from '../../shared/globalState'

// TODO: is there a better way to manage all endpoint strings in one place?
export const defaultServiceConfig: CodeWhispererConfig = {
Expand Down Expand Up @@ -120,7 +121,18 @@ export class RegionProfileManager {
return this._profiles
}

constructor(private readonly authProvider: IAuthProvider) {}
constructor(private readonly authProvider: IAuthProvider) {
const getProfileFunction = () =>
globals.globalState.tryGet<{ [label: string]: RegionProfile }>('aws.amazonq.regionProfiles', Object, {})
const profileChangedHandler = async () => {
const profile = this.loadPersistedRegionProfle()
void this._switchRegionProfile(profile[this.authProvider.profileName], 'reload')
}

// This is a poller that handles synchornization of selected region profiles between different IDE windows.
// It checks for changes in global state of region profile, invoking the change handler to switch profiles
GlobalStatePoller.create(getProfileFunction, profileChangedHandler)
}

async getProfiles(): Promise<RegionProfile[]> {
return this.cache.getResource()
Expand Down Expand Up @@ -232,6 +244,10 @@ export class RegionProfileManager {
}

private async _switchRegionProfile(regionProfile: RegionProfile | undefined, source: ProfileSwitchIntent) {
if (this._activeRegionProfile?.arn === regionProfile?.arn) {
return
}

this._activeRegionProfile = regionProfile

this._onDidChangeRegionProfile.fire({
Expand Down
9 changes: 9 additions & 0 deletions packages/core/src/codewhisperer/util/authUtil.ts
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ export class AuthUtil implements IAuthProvider {
this.regionProfileManager.onDidChangeRegionProfile(async () => {
await this.setVscodeContextProps()
})
lspAuth.registerCacheWatcher(async (event: string) => await this.cacheChangedHandler(event))
}

// Do NOT use this in production code, only used for testing
Expand Down Expand Up @@ -276,6 +277,14 @@ export class AuthUtil implements IAuthProvider {
})
}

private async cacheChangedHandler(event: string) {
if (event === 'delete') {
await this.logout()
} else if (event === 'create') {
await this.restore()
}
}

private async stateChangeHandler(e: AuthStateEvent) {
if (e.state === 'refreshed') {
const params = this.isSsoSession() ? (await this.session.getToken()).updateCredentialsParams : undefined
Expand Down
45 changes: 45 additions & 0 deletions packages/core/src/shared/globalState.ts
Original file line number Diff line number Diff line change
Expand Up @@ -318,3 +318,48 @@ export class GlobalState implements vscode.Memento {
return all?.[id]
}
}

/**
* Utility class that polls a state value at regular intervals and triggers a callback when the state changes.
*
* This class can be used to monitor changes in global state and react to those changes.
*/
export class GlobalStatePoller {
protected oldValue: any
protected getState: () => any
protected changeHandler: () => void

constructor(getState: () => any, changeHandler: () => void) {
this.getState = getState
this.changeHandler = changeHandler
this.oldValue = getState()
}

/**
* Factory method that creates and starts a GlobalStatePoller instance.
*
* @param getState - Function that returns the current state value to monitor, e.g. globals.globalState.tryGet
* @param changeHandler - Callback function that is invoked when the state changes
* @returns A new GlobalStatePoller instance that has already started polling
*/
static create(getState: () => any, changeHandler: () => void) {
const instance = new GlobalStatePoller(getState, changeHandler)
instance.poll()
return instance
}

/**
* Starts polling the state value at 1 second intervals.
* When a change is detected, the changeHandler callback is invoked.
*/
poll() {
const interval = 1000 // ms
setInterval(() => {
const newValue = this.getState()
if (this.oldValue !== newValue) {
this.oldValue = newValue
this.changeHandler()
}
}, interval)
}
}
1 change: 1 addition & 0 deletions packages/core/src/test/testAuthUtil.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ export async function createTestAuthUtil() {
deleteBearerToken: sinon.stub().resolves(),
updateBearerToken: sinon.stub().resolves(),
invalidateSsoToken: sinon.stub().resolves(),
registerCacheWatcher: sinon.stub().resolves(),
encryptionKey,
}

Expand Down
Loading