|
| 1 | +/*! |
| 2 | + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 3 | + * SPDX-License-Identifier: Apache-2.0 |
| 4 | + */ |
| 5 | +import { |
| 6 | + isValidAuthFollowUpType, |
| 7 | + INSERT_TO_CURSOR_POSITION, |
| 8 | + AUTH_FOLLOW_UP_CLICKED, |
| 9 | + CHAT_OPTIONS, |
| 10 | + COPY_TO_CLIPBOARD, |
| 11 | +} from '@aws/chat-client-ui-types' |
| 12 | +import { |
| 13 | + ChatResult, |
| 14 | + chatRequestType, |
| 15 | + ChatParams, |
| 16 | + followUpClickNotificationType, |
| 17 | + quickActionRequestType, |
| 18 | + QuickActionResult, |
| 19 | + QuickActionParams, |
| 20 | + insertToCursorPositionNotificationType, |
| 21 | +} from '@aws/language-server-runtimes/protocol' |
| 22 | +import { v4 as uuidv4 } from 'uuid' |
| 23 | +import { Webview, window } from 'vscode' |
| 24 | +import { Disposable, LanguageClient, Position, State, TextDocumentIdentifier } from 'vscode-languageclient' |
| 25 | +import * as jose from 'jose' |
| 26 | +import { encryptionKey } from '../lsp/auth' |
| 27 | +import { Commands } from 'aws-core-vscode/shared' |
| 28 | + |
| 29 | +export function handle(client: LanguageClient, webview: Webview) { |
| 30 | + // Listen for Initialize handshake from LSP server to register quick actions dynamically |
| 31 | + client.onDidChangeState(({ oldState, newState }) => { |
| 32 | + if (oldState === State.Starting && newState === State.Running) { |
| 33 | + client.info( |
| 34 | + 'Language client received initializeResult from server:', |
| 35 | + JSON.stringify(client.initializeResult) |
| 36 | + ) |
| 37 | + |
| 38 | + const chatOptions = client.initializeResult?.awsServerCapabilities?.chatOptions |
| 39 | + |
| 40 | + void webview.postMessage({ |
| 41 | + command: CHAT_OPTIONS, |
| 42 | + params: chatOptions, |
| 43 | + }) |
| 44 | + } |
| 45 | + }) |
| 46 | + |
| 47 | + client.onTelemetry((e) => { |
| 48 | + client.info(`[VSCode Client] Received telemetry event from server ${JSON.stringify(e)}`) |
| 49 | + }) |
| 50 | + |
| 51 | + webview.onDidReceiveMessage(async (message) => { |
| 52 | + client.info(`[VSCode Client] Received ${JSON.stringify(message)} from chat`) |
| 53 | + |
| 54 | + switch (message.command) { |
| 55 | + case COPY_TO_CLIPBOARD: |
| 56 | + client.info('[VSCode Client] Copy to clipboard event received') |
| 57 | + break |
| 58 | + case INSERT_TO_CURSOR_POSITION: { |
| 59 | + const editor = window.activeTextEditor |
| 60 | + let textDocument: TextDocumentIdentifier | undefined = undefined |
| 61 | + let cursorPosition: Position | undefined = undefined |
| 62 | + if (editor) { |
| 63 | + cursorPosition = editor.selection.active |
| 64 | + textDocument = { uri: editor.document.uri.toString() } |
| 65 | + } |
| 66 | + |
| 67 | + client.sendNotification(insertToCursorPositionNotificationType.method, { |
| 68 | + ...message.params, |
| 69 | + cursorPosition, |
| 70 | + textDocument, |
| 71 | + }) |
| 72 | + break |
| 73 | + } |
| 74 | + case AUTH_FOLLOW_UP_CLICKED: |
| 75 | + client.info('[VSCode Client] AuthFollowUp clicked') |
| 76 | + break |
| 77 | + case chatRequestType.method: { |
| 78 | + const partialResultToken = uuidv4() |
| 79 | + const chatDisposable = client.onProgress(chatRequestType, partialResultToken, (partialResult) => |
| 80 | + handlePartialResult<ChatResult>(partialResult, encryptionKey, message.params.tabId, webview) |
| 81 | + ) |
| 82 | + |
| 83 | + const editor = |
| 84 | + window.activeTextEditor || |
| 85 | + window.visibleTextEditors.find((editor) => editor.document.languageId !== 'Log') |
| 86 | + if (editor) { |
| 87 | + message.params.cursorPosition = [editor.selection.active] |
| 88 | + message.params.textDocument = { uri: editor.document.uri.toString() } |
| 89 | + } |
| 90 | + |
| 91 | + const chatRequest = await encryptRequest<ChatParams>(message.params, encryptionKey) |
| 92 | + const chatResult = (await client.sendRequest(chatRequestType.method, { |
| 93 | + ...chatRequest, |
| 94 | + partialResultToken, |
| 95 | + })) as string | ChatResult |
| 96 | + void handleCompleteResult<ChatResult>( |
| 97 | + chatResult, |
| 98 | + encryptionKey, |
| 99 | + message.params.tabId, |
| 100 | + chatDisposable, |
| 101 | + webview |
| 102 | + ) |
| 103 | + break |
| 104 | + } |
| 105 | + case quickActionRequestType.method: { |
| 106 | + const quickActionPartialResultToken = uuidv4() |
| 107 | + const quickActionDisposable = client.onProgress( |
| 108 | + quickActionRequestType, |
| 109 | + quickActionPartialResultToken, |
| 110 | + (partialResult) => |
| 111 | + handlePartialResult<QuickActionResult>( |
| 112 | + partialResult, |
| 113 | + encryptionKey, |
| 114 | + message.params.tabId, |
| 115 | + webview |
| 116 | + ) |
| 117 | + ) |
| 118 | + |
| 119 | + const quickActionRequest = await encryptRequest<QuickActionParams>(message.params, encryptionKey) |
| 120 | + const quickActionResult = (await client.sendRequest(quickActionRequestType.method, { |
| 121 | + ...quickActionRequest, |
| 122 | + partialResultToken: quickActionPartialResultToken, |
| 123 | + })) as string | ChatResult |
| 124 | + void handleCompleteResult<ChatResult>( |
| 125 | + quickActionResult, |
| 126 | + encryptionKey, |
| 127 | + message.params.tabId, |
| 128 | + quickActionDisposable, |
| 129 | + webview |
| 130 | + ) |
| 131 | + break |
| 132 | + } |
| 133 | + case followUpClickNotificationType.method: |
| 134 | + if (!isValidAuthFollowUpType(message.params.followUp.type)) { |
| 135 | + client.sendNotification(followUpClickNotificationType.method, message.params) |
| 136 | + } |
| 137 | + break |
| 138 | + default: |
| 139 | + if (isServerEvent(message.command)) { |
| 140 | + client.sendNotification(message.command, message.params) |
| 141 | + } |
| 142 | + break |
| 143 | + } |
| 144 | + }, undefined) |
| 145 | + |
| 146 | + registerGenericCommand('aws.amazonq.explainCode', 'Explain', webview) |
| 147 | + registerGenericCommand('aws.amazonq.refactorCode', 'Refactor', webview) |
| 148 | + registerGenericCommand('aws.amazonq.fixCode', 'Fix', webview) |
| 149 | + registerGenericCommand('aws.amazonq.optimizeCode', 'Optimize', webview) |
| 150 | + |
| 151 | + Commands.register('aws.amazonq.sendToPrompt', (data) => { |
| 152 | + const triggerType = getCommandTriggerType(data) |
| 153 | + const selection = getSelectedText() |
| 154 | + |
| 155 | + void webview.postMessage({ |
| 156 | + command: 'sendToPrompt', |
| 157 | + params: { selection: selection, triggerType }, |
| 158 | + }) |
| 159 | + }) |
| 160 | +} |
| 161 | + |
| 162 | +function getSelectedText(): string { |
| 163 | + const editor = window.activeTextEditor |
| 164 | + if (editor) { |
| 165 | + const selection = editor.selection |
| 166 | + const selectedText = editor.document.getText(selection) |
| 167 | + return selectedText |
| 168 | + } |
| 169 | + |
| 170 | + return ' ' |
| 171 | +} |
| 172 | + |
| 173 | +function getCommandTriggerType(data: any): string { |
| 174 | + // data is undefined when commands triggered from keybinding or command palette. Currently no |
| 175 | + // way to differentiate keybinding and command palette, so both interactions are recorded as keybinding |
| 176 | + return data === undefined ? 'hotkeys' : 'contextMenu' |
| 177 | +} |
| 178 | + |
| 179 | +function registerGenericCommand(commandName: string, genericCommand: string, webview?: Webview) { |
| 180 | + Commands.register(commandName, (data) => { |
| 181 | + const triggerType = getCommandTriggerType(data) |
| 182 | + const selection = getSelectedText() |
| 183 | + |
| 184 | + void webview?.postMessage({ |
| 185 | + command: 'genericCommand', |
| 186 | + params: { genericCommand, selection, triggerType }, |
| 187 | + }) |
| 188 | + }) |
| 189 | +} |
| 190 | + |
| 191 | +function isServerEvent(command: string) { |
| 192 | + return command.startsWith('aws/chat/') || command === 'telemetry/event' |
| 193 | +} |
| 194 | + |
| 195 | +// Encrypt the provided request if encryption key exists otherwise do nothing |
| 196 | +async function encryptRequest<T>(params: T, encryptionKey: Buffer | undefined): Promise<{ message: string } | T> { |
| 197 | + if (!encryptionKey) { |
| 198 | + return params |
| 199 | + } |
| 200 | + |
| 201 | + const payload = new TextEncoder().encode(JSON.stringify(params)) |
| 202 | + |
| 203 | + const encryptedMessage = await new jose.CompactEncrypt(payload) |
| 204 | + .setProtectedHeader({ alg: 'dir', enc: 'A256GCM' }) |
| 205 | + .encrypt(encryptionKey) |
| 206 | + |
| 207 | + return { message: encryptedMessage } |
| 208 | +} |
| 209 | + |
| 210 | +async function decodeRequest<T>(request: string, key: Buffer): Promise<T> { |
| 211 | + const result = await jose.jwtDecrypt(request, key, { |
| 212 | + clockTolerance: 60, // Allow up to 60 seconds to account for clock differences |
| 213 | + contentEncryptionAlgorithms: ['A256GCM'], |
| 214 | + keyManagementAlgorithms: ['dir'], |
| 215 | + }) |
| 216 | + |
| 217 | + if (!result.payload) { |
| 218 | + throw new Error('JWT payload not found') |
| 219 | + } |
| 220 | + return result.payload as T |
| 221 | +} |
| 222 | + |
| 223 | +async function handlePartialResult<T extends ChatResult>( |
| 224 | + partialResult: string | T, |
| 225 | + encryptionKey: Buffer | undefined, |
| 226 | + tabId: string, |
| 227 | + webview: Webview |
| 228 | +) { |
| 229 | + const decryptedMessage = |
| 230 | + typeof partialResult === 'string' && encryptionKey |
| 231 | + ? await decodeRequest<T>(partialResult, encryptionKey) |
| 232 | + : (partialResult as T) |
| 233 | + |
| 234 | + if (decryptedMessage.body) { |
| 235 | + void webview?.postMessage({ |
| 236 | + command: chatRequestType.method, |
| 237 | + params: decryptedMessage, |
| 238 | + isPartialResult: true, |
| 239 | + tabId: tabId, |
| 240 | + }) |
| 241 | + } |
| 242 | +} |
| 243 | + |
| 244 | +async function handleCompleteResult<T>( |
| 245 | + result: string | T, |
| 246 | + encryptionKey: Buffer | undefined, |
| 247 | + tabId: string, |
| 248 | + disposable: Disposable, |
| 249 | + webview: Webview |
| 250 | +) { |
| 251 | + const decryptedMessage = |
| 252 | + typeof result === 'string' && encryptionKey ? await decodeRequest(result, encryptionKey) : result |
| 253 | + |
| 254 | + void webview?.postMessage({ |
| 255 | + command: chatRequestType.method, |
| 256 | + params: decryptedMessage, |
| 257 | + tabId: tabId, |
| 258 | + }) |
| 259 | + disposable.dispose() |
| 260 | +} |
0 commit comments