Skip to content

Commit 1a5a7cf

Browse files
feat(smus): UX improvements for connecting running spaces from toolkit
1 parent ef992ae commit 1a5a7cf

File tree

4 files changed

+265
-65
lines changed

4 files changed

+265
-65
lines changed

packages/core/src/awsService/sagemaker/commands.ts

Lines changed: 213 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,12 @@ import { SagemakerClient } from '../../shared/clients/sagemaker'
1919
import { ToolkitError } from '../../shared/errors'
2020
import { showConfirmationMessage } from '../../shared/utilities/messages'
2121
import { RemoteSessionError } from '../../shared/remoteSession'
22-
import { ConnectFromRemoteWorkspaceMessage, InstanceTypeError } from './constants'
22+
import {
23+
ConnectFromRemoteWorkspaceMessage,
24+
InstanceTypeError,
25+
InstanceTypeInsufficientMemory,
26+
InstanceTypeInsufficientMemoryMessage,
27+
} from './constants'
2328
import { SagemakerUnifiedStudioSpaceNode } from '../../sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioSpaceNode'
2429

2530
const localize = nls.loadMessageBundle()
@@ -195,56 +200,228 @@ export async function openRemoteConnect(
195200
const spaceName = node.spaceApp.SpaceName!
196201
await tryRefreshNode(node)
197202

198-
// for Stopped SM spaces - check instance type before showing progress
199-
if (node.getStatus() === 'Stopped') {
200-
// In case of SMUS, we pass in a SM Client and for SM AI, it creates a new SM Client.
201-
const client = sageMakerClient ? sageMakerClient : new SagemakerClient(node.regionCode)
202-
203-
try {
204-
await client.startSpace(spaceName, node.spaceApp.DomainId!)
205-
await tryRefreshNode(node)
206-
const appType = node.spaceApp.SpaceSettingsSummary?.AppType
207-
if (!appType) {
208-
throw new ToolkitError('AppType is undefined for the selected space. Cannot start remote connection.', {
209-
code: 'undefinedAppType',
210-
})
203+
const remoteAccess = node.spaceApp.SpaceSettingsSummary?.RemoteAccess
204+
const nodeStatus = node.getStatus()
205+
206+
// Route to appropriate handler based on space state
207+
if (nodeStatus === 'Running' && remoteAccess === 'DISABLED') {
208+
return handleRunningSpaceWithDisabledAccess(node, ctx, spaceName, sageMakerClient)
209+
} else if (nodeStatus === 'Stopped') {
210+
return handleStoppedSpace(node, ctx, spaceName, sageMakerClient)
211+
} else if (nodeStatus === 'Running') {
212+
return handleRunningSpaceWithEnabledAccess(node, ctx, spaceName)
213+
}
214+
}
215+
216+
/**
217+
* Checks if an instance type upgrade will be needed for remote access
218+
*/
219+
async function checkInstanceTypeUpgradeNeeded(
220+
node: SagemakerSpaceNode | SagemakerUnifiedStudioSpaceNode,
221+
sageMakerClient?: SagemakerClient
222+
): Promise<{ upgradeNeeded: boolean; currentType?: string; recommendedType?: string }> {
223+
const client = sageMakerClient || new SagemakerClient(node.regionCode)
224+
225+
try {
226+
const spaceDetails = await client.describeSpace({
227+
DomainId: node.spaceApp.DomainId!,
228+
SpaceName: node.spaceApp.SpaceName!,
229+
})
230+
231+
const appType = spaceDetails.SpaceSettings?.AppType
232+
if (!appType) {
233+
return { upgradeNeeded: false }
234+
}
235+
236+
// Get current instance type
237+
const requestedResourceSpec =
238+
appType === 'JupyterLab'
239+
? spaceDetails.SpaceSettings?.JupyterLabAppSettings?.DefaultResourceSpec
240+
: spaceDetails.SpaceSettings?.CodeEditorAppSettings?.DefaultResourceSpec
241+
242+
const currentInstanceType = requestedResourceSpec?.InstanceType
243+
244+
// Check if upgrade is needed
245+
if (currentInstanceType && currentInstanceType in InstanceTypeInsufficientMemory) {
246+
// Current type has insufficient memory
247+
return {
248+
upgradeNeeded: true,
249+
currentType: currentInstanceType,
250+
recommendedType: InstanceTypeInsufficientMemory[currentInstanceType],
211251
}
252+
}
253+
254+
return { upgradeNeeded: false, currentType: currentInstanceType }
255+
} catch (err) {
256+
const error = err as Error
257+
if (error.name === 'AccessDeniedException') {
258+
throw new ToolkitError('You do not have permission to describe spaces. Please contact your administrator', {
259+
cause: error,
260+
code: error.name,
261+
})
262+
}
263+
// For other errors, assume no upgrade needed to avoid blocking the flow
264+
return { upgradeNeeded: false }
265+
}
266+
}
267+
268+
/**
269+
* Handles connecting to a running space with disabled remote access
270+
* Requires stopping the space, enabling remote access, and restarting
271+
*/
272+
async function handleRunningSpaceWithDisabledAccess(
273+
node: SagemakerSpaceNode | SagemakerUnifiedStudioSpaceNode,
274+
ctx: vscode.ExtensionContext,
275+
spaceName: string,
276+
sageMakerClient?: SagemakerClient
277+
) {
278+
// Check if instance type upgrade will be needed
279+
const instanceTypeInfo = await checkInstanceTypeUpgradeNeeded(node, sageMakerClient)
280+
281+
let prompt: string
282+
if (instanceTypeInfo.upgradeNeeded) {
283+
prompt = InstanceTypeInsufficientMemoryMessage(
284+
spaceName,
285+
instanceTypeInfo.currentType!,
286+
instanceTypeInfo.recommendedType!
287+
)
288+
} else {
289+
// Only remote access needs to be enabled
290+
prompt = `This space requires remote access to be enabled.\nWould you like to restart the space and connect?\nAny unsaved work will be lost.`
291+
}
292+
293+
const confirmed = await showConfirmationMessage({
294+
prompt,
295+
confirm: 'Restart and Connect',
296+
cancel: 'Cancel',
297+
type: 'warning',
298+
})
299+
300+
if (!confirmed) {
301+
return
302+
}
303+
304+
// Enable remote access and connect
305+
const client = sageMakerClient || new SagemakerClient(node.regionCode)
306+
307+
return await vscode.window.withProgress(
308+
{
309+
location: vscode.ProgressLocation.Notification,
310+
cancellable: false,
311+
title: `Connecting to ${spaceName}`,
312+
},
313+
async (progress) => {
314+
try {
315+
// Show initial progress message
316+
progress.report({ message: 'Stopping the space' })
317+
318+
// Stop the running space
319+
await client.deleteApp({
320+
DomainId: node.spaceApp.DomainId!,
321+
SpaceName: spaceName,
322+
AppType: node.spaceApp.App!.AppType!,
323+
AppName: node.spaceApp.App?.AppName,
324+
})
325+
326+
// Update progress message
327+
progress.report({ message: 'Starting the space' })
328+
329+
// Start the space with remote access enabled (skip prompts since user already consented)
330+
await client.startSpace(spaceName, node.spaceApp.DomainId!, true)
331+
await tryRefreshNode(node)
212332

213-
// Only start showing progress after instance type validation
214-
return await vscode.window.withProgress(
215-
{
216-
location: vscode.ProgressLocation.Notification,
217-
cancellable: false,
218-
title: `Connecting to ${spaceName}`,
219-
},
220-
async (progress) => {
221-
progress.report({ message: 'Starting the space.' })
222-
await client.waitForAppInService(node.spaceApp.DomainId!, spaceName, appType)
223-
await tryRemoteConnection(node, ctx, progress)
333+
const appType = node.spaceApp.SpaceSettingsSummary?.AppType
334+
if (!appType) {
335+
throw new ToolkitError(
336+
'AppType is undefined for the selected space. Cannot start remote connection.',
337+
{
338+
code: 'undefinedAppType',
339+
}
340+
)
224341
}
225-
)
226-
} catch (err: any) {
227-
// Ignore InstanceTypeError since it means the user decided not to use an instanceType with more memory
228-
// just return without showing progress
229-
if (err.code === InstanceTypeError) {
230-
return
342+
343+
progress.report({ message: 'Starting the space' })
344+
await client.waitForAppInService(node.spaceApp.DomainId!, spaceName, appType)
345+
await tryRemoteConnection(node, ctx, progress)
346+
} catch (err: any) {
347+
// Handle user declining instance type upgrade
348+
if (err.code === InstanceTypeError) {
349+
return
350+
}
351+
throw new ToolkitError(`Remote connection failed: ${err.message}`, {
352+
cause: err,
353+
code: err.code,
354+
})
231355
}
232-
throw new ToolkitError(`Remote connection failed: ${(err as Error).message}`, {
233-
cause: err as Error,
234-
code: err.code,
356+
}
357+
)
358+
}
359+
360+
/**
361+
* Handles connecting to a stopped space
362+
* Starts the space and connects (remote access enabled automatically if needed)
363+
*/
364+
async function handleStoppedSpace(
365+
node: SagemakerSpaceNode | SagemakerUnifiedStudioSpaceNode,
366+
ctx: vscode.ExtensionContext,
367+
spaceName: string,
368+
sageMakerClient?: SagemakerClient
369+
) {
370+
const client = sageMakerClient || new SagemakerClient(node.regionCode)
371+
372+
try {
373+
await client.startSpace(spaceName, node.spaceApp.DomainId!)
374+
await tryRefreshNode(node)
375+
376+
const appType = node.spaceApp.SpaceSettingsSummary?.AppType
377+
if (!appType) {
378+
throw new ToolkitError('AppType is undefined for the selected space. Cannot start remote connection.', {
379+
code: 'undefinedAppType',
235380
})
236381
}
237-
} else if (node.getStatus() === 'Running') {
238-
// For running spaces, show progress
382+
239383
return await vscode.window.withProgress(
240384
{
241385
location: vscode.ProgressLocation.Notification,
242386
cancellable: false,
243387
title: `Connecting to ${spaceName}`,
244388
},
245389
async (progress) => {
390+
progress.report({ message: 'Starting the space' })
391+
await client.waitForAppInService(node.spaceApp.DomainId!, spaceName, appType)
246392
await tryRemoteConnection(node, ctx, progress)
247393
}
248394
)
395+
} catch (err: any) {
396+
// Handle user declining instance type upgrade
397+
if (err.code === InstanceTypeError) {
398+
return
399+
}
400+
throw new ToolkitError(`Remote connection failed: ${(err as Error).message}`, {
401+
cause: err as Error,
402+
code: err.code,
403+
})
249404
}
250405
}
406+
407+
/**
408+
* Handles connecting to a running space with enabled remote access
409+
* Direct connection without any space modifications
410+
*/
411+
async function handleRunningSpaceWithEnabledAccess(
412+
node: SagemakerSpaceNode | SagemakerUnifiedStudioSpaceNode,
413+
ctx: vscode.ExtensionContext,
414+
spaceName: string,
415+
sageMakerClient?: SagemakerClient
416+
) {
417+
return await vscode.window.withProgress(
418+
{
419+
location: vscode.ProgressLocation.Notification,
420+
cancellable: false,
421+
title: `Connecting to ${spaceName}`,
422+
},
423+
async (progress) => {
424+
await tryRemoteConnection(node, ctx, progress)
425+
}
426+
)
427+
}

packages/core/src/awsService/sagemaker/sagemakerSpace.ts

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,18 @@ export class SagemakerSpace {
148148
const domainId = this.spaceApp?.DomainId ?? '-'
149149
const owner = this.spaceApp?.OwnershipSettingsSummary?.OwnerUserProfileName || '-'
150150
const instanceType = this.spaceApp?.App?.ResourceSpec?.InstanceType ?? '-'
151+
const remoteAccess = this.spaceApp?.SpaceSettingsSummary?.RemoteAccess
152+
153+
let baseTooltip = ''
151154
if (this.isSMUSSpace) {
152-
return `**Space:** ${spaceName} \n\n**Application:** ${appType} \n\n**Instance Type:** ${instanceType}`
155+
baseTooltip = `**Space:** ${spaceName} \n\n**Application:** ${appType} \n\n**Instance Type:** ${instanceType}`
156+
if (remoteAccess === 'ENABLED') {
157+
baseTooltip += `\n\n**Remote Access:** Enabled`
158+
} else if (remoteAccess === 'DISABLED') {
159+
baseTooltip += `\n\n**Remote Access:** Disabled`
160+
}
161+
162+
return baseTooltip
153163
}
154164
return `**Space:** ${spaceName} \n\n**Application:** ${appType} \n\n**Domain ID:** ${domainId} \n\n**User Profile:** ${owner}`
155165
}

packages/core/src/shared/clients/sagemaker.ts

Lines changed: 39 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ import {
4848
import { getDomainSpaceKey } from '../../awsService/sagemaker/utils'
4949
import { getLogger } from '../logger/logger'
5050
import { ToolkitError } from '../errors'
51-
import { yes, no, continueText, cancel } from '../localizedText'
51+
import { continueText, cancel } from '../localizedText'
52+
import { showConfirmationMessage } from '../utilities/messages'
5253
import { AwsCredentialIdentity } from '@aws-sdk/types'
5354
import globals from '../extensionGlobals'
5455

@@ -126,7 +127,7 @@ export class SagemakerClient extends ClientWrapper<SageMakerClient> {
126127
return this.makeRequest(DeleteAppCommand, request)
127128
}
128129

129-
public async startSpace(spaceName: string, domainId: string) {
130+
public async startSpace(spaceName: string, domainId: string, skipInstanceTypePrompts: boolean = false) {
130131
let spaceDetails: DescribeSpaceCommandOutput
131132

132133
// Get existing space details
@@ -155,35 +156,47 @@ export class SagemakerClient extends ClientWrapper<SageMakerClient> {
155156

156157
// Is InstanceType defined and has enough memory?
157158
if (instanceType && instanceType in InstanceTypeInsufficientMemory) {
158-
// Prompt user to select one with sufficient memory (1 level up from their chosen one)
159-
const response = await vscode.window.showErrorMessage(
160-
InstanceTypeInsufficientMemoryMessage(
161-
spaceDetails.SpaceName || '',
162-
instanceType,
163-
InstanceTypeInsufficientMemory[instanceType]
164-
),
165-
yes,
166-
no
167-
)
159+
if (skipInstanceTypePrompts) {
160+
// User already consented, upgrade automatically
161+
instanceType = InstanceTypeInsufficientMemory[instanceType]
162+
} else {
163+
// Prompt user to select one with sufficient memory (1 level up from their chosen one)
164+
const confirmed = await showConfirmationMessage({
165+
prompt: InstanceTypeInsufficientMemoryMessage(
166+
spaceDetails.SpaceName || '',
167+
instanceType,
168+
InstanceTypeInsufficientMemory[instanceType]
169+
),
170+
confirm: 'Restart and Connect',
171+
cancel: 'Cancel',
172+
type: 'warning',
173+
})
168174

169-
if (response === no) {
170-
throw new ToolkitError('InstanceType has insufficient memory.', { code: InstanceTypeError })
171-
}
175+
if (!confirmed) {
176+
throw new ToolkitError('InstanceType has insufficient memory.', { code: InstanceTypeError })
177+
}
172178

173-
instanceType = InstanceTypeInsufficientMemory[instanceType]
179+
instanceType = InstanceTypeInsufficientMemory[instanceType]
180+
}
174181
} else if (!instanceType) {
175-
// Prompt user to select the minimum supported instance type
176-
const response = await vscode.window.showErrorMessage(
177-
InstanceTypeNotSelectedMessage(spaceDetails.SpaceName || ''),
178-
continueText,
179-
cancel
180-
)
182+
if (skipInstanceTypePrompts) {
183+
// User already consented, use minimum
184+
instanceType = InstanceTypeMinimum
185+
} else {
186+
// Prompt user to select the minimum supported instance type
187+
const confirmed = await showConfirmationMessage({
188+
prompt: InstanceTypeNotSelectedMessage(spaceDetails.SpaceName || ''),
189+
confirm: continueText,
190+
cancel: cancel,
191+
type: 'warning',
192+
})
181193

182-
if (response === cancel) {
183-
throw new ToolkitError('InstanceType not defined.', { code: InstanceTypeError })
184-
}
194+
if (!confirmed) {
195+
throw new ToolkitError('InstanceType not defined.', { code: InstanceTypeError })
196+
}
185197

186-
instanceType = InstanceTypeMinimum
198+
instanceType = InstanceTypeMinimum
199+
}
187200
}
188201

189202
// First, update the space if needed

0 commit comments

Comments
 (0)