diff --git a/packages/core/src/awsService/ec2/model.ts b/packages/core/src/awsService/ec2/model.ts index 021ad623dd2..f47386ef47d 100644 --- a/packages/core/src/awsService/ec2/model.ts +++ b/packages/core/src/awsService/ec2/model.ts @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ import * as vscode from 'vscode' -import * as path from 'path' import { Session } from 'aws-sdk/clients/ssm' import { IAM, SSM } from 'aws-sdk' import { Ec2Selection } from './prompter' @@ -28,7 +27,6 @@ import { CancellationError, Timeout } from '../../shared/utilities/timeoutUtils' import { showMessageWithCancel } from '../../shared/utilities/messages' import { SshConfig, sshLogFileLocation } from '../../shared/sshConfig' import { SshKeyPair } from './sshKeyPair' -import globals from '../../shared/extensionGlobals' export type Ec2ConnectErrorCode = 'EC2SSMStatus' | 'EC2SSMPermission' | 'EC2SSMConnect' | 'EC2SSMAgentStatus' @@ -235,8 +233,7 @@ export class Ec2ConnectionManager { } public async configureSshKeys(selection: Ec2Selection, remoteUser: string): Promise { - const keyPath = path.join(globals.context.globalStorageUri.fsPath, `aws-ec2-key`) - const keyPair = await SshKeyPair.getSshKeyPair(keyPath, 30000) + const keyPair = await SshKeyPair.getSshKeyPair(`aws-ec2-key`, 30000) await this.sendSshKeyToInstance(selection, keyPair, remoteUser) return keyPair } diff --git a/packages/core/src/awsService/ec2/sshKeyPair.ts b/packages/core/src/awsService/ec2/sshKeyPair.ts index 3f0f1e23de8..45565aa6e41 100644 --- a/packages/core/src/awsService/ec2/sshKeyPair.ts +++ b/packages/core/src/awsService/ec2/sshKeyPair.ts @@ -9,6 +9,7 @@ import { tryRun } from '../../shared/utilities/pathFind' import { Timeout } from '../../shared/utilities/timeoutUtils' import { findAsync } from '../../shared/utilities/collectionUtils' import { RunParameterContext } from '../../shared/utilities/processUtils' +import path from 'path' type sshKeyType = 'rsa' | 'ed25519' @@ -20,7 +21,7 @@ export class SshKeyPair { private readonly keyPath: string, lifetime: number ) { - this.publicKeyPath = `${keyPath}.pub` + this.publicKeyPath = `${this.keyPath}.pub` this.lifeTimeout = new Timeout(lifetime) this.lifeTimeout.onCompletion(async () => { @@ -28,11 +29,27 @@ export class SshKeyPair { }) } - public static async getSshKeyPair(keyPath: string, lifetime: number) { + private static getKeypath(keyName: string): string { + return path.join(globals.context.globalStorageUri.fsPath, keyName) + } + + public static async getSshKeyPair(keyName: string, lifetime: number) { + const keyPath = SshKeyPair.getKeypath(keyName) await SshKeyPair.generateSshKeyPair(keyPath) return new SshKeyPair(keyPath, lifetime) } + private static isValidKeyPath(keyPath: string): boolean { + const relative = path.relative(globals.context.globalStorageUri.fsPath, keyPath) + return relative !== undefined && !relative.startsWith('..') && !path.isAbsolute(relative) && keyPath.length > 4 + } + + private static assertValidKeypath(keyPath: string, message: string): void | never { + if (!SshKeyPair.isValidKeyPath(keyPath)) { + throw new ToolkitError(message) + } + } + private static async assertGenerated(keyPath: string, keyGenerated: boolean): Promise { if (!keyGenerated) { throw new ToolkitError('ec2: Unable to generate ssh key pair with either ed25519 or rsa') @@ -85,6 +102,10 @@ export class SshKeyPair { } public async delete(): Promise { + SshKeyPair.assertValidKeypath( + this.keyPath, + `ec2: keyPath became invalid after creation, not deleting key at ${this.keyPath}` + ) await fs.delete(this.keyPath, { force: true }) await fs.delete(this.publicKeyPath, { force: true }) diff --git a/packages/core/src/test/awsService/ec2/model.test.ts b/packages/core/src/test/awsService/ec2/model.test.ts index 6e750f6ebfe..f2d5c8728c0 100644 --- a/packages/core/src/test/awsService/ec2/model.test.ts +++ b/packages/core/src/test/awsService/ec2/model.test.ts @@ -15,7 +15,6 @@ import { SshKeyPair } from '../../../awsService/ec2/sshKeyPair' import { DefaultIamClient } from '../../../shared/clients/iamClient' import { assertNoTelemetryMatch, createTestWorkspaceFolder } from '../../testUtil' import { fs } from '../../../shared' -import path from 'path' describe('Ec2ConnectClient', function () { let client: Ec2ConnectionManager @@ -135,15 +134,13 @@ describe('Ec2ConnectClient', function () { it('calls the sdk with the proper parameters', async function () { const sendCommandStub = sinon.stub(SsmClient.prototype, 'sendCommandAndWait') - sinon.stub(SshKeyPair, 'generateSshKeyPair') - sinon.stub(SshKeyPair.prototype, 'getPublicKey').resolves('test-key') - const testSelection = { instanceId: 'test-id', region: 'test-region', } - const mockKeys = await SshKeyPair.getSshKeyPair('fakeDir', 30000) - await client.sendSshKeyToInstance(testSelection, mockKeys, 'test-user') + + const keys = await SshKeyPair.getSshKeyPair('key', 30000) + await client.sendSshKeyToInstance(testSelection, keys, 'test-user') sinon.assert.calledWith(sendCommandStub, testSelection.instanceId, 'AWS-RunShellScript') sinon.restore() }) @@ -156,10 +153,9 @@ describe('Ec2ConnectClient', function () { region: 'test-region', } const testWorkspaceFolder = await createTestWorkspaceFolder() - const keyPath = path.join(testWorkspaceFolder.uri.fsPath, 'key') - const keys = await SshKeyPair.getSshKeyPair(keyPath, 60000) + const keys = await SshKeyPair.getSshKeyPair('key', 60000) await client.sendSshKeyToInstance(testSelection, keys, 'test-user') - const privKey = await fs.readFileText(keyPath) + const privKey = await fs.readFileText(keys.getPrivateKeyPath()) assertNoTelemetryMatch(privKey) sinon.restore() diff --git a/packages/core/src/test/awsService/ec2/sshKeyPair.test.ts b/packages/core/src/test/awsService/ec2/sshKeyPair.test.ts index 25b0578694c..144f436e319 100644 --- a/packages/core/src/test/awsService/ec2/sshKeyPair.test.ts +++ b/packages/core/src/test/awsService/ec2/sshKeyPair.test.ts @@ -5,28 +5,25 @@ import assert from 'assert' import nodefs from 'fs' // eslint-disable-line no-restricted-imports import * as sinon from 'sinon' -import * as path from 'path' import * as os from 'os' import { SshKeyPair } from '../../../awsService/ec2/sshKeyPair' -import { createTestWorkspaceFolder, installFakeClock } from '../../testUtil' +import { installFakeClock } from '../../testUtil' import { InstalledClock } from '@sinonjs/fake-timers' import { ChildProcess } from '../../../shared/utilities/processUtils' import { fs, globals } from '../../../shared' describe('SshKeyPair', async function () { - let temporaryDirectory: string - let keyPath: string - let keyPair: SshKeyPair let clock: InstalledClock + let keyPair: SshKeyPair + let keyName: string before(async function () { - temporaryDirectory = (await createTestWorkspaceFolder()).uri.fsPath - keyPath = path.join(temporaryDirectory, 'testKeyPair') clock = installFakeClock() }) beforeEach(async function () { - keyPair = await SshKeyPair.getSshKeyPair(keyPath, 30000) + keyName = 'testKeyPair' + keyPair = await SshKeyPair.getSshKeyPair(keyName, 30000) }) afterEach(async function () { @@ -34,21 +31,21 @@ describe('SshKeyPair', async function () { }) after(async function () { - await fs.delete(temporaryDirectory, { recursive: true }) clock.uninstall() sinon.restore() }) it('generates key in target file', async function () { - const contents = await fs.readFileBytes(keyPath) + const contents = await fs.readFileBytes(keyPair.getPrivateKeyPath()) assert.notStrictEqual(contents.length, 0) }) it('generates unique key each time', async function () { - const beforeContent = await fs.readFileBytes(keyPath) - keyPair = await SshKeyPair.getSshKeyPair(keyPath, 30000) - const afterContent = await fs.readFileBytes(keyPath) - assert.notStrictEqual(beforeContent, afterContent) + const keyPair2 = await SshKeyPair.getSshKeyPair(`${keyName}2`, 30000) + const content1 = await fs.readFileBytes(keyPair2.getPrivateKeyPath()) + const content2 = await fs.readFileBytes(keyPair.getPrivateKeyPath()) + assert.notStrictEqual(content1, content2) + await keyPair2.delete() }) it('sets permission of the file to read/write owner', async function () { @@ -59,7 +56,7 @@ describe('SshKeyPair', async function () { }) it('defaults to ed25519 key type', async function () { - const process = new ChildProcess(`ssh-keygen`, ['-vvv', '-l', '-f', keyPath]) + const process = new ChildProcess(`ssh-keygen`, ['-vvv', '-l', '-f', keyPair.getPrivateKeyPath()]) const result = await process.run() // Check private key header for algorithm name assert.strictEqual(result.stdout.includes('[ED25519 256]'), true) @@ -70,18 +67,14 @@ describe('SshKeyPair', async function () { const stub = sinon.stub(SshKeyPair, 'tryKeyGen') stub.onFirstCall().resolves(false) stub.callThrough() - keyPair = await SshKeyPair.getSshKeyPair(keyPath, 30000) - const process = new ChildProcess(`ssh-keygen`, ['-vvv', '-l', '-f', keyPath]) + const rsaKey = await SshKeyPair.getSshKeyPair('rsa', 30000) + const process = new ChildProcess(`ssh-keygen`, ['-vvv', '-l', '-f', rsaKey.getPrivateKeyPath()]) const result = await process.run() // Check private key header for algorithm name assert.strictEqual(result.stdout.includes('[RSA'), true) stub.restore() }) - it('properly names the public key', function () { - assert.strictEqual(keyPair.getPublicKeyPath(), `${keyPath}.pub`) - }) - it('reads in public ssh key that is non-empty', async function () { const key = await keyPair.getPublicKey() assert.notStrictEqual(key.length, 0) @@ -89,10 +82,10 @@ describe('SshKeyPair', async function () { it('does overwrite existing keys on get call', async function () { const generateStub = sinon.spy(SshKeyPair, 'generateSshKeyPair') - const keyBefore = await fs.readFileBytes(keyPath) - keyPair = await SshKeyPair.getSshKeyPair(keyPath, 30000) + const keyBefore = await fs.readFileBytes(keyPair.getPrivateKeyPath()) + keyPair = await SshKeyPair.getSshKeyPair(keyName, 30000) - const keyAfter = await fs.readFileBytes(keyPath) + const keyAfter = await fs.readFileBytes(keyPair.getPrivateKeyPath()) sinon.assert.calledOnce(generateStub) assert.notStrictEqual(keyBefore, keyAfter) @@ -118,7 +111,7 @@ describe('SshKeyPair', async function () { sinon.stub(SshKeyPair, 'generateSshKeyPair') const deleteStub = sinon.stub(SshKeyPair.prototype, 'delete') - keyPair = await SshKeyPair.getSshKeyPair(keyPath, 50) + keyPair = await SshKeyPair.getSshKeyPair(keyName, 50) await clock.tickAsync(10) sinon.assert.notCalled(deleteStub) await clock.tickAsync(100) diff --git a/packages/core/src/test/globalSetup.test.ts b/packages/core/src/test/globalSetup.test.ts index c7bc947c482..92e2db349ee 100644 --- a/packages/core/src/test/globalSetup.test.ts +++ b/packages/core/src/test/globalSetup.test.ts @@ -101,10 +101,13 @@ export const mochaHooks = { globals.telemetry.clearRecords() globals.telemetry.logger.clear() TelemetryDebounceInfo.instance.clear() + // mochaGlobalSetup() set this to a fake, so it's safe to clear it here. await globals.globalState.clear() await testUtil.closeAllEditors() + await fs.delete(globals.context.globalStorageUri.fsPath, { recursive: true, force: true }) + await fs.mkdir(globals.context.globalStorageUri.fsPath) }, async afterEach(this: Mocha.Context) { if (openExternalStub.called && openExternalStub.returned(sinon.match.typeOf('undefined'))) { diff --git a/packages/core/src/test/shared/sshConfig.test.ts b/packages/core/src/test/shared/sshConfig.test.ts index 2400e7f0e44..96ca450ae14 100644 --- a/packages/core/src/test/shared/sshConfig.test.ts +++ b/packages/core/src/test/shared/sshConfig.test.ts @@ -5,13 +5,11 @@ import assert from 'assert' import * as sinon from 'sinon' import * as path from 'path' -import * as vscode from 'vscode' import * as http from 'http' import { ToolkitError } from '../../shared/errors' import { Result } from '../../shared/utilities/result' import { ChildProcess, ChildProcessResult } from '../../shared/utilities/processUtils' import { SshConfig, ensureConnectScript, sshLogFileLocation } from '../../shared/sshConfig' -import { FakeExtensionContext } from '../fakeExtensionContext' import { fileExists, makeTemporaryToolkitFolder } from '../../shared/filesystemUtilities' import { DevEnvironmentId, @@ -20,8 +18,9 @@ import { getCodeCatalystSsmEnv, } from '../../codecatalyst/model' import { StartDevEnvironmentSessionRequest } from 'aws-sdk/clients/codecatalyst' -import { mkdir, readFile, writeFile } from 'fs/promises' +import { mkdir, readFile } from 'fs/promises' import fs from '../../shared/fs/fs' +import { globals } from '../../shared' class MockSshConfig extends SshConfig { // State variables to track logic flow. @@ -181,22 +180,15 @@ describe('VscodeRemoteSshConfig', async function () { }) describe('CodeCatalyst Connect Script', function () { - let context: FakeExtensionContext - function isWithin(path1: string, path2: string): boolean { const rel = path.relative(path1, path2) return !path.isAbsolute(rel) && !rel.startsWith('..') && !!rel } - beforeEach(async function () { - context = await FakeExtensionContext.create() - context.globalStorageUri = vscode.Uri.file(await makeTemporaryToolkitFolder()) - }) - it('can get a connect script path, adding a copy to global storage', async function () { - const script = (await ensureConnectScript(connectScriptPrefix, context)).unwrap().fsPath + const script = (await ensureConnectScript(connectScriptPrefix, globals.context)).unwrap().fsPath assert.ok(await fileExists(script)) - assert.ok(isWithin(context.globalStorageUri.fsPath, script)) + assert.ok(isWithin(globals.context.globalStorageUri.fsPath, script)) }) function createFakeServer(testDevEnv: DevEnvironmentId) { @@ -247,8 +239,8 @@ describe('CodeCatalyst Connect Script', function () { server.listen({ host: 'localhost', port: 28142 }, () => resolve(`http://localhost:28142`)) }) - await writeFile(bearerTokenCacheLocation(testDevEnv.id), 'token') - const script = (await ensureConnectScript(connectScriptPrefix, context)).unwrap().fsPath + await fs.writeFile(bearerTokenCacheLocation(testDevEnv.id), 'token') + const script = (await ensureConnectScript(connectScriptPrefix, globals.context)).unwrap().fsPath const env = getCodeCatalystSsmEnv('us-weast-1', 'echo', testDevEnv) env.CODECATALYST_ENDPOINT = address @@ -280,12 +272,12 @@ describe('CodeCatalyst Connect Script', function () { }) it('works if the .ssh directory is missing', async function () { - ;(await ensureConnectScript(connectScriptPrefix, context)).unwrap() + ;(await ensureConnectScript(connectScriptPrefix, globals.context)).unwrap() }) it('works if the .ssh directory exists but has different perms', async function () { await mkdir(path.join(tmpDir, '.ssh'), 0o777) - ;(await ensureConnectScript(connectScriptPrefix, context)).unwrap() + ;(await ensureConnectScript(connectScriptPrefix, globals.context)).unwrap() }) }) })