Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions packages/amazonq/test/unit/codewhisperer/util/authUtil.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,34 @@ describe('AuthUtil', async function () {
})
})

describe('cacheChangedHandler', function () {
it('calls logout when event is delete', async function () {
const logoutSpy = sinon.spy(auth, 'logout')

await (auth as any).cacheChangedHandler('delete')

assert.ok(logoutSpy.calledOnce)
})

it('calls restore when event is create', async function () {
const restoreSpy = sinon.spy(auth, 'restore')

await (auth as any).cacheChangedHandler('create')

assert.ok(restoreSpy.calledOnce)
})

it('does nothing for other events', async function () {
const logoutSpy = sinon.spy(auth, 'logout')
const restoreSpy = sinon.spy(auth, 'restore')

await (auth as any).cacheChangedHandler('unknown')

assert.ok(logoutSpy.notCalled)
assert.ok(restoreSpy.notCalled)
})
})

describe('stateChangeHandler', function () {
let mockLspAuth: any
let regionProfileManager: any
Expand Down
12 changes: 12 additions & 0 deletions packages/core/src/auth/auth2.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import { LanguageClient } from 'vscode-languageclient'
import { getLogger } from '../shared/logger/logger'
import { ToolkitError } from '../shared/errors'
import { useDeviceFlow } from './sso/ssoAccessTokenProvider'
import { getCacheFileWatcher } from './sso/cache'

export const notificationTypes = {
updateBearerToken: new RequestType<UpdateCredentialsParams, ResponseMessage, Error>(
Expand Down Expand Up @@ -74,12 +75,18 @@ export type TokenSource = IamIdentityCenterSsoTokenSource | AwsBuilderIdSsoToken
* Handles auth requests to the Identity Server in the Amazon Q LSP.
*/
export class LanguageClientAuth {
readonly #ssoCacheWatcher = getCacheFileWatcher()

constructor(
private readonly client: LanguageClient,
private readonly clientName: string,
public readonly encryptionKey: Buffer
) {}

public get cacheWatcher() {
return this.#ssoCacheWatcher
}

getSsoToken(
tokenSource: TokenSource,
login: boolean = false,
Expand Down Expand Up @@ -160,6 +167,11 @@ export class LanguageClientAuth {
registerSsoTokenChangedHandler(ssoTokenChangedHandler: (params: SsoTokenChangedParams) => any) {
this.client.onNotification(ssoTokenChangedRequestType.method, ssoTokenChangedHandler)
}

registerCacheWatcher(cacheChangedHandler: (event: string) => any) {
this.cacheWatcher.onDidCreate(() => cacheChangedHandler('create'))
this.cacheWatcher.onDidDelete(() => cacheChangedHandler('delete'))
}
}

/**
Expand Down
1 change: 1 addition & 0 deletions packages/core/src/codewhisperer/activation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,7 @@ export async function activate(context: ExtContext): Promise<void> {
export async function shutdown() {
RecommendationHandler.instance.reportUserDecisions(-1)
await CodeWhispererTracker.getTracker().shutdown()
AuthUtil.instance.regionProfileManager.globalStatePoller.kill()
}

function toggleIssuesVisibility(visibleCondition: (issue: CodeScanIssue, filePath: string) => boolean) {
Expand Down
33 changes: 21 additions & 12 deletions packages/core/src/codewhisperer/region/regionProfileManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import { localize } from '../../shared/utilities/vsCodeUtils'
import { IAuthProvider } from '../util/authUtil'
import { Commands } from '../../shared/vscode/commands2'
import { CachedResource } from '../../shared/utilities/resourceCache'
import { GlobalStatePoller } from '../../shared/globalState'

// TODO: is there a better way to manage all endpoint strings in one place?
export const defaultServiceConfig: CodeWhispererConfig = {
Expand All @@ -37,6 +38,9 @@ const endpoints = createConstantMap({
'eu-central-1': 'https://q.eu-central-1.amazonaws.com/',
})

const getRegionProfile = () =>
globals.globalState.tryGet<{ [label: string]: RegionProfile }>('aws.amazonq.regionProfiles', Object, {})

/**
* 'user' -> users change the profile through Q menu
* 'auth' -> users change the profile through webview profile selector page
Expand Down Expand Up @@ -79,6 +83,17 @@ export class RegionProfileManager {
}
})(this.listRegionProfile.bind(this))

// This is a poller that handles synchornization of selected region profiles between different IDE windows.
// It checks for changes in global state of region profile, invoking the change handler to switch profiles
public globalStatePoller = GlobalStatePoller.create({
getState: getRegionProfile,
changeHandler: async () => {
const profile = this.loadPersistedRegionProfle()
void this._switchRegionProfile(profile[this.authProvider.profileName], 'reload')
},
pollIntervalInMs: 2000,
})

get activeRegionProfile() {
if (this.authProvider.isBuilderIdConnection()) {
return undefined
Expand Down Expand Up @@ -232,6 +247,10 @@ export class RegionProfileManager {
}

private async _switchRegionProfile(regionProfile: RegionProfile | undefined, source: ProfileSwitchIntent) {
if (this._activeRegionProfile?.arn === regionProfile?.arn) {
return
}

this._activeRegionProfile = regionProfile

this._onDidChangeRegionProfile.fire({
Expand Down Expand Up @@ -293,13 +312,7 @@ export class RegionProfileManager {
}

private loadPersistedRegionProfle(): { [label: string]: RegionProfile } {
const previousPersistedState = globals.globalState.tryGet<{ [label: string]: RegionProfile }>(
'aws.amazonq.regionProfiles',
Object,
{}
)

return previousPersistedState
return getRegionProfile()
}

async persistSelectRegionProfile() {
Expand All @@ -309,11 +322,7 @@ export class RegionProfileManager {
}

// persist connectionId to profileArn
const previousPersistedState = globals.globalState.tryGet<{ [label: string]: RegionProfile }>(
'aws.amazonq.regionProfiles',
Object,
{}
)
const previousPersistedState = getRegionProfile()

previousPersistedState[this.authProvider.profileName] = this.activeRegionProfile
await globals.globalState.update('aws.amazonq.regionProfiles', previousPersistedState)
Expand Down
9 changes: 9 additions & 0 deletions packages/core/src/codewhisperer/util/authUtil.ts
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ export class AuthUtil implements IAuthProvider {
this.regionProfileManager.onDidChangeRegionProfile(async () => {
await this.setVscodeContextProps()
})
lspAuth.registerCacheWatcher(async (event: string) => await this.cacheChangedHandler(event))
}

// Do NOT use this in production code, only used for testing
Expand Down Expand Up @@ -276,6 +277,14 @@ export class AuthUtil implements IAuthProvider {
})
}

private async cacheChangedHandler(event: 'create' | 'delete') {
if (event === 'delete') {
await this.logout()
} else if (event === 'create') {
await this.restore()
}
}

private async stateChangeHandler(e: AuthStateEvent) {
if (e.state === 'refreshed') {
const params = this.isSsoSession() ? (await this.session.getToken()).updateCredentialsParams : undefined
Expand Down
62 changes: 62 additions & 0 deletions packages/core/src/shared/globalState.ts
Original file line number Diff line number Diff line change
Expand Up @@ -318,3 +318,65 @@ export class GlobalState implements vscode.Memento {
return all?.[id]
}
}

export interface GlobalStatePollerProps {
getState: () => any
changeHandler: () => void
pollIntervalInMs: number
}

/**
* Utility class that polls a state value at regular intervals and triggers a callback when the state changes.
*
* This class can be used to monitor changes in global state and react to those changes.
*/
export class GlobalStatePoller {
protected oldValue: any
protected pollIntervalInMs: number
protected getState: () => any
protected changeHandler: () => void
protected intervalId?: NodeJS.Timeout

constructor(props: GlobalStatePollerProps) {
this.getState = props.getState
this.changeHandler = props.changeHandler
this.pollIntervalInMs = props.pollIntervalInMs
this.oldValue = this.getState()
}

/**
* Factory method that creates and starts a GlobalStatePoller instance.
*
* @param getState - Function that returns the current state value to monitor, e.g. globals.globalState.tryGet
* @param changeHandler - Callback function that is invoked when the state changes
* @returns A new GlobalStatePoller instance that has already started polling
*/
static create(props: GlobalStatePollerProps) {
const instance = new GlobalStatePoller(props)
instance.poll()
return instance
}

/**
* Starts polling the state value. When a change is detected, the changeHandler callback is invoked.
*/
private poll() {
this.intervalId = setInterval(() => {
const newValue = this.getState()
if (this.oldValue !== newValue) {
this.oldValue = newValue
this.changeHandler()
}
}, this.pollIntervalInMs)
}

/**
* Stops the polling interval.
*/
kill() {
if (this.intervalId) {
clearInterval(this.intervalId)
this.intervalId = undefined
}
}
}
1 change: 1 addition & 0 deletions packages/core/src/test/testAuthUtil.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ export async function createTestAuthUtil() {
deleteBearerToken: sinon.stub().resolves(),
updateBearerToken: sinon.stub().resolves(),
invalidateSsoToken: sinon.stub().resolves(),
registerCacheWatcher: sinon.stub().resolves(),
encryptionKey,
}

Expand Down
Loading