Skip to content

Commit 7498ed2

Browse files
authored
Merge pull request #7370 from aws/feature/flare-mega
fix(amazonq): Merge feature/flare-mega
2 parents 1c349f1 + 0ddbf4a commit 7498ed2

File tree

5 files changed

+167
-61
lines changed

5 files changed

+167
-61
lines changed

packages/amazonq/test/unit/codewhisperer/region/regionProfileManager.test.ts

Lines changed: 52 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import { AuthUtil, RegionProfile, RegionProfileManager, defaultServiceConfig } f
99
import { globals } from 'aws-core-vscode/shared'
1010
import { constants } from 'aws-core-vscode/auth'
1111
import { createTestAuthUtil } from 'aws-core-vscode/test'
12+
import { randomUUID } from 'crypto'
1213

1314
const enterpriseSsoStartUrl = 'https://enterprise.awsapps.com/start'
1415
const region = 'us-east-1'
@@ -158,7 +159,7 @@ describe('RegionProfileManager', async function () {
158159
})
159160
})
160161

161-
describe('persistence', function () {
162+
describe('persistSelectedRegionProfile', function () {
162163
it('persistSelectedRegionProfile', async function () {
163164
await setupConnection('idc')
164165
await regionProfileManager.switchRegionProfile(profileFoo, 'user')
@@ -177,14 +178,13 @@ describe('RegionProfileManager', async function () {
177178

178179
assert.strictEqual(state[AuthUtil.instance.profileName], profileFoo)
179180
})
181+
})
180182

181-
it(`restoreRegionProfile`, async function () {
182-
sinon.stub(regionProfileManager, 'listRegionProfile').resolves([profileFoo])
183+
describe('restoreRegionProfile', function () {
184+
beforeEach(async function () {
183185
await setupConnection('idc')
184-
if (!AuthUtil.instance.isConnected()) {
185-
fail('connection should not be undefined')
186-
}
187-
186+
})
187+
it('restores region profile if profile name matches', async function () {
188188
const state = {} as any
189189
state[AuthUtil.instance.profileName] = profileFoo
190190

@@ -194,6 +194,51 @@ describe('RegionProfileManager', async function () {
194194

195195
assert.strictEqual(regionProfileManager.activeRegionProfile, profileFoo)
196196
})
197+
198+
it('returns early when no profiles exist', async function () {
199+
const state = {} as any
200+
state[AuthUtil.instance.profileName] = undefined
201+
202+
await globals.globalState.update('aws.amazonq.regionProfiles', state)
203+
204+
await regionProfileManager.restoreRegionProfile()
205+
assert.strictEqual(regionProfileManager.activeRegionProfile, undefined)
206+
})
207+
208+
it('returns early when no profile name matches, and multiple profiles exist', async function () {
209+
const state = {} as any
210+
state[AuthUtil.instance.profileName] = undefined
211+
state[randomUUID()] = profileFoo
212+
213+
await globals.globalState.update('aws.amazonq.regionProfiles', state)
214+
215+
await regionProfileManager.restoreRegionProfile()
216+
assert.strictEqual(regionProfileManager.activeRegionProfile, undefined)
217+
})
218+
219+
it('uses single profile when no profile name matches', async function () {
220+
const state = {} as any
221+
state[randomUUID()] = profileFoo
222+
223+
await globals.globalState.update('aws.amazonq.regionProfiles', state)
224+
225+
await regionProfileManager.restoreRegionProfile()
226+
227+
assert.strictEqual(regionProfileManager.activeRegionProfile, profileFoo)
228+
})
229+
230+
it('handles cross-validation failure', async function () {
231+
const state = {
232+
[AuthUtil.instance.profileName]: profileFoo,
233+
}
234+
sinon.stub(regionProfileManager, 'loadPersistedRegionProfiles').returns(state)
235+
sinon.stub(regionProfileManager, 'getProfiles').resolves([]) // No matching profile
236+
const invalidateStub = sinon.stub(regionProfileManager, 'invalidateProfile')
237+
238+
await regionProfileManager.restoreRegionProfile()
239+
240+
assert.ok(invalidateStub.calledWith(profileFoo.arn))
241+
})
197242
})
198243

199244
describe('invalidate', function () {

packages/amazonq/test/unit/codewhisperer/util/authUtil.test.ts

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ describe('AuthUtil', async function () {
225225
})
226226

227227
describe('migrateSsoConnectionToLsp', function () {
228+
let mockLspAuth: any
228229
let memento: any
229230
let cacheDir: string
230231
let fromRegistrationFile: string
@@ -250,6 +251,9 @@ describe('AuthUtil', async function () {
250251
sinon.stub(mementoUtils, 'getEnvironmentSpecificMemento').returns(memento)
251252
sinon.stub(cache, 'getCacheDir').returns(cacheDir)
252253

254+
mockLspAuth = (auth as any).lspAuth
255+
mockLspAuth.getSsoToken.resolves(undefined)
256+
253257
fromTokenFile = cache.getTokenCacheFile(cacheDir, 'profile1')
254258
const registrationKey = {
255259
startUrl: validProfile.startUrl,
@@ -269,6 +273,27 @@ describe('AuthUtil', async function () {
269273
sinon.restore()
270274
})
271275

276+
it('skips migration if LSP token exists', async function () {
277+
memento.get.returns({ profile1: validProfile })
278+
mockLspAuth.getSsoToken.resolves({ token: 'valid-token' })
279+
280+
await auth.migrateSsoConnectionToLsp('test-client')
281+
282+
assert.ok(memento.update.calledWith('auth.profiles', undefined))
283+
assert.ok(!auth.session.updateProfile?.called)
284+
})
285+
286+
it('proceeds with migration if LSP token check throws', async function () {
287+
memento.get.returns({ profile1: validProfile })
288+
mockLspAuth.getSsoToken.rejects(new Error('Token check failed'))
289+
const updateProfileStub = sinon.stub((auth as any).session, 'updateProfile').resolves()
290+
291+
await auth.migrateSsoConnectionToLsp('test-client')
292+
293+
assert.ok(updateProfileStub.calledOnce)
294+
assert.ok(memento.update.calledWith('auth.profiles', undefined))
295+
})
296+
272297
it('migrates valid SSO connection', async function () {
273298
memento.get.returns({ profile1: validProfile })
274299

packages/core/src/codewhisperer/region/regionProfileManager.ts

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ const endpoints = createConstantMap({
3838
'eu-central-1': 'https://q.eu-central-1.amazonaws.com/',
3939
})
4040

41-
const getRegionProfile = () =>
41+
const getRegionProfiles = () =>
4242
globals.globalState.tryGet<{ [label: string]: RegionProfile }>('aws.amazonq.regionProfiles', Object, {})
4343

4444
/**
@@ -86,9 +86,9 @@ export class RegionProfileManager {
8686
// This is a poller that handles synchornization of selected region profiles between different IDE windows.
8787
// It checks for changes in global state of region profile, invoking the change handler to switch profiles
8888
public globalStatePoller = GlobalStatePoller.create({
89-
getState: getRegionProfile,
89+
getState: getRegionProfiles,
9090
changeHandler: async () => {
91-
const profile = this.loadPersistedRegionProfle()
91+
const profile = this.loadPersistedRegionProfiles()
9292
void this._switchRegionProfile(profile[this.authProvider.profileName], 'reload')
9393
},
9494
pollIntervalInMs: 2000,
@@ -285,10 +285,23 @@ export class RegionProfileManager {
285285

286286
// Note: should be called after [this.authProvider.isConnected()] returns non null
287287
async restoreRegionProfile() {
288-
const previousSelected = this.loadPersistedRegionProfle()[this.authProvider.profileName] || undefined
289-
if (!previousSelected) {
288+
const profiles = this.loadPersistedRegionProfiles()
289+
if (!profiles || Object.keys(profiles).length === 0) {
290290
return
291291
}
292+
293+
let previousSelected = profiles[this.authProvider.profileName]
294+
295+
// If no profile matches auth profileName and there are multiple profiles, return so user can select
296+
if (!previousSelected && Object.keys(profiles).length > 1) {
297+
return
298+
}
299+
300+
// If no profile matches auth profileName but there's only one profile, use that one
301+
if (!previousSelected && Object.keys(profiles).length === 1) {
302+
previousSelected = Object.values(profiles)[0]
303+
}
304+
292305
// cross-validation
293306
this.getProfiles()
294307
.then(async (profiles) => {
@@ -319,8 +332,8 @@ export class RegionProfileManager {
319332
await this.switchRegionProfile(previousSelected, 'reload')
320333
}
321334

322-
private loadPersistedRegionProfle(): { [label: string]: RegionProfile } {
323-
return getRegionProfile()
335+
public loadPersistedRegionProfiles(): { [label: string]: RegionProfile } {
336+
return getRegionProfiles()
324337
}
325338

326339
async persistSelectRegionProfile() {
@@ -330,7 +343,7 @@ export class RegionProfileManager {
330343
}
331344

332345
// persist connectionId to profileArn
333-
const previousPersistedState = getRegionProfile()
346+
const previousPersistedState = getRegionProfiles()
334347

335348
previousPersistedState[this.authProvider.profileName] = this.activeRegionProfile
336349
await globals.globalState.update('aws.amazonq.regionProfiles', previousPersistedState)
@@ -379,7 +392,7 @@ export class RegionProfileManager {
379392
this._activeRegionProfile = undefined
380393
}
381394

382-
const profiles = this.loadPersistedRegionProfle()
395+
const profiles = this.loadPersistedRegionProfiles()
383396
const updatedProfiles = Object.fromEntries(
384397
Object.entries(profiles).filter(([connId, profile]) => profile.arn !== arn)
385398
)

packages/core/src/codewhisperer/util/authUtil.ts

Lines changed: 67 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ import { getEnvironmentSpecificMemento } from '../../shared/utilities/mementos'
3939
import { getCacheDir, getFlareCacheFileName, getRegistrationCacheFile, getTokenCacheFile } from '../../auth/sso/cache'
4040
import { notifySelectDeveloperProfile } from '../region/utils'
4141
import { once } from '../../shared/utilities/functionUtils'
42+
import { CancellationTokenSource, SsoTokenSourceKind } from '@aws/language-server-runtimes/server-interface'
4243

4344
const localize = nls.loadMessageBundle()
4445

@@ -64,6 +65,8 @@ export interface IAuthProvider {
6465
*/
6566
export class AuthUtil implements IAuthProvider {
6667
public readonly profileName = VSCODE_EXTENSION_ID.amazonq
68+
protected logger = getLogger('amazonqAuth')
69+
6770
public readonly regionProfileManager: RegionProfileManager
6871

6972
// IAM login currently not supported
@@ -277,7 +280,7 @@ export class AuthUtil implements IAuthProvider {
277280
}
278281

279282
private async cacheChangedHandler(event: cacheChangedEvent) {
280-
getLogger().debug(`Auth: Cache change event received: ${event}`)
283+
this.logger.debug(`Cache change event received: ${event}`)
281284
if (event === 'delete') {
282285
await this.logout()
283286
} else if (event === 'create') {
@@ -291,7 +294,7 @@ export class AuthUtil implements IAuthProvider {
291294
await this.lspAuth.updateBearerToken(params!)
292295
return
293296
} else {
294-
getLogger().info(`codewhisperer: connection changed to ${e.state}`)
297+
this.logger.info(`codewhisperer: connection changed to ${e.state}`)
295298
await this.refreshState(e.state)
296299
}
297300
}
@@ -402,58 +405,77 @@ export class AuthUtil implements IAuthProvider {
402405

403406
if (!profiles) {
404407
return
405-
} else {
406-
getLogger().info(`codewhisperer: checking for old SSO connections`)
407-
for (const [id, p] of Object.entries(profiles)) {
408-
if (p.type === 'sso' && hasExactScopes(p.scopes ?? [], amazonQScopes)) {
409-
toImport = p
410-
profileId = id
411-
if (p.metadata.connectionState === 'valid') {
412-
break
413-
}
414-
}
415-
}
408+
}
416409

417-
if (toImport && profileId) {
418-
getLogger().info(`codewhisperer: migrating SSO connection to LSP identity server...`)
410+
try {
411+
// Try go get token from LSP auth. If available, skip migration and delete old auth profile
412+
const token = await this.lspAuth.getSsoToken(
413+
{
414+
kind: SsoTokenSourceKind.IamIdentityCenter,
415+
profileName: this.profileName,
416+
},
417+
false,
418+
new CancellationTokenSource().token
419+
)
420+
if (token) {
421+
this.logger.info('existing LSP auth connection found. Skipping migration')
422+
await memento.update(key, undefined)
423+
return
424+
}
425+
} catch {
426+
this.logger.info('unable to get token from LSP auth, proceeding migration')
427+
}
419428

420-
const registrationKey = {
421-
startUrl: toImport.startUrl,
422-
region: toImport.ssoRegion,
423-
scopes: amazonQScopes,
429+
this.logger.info('checking for old SSO connections')
430+
for (const [id, p] of Object.entries(profiles)) {
431+
if (p.type === 'sso' && hasExactScopes(p.scopes ?? [], amazonQScopes)) {
432+
toImport = p
433+
profileId = id
434+
if (p.metadata.connectionState === 'valid') {
435+
break
424436
}
437+
}
438+
}
425439

426-
await this.session.updateProfile(registrationKey)
440+
if (toImport && profileId) {
441+
this.logger.info('migrating SSO connection to LSP identity server...')
427442

428-
const cacheDir = getCacheDir()
443+
const registrationKey = {
444+
startUrl: toImport.startUrl,
445+
region: toImport.ssoRegion,
446+
scopes: amazonQScopes,
447+
}
429448

430-
const fromRegistrationFile = getRegistrationCacheFile(cacheDir, registrationKey)
431-
const toRegistrationFile = path.join(
432-
cacheDir,
433-
getFlareCacheFileName(
434-
JSON.stringify({
435-
region: toImport.ssoRegion,
436-
startUrl: toImport.startUrl,
437-
tool: clientName,
438-
})
439-
)
440-
)
449+
await this.session.updateProfile(registrationKey)
441450

442-
const fromTokenFile = getTokenCacheFile(cacheDir, profileId)
443-
const toTokenFile = path.join(cacheDir, getFlareCacheFileName(this.profileName))
451+
const cacheDir = getCacheDir()
444452

445-
try {
446-
await fs.rename(fromRegistrationFile, toRegistrationFile)
447-
await fs.rename(fromTokenFile, toTokenFile)
448-
getLogger().debug('Successfully renamed registration and token files')
449-
} catch (err) {
450-
getLogger().error(`Failed to rename files during migration: ${err}`)
451-
throw err
452-
}
453-
454-
await memento.update(key, undefined)
455-
getLogger().info(`codewhisperer: successfully migrated SSO connection to LSP identity server`)
453+
const fromRegistrationFile = getRegistrationCacheFile(cacheDir, registrationKey)
454+
const toRegistrationFile = path.join(
455+
cacheDir,
456+
getFlareCacheFileName(
457+
JSON.stringify({
458+
region: toImport.ssoRegion,
459+
startUrl: toImport.startUrl,
460+
tool: clientName,
461+
})
462+
)
463+
)
464+
465+
const fromTokenFile = getTokenCacheFile(cacheDir, profileId)
466+
const toTokenFile = path.join(cacheDir, getFlareCacheFileName(this.profileName))
467+
468+
try {
469+
await fs.rename(fromRegistrationFile, toRegistrationFile)
470+
await fs.rename(fromTokenFile, toTokenFile)
471+
this.logger.debug('Successfully renamed registration and token files')
472+
} catch (err) {
473+
this.logger.error(`Failed to rename files during migration: ${err}`)
474+
throw err
456475
}
476+
477+
this.logger.info('successfully migrated SSO connection to LSP identity server')
478+
await memento.update(key, undefined)
457479
}
458480
}
459481
}

packages/core/src/shared/logger/logger.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ export type LogTopic =
2121
| 'nextEditPrediction'
2222
| 'resourceCache'
2323
| 'telemetry'
24+
| 'amazonqAuth'
2425

2526
class ErrorLog {
2627
constructor(

0 commit comments

Comments
 (0)