diff --git a/packages/core/src/awsService/sagemaker/commands.ts b/packages/core/src/awsService/sagemaker/commands.ts index 0c92a31c6f8..0a2f4a8c671 100644 --- a/packages/core/src/awsService/sagemaker/commands.ts +++ b/packages/core/src/awsService/sagemaker/commands.ts @@ -30,9 +30,28 @@ import { SpaceStatus, } from './constants' import { SagemakerUnifiedStudioSpaceNode } from '../../sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioSpaceNode' +import { SageMakerSshConfig } from './sshConfig' +import { findSshPath } from '../../shared/utilities/pathFind' const localize = nls.loadMessageBundle() +/** + * Validates SSH configuration before starting connection. + */ +async function validateSshConfig(): Promise { + const sshPath = await findSshPath() + if (!sshPath) { + throw new ToolkitError( + 'SSH is required to connect to SageMaker spaces, but was not found.Install SSH to connect to spaces.' + ) + } + const sshConfig = new SageMakerSshConfig(sshPath, 'sm_', 'sagemaker_connect') + const result = await sshConfig.ensureValid() + if (result.isErr()) { + throw result.err() + } +} + export async function filterSpaceAppsByDomainUserProfiles(parentNode: SagemakerParentNode): Promise { if (parentNode.domainUserProfiles.size === 0) { // if parentNode has not been expanded, domainUserProfiles will be empty @@ -113,6 +132,9 @@ export async function deeplinkConnect( return } + // Validate SSH config before attempting connection + await validateSshConfig() + try { const remoteEnv = await prepareDevEnvConnection( connectionIdentifier, @@ -210,6 +232,9 @@ export async function openRemoteConnect( return } + // Validate SSH config before attempting connection + await validateSshConfig() + const spaceName = node.spaceApp.SpaceName! await tryRefreshNode(node) diff --git a/packages/core/src/awsService/sagemaker/constants.ts b/packages/core/src/awsService/sagemaker/constants.ts index 6e5f33195a0..4f9de1b1831 100644 --- a/packages/core/src/awsService/sagemaker/constants.ts +++ b/packages/core/src/awsService/sagemaker/constants.ts @@ -46,6 +46,25 @@ export const InstanceTypeNotSelectedMessage = (spaceName: string) => { export const RemoteAccessRequiredMessage = 'This space requires remote access to be enabled.\nWould you like to restart the space and connect?\nAny unsaved work will be lost.' +// SSH Configuration Error Messages +export const SshConfigUpdateDeclinedMessage = (configHostName: string, configPath: string) => + `SSH configuration has an outdated ${configHostName} section. Fix your ${configPath} file manually to enable remote connection.` + +export const SshConfigOpenedForEditMessage = () => + `SSH configuration file opened for editing. Fix the issue and try connecting again.` + +export const SshConfigSyntaxErrorMessage = (configPath: string) => + `SSH configuration has syntax errors in your ${configPath} file. Fix the configuration manually to enable remote connection.` + +export const SshConfigRemovalFailedMessage = (configHostName: string) => + `Failed to remove SSH config section for ${configHostName}` + +export const SshConfigUpdateFailedMessage = (configPath: string, configHostName: string) => + `Failed to update SSH config section. Fix your ${configPath} file manually or remove the outdated ${configHostName} section.` + +export const SshConfigModifiedMessage = (configHostName: string) => + `SSH config section for ${configHostName} has been modified. Manually remove the section and try again.` + export const SmusDeeplinkSessionExpiredError = { title: 'Session Disconnected', message: diff --git a/packages/core/src/awsService/sagemaker/model.ts b/packages/core/src/awsService/sagemaker/model.ts index a9ab87647bf..678b785e6af 100644 --- a/packages/core/src/awsService/sagemaker/model.ts +++ b/packages/core/src/awsService/sagemaker/model.ts @@ -8,7 +8,6 @@ import * as vscode from 'vscode' import { sshAgentSocketVariable, startSshAgent, startVscodeRemote } from '../../shared/extensions/ssh' import { createBoundProcess, ensureDependencies } from '../../shared/remoteSession' -import { SshConfig } from '../../shared/sshConfig' import * as path from 'path' import { persistLocalCredentials, persistSmusProjectCreds, persistSSMConnection } from './credentialMapping' import * as os from 'os' @@ -91,14 +90,6 @@ export async function prepareDevEnvConnection( await startLocalServer(ctx) await removeKnownHost(hostname) - const sshConfig = new SshConfig(ssh, 'sm_', 'sagemaker_connect') - const config = await sshConfig.ensureValid() - if (config.isErr()) { - const err = config.err() - logger.error(`sagemaker: failed to add ssh config section: ${err.message}`) - throw err - } - // set envirionment variables const vars = getSmSsmEnv(ssm, path.join(ctx.globalStorageUri.fsPath, 'sagemaker-local-server-info.json')) logger.info(`connect script logs at ${vars.LOG_FILE_LOCATION}`) diff --git a/packages/core/src/awsService/sagemaker/sshConfig.ts b/packages/core/src/awsService/sagemaker/sshConfig.ts new file mode 100644 index 00000000000..db5b3b23ae4 --- /dev/null +++ b/packages/core/src/awsService/sagemaker/sshConfig.ts @@ -0,0 +1,472 @@ +/*! + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +import * as vscode from 'vscode' +import { SshConfig } from '../../shared/sshConfig' +import { Result } from '../../shared/utilities/result' +import { ToolkitError } from '../../shared/errors' +import { getLogger } from '../../shared/logger/logger' +import { getIdeProperties } from '../../shared/extensionUtilities' +import { showConfirmationMessage } from '../../shared/utilities/messages' +import { CancellationError } from '../../shared/utilities/timeoutUtils' +import { getSshConfigPath } from '../../shared/extensions/ssh' +import { fileExists, readFileAsString } from '../../shared/filesystemUtilities' +import fs from '../../shared/fs/fs' +import { + SshConfigUpdateDeclinedMessage, + SshConfigOpenedForEditMessage, + SshConfigSyntaxErrorMessage, + SshConfigRemovalFailedMessage, + SshConfigUpdateFailedMessage, + SshConfigModifiedMessage, +} from './constants' + +const logger = getLogger('sagemaker') + +/** + * SageMaker-specific SSH configuration that handles outdated config detection and updates. + * Extends the base SshConfig with SageMaker-specific validation logic. + */ +export class SageMakerSshConfig extends SshConfig { + public override async verifySSHHost(proxyCommand: string) { + // Read the current state of SSH config + const configStateResult = await this.readSshConfigState(proxyCommand) + + // If reading config state failed, return the error result + if (configStateResult.isErr()) { + return configStateResult + } + + // Extract the state if section exists and if it's outdated + const configState = configStateResult.ok() + + // Check if section exists and is outdated + if (configState.hasSshSection && configState.isOutdated) { + const updateResult = await this.updateOutdatedSection(proxyCommand) + if (updateResult.isErr()) { + return updateResult + } + } + + // Run validation + const matchResult = await this.matchSshSection() + if (matchResult.isErr()) { + const sshError = matchResult.err() + + // Check if SM section either existed before or just created) + const hasSshSection = configState.hasSshSection || !configState.isOutdated + + if (hasSshSection) { + // Section exists and should be up-to-date, but validation still failed + // This means the error is elsewhere in the SSH config + try { + await this.promptOtherSshConfigError(sshError) + const configOpenedError = new ToolkitError(SshConfigOpenedForEditMessage(), { + code: 'SshConfigOpenedForEdit', + details: { configPath: getSshConfigPath() }, + }) + return Result.err(configOpenedError) + } catch (e) { + // User cancelled the "Open SSH Config" prompt (from promptOtherSshConfigError) + if (e instanceof CancellationError) { + const configPath = getSshConfigPath() + const externalConfigError = new ToolkitError(SshConfigSyntaxErrorMessage(configPath), { + code: 'SshConfigExternalError', + details: { configPath }, + }) + return Result.err(externalConfigError) + } + return Result.err( + ToolkitError.chain(e, 'Unexpected error while handling SSH config error', { + code: 'SshConfigErrorHandlingFailed', + }) + ) + } + } + return matchResult + } + + const configSection = matchResult.ok() + const hasProxyCommand = configSection?.includes(proxyCommand) + + if (!hasProxyCommand) { + try { + await this.promptUserToConfigureSshConfig(configSection, proxyCommand) + } catch (e) { + return Result.err( + ToolkitError.chain(e, 'Failed to configure SSH config', { + code: 'SshConfigPromptFailed', + }) + ) + } + } + + return Result.ok() + } + + /** + * Reads SSH config file once and determines its current state. + * + * State represents the current condition of the SSH config: + * - hasSshSection: Does the sm_* section exist in the file? + * - isOutdated: Is the section in an old/incorrect format? + * - existingSection: The actual content of the section (if it exists) + * + * @returns Result containing the state object or an error if file read fails + */ + public async readSshConfigState(proxyCommand: string): Promise< + Result< + { + hasSshSection: boolean // True if sm_* section exists + isOutdated: boolean // True if section needs updating + existingSection?: string // Current section content (optional) + }, + ToolkitError + > + > { + const sshConfigPath = getSshConfigPath() + + // File not existing + if (!(await fileExists(sshConfigPath))) { + return Result.ok({ hasSshSection: false, isOutdated: false }) + } + + try { + const configContent = await readFileAsString(sshConfigPath) + + // Extract the toolkit section + const existingSection = this.extractToolkitSection(configContent) + + if (!existingSection) { + return Result.ok({ hasSshSection: false, isOutdated: false }) + } + + // Generate the expected current version + const expectedSection = this.createSSHConfigSection(proxyCommand).trim() + + // Compare existing vs expected to check if outdated + const normalizeWhitespace = (str: string) => str.replace(/\s+/g, ' ').trim() + const isOutdated = normalizeWhitespace(existingSection) !== normalizeWhitespace(expectedSection) + + return Result.ok({ hasSshSection: true, isOutdated, existingSection }) + } catch (e) { + return Result.err( + ToolkitError.chain(e, 'Failed to read SSH config file', { + code: 'SshConfigReadFailed', + details: { configPath: sshConfigPath }, + }) + ) + } + } + + /** + * Handles updating an outdated SSH config section. + * Prompts user, removes old section, writes new section. + * + * @returns Result.ok() if updated successfully, Result.err() if user declined or update failed + */ + private async updateOutdatedSection(proxyCommand: string): Promise> { + const shouldUpdate = await this.promptToUpdateSshConfig() + + if (!shouldUpdate) { + // User declined the auto-update + const configPath = getSshConfigPath() + return Result.err( + new ToolkitError(SshConfigUpdateDeclinedMessage(this.configHostName, configPath), { + code: 'SshConfigUpdateDeclined', + details: { configHostName: this.configHostName, configPath }, + }) + ) + } + + try { + // Remove the outdated section + await this.removeSshConfigSection() + // Write the new section + await this.writeSectionToConfig(proxyCommand) + logger.info('Successfully updated sm_* section') + return Result.ok() + } catch (e) { + // Failed to update, handle the failure + return await this.handleSshConfigUpdateFailure(e) + } + } + + /** + * Prompts user to update the outdated SSH config section. + * This is shown when the host section exists but is outdated. + */ + private async promptToUpdateSshConfig(): Promise { + logger.warn(`Section is outdated for ${this.configHostName}`) + + const configPath = getSshConfigPath() + const confirmTitle = `${getIdeProperties().company} Toolkit will update the ${this.configHostName} section in ${configPath}` + const confirmText = 'Update SSH config' + + const response = await showConfirmationMessage({ prompt: confirmTitle, confirm: confirmText }) + + return response === true + } + + /** + * Prompts user when automatic SSH config update fails. + * @throws CancellationError if user cancels + */ + public async promptToFixUpdateFailure(updateError?: Error): Promise { + const sshConfigPath = getSshConfigPath() + + // Include error details if available + let errorDetails = '' + if (updateError?.message) { + errorDetails = `\n\nError: ${updateError.message}` + } + + const message = `Failed to update your ${sshConfigPath} file automatically.${errorDetails}\n\nOpen the file to fix the issue manually.` + + const openButton = 'Open SSH Config' + const cancelButton = 'Cancel' + + const response = await vscode.window.showErrorMessage(message, openButton, cancelButton) + + // User clicked Cancel or closed the dialog + if (response !== openButton) { + throw new CancellationError('user') + } + + await vscode.window.showTextDocument(vscode.Uri.file(sshConfigPath)) + } + + /** + * Prompts user when SSH config has errors elsewhere (not in toolkit's section). + * @throws CancellationError if user cancels + */ + public async promptOtherSshConfigError(sshError?: Error): Promise { + const sshConfigPath = getSshConfigPath() + + // Extract line number from SSH error message (best-effort). + // Note: SSH error formats are not standardized and may vary across implementations. + let errorDetails = '' + if (sshError?.message) { + const lineMatch = sshError.message.match(/line (\d+)/i) + if (lineMatch) { + errorDetails = `\n\nError at line ${lineMatch[1]}` + } + } + + const message = `There is an error in your ${sshConfigPath} file.${errorDetails}\n\nFix the error and try again.` + + const openButton = 'Open SSH Config' + const cancelButton = 'Cancel' + + const response = await vscode.window.showErrorMessage(message, openButton, cancelButton) + + // User clicked Cancel or closed the dialog + if (response !== openButton) { + throw new CancellationError('user') + } + + await vscode.window.showTextDocument(vscode.Uri.file(sshConfigPath)) + } + + /** + * Extracts the toolkit-managed SSH config section from the config content. + * returns Object with fullSection((comment + Host + directives)) and hostSection(Host + directives), or null if not found + */ + private extractToolkitSection(configContent: string): string | undefined { + const lines = configContent.split('\n') + let startIndex = -1 + let endIndex = -1 + + // Find the toolkit comment marker + for (let i = 0; i < lines.length; i++) { + if (lines[i].includes('# Created by AWS Toolkit')) { + startIndex = i + break + } + } + + if (startIndex === -1) { + return undefined + } + + // Check if next line is our Host directive + if (startIndex + 1 >= lines.length) { + return undefined + } + + const hostLine = lines[startIndex + 1] + if (!hostLine.trim().startsWith(`Host ${this.configHostName}`)) { + return undefined + } + + // Extract all indented lines (directives) after the Host line + // Stop at: blank line, non-indented line, or another Host directive + endIndex = startIndex + 2 // Start after Host line + + for (let i = startIndex + 2; i < lines.length; i++) { + const line = lines[i] + const trimmed = line.trim() + + // Stop at blank line + if (trimmed === '') { + endIndex = i + break + } + + // Stop at another Host directive + if (trimmed.startsWith('Host ')) { + endIndex = i + break + } + + // Stop at non-indented line (comment or another section) + if (line.length > 0 && line[0] !== ' ' && line[0] !== '\t') { + endIndex = i + break + } + + endIndex = i + 1 + } + + // Extract the full section (comment + Host + directives) + return lines.slice(startIndex, endIndex).join('\n') + } + + /** + * Removes the toolkit-managed SSH config section using version matching. + * + * This method checks for exact matches against known toolkit-generated configs + * to ensure we only remove content we created, not user-defined content. + */ + public async removeSshConfigSection(): Promise { + const sshConfigPath = getSshConfigPath() + + if (!(await fileExists(sshConfigPath))) { + logger.info('Config file does not exist, nothing to remove') + return + } + + try { + const configContent = await readFileAsString(sshConfigPath) + const extractedSection = this.extractToolkitSection(configContent) + + if (!extractedSection) { + logger.warn(`No ${this.configHostName} section found to remove`) + return + } + + // Get the proxy command from the extracted section + const proxyCommandMatch = extractedSection.match(/ProxyCommand\s+(.+)/) + if (!proxyCommandMatch) { + logger.warn('Could not extract ProxyCommand from section, skipping removal') + return + } + const proxyCommand = proxyCommandMatch[1].trim() + + // Check against known versions + const knownVersions = [ + this.createSSHConfigSection(proxyCommand).trim(), // Current version + this.createSSHConfigV1(proxyCommand).trim(), // Old version with User '%r' + ] + + const normalizeWhitespace = (str: string) => str.replace(/\s+/g, ' ').trim() + const extractedNormalized = normalizeWhitespace(extractedSection) + + let matchedVersion: string | undefined + for (const knownVersion of knownVersions) { + if (normalizeWhitespace(knownVersion) === extractedNormalized) { + matchedVersion = extractedSection + break + } + } + + if (!matchedVersion) { + // Section doesn't match any known version - likely user-modified + // Throw error so handleSshConfigUpdateFailure() prompts user to fix manually + throw new ToolkitError(SshConfigModifiedMessage(this.configHostName), { + code: 'SshConfigModified', + }) + } + + const updatedContent = configContent.replace(matchedVersion, '') + + await fs.writeFile(sshConfigPath, updatedContent, { atomic: true }) + + logger.info(`Removed ${this.configHostName} section`) + } catch (e) { + throw ToolkitError.chain(e, SshConfigRemovalFailedMessage(this.configHostName), { + code: 'SshConfigRemovalFailed', + }) + } + } + + /** + * Handles SSH config update failure by prompting user to fix manually. + */ + private async handleSshConfigUpdateFailure(updateError: unknown): Promise> { + try { + // Prompt user to open SSH config file to fix manually + await this.promptToFixUpdateFailure(updateError instanceof Error ? updateError : undefined) + + // User opened the file + const configOpenedError = new ToolkitError(SshConfigOpenedForEditMessage(), { + code: 'SshConfigOpenedForEdit', + details: { configPath: getSshConfigPath() }, + }) + return Result.err(configOpenedError) + } catch (promptError) { + // User cancelled the "Open SSH Config" prompt (from promptToFixUpdateFailure) + if (promptError instanceof CancellationError) { + const configPath = getSshConfigPath() + return Result.err( + ToolkitError.chain(updateError, SshConfigUpdateFailedMessage(configPath, this.configHostName), { + code: 'SshConfigUpdateFailed', + details: { + configHostName: this.configHostName, + configPath, + }, + }) + ) + } + + // Unexpected error during prompt + return Result.err( + ToolkitError.chain(promptError, 'Unexpected error while handling SSH config update failure', { + code: 'SshConfigErrorHandlingFailed', + }) + ) + } + } + + /** + * Generates old version 1 SSH config (with User '%r' directive). + * This was the format used in earlier versions of the toolkit. + */ + private createSSHConfigV1(proxyCommand: string): string { + return ` +# Created by AWS Toolkit for VSCode. https://github.com/aws/aws-toolkit-vscode +Host ${this.configHostName} + ForwardAgent yes + AddKeysToAgent yes + StrictHostKeyChecking accept-new + ProxyCommand ${proxyCommand} + User '%r' + ` + } + + /** + * Creates SageMaker-specific SSH config section (current version). + */ + protected override createSSHConfigSection(proxyCommand: string): string { + return ` +# Created by AWS Toolkit for VSCode. https://github.com/aws/aws-toolkit-vscode +Host ${this.configHostName} + ForwardAgent yes + AddKeysToAgent yes + StrictHostKeyChecking accept-new + ProxyCommand ${proxyCommand} + ` + } +} diff --git a/packages/core/src/shared/sshConfig.ts b/packages/core/src/shared/sshConfig.ts index 92b32666b06..e214eac3e1d 100644 --- a/packages/core/src/shared/sshConfig.ts +++ b/packages/core/src/shared/sshConfig.ts @@ -85,7 +85,15 @@ export class SshConfig { protected async matchSshSection() { const result = await this.checkSshOnHost() if (result.exitCode !== 0) { - return Result.err(result.error ?? new Error(`ssh check against host failed: ${result.exitCode}`)) + const errorMessage = result.stderr || `ssh check against host failed: ${result.exitCode}` + + if (result.error) { + // System level error + return Result.err(ToolkitError.chain(result.error, errorMessage)) + } + + // SSH ran but returned error exit code (config error, validation failed) + return Result.err(new ToolkitError(errorMessage)) } const matches = result.stdout.match(this.proxyCommandRegExp) return Result.ok(matches?.[0]) @@ -195,21 +203,8 @@ Host ${this.configHostName} ` } - private getSageMakerSSHConfig(proxyCommand: string): string { - return ` -# Created by AWS Toolkit for VSCode. https://github.com/aws/aws-toolkit-vscode -Host ${this.configHostName} - ForwardAgent yes - AddKeysToAgent yes - StrictHostKeyChecking accept-new - ProxyCommand ${proxyCommand} - ` - } - protected createSSHConfigSection(proxyCommand: string): string { - if (this.scriptPrefix === 'sagemaker_connect') { - return `${this.getSageMakerSSHConfig(proxyCommand)}` - } else if (this.keyPath) { + if (this.keyPath) { return `${this.getBaseSSHConfig(proxyCommand)}IdentityFile '${this.keyPath}'\n User '%r'\n` } return this.getBaseSSHConfig(proxyCommand) diff --git a/packages/core/src/test/awsService/sagemaker/commands.test.ts b/packages/core/src/test/awsService/sagemaker/commands.test.ts index fd835cfe79e..03f4b54dea1 100644 --- a/packages/core/src/test/awsService/sagemaker/commands.test.ts +++ b/packages/core/src/test/awsService/sagemaker/commands.test.ts @@ -46,6 +46,7 @@ describe('SageMaker Commands', () => { let mockTryRefreshNode: sinon.SinonStub let mockTryRemoteConnection: sinon.SinonStub let mockIsRemoteWorkspace: sinon.SinonStub + let mockValidateSshConfig: sinon.SinonStub let openRemoteConnect: typeof openRemoteConnectStatic beforeEach(() => { @@ -83,10 +84,70 @@ describe('SageMaker Commands', () => { ) sandbox.replace(require('../../../shared/vscode/env'), 'isRemoteWorkspace', mockIsRemoteWorkspace) + // Mock SSH validation components BEFORE loading commands module + const mockFindSshPath = sandbox.stub().resolves('/usr/bin/ssh') + sandbox.replace(require('../../../shared/utilities/pathFind'), 'findSshPath', mockFindSshPath) + + // Mock file system operations to prevent actual SSH config reading + const mockFileExists = sandbox.stub().resolves(false) // No SSH config file exists + sandbox.replace(require('../../../shared/filesystemUtilities'), 'fileExists', mockFileExists) + + // Mock SageMakerSshConfig.ensureValid BEFORE loading commands module + const sshConfigModule = require('../../../awsService/sagemaker/sshConfig') + mockValidateSshConfig = sandbox.stub(sshConfigModule.SageMakerSshConfig.prototype, 'ensureValid') + mockValidateSshConfig.resolves({ + isOk: () => true, + isErr: () => false, + ok: () => undefined, + unwrap: () => undefined, + }) + + // load the commands module const freshModule = require('../../../awsService/sagemaker/commands') openRemoteConnect = freshModule.openRemoteConnect }) + describe('SSH validation', () => { + it('calls ensureValid before attempting connection', async () => { + mockNode.getStatus.returns('Running') + mockNode.spaceApp.SpaceSettingsSummary.RemoteAccess = 'ENABLED' + + await openRemoteConnect(mockNode, {} as any, mockClient) + + // Verify ensureValid was called + assert(mockValidateSshConfig.called, 'ensureValid should be called') + + // Verify it was called before tryRemoteConnection + assert( + mockValidateSshConfig.calledBefore(mockTryRemoteConnection), + 'ensureValid should be called before connection attempt' + ) + }) + + it('does not attempt connection if ensureValid fails', async () => { + mockNode.getStatus.returns('Running') + mockNode.spaceApp.SpaceSettingsSummary.RemoteAccess = 'ENABLED' + + // Make ensureValid fail + mockValidateSshConfig.resetBehavior() + mockValidateSshConfig.resolves({ + isOk: () => false, + isErr: () => true, + err: () => new Error('SSH validation failed'), + }) + + try { + await openRemoteConnect(mockNode, {} as any, mockClient) + assert.fail('Should have thrown error') + } catch (error) { + // Expected to throw + } + + // Verify connection was NOT attempted + assert(mockTryRemoteConnection.notCalled, 'Connection should not be attempted when ensureValid fails') + }) + }) + describe('handleRunningSpaceWithDisabledAccess', () => { beforeEach(() => { mockNode.getStatus.returns('Running') diff --git a/packages/core/src/test/awsService/sagemaker/sshConfig.test.ts b/packages/core/src/test/awsService/sagemaker/sshConfig.test.ts new file mode 100644 index 00000000000..dd2339bb47d --- /dev/null +++ b/packages/core/src/test/awsService/sagemaker/sshConfig.test.ts @@ -0,0 +1,601 @@ +/*! + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +import assert from 'assert' +import * as sinon from 'sinon' +import * as path from 'path' +import { ToolkitError } from '../../../shared/errors' +import { Result } from '../../../shared/utilities/result' +import { SageMakerSshConfig } from '../../../awsService/sagemaker/sshConfig' +import { makeTemporaryToolkitFolder } from '../../../shared/filesystemUtilities' +import { CancellationError } from '../../../shared/utilities/timeoutUtils' +import fs from '../../../shared/fs/fs' +import { getTestWindow } from '../../shared/vscode/window' +import { SshConfigOpenedForEditMessage } from '../../../awsService/sagemaker/constants' + +describe('SageMakerSshConfig', function () { + let sandbox: sinon.SinonSandbox + let config: SageMakerSshConfig + let tempDir: string + let sshConfigPath: string + const testProxyCommand = "'sagemaker_connect' '%n'" + + beforeEach(async function () { + sandbox = sinon.createSandbox() + tempDir = await makeTemporaryToolkitFolder() + sshConfigPath = path.join(tempDir, 'config') + + // Mock getSshConfigPath to use temp directory + sandbox.stub(require('../../../shared/extensions/ssh'), 'getSshConfigPath').returns(sshConfigPath) + + config = new SageMakerSshConfig('/usr/bin/ssh', 'sm_', 'sagemaker_connect') + }) + + afterEach(async function () { + sandbox.restore() + getTestWindow().dispose() + if (tempDir) { + await fs.delete(tempDir, { recursive: true }) + } + }) + + async function writeTestSshConfig(content: string): Promise { + await fs.writeFile(sshConfigPath, content) + } + + describe('readSshConfigState', function () { + /** + * Test: No SSH config file exists + * Expected: Returns Ok with hasSshSection=false, isOutdated=false + */ + it('returns Ok with false state when file does not exist', async function () { + const result = await config.readSshConfigState(testProxyCommand) + + assert.ok(result.isOk()) + const state = result.ok() + assert.strictEqual(state.hasSshSection, false) + assert.strictEqual(state.isOutdated, false) + }) + + /** + * Test: SSH config has sm_* section with old "User '%r'" directive + * Expected: Returns Ok with hasSshSection=true, isOutdated=true + */ + it('returns Ok with outdated=true when section has old User field', async function () { + await writeTestSshConfig(`# Created by AWS Toolkit +Host sm_* + ProxyCommand 'sagemaker_connect' '%n' + User '%r' +`) + + const result = await config.readSshConfigState(testProxyCommand) + + assert.ok(result.isOk()) + const state = result.ok() + assert.strictEqual(state.hasSshSection, true) + assert.strictEqual(state.isOutdated, true) + }) + + /** + * Test: SSH config has sm_* section with correct/updated format + * Expected: Returns Ok with hasSshSection=true, isOutdated=false + */ + it('returns Ok with outdated=false when section is current', async function () { + await writeTestSshConfig(`# Created by AWS Toolkit for VSCode. https://github.com/aws/aws-toolkit-vscode +Host sm_* + ForwardAgent yes + AddKeysToAgent yes + StrictHostKeyChecking accept-new + ProxyCommand 'sagemaker_connect' '%n' + `) + + const result = await config.readSshConfigState(testProxyCommand) + + assert.ok(result.isOk()) + const state = result.ok() + assert.strictEqual(state.hasSshSection, true) + assert.strictEqual(state.isOutdated, false) + }) + + /** + * Test: SSH config file doesn't end with newline + * Expected: Correctly detects section even without trailing newline + */ + it('handles files without trailing newline (EOF issue)', async function () { + const configContent = `# Created by AWS Toolkit for VSCode. https://github.com/aws/aws-toolkit-vscode +Host sm_* + ForwardAgent yes + AddKeysToAgent yes + StrictHostKeyChecking accept-new + ProxyCommand 'sagemaker_connect' '%n' + ` + await fs.writeFile(sshConfigPath, configContent.trimEnd(), { flag: 'w' }) + + const result = await config.readSshConfigState(testProxyCommand) + + assert.ok(result.isOk()) + const state = result.ok() + assert.strictEqual(state.hasSshSection, true) + assert.strictEqual(state.isOutdated, false) + }) + + /** + * Test: SSH config has sm_* section but missing some directives + * Expected: Detects as outdated because it doesn't match expected format + */ + it('detects outdated config with different whitespace', async function () { + await writeTestSshConfig(`# Created by AWS Toolkit +Host sm_* + ProxyCommand 'sagemaker_connect' '%n' + ForwardAgent yes +`) + + const result = await config.readSshConfigState(testProxyCommand) + + assert.ok(result.isOk()) + const state = result.ok() + assert.strictEqual(state.hasSshSection, true) + assert.strictEqual(state.isOutdated, true) + }) + }) + + describe('verifySSHHost', function () { + let removeStub: sinon.SinonStub + let writeStub: sinon.SinonStub + let matchStub: sinon.SinonStub + + beforeEach(function () { + removeStub = sandbox.stub(config, 'removeSshConfigSection') + writeStub = sandbox.stub(config as any, 'writeSectionToConfig') + matchStub = sandbox.stub(config as any, 'matchSshSection') + }) + + /** + * Test: checks for outdated config BEFORE SSH validation + * Expected: Section is updated before validation runs + */ + it('checks for outdated config BEFORE running SSH validation', async function () { + await writeTestSshConfig(`# Created by AWS Toolkit for VSCode. https://github.com/aws/aws-toolkit-vscode +Host sm_* + ForwardAgent yes + AddKeysToAgent yes + StrictHostKeyChecking accept-new + ProxyCommand 'sagemaker_connect' '%n' + User '%r' + `) + + getTestWindow().onDidShowMessage((message) => { + if (message.items.some((item) => item.title === 'Update SSH config')) { + message.selectItem('Update SSH config') + } + }) + removeStub.resolves() + writeStub.resolves() + matchStub.resolves(Result.ok(`Host sm_*\n ProxyCommand ${testProxyCommand}`)) + + await config.verifySSHHost(testProxyCommand) + + assert(removeStub.calledBefore(matchStub), 'Section should be updated before validation runs') + }) + + /** + * Test: User accepts the update prompt for outdated config + * Expected: Removes old section, writes new section, returns Ok + */ + it('prompts user to update when config is outdated', async function () { + await writeTestSshConfig(`# Created by AWS Toolkit for VSCode. https://github.com/aws/aws-toolkit-vscode +Host sm_* + ForwardAgent yes + AddKeysToAgent yes + StrictHostKeyChecking accept-new + ProxyCommand 'sagemaker_connect' '%n' + User '%r' + `) + + getTestWindow().onDidShowMessage((message) => { + if (message.items.some((item) => item.title === 'Update SSH config')) { + message.selectItem('Update SSH config') + } + }) + + removeStub.resolves() + writeStub.resolves() + // After update, matchSshSection should return section with the proxy command + matchStub.resolves(Result.ok(`Host sm_*\n ProxyCommand ${testProxyCommand}`)) + + const result = await config.verifySSHHost(testProxyCommand) + + assert.ok(result.isOk(), `Expected Ok but got: ${result.isErr() ? result.err().message : 'unknown'}`) + assert(removeStub.calledOnce, 'Should remove old section once') + assert(writeStub.calledOnce, 'Should write new section once') + }) + + /** + * Test: User clicks "Cancel" when prompted to update outdated config + * Expected: Returns error with code 'SshConfigUpdateDeclined' + */ + it('returns error when user declines update', async function () { + await writeTestSshConfig(`# Created by AWS Toolkit for VSCode. https://github.com/aws/aws-toolkit-vscode +Host sm_* + ForwardAgent yes + AddKeysToAdmin yes + StrictHostKeyChecking accept-new + ProxyCommand 'sagemaker_connect' '%n' + User '%r' + `) + + // User clicks Cancel + getTestWindow().onDidShowMessage((message) => { + message.selectItem('Cancel') + }) + + const result = await config.verifySSHHost(testProxyCommand) + assert.ok(result.isErr()) + + const error = result.err() + assert.ok(error instanceof ToolkitError) + + assert.strictEqual(error.code, 'SshConfigUpdateDeclined') + assert(removeStub.notCalled, 'Should not remove section when user declines') + assert(writeStub.notCalled, 'Should not write section when user declines') + }) + + /** + * Test: SSH validation fails due to error elsewhere in config (not in sm_* section) + * Expected: Extracts line number from error, prompts user to fix external error + */ + it('shows helpful error with line number when external error exists', async function () { + await writeTestSshConfig(`# Created by AWS Toolkit for VSCode. https://github.com/aws/aws-toolkit-vscode +Host sm_* + ForwardAgent yes + AddKeysToAgent yes + StrictHostKeyChecking accept-new + ProxyCommand 'sagemaker_connect' '%n' + +Host github.com + InvalidDirective bad-value +`) + + // Mock SSH validation to fail with line number + matchStub.resolves( + Result.err(new Error('~/.ssh/config: line 9: Bad configuration option: InvalidDirective')) + ) + + const promptErrorStub = sandbox.stub(config as any, 'promptOtherSshConfigError') + promptErrorStub.rejects(new CancellationError('user')) + + const result = await config.verifySSHHost(testProxyCommand) + + assert.ok(result.isErr()) + assert(promptErrorStub.calledOnce, 'Should prompt about external error') + + // Verify error message was passed with line number + const errorArg = promptErrorStub.firstCall.args[0] + assert.ok(errorArg.message.includes('line 9'), 'Error should include line number') + }) + + /** + * Test: Happy path - config is up-to-date and SSH validation succeeds + * Expected: No prompts shown, validation runs successfully, returns Ok + */ + it('handles successful validation when config is up-to-date', async function () { + await writeTestSshConfig(`# Created by AWS Toolkit for VSCode. https://github.com/aws/aws-toolkit-vscode +Host sm_* + ForwardAgent yes + AddKeysToAgent yes + StrictHostKeyChecking accept-new + ProxyCommand 'sagemaker_connect' '%n' + `) + + matchStub.resolves(Result.ok("Host sm_*\n ProxyCommand 'sagemaker_connect' '%n'")) + + const result = await config.verifySSHHost(testProxyCommand) + + assert.ok(result.isOk()) + assert(removeStub.notCalled, 'Should not update when config is up-to-date') + assert(matchStub.calledOnce, 'Should run SSH validation') + }) + }) + + describe('removeSshConfigSection', function () { + /** + * Test: Removes only the sm_* section, preserves other sections + * Expected: sm_* section and toolkit comment removed, other sections intact + */ + it('removes the sm_* section from config', async function () { + await writeTestSshConfig(`# Some other config +Host github.com + User git + +# Created by AWS Toolkit for VSCode. https://github.com/aws/aws-toolkit-vscode +Host sm_* + ForwardAgent yes + AddKeysToAgent yes + StrictHostKeyChecking accept-new + ProxyCommand 'sagemaker_connect' '%n' + User '%r' + +Host another.com + User test +`) + + await config.removeSshConfigSection() + + const content = await fs.readFileText(sshConfigPath) + assert.ok(!content.includes('Host sm_*'), 'Should remove sm_* section') + assert.ok(!content.includes('Created by AWS Toolkit'), 'Should remove toolkit comment') + assert.ok(content.includes('Host github.com'), 'Should keep other sections') + assert.ok(content.includes('Host another.com'), 'Should keep other sections') + }) + + /** + * Test: Attempts to remove section when it doesn't exist in config + * Expected: No error thrown, existing content preserved + */ + it('handles missing section gracefully', async function () { + await writeTestSshConfig(`Host github.com + User git +`) + + await config.removeSshConfigSection() + + const content = await fs.readFileText(sshConfigPath) + assert.ok(content.includes('Host github.com'), 'Should keep existing content') + }) + + /** + * Test: Removes section from file without trailing newline (EOF edge case) + * Expected: Section removed correctly even without trailing newline + */ + it('handles files without trailing newline', async function () { + const configContent = `# Created by AWS Toolkit for VSCode. https://github.com/aws/aws-toolkit-vscode +Host sm_* + ForwardAgent yes + AddKeysToAgent yes + StrictHostKeyChecking accept-new + ProxyCommand 'sagemaker_connect' '%n' + User '%r' + ` + await fs.writeFile(sshConfigPath, configContent.trimEnd(), { flag: 'w' }) + + await config.removeSshConfigSection() + + const content = await fs.readFileText(sshConfigPath) + assert.strictEqual(content.trim(), '', 'Should remove section even without trailing newline') + }) + + /** + * Test: Versioned matching - removes old v1 format (with User '%r') + * Expected: Old format section is recognized and removed + */ + it('removes old v1 format with User directive', async function () { + await writeTestSshConfig(`# Created by AWS Toolkit for VSCode. https://github.com/aws/aws-toolkit-vscode +Host sm_* + ForwardAgent yes + AddKeysToAgent yes + StrictHostKeyChecking accept-new + ProxyCommand 'sagemaker_connect' '%n' + User '%r' + +Host github.com + User git +`) + + await config.removeSshConfigSection() + + const content = await fs.readFileText(sshConfigPath) + assert.ok(!content.includes('Host sm_*'), 'Should remove old v1 section') + assert.ok(!content.includes("User '%r'"), 'Should remove User directive') + assert.ok(content.includes('Host github.com'), 'Should keep other sections') + }) + + /** + * Test: Versioned matching - does NOT remove user-modified section + * Expected: Throws error, section is not removed + */ + it('throws error for user-modified section', async function () { + await writeTestSshConfig(`# Created by AWS Toolkit for VSCode. https://github.com/aws/aws-toolkit-vscode +Host sm_* + ForwardAgent yes + ServerAliveInterval 60 + ProxyCommand 'sagemaker_connect' '%n' + +Host github.com + User git +`) + + await assert.rejects( + async () => await config.removeSshConfigSection(), + (error: Error) => { + assert.ok(error instanceof ToolkitError) + // Error is wrapped in SshConfigRemovalFailed, check the cause + const toolkitError = error as ToolkitError + assert.strictEqual(toolkitError.code, 'SshConfigRemovalFailed') + assert.ok(toolkitError.cause instanceof ToolkitError) + assert.strictEqual((toolkitError.cause as ToolkitError).code, 'SshConfigModified') + return true + }, + 'Should throw SshConfigRemovalFailed with SshConfigModified cause' + ) + + // Verify section was NOT removed + const content = await fs.readFileText(sshConfigPath) + assert.ok(content.includes('Host sm_*'), 'Should NOT remove modified section') + assert.ok(content.includes('ServerAliveInterval 60'), 'User customization should remain') + }) + }) + + describe('promptOtherSshConfigError', function () { + /** + * Test: SSH error message contains line number (e.g., "line 42") + * Expected: Extracts line number and includes it in error message (but doesn't navigate cursor) + */ + it('extracts and displays line number from SSH error', async function () { + await writeTestSshConfig(`# Some SSH config +Host github.com + User git +`) + + const sshError = new Error('~/.ssh/config: line 42: Bad configuration option: InvalidDirective') + + getTestWindow().onDidShowMessage((message) => { + assert.ok(message.message.includes('line 42'), 'Should include line number in error message') + message.selectItem('Open SSH Config') + }) + + await config.promptOtherSshConfigError(sshError) + + const messages = getTestWindow().shownMessages + assert(messages.length > 0, 'Should show error message') + }) + + /** + * Test: User clicks "Cancel" when prompted about external SSH error + * Expected: Throws CancellationError to signal user cancellation + */ + it('throws CancellationError when user cancels', async function () { + const sshError = new Error('SSH error') + + getTestWindow().onDidShowMessage((message) => { + message.selectItem('Cancel') + }) + + try { + await config.promptOtherSshConfigError(sshError) + assert.fail('Should have thrown CancellationError') + } catch (e) { + assert.ok(e instanceof CancellationError) + } + }) + + /** + * Test: User clicks "Open SSH Config" to fix external error + * Expected: Opens SSH config file in editor, no error thrown + */ + it('opens SSH config file when user clicks Open', async function () { + // Create the SSH config file + await writeTestSshConfig(`# Some SSH config +Host github.com + User git +`) + + const sshError = new Error('SSH error') + + getTestWindow().onDidShowMessage((message) => { + message.selectItem('Open SSH Config') + }) + + await config.promptOtherSshConfigError(sshError) + + const messages = getTestWindow().shownMessages + assert(messages.length > 0, 'Should show error message') + }) + }) + + describe('error handling', function () { + /** + * Test: Error occurs during config update (e.g., write fails) and user cancels + * Expected: Returns ToolkitError with code 'SshConfigUpdateFailed' + */ + it('returns proper error when update fails and user cancels', async function () { + await writeTestSshConfig(`# Created by AWS Toolkit +Host sm_* + User '%r' +`) + + // removal failure during update + sandbox.stub(config, 'removeSshConfigSection').rejects(new Error('Write failed')) + + // User accepts update prompt (handled by updateOutdatedSection internally) + getTestWindow().onDidShowMessage((message) => { + if (message.items.some((item) => item.title === 'Update SSH config')) { + message.selectItem('Update SSH config') + } + // User cancels the "Open SSH Config" prompt after failure + if (message.items.some((item) => item.title === 'Open SSH Config')) { + message.selectItem('Cancel') + } + }) + + const result = await config.verifySSHHost(testProxyCommand) + + assert.ok(result.isErr()) + const error = result.err() + assert.ok(error instanceof ToolkitError) + assert.strictEqual(error.code, 'SshConfigUpdateFailed') + }) + + /** + * Test: Error occurs during config update and user opens file to fix + * Expected: Returns ToolkitError with code 'SshConfigOpenedForEdit' + */ + it('opens config file when update fails and user accepts', async function () { + await writeTestSshConfig(`# Created by AWS Toolkit for VSCode. https://github.com/aws/aws-toolkit-vscode +Host sm_* + ForwardAgent yes + AddKeysToAgent yes + StrictHostKeyChecking accept-new + ProxyCommand 'sagemaker_connect' '%n' + User '%r' + `) + + // Simulate removal failure during update + sandbox.stub(config, 'removeSshConfigSection').rejects(new Error('Write failed')) + + // User accepts update prompt, then clicks "Open SSH Config" after failure + getTestWindow().onDidShowMessage((message) => { + if (message.items.some((item) => item.title === 'Update SSH config')) { + message.selectItem('Update SSH config') + } + if (message.items.some((item) => item.title === 'Open SSH Config')) { + message.selectItem('Open SSH Config') + } + }) + + const result = await config.verifySSHHost(testProxyCommand) + + assert.ok(result.isErr()) + const error = result.err() + assert.ok(error instanceof ToolkitError) + assert.strictEqual(error.code, 'SshConfigOpenedForEdit') + assert.strictEqual(error.message, SshConfigOpenedForEditMessage()) + }) + }) + + describe('createSSHConfigSection', function () { + /** + * Test: SageMaker SSH config format + * Expected: Contains SageMaker-specific directives (ForwardAgent, AddKeysToAgent, StrictHostKeyChecking) + * Does NOT contain User '%r' + */ + it('creates SageMaker-specific SSH config section', function () { + // Access the protected method through type casting + const section = (config as any).createSSHConfigSection(testProxyCommand) + + // Verify SageMaker-specific directives + assert.ok(section.includes('Host sm_'), 'Should include Host sm_*') + assert.ok(section.includes('ForwardAgent yes'), 'Should include ForwardAgent yes') + assert.ok(section.includes('AddKeysToAgent yes'), 'Should include AddKeysToAgent yes') + assert.ok(section.includes('StrictHostKeyChecking accept-new'), 'Should include StrictHostKeyChecking') + assert.ok(section.includes(`ProxyCommand ${testProxyCommand}`), 'Should include ProxyCommand') + + // Verify it does NOT include CodeCatalyst-specific directives + assert.ok(!section.includes("User '%r'"), 'Should NOT include User directive (SageMaker-specific)') + assert.ok(!section.includes('IdentityFile'), 'Should NOT include IdentityFile (SageMaker-specific)') + }) + + /** + * Test: SSH config includes AWS Toolkit comment + * Expected: Section starts with AWS Toolkit comment for identification + */ + it('includes AWS Toolkit comment in config section', function () { + const section = (config as any).createSSHConfigSection(testProxyCommand) + + assert.ok( + section.includes('# Created by AWS Toolkit'), + 'Should include AWS Toolkit comment for identification' + ) + }) + }) +})