Skip to content

Commit b8ec914

Browse files
committed
refactor: pull out all encrypt and decrypt logic, and add tests
1 parent 0b751ac commit b8ec914

File tree

4 files changed

+48
-25
lines changed

4 files changed

+48
-25
lines changed

packages/amazonq/src/inlineChat/provider/inlineChatProvider.ts

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import type { InlineChatEvent } from 'aws-core-vscode/codewhisperer'
2828
import { InlineTask } from '../controller/inlineTask'
2929
import { extractAuthFollowUp } from 'aws-core-vscode/amazonq'
3030
import { InlineChatParams, InlineChatResult } from '@aws/language-server-runtimes-types'
31-
import { decodeRequest, encryptRequest } from '../../lsp/encryption'
31+
import { decryptResponse, encryptRequest } from '../../lsp/encryption'
3232
import { getCursorState } from '../../lsp/utils'
3333

3434
export class InlineChatProvider {
@@ -72,16 +72,13 @@ export class InlineChatProvider {
7272
// TODO: handle partial responses.
7373
getLogger().info('Making inline chat request with message %O', message)
7474
const params = this.getCurrentEditorParams(message.message ?? '')
75+
7576
const inlineChatRequest = await encryptRequest<InlineChatParams>(params, this.encryptionKey)
7677
const response = await this.client.sendRequest(inlineChatRequestType.method, inlineChatRequest)
77-
const decryptedMessage =
78-
typeof response === 'string' && this.encryptionKey
79-
? await decodeRequest(response, this.encryptionKey)
80-
: response
81-
const result: InlineChatResult = decryptedMessage as InlineChatResult
82-
this.client.info(`Logging response for inline chat ${JSON.stringify(decryptedMessage)}`)
83-
84-
return result
78+
const inlineChatResponse = await decryptResponse<InlineChatResult>(response, this.encryptionKey)
79+
this.client.info(`Logging response for inline chat ${JSON.stringify(inlineChatResponse)}`)
80+
81+
return inlineChatResponse
8582
}
8683

8784
// TODO: remove in favor of LSP implementation.

packages/amazonq/src/lsp/chat/messages.ts

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ import {
6767
} from 'aws-core-vscode/amazonq'
6868
import { telemetry, TelemetryBase } from 'aws-core-vscode/telemetry'
6969
import { isValidResponseError } from './error'
70-
import { decodeRequest, encryptRequest } from '../encryption'
70+
import { decryptResponse, encryptRequest } from '../encryption'
7171
import { getCursorState } from '../utils'
7272

7373
export function registerLanguageServerEventListener(languageClient: LanguageClient, provider: AmazonQChatViewProvider) {
@@ -210,13 +210,9 @@ export function registerMessageListeners(
210210
partialResultToken,
211211
(partialResult) => {
212212
// Store the latest partial result
213-
if (typeof partialResult === 'string' && encryptionKey) {
214-
void decodeRequest<ChatResult>(partialResult, encryptionKey).then(
215-
(decoded) => (lastPartialResult = decoded)
216-
)
217-
} else {
218-
lastPartialResult = partialResult as ChatResult
219-
}
213+
decryptResponse<ChatResult>(partialResult, encryptionKey).then((result) => {
214+
lastPartialResult = result
215+
})
220216

221217
void handlePartialResult<ChatResult>(partialResult, encryptionKey, provider, chatParams.tabId)
222218
}
@@ -482,10 +478,7 @@ async function handlePartialResult<T extends ChatResult>(
482478
provider: AmazonQChatViewProvider,
483479
tabId: string
484480
) {
485-
const decryptedMessage =
486-
typeof partialResult === 'string' && encryptionKey
487-
? await decodeRequest<T>(partialResult, encryptionKey)
488-
: (partialResult as T)
481+
const decryptedMessage = await decryptResponse<T>(partialResult, encryptionKey)
489482

490483
if (decryptedMessage.body !== undefined) {
491484
void provider.webview?.postMessage({
@@ -508,8 +501,8 @@ async function handleCompleteResult<T extends ChatResult>(
508501
tabId: string,
509502
disposable: Disposable
510503
) {
511-
const decryptedMessage =
512-
typeof result === 'string' && encryptionKey ? await decodeRequest<T>(result, encryptionKey) : (result as T)
504+
const decryptedMessage = await decryptResponse<T>(result, encryptionKey)
505+
513506
void provider.webview?.postMessage({
514507
command: chatRequestType.method,
515508
params: decryptedMessage,

packages/amazonq/src/lsp/encryption.ts

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,14 @@ export async function encryptRequest<T>(params: T, encryptionKey: Buffer): Promi
1414
return { message: encryptedMessage }
1515
}
1616

17-
export async function decodeRequest<T>(request: string, key: Buffer): Promise<T> {
18-
const result = await jose.jwtDecrypt(request, key, {
17+
export async function decryptResponse<T>(response: unknown, key: Buffer | undefined) {
18+
// Note that casts are required since language client requests return 'unknown' type.
19+
// If we can't decrypt, return original response casted.
20+
if (typeof response !== 'string' || key === undefined) {
21+
return response as T
22+
}
23+
24+
const result = await jose.jwtDecrypt(response, key, {
1925
clockTolerance: 60, // Allow up to 60 seconds to account for clock differences
2026
contentEncryptionAlgorithms: ['A256GCM'],
2127
keyManagementAlgorithms: ['dir'],
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
/*!
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
import * as assert from 'assert'
7+
import { decryptResponse, encryptRequest } from '../../../../src/lsp/encryption'
8+
import { encryptionKey } from '../../../../src/lsp/auth'
9+
10+
describe('LSP encryption', function () {
11+
it('encrypt and decrypt invert eachother with same key', async function () {
12+
const key = encryptionKey
13+
const request = {
14+
id: 0,
15+
name: 'my Request',
16+
isRealRequest: false,
17+
metadata: {
18+
tags: ['tag1', 'tag2'],
19+
},
20+
}
21+
const encryptedPayload = await encryptRequest<typeof request>(request, key)
22+
const message = (encryptedPayload as { message: string }).message
23+
const decrypted = await decryptResponse<typeof request>(message, key)
24+
25+
assert.deepStrictEqual(decrypted, request)
26+
})
27+
})

0 commit comments

Comments
 (0)