Skip to content

Commit fddca6f

Browse files
fix(Sagemaker) : Improve error handling when the Space SSH host check fails
1 parent 1b7c4c4 commit fddca6f

File tree

6 files changed

+942
-9
lines changed

6 files changed

+942
-9
lines changed

packages/core/src/awsService/sagemaker/commands.ts

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,28 @@ import {
3030
SpaceStatus,
3131
} from './constants'
3232
import { SagemakerUnifiedStudioSpaceNode } from '../../sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioSpaceNode'
33+
import { SageMakerSshConfig } from './sshConfig'
34+
import { findSshPath } from '../../shared/utilities/pathFind'
3335

3436
const localize = nls.loadMessageBundle()
3537

38+
/**
39+
* Validates SSH configuration before starting connection.
40+
*/
41+
async function validateSshConfig(): Promise<void> {
42+
const sshPath = await findSshPath()
43+
if (!sshPath) {
44+
throw new ToolkitError(
45+
'SSH is required to connect to SageMaker spaces, but was not found.Install SSH to connect to spaces.'
46+
)
47+
}
48+
const sshConfig = new SageMakerSshConfig(sshPath, 'sm_', 'sagemaker_connect')
49+
const result = await sshConfig.ensureValid()
50+
if (result.isErr()) {
51+
throw result.err()
52+
}
53+
}
54+
3655
export async function filterSpaceAppsByDomainUserProfiles(parentNode: SagemakerParentNode): Promise<void> {
3756
if (parentNode.domainUserProfiles.size === 0) {
3857
// if parentNode has not been expanded, domainUserProfiles will be empty
@@ -107,6 +126,9 @@ export async function deeplinkConnect(
107126
return
108127
}
109128

129+
// Validate SSH config before attempting connection
130+
await validateSshConfig()
131+
110132
try {
111133
const remoteEnv = await prepareDevEnvConnection(
112134
connectionIdentifier,
@@ -301,6 +323,9 @@ async function handleRunningSpaceWithDisabledAccess(
301323
return
302324
}
303325

326+
// Validate SSH config before showing progress
327+
await validateSshConfig()
328+
304329
// Enable remote access and connect
305330
const client = sageMakerClient || new SagemakerClient(node.regionCode)
306331

@@ -359,6 +384,9 @@ async function handleStoppedSpace(
359384
spaceName: string,
360385
sageMakerClient?: SagemakerClient
361386
) {
387+
// Validate SSH config before showing progress
388+
await validateSshConfig()
389+
362390
const client = sageMakerClient || new SagemakerClient(node.regionCode)
363391

364392
try {
@@ -403,6 +431,9 @@ async function handleRunningSpaceWithEnabledAccess(
403431
spaceName: string,
404432
sageMakerClient?: SagemakerClient
405433
) {
434+
// Validate SSH config before showing progress
435+
await validateSshConfig()
436+
406437
return await vscode.window.withProgress(
407438
{
408439
location: vscode.ProgressLocation.Notification,

packages/core/src/awsService/sagemaker/model.ts

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import * as vscode from 'vscode'
99
import { sshAgentSocketVariable, startSshAgent, startVscodeRemote } from '../../shared/extensions/ssh'
1010
import { createBoundProcess, ensureDependencies } from '../../shared/remoteSession'
11-
import { SshConfig } from '../../shared/sshConfig'
1211
import * as path from 'path'
1312
import { persistLocalCredentials, persistSmusProjectCreds, persistSSMConnection } from './credentialMapping'
1413
import * as os from 'os'
@@ -91,13 +90,8 @@ export async function prepareDevEnvConnection(
9190
await startLocalServer(ctx)
9291
await removeKnownHost(hostname)
9392

94-
const sshConfig = new SshConfig(ssh, 'sm_', 'sagemaker_connect')
95-
const config = await sshConfig.ensureValid()
96-
if (config.isErr()) {
97-
const err = config.err()
98-
logger.error(`sagemaker: failed to add ssh config section: ${err.message}`)
99-
throw err
100-
}
93+
// Note: SSH config validation is done in commands.ts before calling this function
94+
// to ensure users can fix config issues before the progress dialog appears
10195

10296
// set envirionment variables
10397
const vars = getSmSsmEnv(ssm, path.join(ctx.globalStorageUri.fsPath, 'sagemaker-local-server-info.json'))

0 commit comments

Comments
 (0)