Skip to content

Commit b39ab4e

Browse files
laileni-awsaws-toolkit-automationroger-zhanggrhamiltchungjac
authored
feat(sagemaker): Merge sagemaker to master (#7681)
## Notes - feat(sagemaker): Merging Feature/sagemaker-connect-phase-2 to master - Reference PR: #7677 --- - Treat all work as PUBLIC. Private `feature/x` branches will not be squash-merged at release time. - Your code changes must meet the guidelines in [CONTRIBUTING.md](https://github.com/aws/aws-toolkit-vscode/blob/master/CONTRIBUTING.md#guidelines). - License: I confirm that my contribution is made under the terms of the Apache 2.0 license. --------- Co-authored-by: aws-toolkit-automation <[email protected]> Co-authored-by: Roger Zhang <[email protected]> Co-authored-by: Reed Hamilton <[email protected]> Co-authored-by: Jacob Chung <[email protected]> Co-authored-by: aws-ides-bot <[email protected]> Co-authored-by: aws-asolidu <[email protected]> Co-authored-by: Newton Der <[email protected]> Co-authored-by: Newton Der <[email protected]>
1 parent da9755a commit b39ab4e

33 files changed

+734
-320
lines changed

packages/core/src/awsService/sagemaker/activation.ts

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,27 @@
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

6+
import * as path from 'path'
7+
import * as vscode from 'vscode'
68
import { Commands } from '../../shared/vscode/commands2'
79
import { SagemakerSpaceNode } from './explorer/sagemakerSpaceNode'
810
import { SagemakerParentNode } from './explorer/sagemakerParentNode'
911
import * as uriHandlers from './uriHandlers'
1012
import { openRemoteConnect, filterSpaceAppsByDomainUserProfiles, stopSpace } from './commands'
13+
import { updateIdleFile, startMonitoringTerminalActivity, ActivityCheckInterval } from './utils'
1114
import { ExtContext } from '../../shared/extensions'
1215
import { telemetry } from '../../shared/telemetry/telemetry'
16+
import { isSageMaker, UserActivity } from '../../shared/extensionUtilities'
17+
18+
let terminalActivityInterval: NodeJS.Timeout | undefined
1319

1420
export async function activate(ctx: ExtContext): Promise<void> {
1521
ctx.extensionContext.subscriptions.push(
1622
uriHandlers.register(ctx),
1723
Commands.register('aws.sagemaker.openRemoteConnection', async (node: SagemakerSpaceNode) => {
24+
if (!validateNode(node)) {
25+
return
26+
}
1827
await telemetry.sagemaker_openRemoteConnection.run(async () => {
1928
await openRemoteConnect(node, ctx.extensionContext)
2029
})
@@ -27,9 +36,47 @@ export async function activate(ctx: ExtContext): Promise<void> {
2736
}),
2837

2938
Commands.register('aws.sagemaker.stopSpace', async (node: SagemakerSpaceNode) => {
39+
if (!validateNode(node)) {
40+
return
41+
}
3042
await telemetry.sagemaker_stopSpace.run(async () => {
3143
await stopSpace(node, ctx.extensionContext)
3244
})
3345
})
3446
)
47+
48+
// If running in SageMaker AI Space, track user activity for autoshutdown feature
49+
if (isSageMaker('SMAI')) {
50+
// Use /tmp/ directory so the file is cleared on each reboot to prevent stale timestamps.
51+
const tmpDirectory = '/tmp/'
52+
const idleFilePath = path.join(tmpDirectory, '.sagemaker-last-active-timestamp')
53+
54+
const userActivity = new UserActivity(ActivityCheckInterval)
55+
userActivity.onUserActivity(() => updateIdleFile(idleFilePath))
56+
57+
terminalActivityInterval = startMonitoringTerminalActivity(idleFilePath)
58+
59+
// Write initial timestamp
60+
await updateIdleFile(idleFilePath)
61+
62+
ctx.extensionContext.subscriptions.push(userActivity, {
63+
dispose: () => {
64+
if (terminalActivityInterval) {
65+
clearInterval(terminalActivityInterval)
66+
terminalActivityInterval = undefined
67+
}
68+
},
69+
})
70+
}
71+
}
72+
73+
/**
74+
* Checks if a node is undefined and shows a warning message if so.
75+
*/
76+
function validateNode(node: unknown): boolean {
77+
if (!node) {
78+
void vscode.window.showWarningMessage('Space information is being refreshed. Please try again shortly.')
79+
return false
80+
}
81+
return true
3582
}

packages/core/src/awsService/sagemaker/commands.ts

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ import { ExtContext } from '../../shared/extensions'
1818
import { SagemakerClient } from '../../shared/clients/sagemaker'
1919
import { ToolkitError } from '../../shared/errors'
2020
import { showConfirmationMessage } from '../../shared/utilities/messages'
21+
import { RemoteSessionError } from '../../shared/remoteSession'
22+
import { ConnectFromRemoteWorkspaceMessage, InstanceTypeError } from './constants'
2123

2224
const localize = nls.loadMessageBundle()
2325

@@ -90,34 +92,36 @@ export async function deeplinkConnect(
9092
)
9193

9294
if (isRemoteWorkspace()) {
93-
void vscode.window.showErrorMessage(
94-
'You are in a remote workspace, skipping deeplink connect. Please open from a local workspace.'
95-
)
95+
void vscode.window.showErrorMessage(ConnectFromRemoteWorkspaceMessage)
9696
return
9797
}
9898

99-
const remoteEnv = await prepareDevEnvConnection(
100-
connectionIdentifier,
101-
ctx.extensionContext,
102-
'sm_dl',
103-
session,
104-
wsUrl,
105-
token,
106-
domain
107-
)
108-
10999
try {
100+
const remoteEnv = await prepareDevEnvConnection(
101+
connectionIdentifier,
102+
ctx.extensionContext,
103+
'sm_dl',
104+
session,
105+
wsUrl,
106+
token,
107+
domain
108+
)
109+
110110
await startVscodeRemote(
111111
remoteEnv.SessionProcess,
112112
remoteEnv.hostname,
113113
'/home/sagemaker-user',
114114
remoteEnv.vscPath,
115115
'sagemaker-user'
116116
)
117-
} catch (err) {
117+
} catch (err: any) {
118118
getLogger().error(
119119
`sm:OpenRemoteConnect: Unable to connect to target space with arn: ${connectionIdentifier} error: ${err}`
120120
)
121+
122+
if (![RemoteSessionError.MissingExtension, RemoteSessionError.ExtensionVersionTooLow].includes(err.code)) {
123+
throw err
124+
}
121125
}
122126
}
123127

@@ -156,16 +160,29 @@ export async function stopSpace(node: SagemakerSpaceNode, ctx: vscode.ExtensionC
156160
}
157161

158162
export async function openRemoteConnect(node: SagemakerSpaceNode, ctx: vscode.ExtensionContext) {
163+
if (isRemoteWorkspace()) {
164+
void vscode.window.showErrorMessage(ConnectFromRemoteWorkspaceMessage)
165+
return
166+
}
167+
159168
if (node.getStatus() === 'Stopped') {
160169
const client = new SagemakerClient(node.regionCode)
161-
await client.startSpace(node.spaceApp.SpaceName!, node.spaceApp.DomainId!)
162-
await tryRefreshNode(node)
163-
const appType = node.spaceApp.SpaceSettingsSummary?.AppType
164-
if (!appType) {
165-
throw new ToolkitError('AppType is undefined for the selected space. Cannot start remote connection.')
170+
171+
try {
172+
await client.startSpace(node.spaceApp.SpaceName!, node.spaceApp.DomainId!)
173+
await tryRefreshNode(node)
174+
const appType = node.spaceApp.SpaceSettingsSummary?.AppType
175+
if (!appType) {
176+
throw new ToolkitError('AppType is undefined for the selected space. Cannot start remote connection.')
177+
}
178+
await client.waitForAppInService(node.spaceApp.DomainId!, node.spaceApp.SpaceName!, appType)
179+
await tryRemoteConnection(node, ctx)
180+
} catch (err: any) {
181+
// Ignore InstanceTypeError since it means the user decided not to use an instanceType with more memory
182+
if (err.code !== InstanceTypeError) {
183+
throw err
184+
}
166185
}
167-
await client.waitForAppInService(node.spaceApp.DomainId!, node.spaceApp.SpaceName!, appType)
168-
await tryRemoteConnection(node, ctx)
169186
} else if (node.getStatus() === 'Running') {
170187
await tryRemoteConnection(node, ctx)
171188
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*!
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
export const ConnectFromRemoteWorkspaceMessage =
7+
'Unable to establish new remote connection. Your last active VS Code window is connected to a remote workspace. To open a new SageMaker Studio connection, select your local VS Code window and try again.'
8+
9+
export const InstanceTypeError = 'InstanceTypeError'
10+
11+
export const InstanceTypeMinimum = 'ml.t3.large'
12+
13+
export const InstanceTypeInsufficientMemory: Record<string, string> = {
14+
'ml.t3.medium': 'ml.t3.large',
15+
'ml.c7i.large': 'ml.c7i.xlarge',
16+
'ml.c6i.large': 'ml.c6i.xlarge',
17+
'ml.c6id.large': 'ml.c6id.xlarge',
18+
'ml.c5.large': 'ml.c5.xlarge',
19+
}
20+
21+
export const InstanceTypeInsufficientMemoryMessage = (
22+
spaceName: string,
23+
chosenInstanceType: string,
24+
recommendedInstanceType: string
25+
) => {
26+
return `Unable to create app for [${spaceName}] because instanceType [${chosenInstanceType}] is not supported for remote access enabled spaces. Use instanceType with at least 8 GiB memory. Would you like to start your space with instanceType [${recommendedInstanceType}]?`
27+
}
28+
29+
export const InstanceTypeNotSelectedMessage = (spaceName: string) => {
30+
return `No instanceType specified for [${spaceName}]. ${InstanceTypeMinimum} is the default instance type, which meets minimum 8 GiB memory requirements for remote access. Continuing will start your space with instanceType [${InstanceTypeMinimum}] and remotely connect.`
31+
}

packages/core/src/awsService/sagemaker/credentialMapping.ts

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ import globals from '../../shared/extensionGlobals'
1010
import { ToolkitError } from '../../shared/errors'
1111
import { DevSettings } from '../../shared/settings'
1212
import { Auth } from '../../auth/auth'
13-
import { parseRegionFromArn } from './utils'
1413
import { SpaceMappings, SsmConnectionInfo } from './types'
1514
import { getLogger } from '../../shared/logger/logger'
15+
import { parseArn } from './detached-server/utils'
1616

1717
const mappingFileName = '.sagemaker-space-profiles'
1818
const mappingFilePath = path.join(os.homedir(), '.aws', mappingFileName)
@@ -81,9 +81,14 @@ export async function persistSSMConnection(
8181
wsUrl?: string,
8282
token?: string
8383
): Promise<void> {
84-
const region = parseRegionFromArn(appArn)
84+
const { region } = parseArn(appArn)
8585
const endpoint = DevSettings.instance.get('endpoints', {})['sagemaker'] ?? ''
8686

87+
// TODO: Hardcoded to 'jupyterlab' due to a bug in Studio that only supports refreshing
88+
// the token for both CodeEditor and JupyterLab Apps in the jupyterlab subdomain.
89+
// This will be fixed shortly after NYSummit launch to support refresh URL in CodeEditor subdomain.
90+
const appSubDomain = 'jupyterlab'
91+
8792
let envSubdomain: string
8893

8994
if (endpoint.includes('beta')) {
@@ -101,8 +106,7 @@ export async function persistSSMConnection(
101106
? `studio.${region}.sagemaker.aws`
102107
: `${envSubdomain}.studio.${region}.asfiovnxocqpcry.com`
103108

104-
const refreshUrl = `https://studio-${domain}.${baseDomain}/api/remoteaccess/token`
105-
109+
const refreshUrl = `https://studio-${domain}.${baseDomain}/${appSubDomain}`
106110
await setSpaceCredentials(appArn, refreshUrl, {
107111
sessionId: session ?? '-',
108112
url: wsUrl ?? '-',

packages/core/src/awsService/sagemaker/detached-server/routes/getSessionAsync.ts

Lines changed: 21 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import { IncomingMessage, ServerResponse } from 'http'
99
import url from 'url'
1010
import { SessionStore } from '../sessionStore'
11+
import { open, parseArn, readServerInfo } from '../utils'
1112

1213
export async function handleGetSessionAsync(req: IncomingMessage, res: ServerResponse): Promise<void> {
1314
const parsedUrl = url.parse(req.url || '', true)
@@ -37,38 +38,30 @@ export async function handleGetSessionAsync(req: IncomingMessage, res: ServerRes
3738
})
3839
)
3940
return
40-
} else {
41-
res.writeHead(200, { 'Content-Type': 'text/plain' })
42-
res.end(
43-
`No session found for connection identifier: ${connectionIdentifier}. Reconnecting for deeplink is not supported yet.`
44-
)
45-
return
4641
}
4742

48-
// Temporarily disabling reconnect logic for the 7/3 Phase 1 launch.
49-
// Will re-enable in the next release around 7/14.
50-
51-
// const status = await store.getStatus(connectionIdentifier, requestId)
52-
// if (status === 'pending') {
53-
// res.writeHead(204)
54-
// res.end()
55-
// return
56-
// } else if (status === 'not-started') {
57-
// const serverInfo = await readServerInfo()
58-
// const refreshUrl = await store.getRefreshUrl(connectionIdentifier)
43+
const status = await store.getStatus(connectionIdentifier, requestId)
44+
if (status === 'pending') {
45+
res.writeHead(204)
46+
res.end()
47+
return
48+
} else if (status === 'not-started') {
49+
const serverInfo = await readServerInfo()
50+
const refreshUrl = await store.getRefreshUrl(connectionIdentifier)
51+
const { spaceName } = parseArn(connectionIdentifier)
5952

60-
// const url = `${refreshUrl}?connection_identifier=${encodeURIComponent(
61-
// connectionIdentifier
62-
// )}&request_id=${encodeURIComponent(requestId)}&call_back_url=${encodeURIComponent(
63-
// `http://localhost:${serverInfo.port}/refresh_token`
64-
// )}`
53+
const url = `${refreshUrl}/${encodeURIComponent(spaceName)}?reconnect_identifier=${encodeURIComponent(
54+
connectionIdentifier
55+
)}&reconnect_request_id=${encodeURIComponent(requestId)}&reconnect_callback_url=${encodeURIComponent(
56+
`http://localhost:${serverInfo.port}/refresh_token`
57+
)}`
6558

66-
// await open(url)
67-
// res.writeHead(202, { 'Content-Type': 'text/plain' })
68-
// res.end('Session is not ready yet. Please retry in a few seconds.')
69-
// await store.markPending(connectionIdentifier, requestId)
70-
// return
71-
// }
59+
await open(url)
60+
res.writeHead(202, { 'Content-Type': 'text/plain' })
61+
res.end('Session is not ready yet. Please retry in a few seconds.')
62+
await store.markPending(connectionIdentifier, requestId)
63+
return
64+
}
7265
} catch (err) {
7366
console.error('Error handling session async request:', err)
7467
res.writeHead(500, { 'Content-Type': 'text/plain' })

packages/core/src/awsService/sagemaker/detached-server/sessionStore.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ export class SessionStore {
4949

5050
const asyncEntry = requests[requestId]
5151
if (asyncEntry?.status === 'fresh') {
52-
await this.markConsumed(connectionId, requestId)
52+
delete requests[requestId]
53+
await writeMapping(mapping)
5354
return asyncEntry
5455
}
5556

packages/core/src/awsService/sagemaker/detached-server/utils.ts

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,18 +44,18 @@ export async function readServerInfo(): Promise<ServerInfo> {
4444
}
4545

4646
/**
47-
* Parses a SageMaker ARN to extract region and account ID.
47+
* Parses a SageMaker ARN to extract region, account ID, and space name.
4848
* Supports formats like:
49-
* arn:aws:sagemaker:<region>:<account_id>:space/<space_name>
49+
* arn:aws:sagemaker:<region>:<account_id>:space/<domain>/<space_name>
5050
* or sm_lc_arn:aws:sagemaker:<region>:<account_id>:space__d-xxxx__<name>
5151
*
5252
* If the input is prefixed with an identifier (e.g. "sagemaker-user@"), the function will strip it.
5353
*
5454
* @param arn - The full SageMaker ARN string
55-
* @returns An object containing the region and accountId
55+
* @returns An object containing the region, accountId, and spaceName
5656
* @throws If the ARN format is invalid
5757
*/
58-
export function parseArn(arn: string): { region: string; accountId: string } {
58+
export function parseArn(arn: string): { region: string; accountId: string; spaceName: string } {
5959
const cleanedArn = arn.includes('@') ? arn.split('@')[1] : arn
6060
const regex = /^arn:aws:sagemaker:(?<region>[^:]+):(?<account_id>\d+):space[/:].+$/i
6161
const match = cleanedArn.match(regex)
@@ -64,9 +64,16 @@ export function parseArn(arn: string): { region: string; accountId: string } {
6464
throw new Error(`Invalid SageMaker ARN format: "${arn}"`)
6565
}
6666

67+
// Extract space name from the end of the ARN (after the last forward slash)
68+
const spaceName = cleanedArn.split('/').pop()
69+
if (!spaceName) {
70+
throw new Error(`Could not extract space name from ARN: "${arn}"`)
71+
}
72+
6773
return {
6874
region: match.groups.region,
6975
accountId: match.groups.account_id,
76+
spaceName: spaceName,
7077
}
7178
}
7279

0 commit comments

Comments
 (0)