Skip to content

Commit e1cfdc3

Browse files
committed
Add support for IAM credentials to region profile manager
1 parent 60ffc92 commit e1cfdc3

File tree

6 files changed

+127
-71
lines changed

6 files changed

+127
-71
lines changed

packages/amazonq/test/unit/codewhisperer/region/regionProfileManager.test.ts

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ describe('RegionProfileManager', async function () {
6262
const mockClient = {
6363
listAvailableProfiles: listProfilesStub,
6464
}
65-
const createClientStub = sinon.stub(regionProfileManager, '_createQClient').resolves(mockClient)
65+
const createClientStub = sinon.stub(regionProfileManager, '_createQUserClient').resolves(mockClient)
6666

6767
const profileList = await regionProfileManager.listRegionProfile()
6868

@@ -272,11 +272,11 @@ describe('RegionProfileManager', async function () {
272272
})
273273
})
274274

275-
describe('createQClient', function () {
275+
describe('createQUserClient', function () {
276276
it(`should configure the endpoint and region from a profile`, async function () {
277277
await setupConnection('idc')
278278

279-
const iadClient = await regionProfileManager.createQClient({
279+
const iadClient = await regionProfileManager.createQUserClient({
280280
name: 'foo',
281281
region: 'us-east-1',
282282
arn: 'arn',
@@ -286,7 +286,7 @@ describe('RegionProfileManager', async function () {
286286
assert.deepStrictEqual(iadClient.config.region, 'us-east-1')
287287
assert.deepStrictEqual(iadClient.endpoint.href, 'https://q.us-east-1.amazonaws.com/')
288288

289-
const fraClient = await regionProfileManager.createQClient({
289+
const fraClient = await regionProfileManager.createQUserClient({
290290
name: 'bar',
291291
region: 'eu-central-1',
292292
arn: 'arn',
@@ -302,7 +302,7 @@ describe('RegionProfileManager', async function () {
302302

303303
await assert.rejects(
304304
async () => {
305-
await regionProfileManager.createQClient({
305+
await regionProfileManager.createQUserClient({
306306
name: 'foo',
307307
region: 'ap-east-1',
308308
arn: 'arn',
@@ -314,7 +314,7 @@ describe('RegionProfileManager', async function () {
314314

315315
await assert.rejects(
316316
async () => {
317-
await regionProfileManager.createQClient({
317+
await regionProfileManager.createQUserClient({
318318
name: 'foo',
319319
region: 'unknown-somewhere',
320320
arn: 'arn',
@@ -330,7 +330,7 @@ describe('RegionProfileManager', async function () {
330330
await regionProfileManager.switchRegionProfile(profileFoo, 'user')
331331
assert.deepStrictEqual(regionProfileManager.activeRegionProfile, profileFoo)
332332

333-
const client = await regionProfileManager._createQClient('eu-central-1', 'https://amazon.com/')
333+
const client = await regionProfileManager._createQUserClient('eu-central-1', 'https://amazon.com/')
334334

335335
assert.deepStrictEqual(client.config.region, 'eu-central-1')
336336
assert.deepStrictEqual(client.endpoint.href, 'https://amazon.com/')

packages/core/src/codewhisperer/region/regionProfileManager.ts

Lines changed: 107 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import { showConfirmationMessage } from '../../shared/utilities/messages'
1111
import globals from '../../shared/extensionGlobals'
1212
import { once } from '../../shared/utilities/functionUtils'
1313
import CodeWhispererUserClient from '../client/codewhispereruserclient'
14+
import CodeWhispererClient from '../client/codewhispererclient'
1415
import { Credentials, Service } from 'aws-sdk'
1516
import { ServiceOptions } from '../../shared/awsClientBuilder'
1617
import userApiConfig = require('../client/user-service-2.json')
@@ -152,30 +153,63 @@ export class RegionProfileManager {
152153
const failedRegions: string[] = []
153154

154155
for (const [region, endpoint] of endpoints.entries()) {
155-
const client = await this._createQClient(region, endpoint)
156-
const requester = async (request: CodeWhispererUserClient.ListAvailableProfilesRequest) =>
157-
client.listAvailableProfiles(request).promise()
158-
const request: CodeWhispererUserClient.ListAvailableProfilesRequest = {}
159156
try {
160-
const profiles = await pageableToCollection(requester, request, 'nextToken', 'profiles')
161-
.flatten()
162-
.promise()
163-
const mappedPfs = profiles.map((it) => {
164-
let accntId = ''
165-
try {
166-
accntId = parse(it.arn).accountId
167-
} catch (e) {}
168-
169-
return {
170-
name: it.profileName,
171-
region: region,
172-
arn: it.arn,
173-
description: accntId,
157+
// Get region profiles (Q developer profiles) from Q client and authenticate with SSO token
158+
if (this.authProvider.isIdcConnection()) {
159+
const client = await this._createQUserClient(region, endpoint)
160+
const requester = async (request: CodeWhispererUserClient.ListAvailableProfilesRequest) => {
161+
return client.listAvailableProfiles(request).promise()
174162
}
175-
})
163+
const request: CodeWhispererUserClient.ListAvailableProfilesRequest = {}
164+
const profiles = await pageableToCollection(requester, request, 'nextToken', 'profiles')
165+
.flatten()
166+
.promise()
167+
const mappedPfs = profiles.map((it) => {
168+
let accntId = ''
169+
try {
170+
accntId = parse(it.arn).accountId
171+
} catch (e) {}
172+
173+
return {
174+
name: it.profileName,
175+
region: region,
176+
arn: it.arn,
177+
description: accntId,
178+
}
179+
})
176180

177-
availableProfiles.push(...mappedPfs)
178-
RegionProfileManager.logger.debug(`Found ${mappedPfs.length} profiles in region ${region}`)
181+
availableProfiles.push(...mappedPfs)
182+
RegionProfileManager.logger.debug(`Found ${mappedPfs.length} profiles in region ${region}`)
183+
}
184+
// Get region profiles (Q developer profiles) from Q client and authenticate with IAM credentials
185+
else if (this.authProvider.isIamSession()) {
186+
const client = await this._createQServiceClient(region, endpoint)
187+
const requester = async (request: CodeWhispererClient.ListProfilesRequest) => {
188+
return client.listProfiles(request).promise()
189+
}
190+
const request: CodeWhispererClient.ListProfilesRequest = {}
191+
const profiles = await pageableToCollection(requester, request, 'nextToken', 'profiles')
192+
.flatten()
193+
.promise()
194+
const mappedPfs = profiles.map((it) => {
195+
let accntId = ''
196+
try {
197+
accntId = parse(it.arn).accountId
198+
} catch (e) {}
199+
200+
return {
201+
name: it.profileName,
202+
region: region,
203+
arn: it.arn,
204+
description: accntId,
205+
}
206+
})
207+
208+
availableProfiles.push(...mappedPfs)
209+
RegionProfileManager.logger.debug(`Found ${mappedPfs.length} profiles in region ${region}`)
210+
} else {
211+
throw new ToolkitError('Failed to list profiles when signed out of identity center and IAM credentials')
212+
}
179213
} catch (e) {
180214
const logMsg = isAwsError(e) ? `requestId=${e.requestId}; message=${e.message}` : (e as Error).message
181215
RegionProfileManager.logger.error(`Failed to list profiles for region ${region}: ${logMsg}`)
@@ -201,7 +235,7 @@ export class RegionProfileManager {
201235
}
202236

203237
async switchRegionProfile(regionProfile: RegionProfile | undefined, source: ProfileSwitchIntent) {
204-
if (!this.authProvider.isConnected() || !this.authProvider.isIdcConnection()) {
238+
if (!this.authProvider.isConnected()) {
205239
return
206240
}
207241

@@ -413,55 +447,72 @@ export class RegionProfileManager {
413447
}
414448

415449
// TODO: Should maintain sdk client in a better way
416-
async createQClient(profile: RegionProfile): Promise<CodeWhispererUserClient> {
450+
// Create a Q user client compatible with SSO tokens
451+
async createQUserClient(profile: RegionProfile): Promise<CodeWhispererUserClient> {
417452
if (!this.authProvider.isConnected()) {
418453
throw new Error('No valid connection')
419454
}
420455
const endpoint = endpoints.get(profile.region)
421456
if (!endpoint) {
422457
throw new Error(`trying to initiatize Q client with unrecognizable region ${profile.region}`)
423458
}
424-
return this._createQClient(profile.region, endpoint)
459+
return this._createQUserClient(profile.region, endpoint)
425460
}
426461

427-
// Visible for testing only, do not use this directly, please use createQClient(profile)
428-
async _createQClient(region: string, endpoint: string): Promise<CodeWhispererUserClient> {
429-
let serviceOption: ServiceOptions = {}
430-
if (this.authProvider.isSsoSession()) {
431-
const token = await this.authProvider.getToken()
432-
serviceOption = {
433-
apiConfig: userApiConfig,
434-
region: region,
435-
endpoint: endpoint,
436-
credentials: new Credentials({ accessKeyId: 'xxx', secretAccessKey: 'xxx' }),
437-
onRequestSetup: [
438-
(req) => {
439-
req.on('build', ({ httpRequest }) => {
440-
httpRequest.headers['Authorization'] = `Bearer ${token}`
441-
})
442-
},
443-
],
444-
} as ServiceOptions
445-
} else if (this.authProvider.isIamSession()) {
446-
const credential = await this.authProvider.getIamCredential()
447-
serviceOption = {
448-
apiConfig: apiConfig,
449-
region: region,
450-
endpoint: endpoint,
451-
credentials: new Credentials({
452-
accessKeyId: credential.accessKeyId,
453-
secretAccessKey: credential.secretAccessKey,
454-
sessionToken: credential.sessionToken,
455-
}),
456-
} as ServiceOptions
462+
// Create a Q service client compatible with IAM credentials
463+
async createQServiceClient(profile: RegionProfile): Promise<CodeWhispererClient> {
464+
if (!this.authProvider.isConnected()) {
465+
throw new Error('No valid connection')
457466
}
467+
const endpoint = endpoints.get(profile.region)
468+
if (!endpoint) {
469+
throw new Error(`trying to initiatize Q client with unrecognizable region ${profile.region}`)
470+
}
471+
return this._createQServiceClient(profile.region, endpoint)
472+
}
458473

459-
const c = (await globals.sdkClientBuilder.createAwsService(
474+
// Visible for testing only, do not use this directly, please use createQUserClient(profile)
475+
async _createQUserClient(region: string, endpoint: string): Promise<CodeWhispererUserClient> {
476+
const token = await this.authProvider.getToken()
477+
const serviceOption: ServiceOptions = {
478+
apiConfig: userApiConfig,
479+
region: region,
480+
endpoint: endpoint,
481+
credentials: new Credentials({ accessKeyId: 'xxx', secretAccessKey: 'xxx' }),
482+
onRequestSetup: [
483+
(req) => {
484+
req.on('build', ({ httpRequest }) => {
485+
httpRequest.headers['Authorization'] = `Bearer ${token}`
486+
})
487+
},
488+
],
489+
} as ServiceOptions
490+
491+
return (await globals.sdkClientBuilder.createAwsService(
460492
Service,
461493
serviceOption,
462494
undefined
463495
)) as CodeWhispererUserClient
496+
}
464497

465-
return c
498+
// Visible for testing only, do not use this directly, please use createQServiceClient(profile)
499+
async _createQServiceClient(region: string, endpoint: string): Promise<CodeWhispererClient> {
500+
const credential = await this.authProvider.getIamCredential()
501+
const serviceOption: ServiceOptions = {
502+
apiConfig: apiConfig,
503+
region: region,
504+
endpoint: endpoint,
505+
credentials: new Credentials({
506+
accessKeyId: credential.accessKeyId,
507+
secretAccessKey: credential.secretAccessKey,
508+
sessionToken: credential.sessionToken,
509+
}),
510+
} as ServiceOptions
511+
512+
return (await globals.sdkClientBuilder.createAwsService(
513+
Service,
514+
serviceOption,
515+
undefined
516+
)) as CodeWhispererClient
466517
}
467518
}

packages/core/src/codewhisperer/util/authUtil.ts

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,20 +192,29 @@ export class AuthUtil implements IAuthProvider {
192192
}
193193

194194
logout() {
195+
// session will be nullified the next time refreshState() is called
195196
return this.session?.logout()
196197
}
197198

198199
async getToken() {
199200
if (this.isSsoSession()) {
200-
return (await (this.session as SsoLogin).getCredential()).credential
201+
const token = (await this.session!.getCredential()).credential
202+
if (typeof token !== 'string') {
203+
throw new ToolkitError('Cannot get token with IAM session')
204+
}
205+
return token
201206
} else {
202207
throw new ToolkitError('Cannot get credential without logging in.')
203208
}
204209
}
205210

206211
async getIamCredential() {
207212
if (this.session) {
208-
return (await (this.session as IamLogin).getCredential()).credential
213+
const credential = (await this.session.getCredential()).credential
214+
if (typeof credential !== 'object') {
215+
throw new ToolkitError('Cannot get token with SSO session')
216+
}
217+
return credential
209218
} else {
210219
throw new ToolkitError('Cannot get credential without logging in.')
211220
}

packages/core/src/codewhisperer/util/customizationUtil.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ export class CustomizationProvider {
4949
}
5050

5151
static async init(profile: RegionProfile): Promise<CustomizationProvider> {
52-
const client = await AuthUtil.instance.regionProfileManager.createQClient(profile)
52+
const client = await AuthUtil.instance.regionProfileManager.createQUserClient(profile)
5353
return new CustomizationProvider(client, profile)
5454
}
5555
}

packages/core/src/login/webview/vue/amazonq/backend_amazonq.ts

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -174,10 +174,6 @@ export class AmazonQLoginWebview extends CommonAuthWebview {
174174

175175
@withTelemetryContext({ name: 'signout', class: className })
176176
override async signout(): Promise<void> {
177-
if (!AuthUtil.instance.isSsoSession()) {
178-
throw new ToolkitError(`Cannot signout non-SSO connection`)
179-
}
180-
181177
this.storeMetricMetadata({
182178
authEnabledFeatures: 'codewhisperer',
183179
isReAuth: true,

packages/core/src/test/amazonq/customizationUtil.test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ describe('customizationProvider', function () {
4343
regionProfileManager: regionProfileManager,
4444
}
4545
sinon.stub(AuthUtil, 'instance').get(() => mockAuthUtil)
46-
const createClientStub = sinon.stub(regionProfileManager, 'createQClient')
46+
const createClientStub = sinon.stub(regionProfileManager, 'createQUserClient')
4747
const mockProfile = {
4848
name: 'foo',
4949
region: 'us-east-1',

0 commit comments

Comments
 (0)