Skip to content

Commit b381d0d

Browse files
authored
feat(amazonq): Add region profile manager functionality (aws#7036)
## Problem The [first version](aws#6958) of migration of all auth for vscode to Flare using the identity server did not support the recently introduced RegionProfileManager ## Solution These code changes bring back the RegionProfileManager functionality, in the new auth setup. Integration tests and unit tests to be fixed after all references are updated in a follow-up PR to keep the changes manageable. **CI is expected to fail**. --- - Treat all work as PUBLIC. Private `feature/x` branches will not be squash-merged at release time. - Your code changes must meet the guidelines in [CONTRIBUTING.md](https://github.com/aws/aws-toolkit-vscode/blob/master/CONTRIBUTING.md#guidelines). - License: I confirm that my contribution is made under the terms of the Apache 2.0 license.
1 parent 5c37ceb commit b381d0d

File tree

2 files changed

+57
-50
lines changed

2 files changed

+57
-50
lines changed

packages/core/src/codewhisperer/region/regionProfileManager.ts

Lines changed: 31 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,6 @@ import { getIcon } from '../../shared/icons'
88
import { DataQuickPickItem } from '../../shared/ui/pickerPrompter'
99
import { CodeWhispererConfig, RegionProfile } from '../models/model'
1010
import { showConfirmationMessage } from '../../shared/utilities/messages'
11-
import {
12-
Connection,
13-
isBuilderIdConnection,
14-
isIdcSsoConnection,
15-
isSsoConnection,
16-
SsoConnection,
17-
} from '../../auth/connection'
1811
import globals from '../../shared/extensionGlobals'
1912
import { once } from '../../shared/utilities/functionUtils'
2013
import CodeWhispererUserClient from '../client/codewhispereruserclient'
@@ -28,6 +21,7 @@ import { parse } from '@aws-sdk/util-arn-parser'
2821
import { isAwsError, ToolkitError } from '../../shared/errors'
2922
import { telemetry } from '../../shared/telemetry/telemetry'
3023
import { localize } from '../../shared/utilities/vsCodeUtils'
24+
import { AuthUtil } from '../util/authUtil'
3125

3226
// TODO: is there a better way to manage all endpoint strings in one place?
3327
export const defaultServiceConfig: CodeWhispererConfig = {
@@ -60,21 +54,19 @@ export class RegionProfileManager {
6054
private _profiles: RegionProfile[] = []
6155

6256
get activeRegionProfile() {
63-
const conn = this.connectionProvider()
64-
if (isBuilderIdConnection(conn)) {
57+
if (AuthUtil.instance.isBuilderIdConnection()) {
6558
return undefined
6659
}
6760
return this._activeRegionProfile
6861
}
6962

7063
get clientConfig(): CodeWhispererConfig {
71-
const conn = this.connectionProvider()
72-
if (!conn) {
64+
if (!AuthUtil.instance.isConnected()) {
7365
throw new ToolkitError('trying to get client configuration without credential')
7466
}
7567

7668
// builder id should simply use default IAD
77-
if (isBuilderIdConnection(conn)) {
69+
if (AuthUtil.instance.isBuilderIdConnection()) {
7870
return defaultServiceConfig
7971
}
8072

@@ -102,18 +94,17 @@ export class RegionProfileManager {
10294
return this._profiles
10395
}
10496

105-
constructor(private readonly connectionProvider: () => Connection | undefined) {}
97+
constructor() {}
10698

107-
async listRegionProfile(): Promise<RegionProfile[]> {
99+
async listRegionProfiles(): Promise<RegionProfile[]> {
108100
this._profiles = []
109101

110-
const conn = this.connectionProvider()
111-
if (conn === undefined || !isSsoConnection(conn)) {
102+
if (!AuthUtil.instance.isConnected() || !AuthUtil.instance.isSsoSession()) {
112103
return []
113104
}
114105
const availableProfiles: RegionProfile[] = []
115106
for (const [region, endpoint] of endpoints.entries()) {
116-
const client = await this.createQClient(region, endpoint, conn as SsoConnection)
107+
const client = await this.createQClient(region, endpoint)
117108
const requester = async (request: CodeWhispererUserClient.ListAvailableProfilesRequest) =>
118109
client.listAvailableProfiles(request).promise()
119110
const request: CodeWhispererUserClient.ListAvailableProfilesRequest = {}
@@ -138,7 +129,7 @@ export class RegionProfileManager {
138129
availableProfiles.push(...mappedPfs)
139130
} catch (e) {
140131
const logMsg = isAwsError(e) ? `requestId=${e.requestId}; message=${e.message}` : (e as Error).message
141-
RegionProfileManager.logger.error(`failed to listRegionProfile: ${logMsg}`)
132+
RegionProfileManager.logger.error(`failed to listRegionProfiles: ${logMsg}`)
142133
throw e
143134
}
144135

@@ -150,18 +141,14 @@ export class RegionProfileManager {
150141
}
151142

152143
async switchRegionProfile(regionProfile: RegionProfile | undefined, source: ProfileSwitchIntent) {
153-
const conn = this.connectionProvider()
154-
if (conn === undefined || !isIdcSsoConnection(conn)) {
144+
if (!AuthUtil.instance.isConnected() || !AuthUtil.instance.isIdcConnection()) {
155145
return
156146
}
157147

158148
if (regionProfile && this.activeRegionProfile && regionProfile.arn === this.activeRegionProfile.arn) {
159149
return
160150
}
161151

162-
// TODO: make it typesafe
163-
const ssoConn = this.connectionProvider() as SsoConnection
164-
165152
// only prompt to users when users switch from A profile to B profile
166153
if (this.activeRegionProfile !== undefined && regionProfile !== undefined) {
167154
const response = await showConfirmationMessage({
@@ -179,9 +166,9 @@ export class RegionProfileManager {
179166
telemetry.amazonq_didSelectProfile.emit({
180167
source: source,
181168
amazonQProfileRegion: this.activeRegionProfile?.region ?? 'not-set',
182-
ssoRegion: ssoConn.ssoRegion,
169+
ssoRegion: AuthUtil.instance.connection?.region,
183170
result: 'Cancelled',
184-
credentialStartUrl: ssoConn.startUrl,
171+
credentialStartUrl: AuthUtil.instance.connection?.startUrl,
185172
profileCount: this.profiles.length,
186173
})
187174
return
@@ -198,9 +185,9 @@ export class RegionProfileManager {
198185
telemetry.amazonq_didSelectProfile.emit({
199186
source: source,
200187
amazonQProfileRegion: regionProfile?.region ?? 'not-set',
201-
ssoRegion: ssoConn.ssoRegion,
188+
ssoRegion: AuthUtil.instance.connection?.region,
202189
result: 'Succeeded',
203-
credentialStartUrl: ssoConn.startUrl,
190+
credentialStartUrl: AuthUtil.instance.connection?.startUrl,
204191
profileCount: this.profiles.length,
205192
})
206193
}
@@ -222,20 +209,19 @@ export class RegionProfileManager {
222209
}
223210

224211
restoreProfileSelection = once(async () => {
225-
const conn = this.connectionProvider()
226-
if (conn) {
227-
await this.restoreRegionProfile(conn)
212+
if (AuthUtil.instance.isConnected()) {
213+
await this.restoreRegionProfile()
228214
}
229215
})
230216

231217
// Note: should be called after [AuthUtil.instance.conn] returns non null
232-
async restoreRegionProfile(conn: Connection) {
233-
const previousSelected = this.loadPersistedRegionProfle()[conn.id] || undefined
218+
async restoreRegionProfile() {
219+
const previousSelected = this.loadPersistedRegionProfle()[AuthUtil.instance.profileName] || undefined
234220
if (!previousSelected) {
235221
return
236222
}
237223
// cross-validation
238-
this.listRegionProfile()
224+
this.listRegionProfiles()
239225
.then(async (profiles) => {
240226
const r = profiles.find((it) => it.arn === previousSelected.arn)
241227
if (!r) {
@@ -275,10 +261,8 @@ export class RegionProfileManager {
275261
}
276262

277263
async persistSelectRegionProfile() {
278-
const conn = this.connectionProvider()
279-
280264
// default has empty arn and shouldn't be persisted because it's just a fallback
281-
if (!conn || this.activeRegionProfile === undefined) {
265+
if (!AuthUtil.instance.isConnected() || this.activeRegionProfile === undefined) {
282266
return
283267
}
284268

@@ -289,15 +273,15 @@ export class RegionProfileManager {
289273
{}
290274
)
291275

292-
previousPersistedState[conn.id] = this.activeRegionProfile
276+
previousPersistedState[AuthUtil.instance.profileName] = this.activeRegionProfile
293277
await globals.globalState.update('aws.amazonq.regionProfiles', previousPersistedState)
294278
}
295279

296280
async generateQuickPickItem(): Promise<DataQuickPickItem<string>[]> {
297281
const selected = this.activeRegionProfile
298282
let profiles: RegionProfile[] = []
299283
try {
300-
profiles = await this.listRegionProfile()
284+
profiles = await this.listRegionProfiles()
301285
} catch (e) {
302286
return [
303287
{
@@ -344,8 +328,15 @@ export class RegionProfileManager {
344328
}
345329
}
346330

347-
async createQClient(region: string, endpoint: string, conn: SsoConnection): Promise<CodeWhispererUserClient> {
348-
const token = (await conn.getToken()).accessToken
331+
requireProfileSelection() {
332+
if (AuthUtil.instance.isBuilderIdConnection()) {
333+
return false
334+
}
335+
return AuthUtil.instance.isIdcConnection() && this.activeRegionProfile === undefined
336+
}
337+
338+
async createQClient(region: string, endpoint: string): Promise<CodeWhispererUserClient> {
339+
const token = await AuthUtil.instance.getToken()
349340
const serviceOption: ServiceOptions = {
350341
apiConfig: userApiConfig,
351342
region: region,

packages/core/src/codewhisperer/util/authUtil.ts

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@ import { showAmazonQWalkthroughOnce } from '../../amazonq/onboardingPage/walkthr
1717
import { setContext } from '../../shared/vscode/setContext'
1818
import { openUrl } from '../../shared/utilities/vsCodeUtils'
1919
import { telemetry } from '../../shared/telemetry/telemetry'
20-
import { AuthStateEvent, AuthStates, LanguageClientAuth, LoginTypes, SsoLogin } from '../../auth/auth2'
20+
import { AuthStateEvent, LanguageClientAuth, LoginTypes, SsoLogin } from '../../auth/auth2'
2121
import { builderIdStartUrl } from '../../auth/sso/constants'
2222
import { VSCODE_EXTENSION_ID } from '../../shared/extensions'
23+
import { RegionProfileManager } from '../region/regionProfileManager'
2324

2425
const localize = nls.loadMessageBundle()
2526

@@ -34,6 +35,7 @@ export const amazonQScopes = [...codeWhispererChatScopes, ...scopesGumby, ...sco
3435
*/
3536
export class AuthUtil {
3637
public readonly profileName = VSCODE_EXTENSION_ID.amazonq
38+
public readonly regionProfileManager: RegionProfileManager
3739

3840
// IAM login currently not supported
3941
private session: SsoLogin
@@ -53,6 +55,11 @@ export class AuthUtil {
5355
private constructor(private readonly lspAuth: LanguageClientAuth) {
5456
this.session = new SsoLogin(this.profileName, this.lspAuth)
5557
this.onDidChangeConnectionState((e: AuthStateEvent) => this.stateChangeHandler(e))
58+
59+
this.regionProfileManager = new RegionProfileManager()
60+
this.regionProfileManager.onDidChangeRegionProfile(async () => {
61+
await this.setVscodeContextProps()
62+
})
5663
}
5764

5865
isSsoSession() {
@@ -104,29 +111,32 @@ export class AuthUtil {
104111
}
105112

106113
isConnected() {
107-
return this.getAuthState() === AuthStates.CONNECTED
114+
return this.getAuthState() === 'connected'
108115
}
109116

110117
isConnectionExpired() {
111-
return this.getAuthState() === AuthStates.EXPIRED
118+
return this.getAuthState() === 'expired'
112119
}
113120

114121
isBuilderIdConnection() {
115122
return this.connection?.startUrl === builderIdStartUrl
116123
}
117124

118125
isIdcConnection() {
119-
return this.connection?.startUrl && this.connection?.startUrl !== builderIdStartUrl
126+
return Boolean(this.connection?.startUrl && this.connection?.startUrl !== builderIdStartUrl)
120127
}
121128

122129
onDidChangeConnectionState(handler: (e: AuthStateEvent) => any) {
123130
return this.session.onDidChangeConnectionState(handler)
124131
}
125132

126133
public async setVscodeContextProps(state = this.getAuthState()) {
127-
await setContext('aws.codewhisperer.connected', state === AuthStates.CONNECTED)
128-
await setContext('aws.amazonq.showLoginView', state !== AuthStates.CONNECTED) // Login view also handles expired state.
129-
await setContext('aws.codewhisperer.connectionExpired', state === AuthStates.EXPIRED)
134+
await setContext('aws.codewhisperer.connected', state === 'connected')
135+
const showAmazonQLoginView =
136+
!this.isConnected() || this.isConnectionExpired() || this.regionProfileManager.requireProfileSelection()
137+
await setContext('aws.amazonq.showLoginView', showAmazonQLoginView)
138+
await setContext('aws.amazonq.connectedSsoIdc', this.isIdcConnection())
139+
await setContext('aws.codewhisperer.connectionExpired', state === 'expired')
130140
}
131141

132142
private reauthenticatePromptShown: boolean = false
@@ -214,10 +224,16 @@ export class AuthUtil {
214224
}
215225

216226
private async refreshState(state = this.getAuthState()) {
217-
if (state === AuthStates.EXPIRED || state === AuthStates.NOT_CONNECTED) {
227+
if (state === 'expired' || state === 'notConnected') {
228+
if (this.isIdcConnection()) {
229+
await this.regionProfileManager.invalidateProfile(this.regionProfileManager.activeRegionProfile?.arn)
230+
}
218231
this.lspAuth.deleteBearerToken()
219232
}
220-
if (state === AuthStates.CONNECTED) {
233+
if (state === 'connected') {
234+
if (this.isIdcConnection()) {
235+
await this.regionProfileManager.restoreProfileSelection()
236+
}
221237
const bearerTokenParams = (await this.session.getToken()).updateCredentialsParams
222238
await this.lspAuth.updateBearerToken(bearerTokenParams)
223239
}
@@ -229,7 +245,7 @@ export class AuthUtil {
229245
Commands.tryExecute('aws.amazonq.updateReferenceLog'),
230246
])
231247

232-
if (state === AuthStates.CONNECTED && this.isIdcConnection()) {
248+
if (state === 'connected' && this.isIdcConnection()) {
233249
void vscode.commands.executeCommand('aws.amazonq.notifyNewCustomizations')
234250
}
235251
}

0 commit comments

Comments
 (0)