Skip to content
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions packages/core/src/awsService/ec2/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
* SPDX-License-Identifier: Apache-2.0
*/
import * as vscode from 'vscode'
import * as path from 'path'
import { Session } from 'aws-sdk/clients/ssm'
import { IAM, SSM } from 'aws-sdk'
import { Ec2Selection } from './prompter'
Expand All @@ -28,7 +27,6 @@ import { CancellationError, Timeout } from '../../shared/utilities/timeoutUtils'
import { showMessageWithCancel } from '../../shared/utilities/messages'
import { SshConfig, sshLogFileLocation } from '../../shared/sshConfig'
import { SshKeyPair } from './sshKeyPair'
import globals from '../../shared/extensionGlobals'

export type Ec2ConnectErrorCode = 'EC2SSMStatus' | 'EC2SSMPermission' | 'EC2SSMConnect' | 'EC2SSMAgentStatus'

Expand Down Expand Up @@ -235,8 +233,7 @@ export class Ec2ConnectionManager {
}

public async configureSshKeys(selection: Ec2Selection, remoteUser: string): Promise<SshKeyPair> {
const keyPath = path.join(globals.context.globalStorageUri.fsPath, `aws-ec2-key`)
const keyPair = await SshKeyPair.getSshKeyPair(keyPath, 30000)
const keyPair = await SshKeyPair.getSshKeyPair(`aws-ec2-key`, 30000)
await this.sendSshKeyToInstance(selection, keyPair, remoteUser)
return keyPair
}
Expand Down
25 changes: 23 additions & 2 deletions packages/core/src/awsService/ec2/sshKeyPair.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import { tryRun } from '../../shared/utilities/pathFind'
import { Timeout } from '../../shared/utilities/timeoutUtils'
import { findAsync } from '../../shared/utilities/collectionUtils'
import { RunParameterContext } from '../../shared/utilities/processUtils'
import path from 'path'

type sshKeyType = 'rsa' | 'ed25519'

Expand All @@ -20,19 +21,35 @@ export class SshKeyPair {
private readonly keyPath: string,
lifetime: number
) {
this.publicKeyPath = `${keyPath}.pub`
this.publicKeyPath = `${this.keyPath}.pub`
this.lifeTimeout = new Timeout(lifetime)

this.lifeTimeout.onCompletion(async () => {
await this.delete()
})
}

public static async getSshKeyPair(keyPath: string, lifetime: number) {
private static getKeypath(keyName: string): string {
return path.join(globals.context.globalStorageUri.fsPath, keyName)
}

public static async getSshKeyPair(keyName: string, lifetime: number) {
const keyPath = SshKeyPair.getKeypath(keyName)
await SshKeyPair.generateSshKeyPair(keyPath)
return new SshKeyPair(keyPath, lifetime)
}

private static isValidKeyPath(keyPath: string): boolean {
const relative = path.relative(globals.context.globalStorageUri.fsPath, keyPath)
return relative !== undefined && !relative.startsWith('..') && !path.isAbsolute(relative) && keyPath.length > 4
}

private static assertValidKeypath(keyPath: string, message: string): void | never {
if (!SshKeyPair.isValidKeyPath(keyPath)) {
throw new ToolkitError(message)
}
}

private static async assertGenerated(keyPath: string, keyGenerated: boolean): Promise<never | void> {
if (!keyGenerated) {
throw new ToolkitError('ec2: Unable to generate ssh key pair with either ed25519 or rsa')
Expand Down Expand Up @@ -85,6 +102,10 @@ export class SshKeyPair {
}

public async delete(): Promise<void> {
SshKeyPair.assertValidKeypath(
this.keyPath,
`ec2: keyPath became invalid after creation, not deleting key at ${this.keyPath}`
)
await fs.delete(this.keyPath, { force: true })
await fs.delete(this.publicKeyPath, { force: true })

Expand Down
14 changes: 5 additions & 9 deletions packages/core/src/test/awsService/ec2/model.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import { SshKeyPair } from '../../../awsService/ec2/sshKeyPair'
import { DefaultIamClient } from '../../../shared/clients/iamClient'
import { assertNoTelemetryMatch, createTestWorkspaceFolder } from '../../testUtil'
import { fs } from '../../../shared'
import path from 'path'

describe('Ec2ConnectClient', function () {
let client: Ec2ConnectionManager
Expand Down Expand Up @@ -135,15 +134,13 @@ describe('Ec2ConnectClient', function () {
it('calls the sdk with the proper parameters', async function () {
const sendCommandStub = sinon.stub(SsmClient.prototype, 'sendCommandAndWait')

sinon.stub(SshKeyPair, 'generateSshKeyPair')
sinon.stub(SshKeyPair.prototype, 'getPublicKey').resolves('test-key')

const testSelection = {
instanceId: 'test-id',
region: 'test-region',
}
const mockKeys = await SshKeyPair.getSshKeyPair('fakeDir', 30000)
await client.sendSshKeyToInstance(testSelection, mockKeys, 'test-user')

const keys = await SshKeyPair.getSshKeyPair('key', 30000)
await client.sendSshKeyToInstance(testSelection, keys, 'test-user')
sinon.assert.calledWith(sendCommandStub, testSelection.instanceId, 'AWS-RunShellScript')
sinon.restore()
})
Expand All @@ -156,10 +153,9 @@ describe('Ec2ConnectClient', function () {
region: 'test-region',
}
const testWorkspaceFolder = await createTestWorkspaceFolder()
const keyPath = path.join(testWorkspaceFolder.uri.fsPath, 'key')
const keys = await SshKeyPair.getSshKeyPair(keyPath, 60000)
const keys = await SshKeyPair.getSshKeyPair('key', 60000)
await client.sendSshKeyToInstance(testSelection, keys, 'test-user')
const privKey = await fs.readFileText(keyPath)
const privKey = await fs.readFileText(keys.getPrivateKeyPath())
assertNoTelemetryMatch(privKey)
sinon.restore()

Expand Down
49 changes: 24 additions & 25 deletions packages/core/src/test/awsService/ec2/sshKeyPair.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,50 +5,47 @@
import assert from 'assert'
import nodefs from 'fs' // eslint-disable-line no-restricted-imports
import * as sinon from 'sinon'
import * as path from 'path'
import * as os from 'os'
import { SshKeyPair } from '../../../awsService/ec2/sshKeyPair'
import { createTestWorkspaceFolder, installFakeClock } from '../../testUtil'
import { installFakeClock } from '../../testUtil'
import { InstalledClock } from '@sinonjs/fake-timers'
import { ChildProcess } from '../../../shared/utilities/processUtils'
import { fs, globals } from '../../../shared'

describe('SshKeyPair', async function () {
let temporaryDirectory: string
let keyPath: string
let keyPair: SshKeyPair
let clock: InstalledClock
let keyPair: SshKeyPair
let keyName: string

before(async function () {
temporaryDirectory = (await createTestWorkspaceFolder()).uri.fsPath
keyPath = path.join(temporaryDirectory, 'testKeyPair')
clock = installFakeClock()
})

beforeEach(async function () {
keyPair = await SshKeyPair.getSshKeyPair(keyPath, 30000)
keyName = 'testKeyPair'
keyPair = await SshKeyPair.getSshKeyPair(keyName, 30000)
})

afterEach(async function () {
await keyPair.delete()
})

after(async function () {
await fs.delete(temporaryDirectory, { recursive: true })
clock.uninstall()
sinon.restore()
})

it('generates key in target file', async function () {
const contents = await fs.readFileBytes(keyPath)
const contents = await fs.readFileBytes(keyPair.getPrivateKeyPath())
assert.notStrictEqual(contents.length, 0)
})

it('generates unique key each time', async function () {
const beforeContent = await fs.readFileBytes(keyPath)
keyPair = await SshKeyPair.getSshKeyPair(keyPath, 30000)
const afterContent = await fs.readFileBytes(keyPath)
assert.notStrictEqual(beforeContent, afterContent)
const keyPair2 = await SshKeyPair.getSshKeyPair(`${keyName}2`, 30000)
const content1 = await fs.readFileBytes(keyPair2.getPrivateKeyPath())
const content2 = await fs.readFileBytes(keyPair.getPrivateKeyPath())
assert.notStrictEqual(content1, content2)
await keyPair2.delete()
})

it('sets permission of the file to read/write owner', async function () {
Expand All @@ -59,7 +56,7 @@ describe('SshKeyPair', async function () {
})

it('defaults to ed25519 key type', async function () {
const process = new ChildProcess(`ssh-keygen`, ['-vvv', '-l', '-f', keyPath])
const process = new ChildProcess(`ssh-keygen`, ['-vvv', '-l', '-f', keyPair.getPrivateKeyPath()])
const result = await process.run()
// Check private key header for algorithm name
assert.strictEqual(result.stdout.includes('[ED25519 256]'), true)
Expand All @@ -70,29 +67,25 @@ describe('SshKeyPair', async function () {
const stub = sinon.stub(SshKeyPair, 'tryKeyGen')
stub.onFirstCall().resolves(false)
stub.callThrough()
keyPair = await SshKeyPair.getSshKeyPair(keyPath, 30000)
const process = new ChildProcess(`ssh-keygen`, ['-vvv', '-l', '-f', keyPath])
const rsaKey = await SshKeyPair.getSshKeyPair('rsa', 30000)
const process = new ChildProcess(`ssh-keygen`, ['-vvv', '-l', '-f', rsaKey.getPrivateKeyPath()])
const result = await process.run()
// Check private key header for algorithm name
assert.strictEqual(result.stdout.includes('[RSA'), true)
stub.restore()
})

it('properly names the public key', function () {
assert.strictEqual(keyPair.getPublicKeyPath(), `${keyPath}.pub`)
})

it('reads in public ssh key that is non-empty', async function () {
const key = await keyPair.getPublicKey()
assert.notStrictEqual(key.length, 0)
})

it('does overwrite existing keys on get call', async function () {
const generateStub = sinon.spy(SshKeyPair, 'generateSshKeyPair')
const keyBefore = await fs.readFileBytes(keyPath)
keyPair = await SshKeyPair.getSshKeyPair(keyPath, 30000)
const keyBefore = await fs.readFileBytes(keyPair.getPrivateKeyPath())
keyPair = await SshKeyPair.getSshKeyPair(keyName, 30000)

const keyAfter = await fs.readFileBytes(keyPath)
const keyAfter = await fs.readFileBytes(keyPair.getPrivateKeyPath())
sinon.assert.calledOnce(generateStub)

assert.notStrictEqual(keyBefore, keyAfter)
Expand All @@ -118,7 +111,7 @@ describe('SshKeyPair', async function () {
sinon.stub(SshKeyPair, 'generateSshKeyPair')
const deleteStub = sinon.stub(SshKeyPair.prototype, 'delete')

keyPair = await SshKeyPair.getSshKeyPair(keyPath, 50)
keyPair = await SshKeyPair.getSshKeyPair(keyName, 50)
await clock.tickAsync(10)
sinon.assert.notCalled(deleteStub)
await clock.tickAsync(100)
Expand All @@ -133,6 +126,12 @@ describe('SshKeyPair', async function () {
assert(keyPair.isDeleted())
})

it('does not allow writing keys to non-global storage', async function () {
await assert.rejects(async () => await SshKeyPair.getSshKeyPair('~/.ssh/someKey', 2000))

await assert.rejects(async () => await SshKeyPair.getSshKeyPair('/a/path/that/isnt/real/key', 2000))
})

describe('isDeleted', async function () {
it('returns false if key files exist', async function () {
assert.strictEqual(await keyPair.isDeleted(), false)
Expand Down
3 changes: 3 additions & 0 deletions packages/core/src/test/globalSetup.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,13 @@ export const mochaHooks = {
globals.telemetry.clearRecords()
globals.telemetry.logger.clear()
TelemetryDebounceInfo.instance.clear()

// mochaGlobalSetup() set this to a fake, so it's safe to clear it here.
await globals.globalState.clear()

await testUtil.closeAllEditors()
await fs.delete(globals.context.globalStorageUri.fsPath, { recursive: true, force: true })
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be fine even for local development, because mochaGlobalSetup does:

fakeContext.globalStorageUri = (await testUtil.createTestWorkspaceFolder('globalStoragePath')).uri

await fs.mkdir(globals.context.globalStorageUri.fsPath)
},
async afterEach(this: Mocha.Context) {
if (openExternalStub.called && openExternalStub.returned(sinon.match.typeOf('undefined'))) {
Expand Down
24 changes: 8 additions & 16 deletions packages/core/src/test/shared/sshConfig.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,11 @@
import assert from 'assert'
import * as sinon from 'sinon'
import * as path from 'path'
import * as vscode from 'vscode'
import * as http from 'http'
import { ToolkitError } from '../../shared/errors'
import { Result } from '../../shared/utilities/result'
import { ChildProcess, ChildProcessResult } from '../../shared/utilities/processUtils'
import { SshConfig, ensureConnectScript, sshLogFileLocation } from '../../shared/sshConfig'
import { FakeExtensionContext } from '../fakeExtensionContext'
import { fileExists, makeTemporaryToolkitFolder } from '../../shared/filesystemUtilities'
import {
DevEnvironmentId,
Expand All @@ -20,8 +18,9 @@ import {
getCodeCatalystSsmEnv,
} from '../../codecatalyst/model'
import { StartDevEnvironmentSessionRequest } from 'aws-sdk/clients/codecatalyst'
import { mkdir, readFile, writeFile } from 'fs/promises'
import { mkdir, readFile } from 'fs/promises'
import fs from '../../shared/fs/fs'
import { globals } from '../../shared'

class MockSshConfig extends SshConfig {
// State variables to track logic flow.
Expand Down Expand Up @@ -181,22 +180,15 @@ describe('VscodeRemoteSshConfig', async function () {
})

describe('CodeCatalyst Connect Script', function () {
let context: FakeExtensionContext

function isWithin(path1: string, path2: string): boolean {
const rel = path.relative(path1, path2)
return !path.isAbsolute(rel) && !rel.startsWith('..') && !!rel
}

beforeEach(async function () {
context = await FakeExtensionContext.create()
context.globalStorageUri = vscode.Uri.file(await makeTemporaryToolkitFolder())
})

it('can get a connect script path, adding a copy to global storage', async function () {
const script = (await ensureConnectScript(connectScriptPrefix, context)).unwrap().fsPath
const script = (await ensureConnectScript(connectScriptPrefix, globals.context)).unwrap().fsPath
assert.ok(await fileExists(script))
assert.ok(isWithin(context.globalStorageUri.fsPath, script))
assert.ok(isWithin(globals.context.globalStorageUri.fsPath, script))
})

function createFakeServer(testDevEnv: DevEnvironmentId) {
Expand Down Expand Up @@ -247,8 +239,8 @@ describe('CodeCatalyst Connect Script', function () {
server.listen({ host: 'localhost', port: 28142 }, () => resolve(`http://localhost:28142`))
})

await writeFile(bearerTokenCacheLocation(testDevEnv.id), 'token')
const script = (await ensureConnectScript(connectScriptPrefix, context)).unwrap().fsPath
await fs.writeFile(bearerTokenCacheLocation(testDevEnv.id), 'token')
const script = (await ensureConnectScript(connectScriptPrefix, globals.context)).unwrap().fsPath
const env = getCodeCatalystSsmEnv('us-weast-1', 'echo', testDevEnv)
env.CODECATALYST_ENDPOINT = address

Expand Down Expand Up @@ -280,12 +272,12 @@ describe('CodeCatalyst Connect Script', function () {
})

it('works if the .ssh directory is missing', async function () {
;(await ensureConnectScript(connectScriptPrefix, context)).unwrap()
;(await ensureConnectScript(connectScriptPrefix, globals.context)).unwrap()
})

it('works if the .ssh directory exists but has different perms', async function () {
await mkdir(path.join(tmpDir, '.ssh'), 0o777)
;(await ensureConnectScript(connectScriptPrefix, context)).unwrap()
;(await ensureConnectScript(connectScriptPrefix, globals.context)).unwrap()
})
})
})
Loading