diff --git a/packages/amazonq/src/inlineChat/provider/inlineChatProvider.ts b/packages/amazonq/src/inlineChat/provider/inlineChatProvider.ts index 86fe0ac2ade..cfa3798945c 100644 --- a/packages/amazonq/src/inlineChat/provider/inlineChatProvider.ts +++ b/packages/amazonq/src/inlineChat/provider/inlineChatProvider.ts @@ -28,7 +28,7 @@ import type { InlineChatEvent } from 'aws-core-vscode/codewhisperer' import { InlineTask } from '../controller/inlineTask' import { extractAuthFollowUp } from 'aws-core-vscode/amazonq' import { InlineChatParams, InlineChatResult } from '@aws/language-server-runtimes-types' -import { decodeRequest, encryptRequest } from '../../lsp/encryption' +import { decryptResponse, encryptRequest } from '../../lsp/encryption' import { getCursorState } from '../../lsp/utils' export class InlineChatProvider { @@ -72,16 +72,13 @@ export class InlineChatProvider { // TODO: handle partial responses. getLogger().info('Making inline chat request with message %O', message) const params = this.getCurrentEditorParams(message.message ?? '') + const inlineChatRequest = await encryptRequest(params, this.encryptionKey) const response = await this.client.sendRequest(inlineChatRequestType.method, inlineChatRequest) - const decryptedMessage = - typeof response === 'string' && this.encryptionKey - ? await decodeRequest(response, this.encryptionKey) - : response - const result: InlineChatResult = decryptedMessage as InlineChatResult - this.client.info(`Logging response for inline chat ${JSON.stringify(decryptedMessage)}`) - - return result + const inlineChatResponse = await decryptResponse(response, this.encryptionKey) + this.client.info(`Logging response for inline chat ${JSON.stringify(inlineChatResponse)}`) + + return inlineChatResponse } // TODO: remove in favor of LSP implementation. diff --git a/packages/amazonq/src/lsp/chat/messages.ts b/packages/amazonq/src/lsp/chat/messages.ts index 93ede65fc9a..32390d52006 100644 --- a/packages/amazonq/src/lsp/chat/messages.ts +++ b/packages/amazonq/src/lsp/chat/messages.ts @@ -67,7 +67,7 @@ import { } from 'aws-core-vscode/amazonq' import { telemetry, TelemetryBase } from 'aws-core-vscode/telemetry' import { isValidResponseError } from './error' -import { decodeRequest, encryptRequest } from '../encryption' +import { decryptResponse, encryptRequest } from '../encryption' import { getCursorState } from '../utils' export function registerLanguageServerEventListener(languageClient: LanguageClient, provider: AmazonQChatViewProvider) { @@ -205,21 +205,12 @@ export function registerMessageListeners( const cancellationToken = new CancellationTokenSource() chatStreamTokens.set(chatParams.tabId, cancellationToken) - const chatDisposable = languageClient.onProgress( - chatRequestType, - partialResultToken, - (partialResult) => { - // Store the latest partial result - if (typeof partialResult === 'string' && encryptionKey) { - void decodeRequest(partialResult, encryptionKey).then( - (decoded) => (lastPartialResult = decoded) - ) - } else { - lastPartialResult = partialResult as ChatResult + const chatDisposable = languageClient.onProgress(chatRequestType, partialResultToken, (partialResult) => + handlePartialResult(partialResult, encryptionKey, provider, chatParams.tabId).then( + (result) => { + lastPartialResult = result } - - void handlePartialResult(partialResult, encryptionKey, provider, chatParams.tabId) - } + ) ) const editor = @@ -482,10 +473,7 @@ async function handlePartialResult( provider: AmazonQChatViewProvider, tabId: string ) { - const decryptedMessage = - typeof partialResult === 'string' && encryptionKey - ? await decodeRequest(partialResult, encryptionKey) - : (partialResult as T) + const decryptedMessage = await decryptResponse(partialResult, encryptionKey) if (decryptedMessage.body !== undefined) { void provider.webview?.postMessage({ @@ -495,6 +483,7 @@ async function handlePartialResult( tabId: tabId, }) } + return decryptedMessage } /** @@ -508,8 +497,8 @@ async function handleCompleteResult( tabId: string, disposable: Disposable ) { - const decryptedMessage = - typeof result === 'string' && encryptionKey ? await decodeRequest(result, encryptionKey) : (result as T) + const decryptedMessage = await decryptResponse(result, encryptionKey) + void provider.webview?.postMessage({ command: chatRequestType.method, params: decryptedMessage, diff --git a/packages/amazonq/src/lsp/encryption.ts b/packages/amazonq/src/lsp/encryption.ts index 213ee3c1553..246c64f476b 100644 --- a/packages/amazonq/src/lsp/encryption.ts +++ b/packages/amazonq/src/lsp/encryption.ts @@ -14,8 +14,14 @@ export async function encryptRequest(params: T, encryptionKey: Buffer): Promi return { message: encryptedMessage } } -export async function decodeRequest(request: string, key: Buffer): Promise { - const result = await jose.jwtDecrypt(request, key, { +export async function decryptResponse(response: unknown, key: Buffer | undefined) { + // Note that casts are required since language client requests return 'unknown' type. + // If we can't decrypt, return original response casted. + if (typeof response !== 'string' || key === undefined) { + return response as T + } + + const result = await jose.jwtDecrypt(response, key, { clockTolerance: 60, // Allow up to 60 seconds to account for clock differences contentEncryptionAlgorithms: ['A256GCM'], keyManagementAlgorithms: ['dir'], diff --git a/packages/amazonq/test/unit/amazonq/lsp/encryption.test.ts b/packages/amazonq/test/unit/amazonq/lsp/encryption.test.ts new file mode 100644 index 00000000000..06a901edde6 --- /dev/null +++ b/packages/amazonq/test/unit/amazonq/lsp/encryption.test.ts @@ -0,0 +1,27 @@ +/*! + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +import * as assert from 'assert' +import { decryptResponse, encryptRequest } from '../../../../src/lsp/encryption' +import { encryptionKey } from '../../../../src/lsp/auth' + +describe('LSP encryption', function () { + it('encrypt and decrypt invert eachother with same key', async function () { + const key = encryptionKey + const request = { + id: 0, + name: 'my Request', + isRealRequest: false, + metadata: { + tags: ['tag1', 'tag2'], + }, + } + const encryptedPayload = await encryptRequest(request, key) + const message = (encryptedPayload as { message: string }).message + const decrypted = await decryptResponse(message, key) + + assert.deepStrictEqual(decrypted, request) + }) +})