|
| 1 | +/*! |
| 2 | + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 3 | + * SPDX-License-Identifier: Apache-2.0 |
| 4 | + */ |
| 5 | + |
| 6 | +import { |
| 7 | + isValidAuthFollowUpType, |
| 8 | + INSERT_TO_CURSOR_POSITION, |
| 9 | + AUTH_FOLLOW_UP_CLICKED, |
| 10 | + CHAT_OPTIONS, |
| 11 | + COPY_TO_CLIPBOARD, |
| 12 | +} from '@aws/chat-client-ui-types' |
| 13 | +import { |
| 14 | + ChatResult, |
| 15 | + chatRequestType, |
| 16 | + ChatParams, |
| 17 | + followUpClickNotificationType, |
| 18 | + quickActionRequestType, |
| 19 | + QuickActionResult, |
| 20 | + QuickActionParams, |
| 21 | + insertToCursorPositionNotificationType, |
| 22 | +} from '@aws/language-server-runtimes/protocol' |
| 23 | +import { v4 as uuidv4 } from 'uuid' |
| 24 | +import { window } from 'vscode' |
| 25 | +import { Disposable, LanguageClient, Position, State, TextDocumentIdentifier } from 'vscode-languageclient' |
| 26 | +import * as jose from 'jose' |
| 27 | +import { AmazonQChatViewProvider } from './webviewProvider' |
| 28 | + |
| 29 | +export function registerLanguageServerEventListener(languageClient: LanguageClient, provider: AmazonQChatViewProvider) { |
| 30 | + languageClient.onDidChangeState(({ oldState, newState }) => { |
| 31 | + if (oldState === State.Starting && newState === State.Running) { |
| 32 | + languageClient.info( |
| 33 | + 'Language client received initializeResult from server:', |
| 34 | + JSON.stringify(languageClient.initializeResult) |
| 35 | + ) |
| 36 | + |
| 37 | + const chatOptions = languageClient.initializeResult?.awsServerCapabilities?.chatOptions |
| 38 | + |
| 39 | + void provider.webview?.postMessage({ |
| 40 | + command: CHAT_OPTIONS, |
| 41 | + params: chatOptions, |
| 42 | + }) |
| 43 | + } |
| 44 | + }) |
| 45 | + |
| 46 | + languageClient.onTelemetry((e) => { |
| 47 | + languageClient.info(`[VSCode Client] Received telemetry event from server ${JSON.stringify(e)}`) |
| 48 | + }) |
| 49 | +} |
| 50 | + |
| 51 | +export function registerMessageListeners( |
| 52 | + languageClient: LanguageClient, |
| 53 | + provider: AmazonQChatViewProvider, |
| 54 | + encryptionKey: Buffer |
| 55 | +) { |
| 56 | + provider.webview?.onDidReceiveMessage(async (message) => { |
| 57 | + languageClient.info(`[VSCode Client] Received ${JSON.stringify(message)} from chat`) |
| 58 | + |
| 59 | + switch (message.command) { |
| 60 | + case COPY_TO_CLIPBOARD: |
| 61 | + // TODO see what we need to hook this up |
| 62 | + languageClient.info('[VSCode Client] Copy to clipboard event received') |
| 63 | + break |
| 64 | + case INSERT_TO_CURSOR_POSITION: { |
| 65 | + const editor = window.activeTextEditor |
| 66 | + let textDocument: TextDocumentIdentifier | undefined = undefined |
| 67 | + let cursorPosition: Position | undefined = undefined |
| 68 | + if (editor) { |
| 69 | + cursorPosition = editor.selection.active |
| 70 | + textDocument = { uri: editor.document.uri.toString() } |
| 71 | + } |
| 72 | + |
| 73 | + languageClient.sendNotification(insertToCursorPositionNotificationType.method, { |
| 74 | + ...message.params, |
| 75 | + cursorPosition, |
| 76 | + textDocument, |
| 77 | + }) |
| 78 | + break |
| 79 | + } |
| 80 | + case AUTH_FOLLOW_UP_CLICKED: |
| 81 | + // TODO hook this into auth |
| 82 | + languageClient.info('[VSCode Client] AuthFollowUp clicked') |
| 83 | + break |
| 84 | + case chatRequestType.method: { |
| 85 | + const partialResultToken = uuidv4() |
| 86 | + const chatDisposable = languageClient.onProgress(chatRequestType, partialResultToken, (partialResult) => |
| 87 | + handlePartialResult<ChatResult>(partialResult, encryptionKey, provider, message.params.tabId) |
| 88 | + ) |
| 89 | + |
| 90 | + const editor = |
| 91 | + window.activeTextEditor || |
| 92 | + window.visibleTextEditors.find((editor) => editor.document.languageId !== 'Log') |
| 93 | + if (editor) { |
| 94 | + message.params.cursorPosition = [editor.selection.active] |
| 95 | + message.params.textDocument = { uri: editor.document.uri.toString() } |
| 96 | + } |
| 97 | + |
| 98 | + const chatRequest = await encryptRequest<ChatParams>(message.params, encryptionKey) |
| 99 | + const chatResult = (await languageClient.sendRequest(chatRequestType.method, { |
| 100 | + ...chatRequest, |
| 101 | + partialResultToken, |
| 102 | + })) as string | ChatResult |
| 103 | + void handleCompleteResult<ChatResult>( |
| 104 | + chatResult, |
| 105 | + encryptionKey, |
| 106 | + provider, |
| 107 | + message.params.tabId, |
| 108 | + chatDisposable |
| 109 | + ) |
| 110 | + break |
| 111 | + } |
| 112 | + case quickActionRequestType.method: { |
| 113 | + const quickActionPartialResultToken = uuidv4() |
| 114 | + const quickActionDisposable = languageClient.onProgress( |
| 115 | + quickActionRequestType, |
| 116 | + quickActionPartialResultToken, |
| 117 | + (partialResult) => |
| 118 | + handlePartialResult<QuickActionResult>( |
| 119 | + partialResult, |
| 120 | + encryptionKey, |
| 121 | + provider, |
| 122 | + message.params.tabId |
| 123 | + ) |
| 124 | + ) |
| 125 | + |
| 126 | + const quickActionRequest = await encryptRequest<QuickActionParams>(message.params, encryptionKey) |
| 127 | + const quickActionResult = (await languageClient.sendRequest(quickActionRequestType.method, { |
| 128 | + ...quickActionRequest, |
| 129 | + partialResultToken: quickActionPartialResultToken, |
| 130 | + })) as string | ChatResult |
| 131 | + void handleCompleteResult<ChatResult>( |
| 132 | + quickActionResult, |
| 133 | + encryptionKey, |
| 134 | + provider, |
| 135 | + message.params.tabId, |
| 136 | + quickActionDisposable |
| 137 | + ) |
| 138 | + break |
| 139 | + } |
| 140 | + case followUpClickNotificationType.method: |
| 141 | + if (!isValidAuthFollowUpType(message.params.followUp.type)) { |
| 142 | + languageClient.sendNotification(followUpClickNotificationType.method, message.params) |
| 143 | + } |
| 144 | + break |
| 145 | + default: |
| 146 | + if (isServerEvent(message.command)) { |
| 147 | + languageClient.sendNotification(message.command, message.params) |
| 148 | + } |
| 149 | + break |
| 150 | + } |
| 151 | + }, undefined) |
| 152 | +} |
| 153 | + |
| 154 | +function isServerEvent(command: string) { |
| 155 | + return command.startsWith('aws/chat/') || command === 'telemetry/event' |
| 156 | +} |
| 157 | + |
| 158 | +async function encryptRequest<T>(params: T, encryptionKey: Buffer): Promise<{ message: string } | T> { |
| 159 | + const payload = new TextEncoder().encode(JSON.stringify(params)) |
| 160 | + |
| 161 | + const encryptedMessage = await new jose.CompactEncrypt(payload) |
| 162 | + .setProtectedHeader({ alg: 'dir', enc: 'A256GCM' }) |
| 163 | + .encrypt(encryptionKey) |
| 164 | + |
| 165 | + return { message: encryptedMessage } |
| 166 | +} |
| 167 | + |
| 168 | +async function decodeRequest<T>(request: string, key: Buffer): Promise<T> { |
| 169 | + const result = await jose.jwtDecrypt(request, key, { |
| 170 | + clockTolerance: 60, // Allow up to 60 seconds to account for clock differences |
| 171 | + contentEncryptionAlgorithms: ['A256GCM'], |
| 172 | + keyManagementAlgorithms: ['dir'], |
| 173 | + }) |
| 174 | + |
| 175 | + if (!result.payload) { |
| 176 | + throw new Error('JWT payload not found') |
| 177 | + } |
| 178 | + return result.payload as T |
| 179 | +} |
| 180 | + |
| 181 | +/** |
| 182 | + * Decodes partial chat responses from the language server before sending them to mynah UI |
| 183 | + */ |
| 184 | +async function handlePartialResult<T extends ChatResult>( |
| 185 | + partialResult: string | T, |
| 186 | + encryptionKey: Buffer | undefined, |
| 187 | + provider: AmazonQChatViewProvider, |
| 188 | + tabId: string |
| 189 | +) { |
| 190 | + const decryptedMessage = |
| 191 | + typeof partialResult === 'string' && encryptionKey |
| 192 | + ? await decodeRequest<T>(partialResult, encryptionKey) |
| 193 | + : (partialResult as T) |
| 194 | + |
| 195 | + if (decryptedMessage.body) { |
| 196 | + void provider.webview?.postMessage({ |
| 197 | + command: chatRequestType.method, |
| 198 | + params: decryptedMessage, |
| 199 | + isPartialResult: true, |
| 200 | + tabId: tabId, |
| 201 | + }) |
| 202 | + } |
| 203 | +} |
| 204 | + |
| 205 | +/** |
| 206 | + * Decodes the final chat responses from the language server before sending it to mynah UI. |
| 207 | + * Once this is called the answer response is finished |
| 208 | + */ |
| 209 | +async function handleCompleteResult<T>( |
| 210 | + result: string | T, |
| 211 | + encryptionKey: Buffer | undefined, |
| 212 | + provider: AmazonQChatViewProvider, |
| 213 | + tabId: string, |
| 214 | + disposable: Disposable |
| 215 | +) { |
| 216 | + const decryptedMessage = |
| 217 | + typeof result === 'string' && encryptionKey ? await decodeRequest(result, encryptionKey) : result |
| 218 | + |
| 219 | + void provider.webview?.postMessage({ |
| 220 | + command: chatRequestType.method, |
| 221 | + params: decryptedMessage, |
| 222 | + tabId: tabId, |
| 223 | + }) |
| 224 | + disposable.dispose() |
| 225 | +} |
0 commit comments