Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions packages/amazonq/src/inlineChat/provider/inlineChatProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import type { InlineChatEvent } from 'aws-core-vscode/codewhisperer'
import { InlineTask } from '../controller/inlineTask'
import { extractAuthFollowUp } from 'aws-core-vscode/amazonq'
import { InlineChatParams, InlineChatResult } from '@aws/language-server-runtimes-types'
import { decodeRequest, encryptRequest } from '../../lsp/encryption'
import { decryptResponse, encryptRequest } from '../../lsp/encryption'
import { getCursorState } from '../../lsp/utils'

export class InlineChatProvider {
Expand Down Expand Up @@ -72,16 +72,13 @@ export class InlineChatProvider {
// TODO: handle partial responses.
getLogger().info('Making inline chat request with message %O', message)
const params = this.getCurrentEditorParams(message.message ?? '')

const inlineChatRequest = await encryptRequest<InlineChatParams>(params, this.encryptionKey)
const response = await this.client.sendRequest(inlineChatRequestType.method, inlineChatRequest)
const decryptedMessage =
typeof response === 'string' && this.encryptionKey
? await decodeRequest(response, this.encryptionKey)
: response
const result: InlineChatResult = decryptedMessage as InlineChatResult
this.client.info(`Logging response for inline chat ${JSON.stringify(decryptedMessage)}`)

return result
const inlineChatResponse = await decryptResponse<InlineChatResult>(response, this.encryptionKey)
this.client.info(`Logging response for inline chat ${JSON.stringify(inlineChatResponse)}`)

return inlineChatResponse
}

// TODO: remove in favor of LSP implementation.
Expand Down
31 changes: 10 additions & 21 deletions packages/amazonq/src/lsp/chat/messages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ import {
} from 'aws-core-vscode/amazonq'
import { telemetry, TelemetryBase } from 'aws-core-vscode/telemetry'
import { isValidResponseError } from './error'
import { decodeRequest, encryptRequest } from '../encryption'
import { decryptResponse, encryptRequest } from '../encryption'
import { getCursorState } from '../utils'

export function registerLanguageServerEventListener(languageClient: LanguageClient, provider: AmazonQChatViewProvider) {
Expand Down Expand Up @@ -205,21 +205,12 @@ export function registerMessageListeners(
const cancellationToken = new CancellationTokenSource()
chatStreamTokens.set(chatParams.tabId, cancellationToken)

const chatDisposable = languageClient.onProgress(
chatRequestType,
partialResultToken,
(partialResult) => {
// Store the latest partial result
if (typeof partialResult === 'string' && encryptionKey) {
void decodeRequest<ChatResult>(partialResult, encryptionKey).then(
(decoded) => (lastPartialResult = decoded)
)
} else {
lastPartialResult = partialResult as ChatResult
const chatDisposable = languageClient.onProgress(chatRequestType, partialResultToken, (partialResult) =>
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this refactor is to avoid decrypting partial responses twice

handlePartialResult<ChatResult>(partialResult, encryptionKey, provider, chatParams.tabId).then(
(result) => {
lastPartialResult = result
}

void handlePartialResult<ChatResult>(partialResult, encryptionKey, provider, chatParams.tabId)
}
)
)

const editor =
Expand Down Expand Up @@ -482,10 +473,7 @@ async function handlePartialResult<T extends ChatResult>(
provider: AmazonQChatViewProvider,
tabId: string
) {
const decryptedMessage =
typeof partialResult === 'string' && encryptionKey
? await decodeRequest<T>(partialResult, encryptionKey)
: (partialResult as T)
const decryptedMessage = await decryptResponse<T>(partialResult, encryptionKey)

if (decryptedMessage.body !== undefined) {
void provider.webview?.postMessage({
Expand All @@ -495,6 +483,7 @@ async function handlePartialResult<T extends ChatResult>(
tabId: tabId,
})
}
return decryptedMessage
}

/**
Expand All @@ -508,8 +497,8 @@ async function handleCompleteResult<T extends ChatResult>(
tabId: string,
disposable: Disposable
) {
const decryptedMessage =
typeof result === 'string' && encryptionKey ? await decodeRequest<T>(result, encryptionKey) : (result as T)
const decryptedMessage = await decryptResponse<T>(result, encryptionKey)

void provider.webview?.postMessage({
command: chatRequestType.method,
params: decryptedMessage,
Expand Down
10 changes: 8 additions & 2 deletions packages/amazonq/src/lsp/encryption.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,14 @@ export async function encryptRequest<T>(params: T, encryptionKey: Buffer): Promi
return { message: encryptedMessage }
}

export async function decodeRequest<T>(request: string, key: Buffer): Promise<T> {
const result = await jose.jwtDecrypt(request, key, {
export async function decryptResponse<T>(response: unknown, key: Buffer | undefined) {
// Note that casts are required since language client requests return 'unknown' type.
// If we can't decrypt, return original response casted.
if (typeof response !== 'string' || key === undefined) {
return response as T
}

const result = await jose.jwtDecrypt(response, key, {
clockTolerance: 60, // Allow up to 60 seconds to account for clock differences
contentEncryptionAlgorithms: ['A256GCM'],
keyManagementAlgorithms: ['dir'],
Expand Down
27 changes: 27 additions & 0 deletions packages/amazonq/test/unit/amazonq/lsp/encryption.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*!
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

import * as assert from 'assert'
import { decryptResponse, encryptRequest } from '../../../../src/lsp/encryption'
import { encryptionKey } from '../../../../src/lsp/auth'

describe('LSP encryption', function () {
it('encrypt and decrypt invert eachother with same key', async function () {
const key = encryptionKey
const request = {
id: 0,
name: 'my Request',
isRealRequest: false,
metadata: {
tags: ['tag1', 'tag2'],
},
}
const encryptedPayload = await encryptRequest<typeof request>(request, key)
const message = (encryptedPayload as { message: string }).message
const decrypted = await decryptResponse<typeof request>(message, key)

assert.deepStrictEqual(decrypted, request)
})
})
Loading