@@ -11,10 +11,10 @@ const localize = nls.loadMessageBundle()
1111import * as vscode from 'vscode'
1212import * as localizedText from '../shared/localizedText'
1313import { Credentials } from '@aws-sdk/types'
14- import { SsoAccessTokenProvider } from './sso/ssoAccessTokenProvider'
14+ import { SsoAccessTokenProvider , SsoTokenProvider } from './sso/ssoAccessTokenProvider'
1515import { Timeout } from '../shared/utilities/timeoutUtils'
1616import { errorCode , isAwsError , isNetworkError , ToolkitError , UnknownError } from '../shared/errors'
17- import { getCache , getCacheFileWatcher } from './sso/cache'
17+ import { getCache , getCacheFileWatcher , SsoCache } from './sso/cache'
1818import { isNonNullable , Mutable } from '../shared/utilities/tsUtils'
1919import { SsoToken , truncateStartUrl } from './sso/model'
2020import { SsoClient } from './sso/clients'
@@ -69,6 +69,7 @@ import { withTelemetryContext } from '../shared/telemetry/util'
6969import { DiskCacheError } from '../shared/utilities/cacheUtils'
7070import { setContext } from '../shared/vscode/setContext'
7171import { builderIdStartUrl , internalStartUrl } from './sso/constants'
72+ import { SageMakerSsoTokenProvider } from './sso/sageMakerAccessTokenProvider'
7273
7374interface AuthService {
7475 /**
@@ -121,6 +122,10 @@ function keyedDebounce<T, U extends any[], K extends string = string>(
121122 }
122123}
123124
125+ export function useSageMakerSsoProfile ( ) {
126+ return isSageMaker ( ) && isAmazonQ ( )
127+ }
128+
124129export interface ConnectionStateChangeEvent {
125130 readonly id : Connection [ 'id' ]
126131 readonly state : ProfileMetadata [ 'connectionState' ]
@@ -141,17 +146,29 @@ export class Auth implements AuthService, ConnectionManager {
141146 readonly #onDidChangeConnectionState = new vscode . EventEmitter < ConnectionStateChangeEvent > ( )
142147 readonly #onDidUpdateConnection = new vscode . EventEmitter < StatefulConnection > ( )
143148 readonly #onDidDeleteConnection = new vscode . EventEmitter < DeletedConnection > ( )
149+ readonly #onDidPrecreateActiveConnection = new vscode . EventEmitter < StatefulConnection > ( )
144150 public readonly onDidChangeActiveConnection = this . #onDidChangeActiveConnection. event
145151 public readonly onDidChangeConnectionState = this . #onDidChangeConnectionState. event
146152 public readonly onDidUpdateConnection = this . #onDidUpdateConnection. event
147153 /** Fired when a connection and its metadata has been completely deleted */
148154 public readonly onDidDeleteConnection = this . #onDidDeleteConnection. event
155+ public readonly onDidPrecreateActiveConnection = this . #onDidPrecreateActiveConnection. event
149156
150157 public constructor (
151158 private readonly store : ProfileStore ,
152159 private readonly iamProfileProvider = CredentialsProviderManager . getInstance ( ) ,
153160 private readonly createSsoClient = SsoClient . create . bind ( SsoClient ) ,
154- private readonly createSsoTokenProvider = SsoAccessTokenProvider . create . bind ( SsoAccessTokenProvider )
161+ private readonly createSsoTokenProvider : (
162+ profile : {
163+ readonly startUrl : string
164+ readonly region : string
165+ readonly identifier ?: string
166+ readonly scopes : string [ ]
167+ } ,
168+ cache ?: SsoCache
169+ ) => SsoTokenProvider = useSageMakerSsoProfile ( )
170+ ? SageMakerSsoTokenProvider . create . bind ( SageMakerSsoTokenProvider )
171+ : SsoAccessTokenProvider . create . bind ( SsoAccessTokenProvider )
155172 ) { }
156173
157174 #activeConnection: Mutable < StatefulConnection > | undefined
@@ -324,6 +341,29 @@ export class Auth implements AuthService, ConnectionManager {
324341 return toCollection ( load . bind ( this ) )
325342 }
326343
344+ private async createSageMakerSsoConnection ( ) : Promise < StatefulConnection | undefined > {
345+ if ( ! useSageMakerSsoProfile ) {
346+ return undefined
347+ }
348+ const id = SageMakerSsoTokenProvider . sagemakerConectionId
349+ const { startUrl, region, scopes } = SageMakerSsoTokenProvider . getSagemakerProfile ( )
350+ const profile = createSsoProfile ( startUrl , region , scopes )
351+ const tokenProvider = this . getSsoTokenProvider ( id , {
352+ ...profile ,
353+ metadata : { connectionState : 'unauthenticated' } ,
354+ } )
355+
356+ const token = await tokenProvider . getToken ( )
357+ if ( ! token ) {
358+ return undefined
359+ }
360+
361+ const storedProfile = await this . store . addProfile ( id , profile )
362+ await this . updateConnectionState ( id , 'valid' )
363+ const connection = this . getSsoConnection ( id , storedProfile )
364+ return connection
365+ }
366+
327367 public async createConnection ( profile : SsoProfile ) : Promise < SsoConnection >
328368 @withTelemetryContext ( { name : 'createConnection' , class : authClassName } )
329369 public async createConnection ( profile : Profile ) : Promise < Connection > {
@@ -786,7 +826,7 @@ export class Auth implements AuthService, ConnectionManager {
786826 {
787827 identifier : tokenIdentifier ,
788828 startUrl : profile . startUrl ,
789- scopes : profile . scopes ,
829+ scopes : profile . scopes ?? [ ] ,
790830 region : profile . ssoRegion ,
791831 } ,
792832 this . #ssoCache
@@ -859,7 +899,7 @@ export class Auth implements AuthService, ConnectionManager {
859899
860900 private readonly getToken = keyedDebounce ( this . _getToken . bind ( this ) )
861901 @withTelemetryContext ( { name : '_getToken' , class : authClassName } )
862- private async _getToken ( id : Connection [ 'id' ] , provider : SsoAccessTokenProvider ) : Promise < SsoToken > {
902+ private async _getToken ( id : Connection [ 'id' ] , provider : SsoTokenProvider ) : Promise < SsoToken > {
863903 const token = await provider . getToken ( ) . catch ( ( err ) => {
864904 this . throwOnRecoverableError ( err )
865905
@@ -963,6 +1003,20 @@ export class Auth implements AuthService, ConnectionManager {
9631003 return this . authenticate ( id , refresh )
9641004 }
9651005
1006+ public async tryAutoConnectSageMaker ( ) : Promise < StatefulConnection | undefined > {
1007+ try {
1008+ const sagemakerConnection = await this . createSageMakerSsoConnection ( )
1009+ if ( ! sagemakerConnection ) {
1010+ return undefined
1011+ }
1012+
1013+ await this . useConnection ( { id : SageMakerSsoTokenProvider . sagemakerConectionId } )
1014+ return sagemakerConnection
1015+ } catch ( err ) {
1016+ getLogger ( ) . warn ( `auth: failed to connect using SageMaker auth token: %s` , err )
1017+ }
1018+ }
1019+
9661020 public readonly tryAutoConnect = once ( async ( ) => this . _tryAutoConnect ( ) )
9671021 @withTelemetryContext ( { name : 'tryAutoConnect' , class : authClassName } )
9681022 private async _tryAutoConnect ( ) {
0 commit comments