Skip to content

Commit 436fc38

Browse files
vpbhargavBhargava Varadharajanchungjac
authored
feat(smsus): Add connection handling for SMUS (aws#2179)
**Description** Added a mechanism to create Connection profiles and track it across session for SMUS. UX after this PR : - User login flow where they provide Domain URL -> Authenticate and have an active session for however long SSO session is valid (default - 8 hours) - User connection info is persisted and tracked across VSCode sessions. User can close/reopen/quit and the session information is retained. - User can have multiple connections - 1 in explorer and other in SMUS. They can choose to use the SMUS connection in explorer also if they want. For now, SMUS required login and no connection selection. - SSO tokens are stored in the SSO token cache at `.aws/sso/cache` - Tokens are refreshed reactively if it it close to expiry. - When connection expires, user gets a notification and an option to reauthenticate. - User has ability to Sign Out from connection. This will also remove the connection metadata. (No past connections history etc for now to keep things simple). Next steps : - Provide credential provider to go from tokens to DER/Project role credentials. **Motivation** Support auth flow in SMUS. **Testing Done** Unit tests and also tested all flows locally. ## Problem ## Solution --- - Treat all work as PUBLIC. Private `feature/x` branches will not be squash-merged at release time. - Your code changes must meet the guidelines in [CONTRIBUTING.md](https://github.com/aws/aws-toolkit-vscode/blob/master/CONTRIBUTING.md#guidelines). - License: I confirm that my contribution is made under the terms of the Apache 2.0 license. Co-authored-by: Bhargava Varadharajan <[email protected]> Co-authored-by: chungjac <[email protected]>
1 parent 84df901 commit 436fc38

File tree

15 files changed

+1712
-370
lines changed

15 files changed

+1712
-370
lines changed

packages/core/package.nls.json

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@
231231
"AWS.command.s3.uploadFile": "Upload Files...",
232232
"AWS.command.s3.uploadFileToParent": "Upload to Parent...",
233233
"AWS.command.smus.switchProject": "Switch Project",
234+
"AWS.command.smus.signOut": "Sign Out",
234235
"AWS.command.sagemaker.filterSpaces": "Filter Sagemaker Spaces",
235236
"AWS.command.stepFunctions.createStateMachineFromTemplate": "Create a new Step Functions state machine",
236237
"AWS.command.stepFunctions.publishStateMachine": "Publish state machine to Step Functions",
@@ -482,5 +483,6 @@
482483
"AWS.toolkit.lambda.walkthrough.step1.description": "Locally test and debug your code.",
483484
"AWS.toolkit.lambda.walkthrough.step2.title": "Deploy to the cloud",
484485
"AWS.toolkit.lambda.walkthrough.step2.description": "Test your application in the cloud from within VS Code. \n\nNote: The AWS CLI and the SAM CLI require AWS Credentials to interact with the cloud. For information on setting up your credentials, see [Authentication and access credentials](https://docs.aws.amazon.com/cli/latest/userguide/cli-chap-configure.html). \n\n[Configure credentials](command:aws.toolkit.lambda.walkthrough.credential)",
485-
"AWS.toolkit.lambda.serverlessLand.quickpickTitle": "Create application with Serverless template"
486+
"AWS.toolkit.lambda.serverlessLand.quickpickTitle": "Create application with Serverless template",
487+
"AWS.command.smus.signout": "Sign Out"
486488
}

packages/core/src/auth/secondaryAuth.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ import { withTelemetryContext } from '../shared/telemetry/util'
1818
import { isNetworkError } from '../shared/errors'
1919
import globals from '../shared/extensionGlobals'
2020

21-
export type ToolId = 'codecatalyst' | 'codewhisperer' | 'testId'
21+
export type ToolId = 'codecatalyst' | 'codewhisperer' | 'testId' | 'smus'
2222

2323
let currentConn: Auth['activeConnection']
2424
const auths = new Map<string, SecondaryAuth>()
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
/*!
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
import { SsoProfile, SsoConnection } from '../../auth/connection'
7+
8+
/**
9+
* Scope for SageMaker Unified Studio authentication
10+
*/
11+
export const scopeSmus = 'datazone:domain:access'
12+
13+
/**
14+
* SageMaker Unified Studio profile extending the base SSO profile
15+
*/
16+
export interface SmusProfile extends SsoProfile {
17+
readonly domainUrl: string
18+
readonly domainId: string
19+
}
20+
21+
/**
22+
* SageMaker Unified Studio connection extending the base SSO connection
23+
*/
24+
export interface SmusConnection extends SmusProfile, SsoConnection {
25+
readonly id: string
26+
readonly label: string
27+
}
28+
29+
/**
30+
* Creates a SageMaker Unified Studio profile
31+
* @param domainUrl The SageMaker Unified Studio domain URL
32+
* @param domainId The SageMaker Unified Studio domain ID
33+
* @param startUrl The SSO start URL (issuer URL)
34+
* @param region The AWS region
35+
* @returns A SageMaker Unified Studio profile
36+
*/
37+
export function createSmusProfile(
38+
domainUrl: string,
39+
domainId: string,
40+
startUrl: string,
41+
region: string,
42+
scopes = [scopeSmus]
43+
): SmusProfile & { readonly scopes: string[] } {
44+
return {
45+
scopes,
46+
type: 'sso',
47+
startUrl,
48+
ssoRegion: region,
49+
domainUrl,
50+
domainId,
51+
}
52+
}
53+
54+
/**
55+
* Checks if a connection is a valid SageMaker Unified Studio connection
56+
* @param conn Connection to check
57+
* @returns True if the connection is a valid SMUS connection
58+
*/
59+
export function isValidSmusConnection(conn?: any): conn is SmusConnection {
60+
if (!conn || conn.type !== 'sso') {
61+
return false
62+
}
63+
// Check if the connection has the required SMUS scope
64+
const hasScope = Array.isArray(conn.scopes) && conn.scopes.includes(scopeSmus)
65+
// Check if the connection has the required SMUS properties
66+
const hasSmusProps = 'domainUrl' in conn && 'domainId' in conn
67+
return !!hasScope && !!hasSmusProps
68+
}
Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
/*!
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
import * as vscode from 'vscode'
7+
import { Auth } from '../../auth/auth'
8+
import { getSecondaryAuth } from '../../auth/secondaryAuth'
9+
import { ToolkitError } from '../../shared/errors'
10+
import { withTelemetryContext } from '../../shared/telemetry/util'
11+
import { SsoConnection } from '../../auth/connection'
12+
import { showReauthenticateMessage } from '../../shared/utilities/messages'
13+
import * as localizedText from '../../shared/localizedText'
14+
import { ToolkitPromptSettings } from '../../shared/settings'
15+
import { setContext } from '../../shared/vscode/setContext'
16+
import { DataZoneClient } from '../shared/client/datazoneClient'
17+
import { createSmusProfile, isValidSmusConnection, SmusConnection } from './model'
18+
import { getLogger } from '../../shared/logger/logger'
19+
20+
/**
21+
* Sets the context variable for SageMaker Unified Studio connection state
22+
* @param isConnected Whether SMUS is connected
23+
*/
24+
export function setSmusConnectedContext(isConnected: boolean): Promise<void> {
25+
return setContext('aws.smus.connected', isConnected)
26+
}
27+
const authClassName = 'SmusAuthenticationProvider'
28+
29+
/**
30+
* Authentication provider for SageMaker Unified Studio
31+
* Manages authentication state and credentials for SMUS
32+
*/
33+
export class SmusAuthenticationProvider {
34+
public readonly onDidChangeActiveConnection = this.secondaryAuth.onDidChangeActiveConnection
35+
private readonly onDidChangeEmitter = new vscode.EventEmitter<void>()
36+
public readonly onDidChange = this.onDidChangeEmitter.event
37+
38+
public constructor(
39+
public readonly auth = Auth.instance,
40+
public readonly secondaryAuth = getSecondaryAuth(
41+
auth,
42+
'smus',
43+
'SageMaker Unified Studio',
44+
isValidSmusConnection
45+
)
46+
) {
47+
this.onDidChangeActiveConnection(async () => {
48+
await setSmusConnectedContext(this.isConnected())
49+
this.onDidChangeEmitter.fire()
50+
})
51+
52+
// Set initial context in case event does not trigger
53+
void setSmusConnectedContext(this.isConnectionValid())
54+
}
55+
56+
/**
57+
* Gets the active connection
58+
*/
59+
public get activeConnection() {
60+
return this.secondaryAuth.activeConnection
61+
}
62+
63+
/**
64+
* Checks if using a saved connection
65+
*/
66+
public get isUsingSavedConnection() {
67+
return this.secondaryAuth.hasSavedConnection
68+
}
69+
70+
/**
71+
* Checks if the connection is valid
72+
*/
73+
public isConnectionValid(): boolean {
74+
return this.activeConnection !== undefined && !this.secondaryAuth.isConnectionExpired
75+
}
76+
77+
/**
78+
* Checks if connected to SMUS
79+
*/
80+
public isConnected(): boolean {
81+
return this.activeConnection !== undefined
82+
}
83+
84+
/**
85+
* Restores the previous connection
86+
* Uses a promise to prevent multiple simultaneous restore calls
87+
*/
88+
public async restore() {
89+
await this.secondaryAuth.restoreConnection()
90+
}
91+
92+
/**
93+
* Authenticates with SageMaker Unified Studio using a domain URL
94+
* @param domainUrl The SageMaker Unified Studio domain URL
95+
* @returns Promise resolving to the connection
96+
*/
97+
@withTelemetryContext({ name: 'connectToSmus', class: authClassName })
98+
public async connectToSmus(domainUrl: string): Promise<SmusConnection> {
99+
const logger = getLogger()
100+
101+
try {
102+
// Create DataZoneClient instance and extract domain info
103+
const dataZoneClient = DataZoneClient.getInstance()
104+
const { domainId, region } = dataZoneClient.extractDomainInfoFromUrl(domainUrl)
105+
106+
// Validate domain ID
107+
if (!domainId) {
108+
throw new ToolkitError('Invalid domain URL format', { code: 'InvalidDomainUrl' })
109+
}
110+
111+
logger.info(`SMUS: Connecting to domain ${domainId} in region ${region}`)
112+
113+
// Check if we already have a connection for this domain
114+
const existingConn = (await this.auth.listConnections()).find(
115+
(c): c is SmusConnection =>
116+
isValidSmusConnection(c) && (c as any).domainUrl?.toLowerCase() === domainUrl.toLowerCase()
117+
)
118+
119+
if (existingConn) {
120+
const connectionState = this.auth.getConnectionState(existingConn)
121+
logger.info(`SMUS: Found existing connection ${existingConn.id} with state: ${connectionState}`)
122+
123+
// If connection is valid, use it directly without triggering new auth flow
124+
if (connectionState === 'valid') {
125+
logger.info('SMUS: Using existing valid connection')
126+
127+
// Use the existing connection
128+
const result = await this.secondaryAuth.useNewConnection(existingConn)
129+
logger.debug(`SMUS: Reused existing connection successfully, id=${result.id}`)
130+
return result
131+
}
132+
133+
// If connection is invalid or expired, reauthenticate
134+
if (connectionState === 'invalid') {
135+
logger.info('SMUS: Existing connection is invalid, reauthenticating')
136+
const reauthenticatedConn = await this.reauthenticate(existingConn)
137+
138+
// Create the SMUS connection wrapper
139+
const smusConn: SmusConnection = {
140+
...reauthenticatedConn,
141+
domainUrl,
142+
domainId,
143+
}
144+
145+
const result = await this.secondaryAuth.useNewConnection(smusConn)
146+
logger.debug(`SMUS: Reauthenticated connection successfully, id=${result.id}`)
147+
return result
148+
}
149+
}
150+
151+
// No existing connection found, create a new one
152+
logger.info('SMUS: No existing connection found, creating new connection')
153+
154+
// Get SSO instance info from DataZone
155+
const ssoInstanceInfo = await dataZoneClient.getSsoInstanceInfo(domainUrl)
156+
157+
// Create a new connection
158+
const profile = createSmusProfile(domainUrl, domainId, ssoInstanceInfo.issuerUrl, ssoInstanceInfo.region)
159+
const newConn = await this.auth.createConnection(profile)
160+
logger.debug(`SMUS: Created new connection ${newConn.id}`)
161+
162+
const smusConn: SmusConnection = {
163+
...newConn,
164+
domainUrl,
165+
domainId,
166+
}
167+
168+
const result = await this.secondaryAuth.useNewConnection(smusConn)
169+
return result
170+
} catch (e) {
171+
throw ToolkitError.chain(e, 'Failed to connect to SageMaker Unified Studio', {
172+
code: 'FailedToConnect',
173+
})
174+
}
175+
}
176+
177+
/**
178+
* Reauthenticates an existing connection
179+
* @param conn Connection to reauthenticate
180+
* @returns Promise resolving to the reauthenticated connection
181+
*/
182+
@withTelemetryContext({ name: 'reauthenticate', class: authClassName })
183+
public async reauthenticate(conn: SsoConnection) {
184+
try {
185+
return await this.auth.reauthenticate(conn)
186+
} catch (err) {
187+
throw ToolkitError.chain(err, 'Unable to reauthenticate SageMaker Unified Studio connection.')
188+
}
189+
}
190+
191+
/**
192+
* Shows a reauthentication prompt to the user
193+
* @param conn Connection to reauthenticate
194+
*/
195+
public async showReauthenticationPrompt(conn: SsoConnection): Promise<void> {
196+
await showReauthenticateMessage({
197+
message: localizedText.connectionExpired('SageMaker Unified Studio'),
198+
connect: localizedText.reauthenticate,
199+
suppressId: 'smusConnectionExpired',
200+
settings: ToolkitPromptSettings.instance,
201+
reauthFunc: async () => {
202+
await this.reauthenticate(conn)
203+
},
204+
})
205+
}
206+
207+
// URL extraction functions have been moved to DataZoneClient
208+
209+
static #instance: SmusAuthenticationProvider | undefined
210+
211+
public static get instance(): SmusAuthenticationProvider | undefined {
212+
return SmusAuthenticationProvider.#instance
213+
}
214+
215+
public static fromContext() {
216+
return (this.#instance ??= new this())
217+
}
218+
}

packages/core/src/sagemakerunifiedstudio/explorer/activation.ts

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,27 @@ import {
99
retrySmusProjectsCommand,
1010
smusLoginCommand,
1111
smusLearnMoreCommand,
12+
smusSignOutCommand,
1213
SageMakerUnifiedStudioRootNode,
1314
selectSMUSProject,
1415
} from './nodes/sageMakerUnifiedStudioRootNode'
1516
import { DataZoneClient } from '../shared/client/datazoneClient'
17+
import { getLogger } from '../../shared/logger/logger'
18+
import { setSmusConnectedContext, SmusAuthenticationProvider } from '../auth/smusAuthenticationProvider'
1619

1720
export async function activate(extensionContext: vscode.ExtensionContext): Promise<void> {
21+
// Initialize the SMUS authentication provider
22+
const logger = getLogger()
23+
logger.debug('SMUS: Initializing authentication provider')
24+
// Create the auth provider instance (this will trigger restore() in the constructor)
25+
const smusAuthProvider = SmusAuthenticationProvider.fromContext()
26+
await smusAuthProvider.restore()
27+
// Set initial auth context after restore
28+
void setSmusConnectedContext(smusAuthProvider.isConnected())
29+
logger.debug('SMUS: Authentication provider initialized')
30+
1831
// Create the SMUS projects tree view
19-
const smusRootNode = new SageMakerUnifiedStudioRootNode()
32+
const smusRootNode = new SageMakerUnifiedStudioRootNode(smusAuthProvider)
2033
const treeDataProvider = new ResourceTreeDataProvider({ getChildren: () => smusRootNode.getChildren() })
2134

2235
// Register the tree view
@@ -27,6 +40,7 @@ export async function activate(extensionContext: vscode.ExtensionContext): Promi
2740
extensionContext.subscriptions.push(
2841
smusLoginCommand.register(),
2942
smusLearnMoreCommand.register(),
43+
smusSignOutCommand.register(),
3044
retrySmusProjectsCommand.register(),
3145
treeView,
3246
vscode.commands.registerCommand('aws.smus.rootView.refresh', () => {
@@ -43,6 +57,23 @@ export async function activate(extensionContext: vscode.ExtensionContext): Promi
4357
return await selectSMUSProject(projectNode)
4458
}),
4559

60+
vscode.commands.registerCommand('aws.smus.reauthenticate', async (connection?: any) => {
61+
if (connection) {
62+
try {
63+
await smusAuthProvider.reauthenticate(connection)
64+
// Refresh the tree view after successful reauthentication
65+
treeDataProvider.refresh()
66+
// Show success message
67+
void vscode.window.showInformationMessage(
68+
'Successfully reauthenticated with SageMaker Unified Studio'
69+
)
70+
} catch (error) {
71+
// Show error message if reauthentication fails
72+
void vscode.window.showErrorMessage(`Failed to reauthenticate: ${error}`)
73+
logger.error('SMUS: Reauthentication failed: %O', error)
74+
}
75+
}
76+
}),
4677
// Dispose DataZoneClient when extension is deactivated
4778
{ dispose: () => DataZoneClient.dispose() }
4879
)

0 commit comments

Comments
 (0)