Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import { AuthUtil, RegionProfile, RegionProfileManager, defaultServiceConfig } f
import { globals } from 'aws-core-vscode/shared'
import { constants } from 'aws-core-vscode/auth'
import { createTestAuthUtil } from 'aws-core-vscode/test'
import { randomUUID } from 'crypto'

const enterpriseSsoStartUrl = 'https://enterprise.awsapps.com/start'
const region = 'us-east-1'
Expand Down Expand Up @@ -158,7 +159,7 @@ describe('RegionProfileManager', async function () {
})
})

describe('persistence', function () {
describe('persistSelectedRegionProfile', function () {
it('persistSelectedRegionProfile', async function () {
await setupConnection('idc')
await regionProfileManager.switchRegionProfile(profileFoo, 'user')
Expand All @@ -177,14 +178,13 @@ describe('RegionProfileManager', async function () {

assert.strictEqual(state[AuthUtil.instance.profileName], profileFoo)
})
})

it(`restoreRegionProfile`, async function () {
sinon.stub(regionProfileManager, 'listRegionProfile').resolves([profileFoo])
describe('restoreRegionProfile', function () {
beforeEach(async function () {
await setupConnection('idc')
if (!AuthUtil.instance.isConnected()) {
fail('connection should not be undefined')
}

})
it('restores region profile if profile name matches', async function () {
const state = {} as any
state[AuthUtil.instance.profileName] = profileFoo

Expand All @@ -194,6 +194,51 @@ describe('RegionProfileManager', async function () {

assert.strictEqual(regionProfileManager.activeRegionProfile, profileFoo)
})

it('returns early when no profiles exist', async function () {
const state = {} as any
state[AuthUtil.instance.profileName] = undefined

await globals.globalState.update('aws.amazonq.regionProfiles', state)

await regionProfileManager.restoreRegionProfile()
assert.strictEqual(regionProfileManager.activeRegionProfile, undefined)
})

it('returns early when no profile name matches, and multiple profiles exist', async function () {
const state = {} as any
state[AuthUtil.instance.profileName] = undefined
state[randomUUID()] = profileFoo

await globals.globalState.update('aws.amazonq.regionProfiles', state)

await regionProfileManager.restoreRegionProfile()
assert.strictEqual(regionProfileManager.activeRegionProfile, undefined)
})

it('uses single profile when no profile name matches', async function () {
const state = {} as any
state[randomUUID()] = profileFoo

await globals.globalState.update('aws.amazonq.regionProfiles', state)

await regionProfileManager.restoreRegionProfile()

assert.strictEqual(regionProfileManager.activeRegionProfile, profileFoo)
})

it('handles cross-validation failure', async function () {
const state = {
[AuthUtil.instance.profileName]: profileFoo,
}
sinon.stub(regionProfileManager, 'loadPersistedRegionProfiles').returns(state)
sinon.stub(regionProfileManager, 'getProfiles').resolves([]) // No matching profile
const invalidateStub = sinon.stub(regionProfileManager, 'invalidateProfile')

await regionProfileManager.restoreRegionProfile()

assert.ok(invalidateStub.calledWith(profileFoo.arn))
})
})

describe('invalidate', function () {
Expand Down
25 changes: 25 additions & 0 deletions packages/amazonq/test/unit/codewhisperer/util/authUtil.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ describe('AuthUtil', async function () {
})

describe('migrateSsoConnectionToLsp', function () {
let mockLspAuth: any
let memento: any
let cacheDir: string
let fromRegistrationFile: string
Expand All @@ -250,6 +251,9 @@ describe('AuthUtil', async function () {
sinon.stub(mementoUtils, 'getEnvironmentSpecificMemento').returns(memento)
sinon.stub(cache, 'getCacheDir').returns(cacheDir)

mockLspAuth = (auth as any).lspAuth
mockLspAuth.getSsoToken.resolves(undefined)

fromTokenFile = cache.getTokenCacheFile(cacheDir, 'profile1')
const registrationKey = {
startUrl: validProfile.startUrl,
Expand All @@ -269,6 +273,27 @@ describe('AuthUtil', async function () {
sinon.restore()
})

it('skips migration if LSP token exists', async function () {
memento.get.returns({ profile1: validProfile })
mockLspAuth.getSsoToken.resolves({ token: 'valid-token' })

await auth.migrateSsoConnectionToLsp('test-client')

assert.ok(memento.update.calledWith('auth.profiles', undefined))
assert.ok(!auth.session.updateProfile?.called)
})

it('proceeds with migration if LSP token check throws', async function () {
memento.get.returns({ profile1: validProfile })
mockLspAuth.getSsoToken.rejects(new Error('Token check failed'))
const updateProfileStub = sinon.stub((auth as any).session, 'updateProfile').resolves()

await auth.migrateSsoConnectionToLsp('test-client')

assert.ok(updateProfileStub.calledOnce)
assert.ok(memento.update.calledWith('auth.profiles', undefined))
})

it('migrates valid SSO connection', async function () {
memento.get.returns({ profile1: validProfile })

Expand Down
5 changes: 4 additions & 1 deletion packages/core/src/auth/sso/clients.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import { StandardRetryStrategy, defaultRetryDecider } from '@smithy/middleware-r
import { AuthenticationFlow } from './model'
import { toSnakeCase } from '../../shared/utilities/textUtilities'
import { getUserAgent, withTelemetryContext } from '../../shared/telemetry/util'
import { oneSecond } from '../../shared/datetime'

export class OidcClient {
public constructor(
Expand Down Expand Up @@ -124,7 +125,9 @@ export class OidcClient {
requestHandler: {
// This field may have a bug: https://github.com/aws/aws-sdk-js-v3/issues/6271
// If the bug is real but is fixed, then we can probably remove this field and just have no timeout by default
requestTimeout: 5000,
//
// Also, we bump this higher due to ticket V1761315147, so that SSO does not timeout
requestTimeout: oneSecond * 12,
},
})

Expand Down
31 changes: 22 additions & 9 deletions packages/core/src/codewhisperer/region/regionProfileManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ const endpoints = createConstantMap({
'eu-central-1': 'https://q.eu-central-1.amazonaws.com/',
})

const getRegionProfile = () =>
const getRegionProfiles = () =>
globals.globalState.tryGet<{ [label: string]: RegionProfile }>('aws.amazonq.regionProfiles', Object, {})

/**
Expand Down Expand Up @@ -86,9 +86,9 @@ export class RegionProfileManager {
// This is a poller that handles synchornization of selected region profiles between different IDE windows.
// It checks for changes in global state of region profile, invoking the change handler to switch profiles
public globalStatePoller = GlobalStatePoller.create({
getState: getRegionProfile,
getState: getRegionProfiles,
changeHandler: async () => {
const profile = this.loadPersistedRegionProfle()
const profile = this.loadPersistedRegionProfiles()
void this._switchRegionProfile(profile[this.authProvider.profileName], 'reload')
},
pollIntervalInMs: 2000,
Expand Down Expand Up @@ -285,10 +285,23 @@ export class RegionProfileManager {

// Note: should be called after [this.authProvider.isConnected()] returns non null
async restoreRegionProfile() {
const previousSelected = this.loadPersistedRegionProfle()[this.authProvider.profileName] || undefined
if (!previousSelected) {
const profiles = this.loadPersistedRegionProfiles()
if (!profiles || Object.keys(profiles).length === 0) {
return
}

let previousSelected = profiles[this.authProvider.profileName]

// If no profile matches auth profileName and there are multiple profiles, return so user can select
if (!previousSelected && Object.keys(profiles).length > 1) {
return
}

// If no profile matches auth profileName but there's only one profile, use that one
if (!previousSelected && Object.keys(profiles).length === 1) {
previousSelected = Object.values(profiles)[0]
}

// cross-validation
this.getProfiles()
.then(async (profiles) => {
Expand Down Expand Up @@ -319,8 +332,8 @@ export class RegionProfileManager {
await this.switchRegionProfile(previousSelected, 'reload')
}

private loadPersistedRegionProfle(): { [label: string]: RegionProfile } {
return getRegionProfile()
public loadPersistedRegionProfiles(): { [label: string]: RegionProfile } {
return getRegionProfiles()
}

async persistSelectRegionProfile() {
Expand All @@ -330,7 +343,7 @@ export class RegionProfileManager {
}

// persist connectionId to profileArn
const previousPersistedState = getRegionProfile()
const previousPersistedState = getRegionProfiles()

previousPersistedState[this.authProvider.profileName] = this.activeRegionProfile
await globals.globalState.update('aws.amazonq.regionProfiles', previousPersistedState)
Expand Down Expand Up @@ -379,7 +392,7 @@ export class RegionProfileManager {
this._activeRegionProfile = undefined
}

const profiles = this.loadPersistedRegionProfle()
const profiles = this.loadPersistedRegionProfiles()
const updatedProfiles = Object.fromEntries(
Object.entries(profiles).filter(([connId, profile]) => profile.arn !== arn)
)
Expand Down
112 changes: 67 additions & 45 deletions packages/core/src/codewhisperer/util/authUtil.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import { getEnvironmentSpecificMemento } from '../../shared/utilities/mementos'
import { getCacheDir, getFlareCacheFileName, getRegistrationCacheFile, getTokenCacheFile } from '../../auth/sso/cache'
import { notifySelectDeveloperProfile } from '../region/utils'
import { once } from '../../shared/utilities/functionUtils'
import { CancellationTokenSource, SsoTokenSourceKind } from '@aws/language-server-runtimes/server-interface'

const localize = nls.loadMessageBundle()

Expand All @@ -64,6 +65,8 @@ export interface IAuthProvider {
*/
export class AuthUtil implements IAuthProvider {
public readonly profileName = VSCODE_EXTENSION_ID.amazonq
protected logger = getLogger('amazonqAuth')

public readonly regionProfileManager: RegionProfileManager

// IAM login currently not supported
Expand Down Expand Up @@ -277,7 +280,7 @@ export class AuthUtil implements IAuthProvider {
}

private async cacheChangedHandler(event: cacheChangedEvent) {
getLogger().debug(`Auth: Cache change event received: ${event}`)
this.logger.debug(`Cache change event received: ${event}`)
if (event === 'delete') {
await this.logout()
} else if (event === 'create') {
Expand All @@ -291,7 +294,7 @@ export class AuthUtil implements IAuthProvider {
await this.lspAuth.updateBearerToken(params!)
return
} else {
getLogger().info(`codewhisperer: connection changed to ${e.state}`)
this.logger.info(`codewhisperer: connection changed to ${e.state}`)
await this.refreshState(e.state)
}
}
Expand Down Expand Up @@ -402,58 +405,77 @@ export class AuthUtil implements IAuthProvider {

if (!profiles) {
return
} else {
getLogger().info(`codewhisperer: checking for old SSO connections`)
for (const [id, p] of Object.entries(profiles)) {
if (p.type === 'sso' && hasExactScopes(p.scopes ?? [], amazonQScopes)) {
toImport = p
profileId = id
if (p.metadata.connectionState === 'valid') {
break
}
}
}
}

if (toImport && profileId) {
getLogger().info(`codewhisperer: migrating SSO connection to LSP identity server...`)
try {
// Try go get token from LSP auth. If available, skip migration and delete old auth profile
const token = await this.lspAuth.getSsoToken(
{
kind: SsoTokenSourceKind.IamIdentityCenter,
profileName: this.profileName,
},
false,
new CancellationTokenSource().token
)
if (token) {
this.logger.info('existing LSP auth connection found. Skipping migration')
await memento.update(key, undefined)
return
}
} catch {
this.logger.info('unable to get token from LSP auth, proceeding migration')
}

const registrationKey = {
startUrl: toImport.startUrl,
region: toImport.ssoRegion,
scopes: amazonQScopes,
this.logger.info('checking for old SSO connections')
for (const [id, p] of Object.entries(profiles)) {
if (p.type === 'sso' && hasExactScopes(p.scopes ?? [], amazonQScopes)) {
toImport = p
profileId = id
if (p.metadata.connectionState === 'valid') {
break
}
}
}

await this.session.updateProfile(registrationKey)
if (toImport && profileId) {
this.logger.info('migrating SSO connection to LSP identity server...')

const cacheDir = getCacheDir()
const registrationKey = {
startUrl: toImport.startUrl,
region: toImport.ssoRegion,
scopes: amazonQScopes,
}

const fromRegistrationFile = getRegistrationCacheFile(cacheDir, registrationKey)
const toRegistrationFile = path.join(
cacheDir,
getFlareCacheFileName(
JSON.stringify({
region: toImport.ssoRegion,
startUrl: toImport.startUrl,
tool: clientName,
})
)
)
await this.session.updateProfile(registrationKey)

const fromTokenFile = getTokenCacheFile(cacheDir, profileId)
const toTokenFile = path.join(cacheDir, getFlareCacheFileName(this.profileName))
const cacheDir = getCacheDir()

try {
await fs.rename(fromRegistrationFile, toRegistrationFile)
await fs.rename(fromTokenFile, toTokenFile)
getLogger().debug('Successfully renamed registration and token files')
} catch (err) {
getLogger().error(`Failed to rename files during migration: ${err}`)
throw err
}

await memento.update(key, undefined)
getLogger().info(`codewhisperer: successfully migrated SSO connection to LSP identity server`)
const fromRegistrationFile = getRegistrationCacheFile(cacheDir, registrationKey)
const toRegistrationFile = path.join(
cacheDir,
getFlareCacheFileName(
JSON.stringify({
region: toImport.ssoRegion,
startUrl: toImport.startUrl,
tool: clientName,
})
)
)

const fromTokenFile = getTokenCacheFile(cacheDir, profileId)
const toTokenFile = path.join(cacheDir, getFlareCacheFileName(this.profileName))

try {
await fs.rename(fromRegistrationFile, toRegistrationFile)
await fs.rename(fromTokenFile, toTokenFile)
this.logger.debug('Successfully renamed registration and token files')
} catch (err) {
this.logger.error(`Failed to rename files during migration: ${err}`)
throw err
}

this.logger.info('successfully migrated SSO connection to LSP identity server')
await memento.update(key, undefined)
}
}
}
1 change: 1 addition & 0 deletions packages/core/src/shared/logger/logger.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ export type LogTopic =
| 'nextEditPrediction'
| 'resourceCache'
| 'telemetry'
| 'amazonqAuth'

class ErrorLog {
constructor(
Expand Down
Loading