diff --git a/packages/core/src/codewhispererChat/clients/chat/v0/chat.ts b/packages/core/src/codewhispererChat/clients/chat/v0/chat.ts index 540135268e0..a89a203430d 100644 --- a/packages/core/src/codewhispererChat/clients/chat/v0/chat.ts +++ b/packages/core/src/codewhispererChat/clients/chat/v0/chat.ts @@ -4,7 +4,11 @@ */ import { SendMessageCommandOutput, SendMessageRequest } from '@amzn/amazon-q-developer-streaming-client' -import { GenerateAssistantResponseCommandOutput, GenerateAssistantResponseRequest } from '@amzn/codewhisperer-streaming' +import { + GenerateAssistantResponseCommandOutput, + GenerateAssistantResponseRequest, + ToolUse, +} from '@amzn/codewhisperer-streaming' import * as vscode from 'vscode' import { ToolkitError } from '../../../../shared/errors' import { createCodeWhispererChatStreamingClient } from '../../../../shared/clients/codewhispererChatClient' @@ -13,6 +17,7 @@ import { UserWrittenCodeTracker } from '../../../../codewhisperer/tracker/userWr export class ChatSession { private sessionId?: string + private _toolUse: ToolUse | undefined contexts: Map = new Map() // TODO: doesn't handle the edge case when two files share the same relativePath string but from different root @@ -22,6 +27,14 @@ export class ChatSession { return this.sessionId } + public get toolUse(): ToolUse | undefined { + return this._toolUse + } + + public setToolUse(toolUse: ToolUse | undefined) { + this._toolUse = toolUse + } + public tokenSource!: vscode.CancellationTokenSource constructor() { diff --git a/packages/core/src/codewhispererChat/constants.ts b/packages/core/src/codewhispererChat/constants.ts index 4566d14ec64..7f7c71435e2 100644 --- a/packages/core/src/codewhispererChat/constants.ts +++ b/packages/core/src/codewhispererChat/constants.ts @@ -4,6 +4,8 @@ */ import * as path from 'path' import fs from '../shared/fs/fs' +import { Tool } from '@amzn/codewhisperer-streaming' +import toolsJson from '../codewhispererChat/tools/tool_index.json' export const promptFileExtension = '.md' @@ -19,3 +21,10 @@ export const getUserPromptsDirectory = () => { } export const createSavedPromptCommandId = 'create-saved-prompt' + +export const tools: Tool[] = Object.entries(toolsJson).map(([, toolSpec]) => ({ + toolSpecification: { + ...toolSpec, + inputSchema: { json: toolSpec.inputSchema }, + }, +})) diff --git a/packages/core/src/codewhispererChat/controllers/chat/chatRequest/converter.ts b/packages/core/src/codewhispererChat/controllers/chat/chatRequest/converter.ts index dd6edfbeff0..d46dbf8565c 100644 --- a/packages/core/src/codewhispererChat/controllers/chat/chatRequest/converter.ts +++ b/packages/core/src/codewhispererChat/controllers/chat/chatRequest/converter.ts @@ -13,6 +13,7 @@ import { } from '@amzn/codewhisperer-streaming' import { ChatTriggerType, TriggerPayload } from '../model' import { undefinedIfEmpty } from '../../../../shared/utilities/textUtilities' +import { tools } from '../../../constants' const fqnNameSizeDownLimit = 1 const fqnNameSizeUpLimit = 256 @@ -115,10 +116,16 @@ export function triggerPayloadToChatRequest(triggerPayload: TriggerPayload): { c cursorState, relevantDocuments, useRelevantDocuments, + // TODO: Need workspace folders here after model update. }, additionalContext: triggerPayload.additionalContents, + tools, + ...(triggerPayload.toolResults !== undefined && + triggerPayload.toolResults !== null && { toolResults: triggerPayload.toolResults }), }, userIntent: triggerPayload.userIntent, + ...(triggerPayload.origin !== undefined && + triggerPayload.origin !== null && { origin: triggerPayload.origin }), }, }, chatTriggerType, diff --git a/packages/core/src/codewhispererChat/controllers/chat/controller.ts b/packages/core/src/codewhispererChat/controllers/chat/controller.ts index 667c3d418f6..b7c67f1ecc1 100644 --- a/packages/core/src/codewhispererChat/controllers/chat/controller.ts +++ b/packages/core/src/codewhispererChat/controllers/chat/controller.ts @@ -45,7 +45,7 @@ import { EditorContextCommand } from '../../commands/registerCommands' import { PromptsGenerator } from './prompts/promptsGenerator' import { TriggerEventsStorage } from '../../storages/triggerEvents' import { SendMessageRequest } from '@amzn/amazon-q-developer-streaming-client' -import { CodeWhispererStreamingServiceException } from '@amzn/codewhisperer-streaming' +import { CodeWhispererStreamingServiceException, Origin, ToolResult } from '@amzn/codewhisperer-streaming' import { UserIntentRecognizer } from './userIntent/userIntentRecognizer' import { CWCTelemetryHelper, recordTelemetryChatRunCommand } from './telemetryHelper' import { CodeWhispererTracker } from '../../../codewhisperer/tracker/codewhispererTracker' @@ -81,6 +81,7 @@ import { } from '../../constants' import { ChatSession } from '../../clients/chat/v0/chat' import { ChatHistoryManager } from '../../storages/chatHistory' +import { FsRead, FsReadParams } from '../../tools/fsRead' export interface ChatControllerMessagePublishers { readonly processPromptChatMessage: MessagePublisher @@ -577,6 +578,8 @@ export class ChatController { const newFileDoc = await vscode.workspace.openTextDocument(newFilePath) await vscode.window.showTextDocument(newFileDoc) telemetry.ui_click.emit({ elementId: 'amazonq_createSavedPrompt' }) + } else if (message.action.id === 'confirm-tool-use') { + await this.processToolUseMessage(message) } } @@ -834,10 +837,108 @@ export class ChatController { } } + private async processToolUseMessage(message: CustomFormActionMessage) { + const tabID = message.tabID + if (!tabID) { + return + } + this.editorContextExtractor + .extractContextForTrigger('ChatMessage') + .then(async (context) => { + const triggerID = randomUUID() + this.triggerEventsStorage.addTriggerEvent({ + id: triggerID, + tabID: message.tabID, + message: undefined, + type: 'chat_message', + context, + }) + const session = this.sessionStorage.getSession(tabID) + const toolUse = session.toolUse + if (!toolUse || !toolUse.input) { + return + } + session.setToolUse(undefined) + + let result: any + const toolResults: ToolResult[] = [] + try { + switch (toolUse.name) { + // case 'execute_bash': { + // const executeBash = new ExecuteBash(toolUse.input as unknown as ExecuteBashParams) + // await executeBash.validate() + // result = await executeBash.invoke(process.stdout) + // break + // } + case 'fs_read': { + const fsRead = new FsRead(toolUse.input as unknown as FsReadParams) + await fsRead.validate() + result = await fsRead.invoke() + break + } + // case 'fs_write': { + // const fsWrite = new FsWrite(toolUse.input as unknown as FsWriteParams) + // const ctx = new DefaultContext() + // result = await fsWrite.invoke(ctx, process.stdout) + // break + // } + // case 'open_file': { + // result = await openFile(toolUse.input as unknown as OpenFileParams) + // break + // } + default: + break + } + toolResults.push({ + content: [ + result.output.kind === 'text' + ? { text: result.output.content } + : { json: result.output.content }, + ], + toolUseId: toolUse.toolUseId, + status: 'success', + }) + } catch (e: any) { + toolResults.push({ content: [{ text: e.message }], toolUseId: toolUse.toolUseId, status: 'error' }) + } + + this.chatHistoryManager.appendUserMessage({ + userInputMessage: { + content: 'Tool Results', + userIntent: undefined, + origin: Origin.IDE, + }, + }) + + await this.generateResponse( + { + message: 'Tool Results', + trigger: ChatTriggerType.ChatMessage, + query: undefined, + codeSelection: context?.focusAreaContext?.selectionInsideExtendedCodeBlock, + fileText: context?.focusAreaContext?.extendedCodeBlock, + fileLanguage: context?.activeFileContext?.fileLanguage, + filePath: context?.activeFileContext?.filePath, + matchPolicy: context?.activeFileContext?.matchPolicy, + codeQuery: context?.focusAreaContext?.names, + userIntent: undefined, + customization: getSelectedCustomization(), + context: undefined, + toolResults: toolResults, + origin: Origin.IDE, + }, + triggerID + ) + }) + .catch((e) => { + this.processException(e, tabID) + }) + } + private async processPromptMessageAsNewThread(message: PromptMessage) { this.editorContextExtractor .extractContextForTrigger('ChatMessage') - .then((context) => { + .then(async (context) => { const triggerID = randomUUID() this.triggerEventsStorage.addTriggerEvent({ id: triggerID, @@ -850,9 +951,10 @@ export class ChatController { userInputMessage: { content: message.message, userIntent: message.userIntent, + origin: Origin.IDE, }, }) - return this.generateResponse( + await this.generateResponse( { message: message.message, trigger: ChatTriggerType.ChatMessage, @@ -867,6 +969,7 @@ export class ChatController { customization: getSelectedCustomization(), context: message.context, chatHistory: this.chatHistoryManager.getHistory(), + origin: Origin.IDE, }, triggerID ) diff --git a/packages/core/src/codewhispererChat/controllers/chat/messenger/messenger.ts b/packages/core/src/codewhispererChat/controllers/chat/messenger/messenger.ts index 3c5f181a3fb..24ca24e0d59 100644 --- a/packages/core/src/codewhispererChat/controllers/chat/messenger/messenger.ts +++ b/packages/core/src/codewhispererChat/controllers/chat/messenger/messenger.ts @@ -20,6 +20,7 @@ import { ChatResponseStream as cwChatResponseStream, CodeWhispererStreamingServiceException, SupplementaryWebLink, + ToolUse, } from '@amzn/codewhisperer-streaming' import { ChatMessage, ErrorMessage, FollowUp, Suggestion } from '../../../view/connector/connector' import { ChatSession } from '../../../clients/chat/v0/chat' @@ -131,6 +132,8 @@ export class Messenger { let followUps: FollowUp[] = [] let relatedSuggestions: Suggestion[] = [] let codeBlockLanguage: string = 'plaintext' + let toolUseInput = '' + const toolUse: ToolUse = { toolUseId: undefined, name: undefined, input: undefined } if (response.message === undefined) { throw new ToolkitError( @@ -158,7 +161,7 @@ export class Messenger { }) const eventCounts = new Map() - waitUntil( + await waitUntil( async () => { for await (const chatEvent of response.message!) { for (const key of keys(chatEvent)) { @@ -188,6 +191,53 @@ export class Messenger { ] } + const cwChatEvent: cwChatResponseStream = chatEvent + if ( + cwChatEvent.toolUseEvent?.input !== undefined && + cwChatEvent.toolUseEvent.input.length > 0 && + !cwChatEvent.toolUseEvent.stop + ) { + toolUseInput += cwChatEvent.toolUseEvent.input + } + + if (cwChatEvent.toolUseEvent?.stop) { + toolUse.input = JSON.parse(toolUseInput) + toolUse.toolUseId = cwChatEvent.toolUseEvent.toolUseId ?? '' + toolUse.name = cwChatEvent.toolUseEvent.name ?? '' + session.setToolUse(toolUse) + + const message = this.getToolUseMessage(toolUse) + // const isConfirmationRequired = this.getIsConfirmationRequired(toolUse) + + this.dispatcher.sendChatMessage( + new ChatMessage( + { + message, + messageType: 'answer', + followUps: undefined, + followUpsHeader: undefined, + relatedSuggestions: undefined, + codeReference, + triggerID, + messageID: toolUse.toolUseId, + userIntent: triggerPayload.userIntent, + codeBlockLanguage: codeBlockLanguage, + contextList: undefined, + // TODO: confirmation buttons + }, + tabID + ) + ) + // TODO: setup permission action + // if (!isConfirmationRequired) { + // this.dispatcher.sendCustomFormActionMessage( + // new CustomFormActionMessage(tabID, { + // id: 'confirm-tool-use', + // }) + // ) + // } + } + if ( chatEvent.assistantResponseEvent?.content !== undefined && chatEvent.assistantResponseEvent.content.length > 0 @@ -338,7 +388,7 @@ export class Messenger { messageId: messageID, content: message, references: codeReference, - // TODO: Add tools data and follow up prompt details + toolUses: [{ ...toolUse }], }, }) @@ -533,4 +583,67 @@ export class Messenger { new ShowCustomFormMessage(tabID, formItems, buttons, title, description) ) } + + // TODO: Make this cleaner + // private getIsConfirmationRequired(toolUse: ToolUse) { + // if (toolUse.name === 'execute_bash') { + // const executeBash = new ExecuteBash(toolUse.input as unknown as ExecuteBashParams) + // return executeBash.requiresAcceptance() + // } + // return toolUse.name === 'fs_write' + // } + private getToolUseMessage(toolUse: ToolUse) { + if (toolUse.name === 'fs_read') { + return `Reading the file at \`${(toolUse.input as any)?.path}\` using the \`fs_read\` tool.` + } + // if (toolUse.name === 'execute_bash') { + // const input = toolUse.input as unknown as ExecuteBashParams + // return `Executing the bash command + // \`\`\`bash + // ${input.command} + // \`\`\` + // using the \`execute_bash\` tool.` + // } + // if (toolUse.name === 'fs_write') { + // const input = toolUse.input as unknown as FsWriteParams + // switch (input.command) { + // case 'create': { + // return `Writing + // \`\`\` + // ${input.file_text} + // \`\`\` + // into the file at \`${input.path}\` using the \`fs_write\` tool.` + // } + // case 'str_replace': { + // return `Replacing + // \`\`\` + // ${input.old_str} + // \`\`\` + // with + // \`\`\` + // ${input.new_str} + // \`\`\` + // at \`${input.path}\` using the \`fs_write\` tool.` + // } + // case 'insert': { + // return `Inserting + // \`\`\` + // ${input.new_str} + // \`\`\` + // at line + // \`\`\` + // ${input.insert_line} + // \`\`\` + // at \`${input.path}\` using the \`fs_write\` tool.` + // } + // case 'append': { + // return `Appending + // \`\`\` + // ${input.new_str} + // \`\`\` + // at \`${input.path}\` using the \`fs_write\` tool.` + // } + // } + // } + } } diff --git a/packages/core/src/codewhispererChat/controllers/chat/model.ts b/packages/core/src/codewhispererChat/controllers/chat/model.ts index 62666316166..71837d10c76 100644 --- a/packages/core/src/codewhispererChat/controllers/chat/model.ts +++ b/packages/core/src/codewhispererChat/controllers/chat/model.ts @@ -4,7 +4,14 @@ */ import * as vscode from 'vscode' -import { AdditionalContentEntry, ChatMessage, RelevantTextDocument, UserIntent } from '@amzn/codewhisperer-streaming' +import { + AdditionalContentEntry, + ChatMessage, + Origin, + RelevantTextDocument, + ToolResult, + UserIntent, +} from '@amzn/codewhisperer-streaming' import { MatchPolicy, CodeQuery } from '../../clients/chat/v0/model' import { Selection } from 'vscode' import { TabOpenType } from '../../../amazonq/webview/ui/storages/tabsStorage' @@ -198,6 +205,8 @@ export interface TriggerPayload { truncatedAdditionalContextLengths?: AdditionalContextLengths workspaceRulesCount?: number chatHistory?: ChatMessage[] + toolResults?: ToolResult[] + origin?: Origin } export type AdditionalContextLengths = { diff --git a/packages/core/src/codewhispererChat/storages/chatHistory.ts b/packages/core/src/codewhispererChat/storages/chatHistory.ts index 5387a36ed09..e808fd77b2a 100644 --- a/packages/core/src/codewhispererChat/storages/chatHistory.ts +++ b/packages/core/src/codewhispererChat/storages/chatHistory.ts @@ -2,9 +2,17 @@ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ -import { ChatMessage } from '@amzn/codewhisperer-streaming' +import { + ChatMessage, + Tool, + ToolResult, + ToolResultStatus, + UserInputMessage, + UserInputMessageContext, +} from '@amzn/codewhisperer-streaming' import { randomUUID } from '../../shared/crypto' import { getLogger } from '../../shared/logger/logger' +import { tools } from '../constants' // Maximum number of messages to keep in history const MaxConversationHistoryLength = 100 @@ -18,10 +26,12 @@ export class ChatHistoryManager { private history: ChatMessage[] = [] private logger = getLogger() private lastUserMessage?: ChatMessage + private tools: Tool[] = [] constructor() { this.conversationId = randomUUID() this.logger.info(`Generated new conversation id: ${this.conversationId}`) + this.tools = tools } /** @@ -54,20 +64,12 @@ export class ChatHistoryManager { * Append a new user message to be sent */ public appendUserMessage(newMessage: ChatMessage): void { + this.lastUserMessage = newMessage this.fixHistory() if (!newMessage.userInputMessage?.content || newMessage.userInputMessage?.content.trim() === '') { this.logger.warn('input must not be empty when adding new messages') - // const emptyMessage: ChatMessage = { - // ...newMessage, - // userInputMessage: { - // ...newMessage.userInputMessage, - // content: 'Empty user input', - // }, - // } - // this.history.push(emptyMessage) } - this.lastUserMessage = newMessage - this.history.push(newMessage) + this.history.push(this.lastUserMessage) } /** @@ -91,15 +93,24 @@ export class ChatHistoryManager { public fixHistory(): void { // Trim the conversation history if it exceeds the maximum length if (this.history.length > MaxConversationHistoryLength) { - // Find the second oldest user message to be the new starting point - const secondUserMessageIndex = this.history - .slice(1) // Skip the first message which might be from the user - .findIndex((msg) => !msg.userInputMessage?.content || msg.userInputMessage?.content.trim() === '') - - if (secondUserMessageIndex !== -1) { - // +1 because we sliced off the first element - this.logger.debug(`Removing the first ${secondUserMessageIndex + 1} elements in the history`) - this.history = this.history.slice(secondUserMessageIndex + 1) + // Find the second oldest user message without tool results + let indexToTrim: number | undefined + + for (let i = 1; i < this.history.length; i++) { + const message = this.history[i] + if (message.userInputMessage) { + const userMessage = message.userInputMessage + const ctx = userMessage.userInputMessageContext + const hasNoToolResults = ctx && (!ctx.toolResults || ctx.toolResults.length === 0) + if (hasNoToolResults && userMessage.content !== '') { + indexToTrim = i + break + } + } + } + if (indexToTrim !== undefined) { + this.logger.debug(`Removing the first ${indexToTrim} elements in the history`) + this.history.splice(0, indexToTrim) } else { this.logger.debug('No valid starting user message found in the history, clearing') this.history = [] @@ -107,12 +118,82 @@ export class ChatHistoryManager { } // Ensure the last message is from the assistant - if (this.history.length > 0 && this.history[this.history.length - 1].userInputMessage !== undefined) { this.logger.debug('Last message in history is from the user, dropping') this.history.pop() } // TODO: If the last message from the assistant contains tool uses, ensure the next user message contains tool results + + const lastHistoryMessage = this.history[this.history.length - 1] + + if ( + lastHistoryMessage && + (lastHistoryMessage.assistantResponseMessage || + lastHistoryMessage.assistantResponseMessage !== undefined) && + this.lastUserMessage + ) { + const toolUses = lastHistoryMessage.assistantResponseMessage.toolUses + + if (toolUses && toolUses.length > 0) { + if (this.lastUserMessage.userInputMessage) { + if (this.lastUserMessage.userInputMessage.userInputMessageContext) { + const ctx = this.lastUserMessage.userInputMessage.userInputMessageContext + + if (!ctx.toolResults || ctx.toolResults.length === 0) { + ctx.toolResults = toolUses.map((toolUse) => ({ + toolUseId: toolUse.toolUseId, + content: [ + { + type: 'Text', + text: 'Tool use was cancelled by the user', + }, + ], + status: ToolResultStatus.ERROR, + })) + } + } else { + const toolResults = toolUses.map((toolUse) => ({ + toolUseId: toolUse.toolUseId, + content: [ + { + type: 'Text', + text: 'Tool use was cancelled by the user', + }, + ], + status: ToolResultStatus.ERROR, + })) + + this.lastUserMessage.userInputMessage.userInputMessageContext = { + shellState: undefined, + envState: undefined, + toolResults: toolResults, + tools: this.tools.length === 0 ? undefined : [...this.tools], + } + } + } + } + } + } + + /** + * Adds tool results to the conversation. + */ + addToolResults(toolResults: ToolResult[]): void { + const userInputMessageContext: UserInputMessageContext = { + shellState: undefined, + envState: undefined, + toolResults: toolResults, + tools: this.tools.length === 0 ? undefined : [...this.tools], + } + + const msg: UserInputMessage = { + content: '', + userInputMessageContext: userInputMessageContext, + } + + if (this.lastUserMessage?.userInputMessage) { + this.lastUserMessage.userInputMessage = msg + } } } diff --git a/packages/core/src/codewhispererChat/tools/tool_index.json b/packages/core/src/codewhispererChat/tools/tool_index.json new file mode 100644 index 00000000000..b88d03c34f9 --- /dev/null +++ b/packages/core/src/codewhispererChat/tools/tool_index.json @@ -0,0 +1,23 @@ +{ + "fsRead": { + "name": "fsRead", + "description": "A tool for reading files (e.g. `cat -n`), or listing files/directories (e.g. `ls -la` or `find . -maxdepth 2). The behavior of this tool is determined by the `path` parameter pointing to a file or directory.\n* If `path` is a file, this tool returns the result of running `cat -n`, and the optional `readRange` determines what range of lines will be read from the specified file.\n* If `path` is a directory, this tool returns the listed files and directories of the specified path, as if running `ls -la`. If the `readRange` parameter is provided, the tool acts like the `find . -maxdepth `, where `readRange` is the number of subdirectories deep to search, e.g. [2] will run `find . -maxdepth 2`.", + "inputSchema": { + "type": "object", + "properties": { + "path": { + "description": "Absolute path to file or directory, e.g. `/repo/file.py` or `/repo`.", + "type": "string" + }, + "readRange": { + "description": "Optional parameter when reading either files or directories.\n* When `path` is a file, if none is given, the full file is shown. If provided, the file will be shown in the indicated line number range, e.g. [11, 12] will show lines 11 and 12. Indexing at 1 to start. Setting `[startLine, -1]` shows all lines from `startLine` to the end of the file.\n* When `path` is a directory, if none is given, the results of `ls -l` are given. If provided, the current directory and indicated number of subdirectories will be shown, e.g. [2] will show the current directory and directories two levels deep.", + "items": { + "type": "integer" + }, + "type": "array" + } + }, + "required": ["path"] + } + } +} diff --git a/packages/core/src/codewhispererChat/view/connector/connector.ts b/packages/core/src/codewhispererChat/view/connector/connector.ts index 0b2b29498c4..b37f5610611 100644 --- a/packages/core/src/codewhispererChat/view/connector/connector.ts +++ b/packages/core/src/codewhispererChat/view/connector/connector.ts @@ -318,4 +318,8 @@ export class AppToWebViewMessageDispatcher { public sendShowCustomFormMessage(message: ShowCustomFormMessage) { this.appsToWebViewMessagePublisher.publish(message) } + + public sendCustomFormActionMessage(message: CustomFormActionMessage) { + this.appsToWebViewMessagePublisher.publish(message) + } }