Skip to content

Commit 00378ee

Browse files
committed
feat(amazonq): Enable SageMaker SSO user to automatically their SSO credentials & Pro Tier profile ARN for AmazonQ features
1 parent e6c3101 commit 00378ee

File tree

17 files changed

+39876
-60
lines changed

17 files changed

+39876
-60
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ src.gen/*
3131
**/src/shared/telemetry/clienttelemetry.d.ts
3232
**/src/codewhisperer/client/codewhispererclient.d.ts
3333
**/src/codewhisperer/client/codewhispereruserclient.d.ts
34+
**/src/shared/sagemaker/client/sagemakerclient.d.ts
3435
**/src/amazonqFeatureDev/client/featuredevproxyclient.d.ts
3536
**/src/auth/sso/oidcclientpkce.d.ts
3637

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "Feature",
3+
"description": "Enable SageMaker SSO user access token & Profile ARN to be used for accessing AmazonQ & CodeWhisperer features"
4+
}

packages/core/scripts/build/generateServiceClient.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,10 @@ void (async () => {
245245
serviceJsonPath: 'src/amazonqFeatureDev/client/codewhispererruntime-2022-11-11.json',
246246
serviceName: 'FeatureDevProxyClient',
247247
},
248+
{
249+
serviceJsonPath: 'src/shared/sagemaker/client/service-2.json',
250+
serviceName: 'SageMakerClient',
251+
},
248252
]
249253
await generateServiceClients(serviceClientDefinitions)
250254
})()

packages/core/src/auth/activation.ts

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,31 +3,21 @@
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

6-
import * as vscode from 'vscode'
76
import { Auth } from './auth'
87
import { LoginManager } from './deprecated/loginManager'
98
import { fromString } from './providers/credentials'
109
import { getLogger } from '../shared/logger'
1110
import { ExtensionUse, initializeCredentialsProviderManager } from './utils'
12-
import { isAmazonQ, isCloud9, isSageMaker } from '../shared/extensionUtilities'
11+
import { isCloud9, isSageMaker } from '../shared/extensionUtilities'
1312
import { isInDevEnv } from '../shared/vscode/env'
1413
import { isWeb } from '../shared/extensionGlobals'
1514

16-
interface SagemakerCookie {
17-
authMode?: 'Sso' | 'Iam'
18-
}
19-
2015
export async function initialize(loginManager: LoginManager): Promise<void> {
21-
if (isAmazonQ() && isSageMaker()) {
22-
// The command `sagemaker.parseCookies` is registered in VS Code Sagemaker environment.
23-
const result = (await vscode.commands.executeCommand('sagemaker.parseCookies')) as SagemakerCookie
24-
if (result.authMode !== 'Sso') {
25-
initializeCredentialsProviderManager()
26-
}
27-
}
16+
await initializeCredentialsProviderManager()
17+
2818
Auth.instance.onDidChangeActiveConnection(async (conn) => {
2919
// This logic needs to be moved to `Auth.useConnection` to correctly record `passive`
30-
if (conn?.type === 'iam' && conn.state === 'valid') {
20+
if (conn?.state === 'valid' && (isSageMaker() || conn?.type === 'iam')) {
3121
await loginManager.login({ passive: true, providerId: fromString(conn.id) })
3222
} else {
3323
await loginManager.logout()

packages/core/src/auth/auth.ts

Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ const localize = nls.loadMessageBundle()
1111
import * as vscode from 'vscode'
1212
import * as localizedText from '../shared/localizedText'
1313
import { Credentials } from '@aws-sdk/types'
14-
import { SsoAccessTokenProvider } from './sso/ssoAccessTokenProvider'
14+
import { SsoAccessTokenProvider, SsoTokenProvider } from './sso/ssoAccessTokenProvider'
1515
import { Timeout } from '../shared/utilities/timeoutUtils'
1616
import { errorCode, isAwsError, isNetworkError, ToolkitError, UnknownError } from '../shared/errors'
17-
import { getCache, getCacheFileWatcher } from './sso/cache'
17+
import { getCache, getCacheFileWatcher, SsoCache } from './sso/cache'
1818
import { isNonNullable, Mutable } from '../shared/utilities/tsUtils'
1919
import { SsoToken, truncateStartUrl } from './sso/model'
2020
import { SsoClient } from './sso/clients'
@@ -69,6 +69,7 @@ import { withTelemetryContext } from '../shared/telemetry/util'
6969
import { DiskCacheError } from '../shared/utilities/cacheUtils'
7070
import { setContext } from '../shared/vscode/setContext'
7171
import { builderIdStartUrl, internalStartUrl } from './sso/constants'
72+
import { SageMakerSsoTokenProvider } from './sso/sageMakerAccessTokenProvider'
7273

7374
interface AuthService {
7475
/**
@@ -121,6 +122,10 @@ function keyedDebounce<T, U extends any[], K extends string = string>(
121122
}
122123
}
123124

125+
export function useSageMakerSsoProfile() {
126+
return isSageMaker() && isAmazonQ()
127+
}
128+
124129
export interface ConnectionStateChangeEvent {
125130
readonly id: Connection['id']
126131
readonly state: ProfileMetadata['connectionState']
@@ -141,17 +146,29 @@ export class Auth implements AuthService, ConnectionManager {
141146
readonly #onDidChangeConnectionState = new vscode.EventEmitter<ConnectionStateChangeEvent>()
142147
readonly #onDidUpdateConnection = new vscode.EventEmitter<StatefulConnection>()
143148
readonly #onDidDeleteConnection = new vscode.EventEmitter<DeletedConnection>()
149+
readonly #onDidPrecreateActiveConnection = new vscode.EventEmitter<StatefulConnection>()
144150
public readonly onDidChangeActiveConnection = this.#onDidChangeActiveConnection.event
145151
public readonly onDidChangeConnectionState = this.#onDidChangeConnectionState.event
146152
public readonly onDidUpdateConnection = this.#onDidUpdateConnection.event
147153
/** Fired when a connection and its metadata has been completely deleted */
148154
public readonly onDidDeleteConnection = this.#onDidDeleteConnection.event
155+
public readonly onDidPrecreateActiveConnection = this.#onDidPrecreateActiveConnection.event
149156

150157
public constructor(
151158
private readonly store: ProfileStore,
152159
private readonly iamProfileProvider = CredentialsProviderManager.getInstance(),
153160
private readonly createSsoClient = SsoClient.create.bind(SsoClient),
154-
private readonly createSsoTokenProvider = SsoAccessTokenProvider.create.bind(SsoAccessTokenProvider)
161+
private readonly createSsoTokenProvider: (
162+
profile: {
163+
readonly startUrl: string
164+
readonly region: string
165+
readonly identifier?: string
166+
readonly scopes: string[]
167+
},
168+
cache?: SsoCache
169+
) => SsoTokenProvider = useSageMakerSsoProfile()
170+
? SageMakerSsoTokenProvider.create.bind(SageMakerSsoTokenProvider)
171+
: SsoAccessTokenProvider.create.bind(SsoAccessTokenProvider)
155172
) {}
156173

157174
#activeConnection: Mutable<StatefulConnection> | undefined
@@ -324,6 +341,29 @@ export class Auth implements AuthService, ConnectionManager {
324341
return toCollection(load.bind(this))
325342
}
326343

344+
private async createSageMakerSsoConnection(): Promise<StatefulConnection | undefined> {
345+
if (!useSageMakerSsoProfile) {
346+
return undefined
347+
}
348+
const id = SageMakerSsoTokenProvider.sagemakerConectionId
349+
const { startUrl, region, scopes } = SageMakerSsoTokenProvider.getSagemakerProfile()
350+
const profile = createSsoProfile(startUrl, region, scopes)
351+
const tokenProvider = this.getSsoTokenProvider(id, {
352+
...profile,
353+
metadata: { connectionState: 'unauthenticated' },
354+
})
355+
356+
const token = await tokenProvider.getToken()
357+
if (!token) {
358+
return undefined
359+
}
360+
361+
const storedProfile = await this.store.addProfile(id, profile)
362+
await this.updateConnectionState(id, 'valid')
363+
const connection = this.getSsoConnection(id, storedProfile)
364+
return connection
365+
}
366+
327367
public async createConnection(profile: SsoProfile): Promise<SsoConnection>
328368
@withTelemetryContext({ name: 'createConnection', class: authClassName })
329369
public async createConnection(profile: Profile): Promise<Connection> {
@@ -786,7 +826,7 @@ export class Auth implements AuthService, ConnectionManager {
786826
{
787827
identifier: tokenIdentifier,
788828
startUrl: profile.startUrl,
789-
scopes: profile.scopes,
829+
scopes: profile.scopes ?? [],
790830
region: profile.ssoRegion,
791831
},
792832
this.#ssoCache
@@ -859,7 +899,7 @@ export class Auth implements AuthService, ConnectionManager {
859899

860900
private readonly getToken = keyedDebounce(this._getToken.bind(this))
861901
@withTelemetryContext({ name: '_getToken', class: authClassName })
862-
private async _getToken(id: Connection['id'], provider: SsoAccessTokenProvider): Promise<SsoToken> {
902+
private async _getToken(id: Connection['id'], provider: SsoTokenProvider): Promise<SsoToken> {
863903
const token = await provider.getToken().catch((err) => {
864904
this.throwOnRecoverableError(err)
865905

@@ -963,6 +1003,20 @@ export class Auth implements AuthService, ConnectionManager {
9631003
return this.authenticate(id, refresh)
9641004
}
9651005

1006+
public async tryAutoConnectSageMaker(): Promise<StatefulConnection | undefined> {
1007+
try {
1008+
const sagemakerConnection = await this.createSageMakerSsoConnection()
1009+
if (!sagemakerConnection) {
1010+
return undefined
1011+
}
1012+
1013+
await this.useConnection({ id: SageMakerSsoTokenProvider.sagemakerConectionId })
1014+
return sagemakerConnection
1015+
} catch (err) {
1016+
getLogger().warn(`auth: failed to connect using SageMaker auth token: %s`, err)
1017+
}
1018+
}
1019+
9661020
public readonly tryAutoConnect = once(async () => this._tryAutoConnect())
9671021
@withTelemetryContext({ name: 'tryAutoConnect', class: authClassName })
9681022
private async _tryAutoConnect() {

packages/core/src/auth/providers/ssoCredentialsProvider.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@ import { CredentialType } from '../../shared/telemetry/telemetry.gen'
88
import { getStringHash } from '../../shared/utilities/textUtilities'
99
import { CredentialsId, CredentialsProvider, CredentialsProviderType } from './credentials'
1010
import { SsoClient } from '../sso/clients'
11-
import { SsoAccessTokenProvider } from '../sso/ssoAccessTokenProvider'
11+
import { SsoTokenProvider } from '../sso/ssoAccessTokenProvider'
1212

1313
export class SsoCredentialsProvider implements CredentialsProvider {
1414
public constructor(
1515
private readonly id: CredentialsId,
1616
private readonly client: SsoClient,
17-
private readonly tokenProvider: SsoAccessTokenProvider,
17+
private readonly tokenProvider: SsoTokenProvider,
1818
private readonly accountId: string,
1919
private readonly roleName: string
2020
) {}

packages/core/src/auth/secondaryAuth.ts

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import * as vscode from 'vscode'
77
import { getLogger } from '../shared/logger'
88
import { cast, Optional } from '../shared/utilities/typeConstructors'
9-
import { Auth } from './auth'
9+
import { Auth, useSageMakerSsoProfile } from './auth'
1010
import { onceChanged } from '../shared/utilities/functionUtils'
1111
import { isNonNullable } from '../shared/utilities/tsUtils'
1212
import { ToolIdStateKey } from '../shared/globalState'
@@ -24,7 +24,7 @@ let currentConn: Auth['activeConnection']
2424
const auths = new Map<string, SecondaryAuth>()
2525
const multiConnectionListeners = new WeakMap<Auth, vscode.Disposable>()
2626
const registerAuthListener = (auth: Auth) => {
27-
return auth.onDidChangeActiveConnection(async (newConn) => {
27+
const activeConnectionChangeListener = auth.onDidChangeActiveConnection(async (newConn) => {
2828
// When we change the active connection, there may be
2929
// secondary auths that were dependent on the previous active connection.
3030
// To ensure secondary auths still work, when we change to a new active connection,
@@ -38,6 +38,21 @@ const registerAuthListener = (auth: Auth) => {
3838
}
3939
currentConn = newConn
4040
})
41+
42+
const precreatedConnectionCreatedListener = auth.onDidPrecreateActiveConnection(async (newConn) => {
43+
await Promise.all(
44+
Array.from(auths.values())
45+
.filter((a) => !a.hasSavedConnection && a.isUsable(newConn))
46+
.map((a) => a.saveConnection(newConn))
47+
)
48+
})
49+
50+
return {
51+
dispose: () => {
52+
activeConnectionChangeListener.dispose()
53+
precreatedConnectionCreatedListener.dispose()
54+
},
55+
}
4156
}
4257

4358
export function getSecondaryAuth<T extends Connection>(
@@ -306,6 +321,12 @@ export class SecondaryAuth<T extends Connection = Connection> {
306321
id: 'undefined',
307322
connectionState: 'undefined',
308323
})
324+
if (useSageMakerSsoProfile()) {
325+
const connection = await this.auth.tryAutoConnectSageMaker()
326+
if (connection) {
327+
this.saveConnection(connection as unknown as T)
328+
}
329+
}
309330
await this.auth.tryAutoConnect()
310331
this.#savedConnection = await this._loadSavedConnection(span)
311332
this.#onDidChangeActiveConnection.fire(this.activeConnection)

packages/core/src/auth/sso/clients.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import { AsyncCollection } from '../../shared/utilities/asyncCollection'
2727
import { pageableToCollection, partialClone } from '../../shared/utilities/collectionUtils'
2828
import { assertHasProps, isNonNullable, RequiredProps, selectFrom } from '../../shared/utilities/tsUtils'
2929
import { getLogger } from '../../shared/logger'
30-
import { SsoAccessTokenProvider } from './ssoAccessTokenProvider'
30+
import { SsoTokenProvider } from './ssoAccessTokenProvider'
3131
import { AwsClientResponseError, isClientFault } from '../../shared/errors'
3232
import { DevSettings } from '../../shared/settings'
3333
import { SdkError } from '@aws-sdk/types'
@@ -158,7 +158,7 @@ export class SsoClient {
158158

159159
public constructor(
160160
private readonly client: PromisifyClient<SSO>,
161-
private readonly provider: SsoAccessTokenProvider
161+
private readonly provider: SsoTokenProvider
162162
) {}
163163

164164
@withTelemetryContext({ name: 'listAccounts', class: ssoClientClassName })
@@ -236,7 +236,7 @@ export class SsoClient {
236236
throw error
237237
}
238238

239-
public static create(region: string, provider: SsoAccessTokenProvider) {
239+
public static create(region: string, provider: SsoTokenProvider) {
240240
return new this(
241241
new SSO({
242242
region,

0 commit comments

Comments
 (0)