Skip to content

Commit df42b76

Browse files
liuzulinsjc25-test
andauthored
feat(sagemakerunifiedstudio): Handle experience in remote ssh connection for SMUS (aws#2204)
## Problem Need to handle experience in remote ssh connection for SMUS ## Solution 1. Read metadata from /opt/ml/metadata/resource-metadata.json 2. Read DER from cred profile 3. Pre-populate Root node and project node, only show data nodes under them --- - Treat all work as PUBLIC. Private `feature/x` branches will not be squash-merged at release time. - Your code changes must meet the guidelines in [CONTRIBUTING.md](https://github.com/aws/aws-toolkit-vscode/blob/master/CONTRIBUTING.md#guidelines). - License: I confirm that my contribution is made under the terms of the Apache 2.0 license. --------- Co-authored-by: Zulin Liu <[email protected]>
1 parent 918f95b commit df42b76

File tree

11 files changed

+226
-26
lines changed

11 files changed

+226
-26
lines changed

packages/core/src/sagemakerunifiedstudio/activation.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import { initializeResourceMetadata } from './shared/utils/resourceMetadataUtils
1212

1313
export async function activate(extensionContext: vscode.ExtensionContext): Promise<void> {
1414
// Only run when environment is a SageMaker Unified Studio space
15-
if (isSageMaker('SMUS')) {
15+
if (isSageMaker('SMUS') || isSageMaker('SMUS-SPACE-REMOTE-ACCESS')) {
1616
await initializeResourceMetadata()
1717
await activateConnectionMagicsSelector(extensionContext)
1818
}

packages/core/src/sagemakerunifiedstudio/auth/providers/smusAuthenticationProvider.ts

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ import { DomainExecRoleCredentialsProvider } from './domainExecRoleCredentialsPr
2020
import { ProjectRoleCredentialsProvider } from './projectRoleCredentialsProvider'
2121
import { ConnectionCredentialsProvider } from './connectionCredentialsProvider'
2222
import { ConnectionClientStore } from '../../shared/client/connectionClientStore'
23+
import { getResourceMetadata } from '../../shared/utils/resourceMetadataUtils'
24+
import { fromIni } from '@aws-sdk/credential-providers'
2325

2426
/**
2527
* Sets the context variable for SageMaker Unified Studio connection state
@@ -28,6 +30,14 @@ import { ConnectionClientStore } from '../../shared/client/connectionClientStore
2830
export function setSmusConnectedContext(isConnected: boolean): Promise<void> {
2931
return setContext('aws.smus.connected', isConnected)
3032
}
33+
34+
/**
35+
* Sets the context variable for SMUS space environment state
36+
* @param inSmusSpace Whether we're in SMUS space environment
37+
*/
38+
export function setSmusSpaceEnvironmentContext(inSmusSpace: boolean): Promise<void> {
39+
return setContext('aws.smus.inSmusSpaceEnvironment', inSmusSpace)
40+
}
3141
const authClassName = 'SmusAuthenticationProvider'
3242

3343
/**
@@ -63,17 +73,34 @@ export class SmusAuthenticationProvider {
6373
// Clear all clients in client store when connection changes
6474
ConnectionClientStore.getInstance().clearAll()
6575
await setSmusConnectedContext(this.isConnected())
76+
await setSmusSpaceEnvironmentContext(SmusUtils.isInSmusSpaceEnvironment())
6677
this.onDidChangeEmitter.fire()
6778
})
6879

6980
// Set initial context in case event does not trigger
7081
void setSmusConnectedContext(this.isConnectionValid())
82+
void setSmusSpaceEnvironmentContext(SmusUtils.isInSmusSpaceEnvironment())
7183
}
7284

7385
/**
7486
* Gets the active connection
7587
*/
7688
public get activeConnection() {
89+
if (SmusUtils.isInSmusSpaceEnvironment()) {
90+
const resourceMetadata = getResourceMetadata()!
91+
if (resourceMetadata.AdditionalMetadata!.DataZoneDomainRegion) {
92+
return {
93+
domainId: resourceMetadata.AdditionalMetadata!.DataZoneDomainId!,
94+
ssoRegion: resourceMetadata.AdditionalMetadata!.DataZoneDomainRegion!,
95+
// The following fields won't be needed in SMUS space environment
96+
// Set them to be empty string for type checks only
97+
domainUrl: '',
98+
id: '',
99+
}
100+
} else {
101+
throw new ToolkitError('Domain region not found in metadata file.')
102+
}
103+
}
77104
return this.secondaryAuth.activeConnection
78105
}
79106

@@ -88,13 +115,19 @@ export class SmusAuthenticationProvider {
88115
* Checks if the connection is valid
89116
*/
90117
public isConnectionValid(): boolean {
118+
if (SmusUtils.isInSmusSpaceEnvironment()) {
119+
return true
120+
}
91121
return this.activeConnection !== undefined && !this.secondaryAuth.isConnectionExpired
92122
}
93123

94124
/**
95125
* Checks if connected to SMUS
96126
*/
97127
public isConnected(): boolean {
128+
if (SmusUtils.isInSmusSpaceEnvironment()) {
129+
return true
130+
}
98131
return this.activeConnection !== undefined
99132
}
100133

@@ -314,6 +347,10 @@ export class SmusAuthenticationProvider {
314347
* @returns Domain ID
315348
*/
316349
public getDomainId(): string {
350+
if (SmusUtils.isInSmusSpaceEnvironment()) {
351+
return getResourceMetadata()!.AdditionalMetadata!.DataZoneDomainId!
352+
}
353+
317354
if (!this.activeConnection) {
318355
throw new ToolkitError('No active SMUS connection available', { code: SmusErrorCodes.NoActiveConnection })
319356
}
@@ -332,6 +369,15 @@ export class SmusAuthenticationProvider {
332369
}
333370

334371
public getDomainRegion(): string {
372+
if (SmusUtils.isInSmusSpaceEnvironment()) {
373+
const resourceMetadata = getResourceMetadata()!
374+
if (resourceMetadata.AdditionalMetadata!.DataZoneDomainRegion) {
375+
return resourceMetadata.AdditionalMetadata!.DataZoneDomainRegion
376+
} else {
377+
throw new ToolkitError('Domain region not found in metadata file.')
378+
}
379+
}
380+
335381
if (!this.activeConnection) {
336382
throw new ToolkitError('No active SMUS connection available', { code: SmusErrorCodes.NoActiveConnection })
337383
}
@@ -342,10 +388,15 @@ export class SmusAuthenticationProvider {
342388
* Gets or creates a cached credentials provider for the active connection
343389
* @returns Promise resolving to the credentials provider
344390
*/
345-
public async getDerCredentialsProvider(): Promise<DomainExecRoleCredentialsProvider> {
391+
public async getDerCredentialsProvider(): Promise<any> {
346392
const logger = getLogger()
347393

348-
// TODO : Return fromIni() credential provider here when in the SageMaker Unified Studio hosted IDE environment.
394+
if (SmusUtils.isInSmusSpaceEnvironment()) {
395+
const credentials = fromIni({ profile: 'DomainExecutionRoleCreds' })
396+
return {
397+
getCredentials: async () => await credentials(),
398+
}
399+
}
349400

350401
if (!this.activeConnection) {
351402
throw new ToolkitError('No active SMUS connection available', { code: SmusErrorCodes.NoActiveConnection })

packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioProjectNode.ts

Lines changed: 66 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import { SmusAuthenticationProvider } from '../../auth/providers/smusAuthenticat
1616
import { SageMakerUnifiedStudioComputeNode } from './sageMakerUnifiedStudioComputeNode'
1717
import { getIcon } from '../../../shared/icons'
1818
import { SmusUtils } from '../../shared/smusUtils'
19+
import { getResourceMetadata } from '../../shared/utils/resourceMetadataUtils'
1920

2021
/**
2122
* Tree node representing a SageMaker Unified Studio project
@@ -34,7 +35,21 @@ export class SageMakerUnifiedStudioProjectNode implements TreeNode {
3435
private readonly parent: SageMakerUnifiedStudioRootNode,
3536
private readonly authProvider: SmusAuthenticationProvider,
3637
private readonly extensionContext: vscode.ExtensionContext
37-
) {}
38+
) {
39+
// If we're in SMUS space environment, set project from resource metadata
40+
if (SmusUtils.isInSmusSpaceEnvironment()) {
41+
const resourceMetadata = getResourceMetadata()!
42+
if (resourceMetadata.AdditionalMetadata!.DataZoneProjectId) {
43+
this.project = {
44+
id: resourceMetadata!.AdditionalMetadata!.DataZoneProjectId!,
45+
name: 'Current Project',
46+
domainId: resourceMetadata!.AdditionalMetadata!.DataZoneDomainId!,
47+
}
48+
// Fetch the actual project name asynchronously
49+
void this.fetchProjectName()
50+
}
51+
}
52+
}
3853

3954
public async getTreeItem(): Promise<vscode.TreeItem> {
4055
if (this.project) {
@@ -70,25 +85,35 @@ export class SageMakerUnifiedStudioProjectNode implements TreeNode {
7085
result: 'Succeeded',
7186
passive: false,
7287
})
73-
const hasAccess = await this.checkProjectAccess(this.project.id)
74-
if (!hasAccess) {
75-
return [
76-
{
77-
id: 'smusProjectAccessDenied',
78-
resource: {},
79-
getTreeItem: () => {
80-
const item = new vscode.TreeItem(
81-
'You are not a member of this project. Contact any of its owners to add you as a member.',
82-
vscode.TreeItemCollapsibleState.None
83-
)
84-
return item
88+
89+
// Skip access check if we're in SMUS space environment (already in project space)
90+
if (!SmusUtils.isInSmusSpaceEnvironment()) {
91+
const hasAccess = await this.checkProjectAccess(this.project.id)
92+
if (!hasAccess) {
93+
return [
94+
{
95+
id: 'smusProjectAccessDenied',
96+
resource: {},
97+
getTreeItem: () => {
98+
const item = new vscode.TreeItem(
99+
'You are not a member of this project. Contact any of its owners to add you as a member.',
100+
vscode.TreeItemCollapsibleState.None
101+
)
102+
return item
103+
},
104+
getParent: () => this,
85105
},
86-
getParent: () => this,
87-
},
88-
]
106+
]
107+
}
89108
}
90109

91110
const dataNode = new SageMakerUnifiedStudioDataNode(this)
111+
112+
// If we're in SMUS space environment, only show data node
113+
if (SmusUtils.isInSmusSpaceEnvironment()) {
114+
return [dataNode]
115+
}
116+
92117
this.sagemakerClient = await this.initializeSagemakerClient(
93118
this.authProvider.activeConnection?.ssoRegion || 'us-east-1'
94119
)
@@ -124,7 +149,10 @@ export class SageMakerUnifiedStudioProjectNode implements TreeNode {
124149

125150
public async clearProject(): Promise<void> {
126151
await this.cleanupProjectResources()
127-
this.project = undefined
152+
// Don't clear project if we're in SMUS space environment
153+
if (!SmusUtils.isInSmusSpaceEnvironment()) {
154+
this.project = undefined
155+
}
128156
await this.refreshNode()
129157
}
130158

@@ -154,6 +182,27 @@ export class SageMakerUnifiedStudioProjectNode implements TreeNode {
154182
}
155183
}
156184

185+
private async fetchProjectName(): Promise<void> {
186+
if (!this.project || !SmusUtils.isInSmusSpaceEnvironment()) {
187+
return
188+
}
189+
190+
try {
191+
const dzClient = await DataZoneClient.getInstance(this.authProvider)
192+
const projectDetails = await dzClient.getProject(this.project.id)
193+
194+
if (projectDetails && projectDetails.name) {
195+
this.project.name = projectDetails.name
196+
// Refresh the tree item to show the updated name
197+
this.onDidChangeEmitter.fire()
198+
}
199+
} catch (err) {
200+
// No need to show error, this is just to dynamically show project name
201+
// If we fail to fetch project name, we will just show the default name
202+
this.logger.debug(`Failed to fetch project name: ${(err as Error).message}`)
203+
}
204+
}
205+
157206
private async initializeSagemakerClient(regionCode: string): Promise<SagemakerClient> {
158207
if (!this.project) {
159208
throw new Error('No project selected for initializing SageMaker client')

packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioRootNode.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ export class SageMakerUnifiedStudioRootNode implements TreeNode {
133133

134134
/**
135135
* Checks if the user has authenticated to SageMaker Unified Studio
136-
* This is validated by checking existing Connections for SMUS.
136+
* This is validated by checking existing Connections for SMUS or resource metadata.
137137
*/
138138
private isAuthenticated(): boolean {
139139
try {
@@ -154,7 +154,7 @@ export class SageMakerUnifiedStudioRootNode implements TreeNode {
154154

155155
if (hasExpiredConnection) {
156156
// Show reauthentication prompt to user
157-
void this.authProvider.showReauthenticationPrompt(this.authProvider.activeConnection!)
157+
void this.authProvider.showReauthenticationPrompt(this.authProvider.activeConnection! as any)
158158
return true
159159
}
160160
return false
@@ -197,6 +197,7 @@ export const smusLearnMoreCommand = Commands.declare('aws.smus.learnMore', () =>
197197
*/
198198
export const smusLoginCommand = Commands.declare('aws.smus.login', () => async () => {
199199
const logger = getLogger()
200+
200201
try {
201202
// Get DataZoneClient instance for URL validation
202203

packages/core/src/sagemakerunifiedstudio/shared/client/datazoneClient.ts

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,39 @@ export class DataZoneClient {
457457
}
458458
}
459459

460+
/**
461+
* Gets a specific project by ID
462+
* @param projectId The project identifier
463+
* @returns Promise resolving to the project details
464+
*/
465+
public async getProject(projectId: string): Promise<DataZoneProject> {
466+
try {
467+
this.logger.info(`DataZoneClient: Getting project ${projectId} in domain ${this.domainId}`)
468+
469+
const datazoneClient = await this.getDataZoneClient()
470+
471+
const response = await datazoneClient.getProject({
472+
domainIdentifier: this.domainId,
473+
identifier: projectId,
474+
})
475+
476+
const project: DataZoneProject = {
477+
id: response.id || '',
478+
name: response.name || '',
479+
description: response.description,
480+
domainId: this.domainId,
481+
createdAt: response.createdAt ? new Date(response.createdAt) : undefined,
482+
updatedAt: response.lastUpdatedAt ? new Date(response.lastUpdatedAt) : undefined,
483+
}
484+
485+
this.logger.debug(`DataZoneClient: Retrieved project ${projectId} with name: ${project.name}`)
486+
return project
487+
} catch (err) {
488+
this.logger.error('DataZoneClient: Failed to get project: %s', err as Error)
489+
throw err
490+
}
491+
}
492+
460493
/*
461494
* Processes a connection response to add jdbcConnection if it's a Redshift connection
462495
* @param connection The connection object to process

packages/core/src/sagemakerunifiedstudio/shared/smusUtils.ts

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
import { getLogger } from '../../shared/logger/logger'
77
import { ToolkitError } from '../../shared/errors'
8+
import { isSageMaker } from '../../shared/extensionUtilities'
9+
import { getResourceMetadata } from './utils/resourceMetadataUtils'
810
import fetch from 'node-fetch'
911

1012
/**
@@ -338,4 +340,14 @@ export class SmusUtils {
338340
}
339341
return match[1]
340342
}
343+
344+
/**
345+
* Checks if we're in SMUS space environment (should hide certain UI elements)
346+
* @returns True if in SMUS space environment with DataZone domain ID
347+
*/
348+
public static isInSmusSpaceEnvironment(): boolean {
349+
const isSMUSspace = isSageMaker('SMUS') || isSageMaker('SMUS-SPACE-REMOTE-ACCESS')
350+
const resourceMetadata = getResourceMetadata()
351+
return isSMUSspace && !!resourceMetadata?.AdditionalMetadata?.DataZoneDomainId
352+
}
341353
}

packages/core/src/sagemakerunifiedstudio/shared/utils/resourceMetadataUtils.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ export function getResourceMetadata(): ResourceMetadata | undefined {
5353
export async function initializeResourceMetadata(): Promise<void> {
5454
const logger = getLogger()
5555

56-
if (!isSageMaker('SMUS')) {
56+
if (!isSageMaker('SMUS') && !isSageMaker('SMUS-SPACE-REMOTE-ACCESS')) {
5757
logger.debug(`Not in SageMaker Unified Studio space, skipping initialization of resource metadata`)
5858
return
5959
}

packages/core/src/shared/extensionUtilities.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ export function isCloud9(flavor: 'classic' | 'codecatalyst' | 'any' = 'any'): bo
188188
* @param appName to identify the proper SM instance
189189
* @returns true if the current system is SageMaker(SMAI or SMUS)
190190
*/
191-
export function isSageMaker(appName: 'SMAI' | 'SMUS' = 'SMAI'): boolean {
191+
export function isSageMaker(appName: 'SMAI' | 'SMUS' | 'SMUS-SPACE-REMOTE-ACCESS' = 'SMAI'): boolean {
192192
// Check for SageMaker-specific environment variables first
193193
let hasSMEnvVars: boolean = false
194194
if (hasSageMakerEnvVars()) {
@@ -201,6 +201,9 @@ export function isSageMaker(appName: 'SMAI' | 'SMUS' = 'SMAI'): boolean {
201201
return vscode.env.appName === sageMakerAppname && hasSMEnvVars
202202
case 'SMUS':
203203
return vscode.env.appName === sageMakerAppname && isSageMakerUnifiedStudio() && hasSMEnvVars
204+
case 'SMUS-SPACE-REMOTE-ACCESS':
205+
// When is true, the AWS toolkit is running in remote SSH conenction to SageMaker Unified Studio space
206+
return vscode.env.appName !== sageMakerAppname && isSageMakerUnifiedStudio() && hasSMEnvVars
204207
default:
205208
return false
206209
}

packages/core/src/shared/vscode/setContext.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ export type contextKey =
3131
| 'aws.toolkit.notifications.show'
3232
| 'aws.amazonq.editSuggestionActive'
3333
| 'aws.smus.connected'
34+
| 'aws.smus.inSmusSpaceEnvironment'
3435
// Deprecated/legacy names. New keys should start with "aws.".
3536
| 'codewhisperer.activeLine'
3637
| 'gumby.isPlanAvailable'

0 commit comments

Comments
 (0)