diff --git a/packages/core/src/codewhispererChat/controllers/chat/chatRequest/converter.ts b/packages/core/src/codewhispererChat/controllers/chat/chatRequest/converter.ts index 82bb79242c3..93947bdfe78 100644 --- a/packages/core/src/codewhispererChat/controllers/chat/chatRequest/converter.ts +++ b/packages/core/src/codewhispererChat/controllers/chat/chatRequest/converter.ts @@ -175,7 +175,6 @@ export function triggerPayloadToChatRequest(triggerPayload: TriggerPayload): { c }, chatTriggerType, customizationArn: customizationArn, - history: triggerPayload.chatHistory, }, } } diff --git a/packages/core/src/codewhispererChat/controllers/chat/controller.ts b/packages/core/src/codewhispererChat/controllers/chat/controller.ts index ee72a3f94cd..0c5574a0a20 100644 --- a/packages/core/src/codewhispererChat/controllers/chat/controller.ts +++ b/packages/core/src/codewhispererChat/controllers/chat/controller.ts @@ -82,7 +82,6 @@ import { createSavedPromptCommandId, aditionalContentNameLimit, additionalContentInnerContextLimit, - tools, workspaceChunkMaxSize, defaultContextLengths, } from '../../constants' @@ -1355,24 +1354,15 @@ export class ChatController { triggerPayload.contextLengths.userInputContextLength = triggerPayload.message.length triggerPayload.contextLengths.focusFileContextLength = triggerPayload.fileText.length + const request = triggerPayloadToChatRequest(triggerPayload) + const chatHistory = this.chatHistoryStorage.getTabHistory(tabID) - const newUserMessage = { - userInputMessage: { - content: triggerPayload.message, - userIntent: triggerPayload.userIntent, - ...(triggerPayload.origin && { origin: triggerPayload.origin }), - userInputMessageContext: { - tools: tools, - ...(triggerPayload.toolResults && { toolResults: triggerPayload.toolResults }), - }, - }, + const currentMessage = request.conversationState.currentMessage + if (currentMessage) { + chatHistory.fixHistory(currentMessage) } - const fixedHistoryMessage = chatHistory.fixHistory(newUserMessage) - if (fixedHistoryMessage.userInputMessage?.userInputMessageContext) { - triggerPayload.toolResults = fixedHistoryMessage.userInputMessage.userInputMessageContext.toolResults - } - triggerPayload.chatHistory = chatHistory.getHistory() - const request = triggerPayloadToChatRequest(triggerPayload) + request.conversationState.history = chatHistory.getHistory() + const conversationId = chatHistory.getConversationId() || randomUUID() chatHistory.setConversationId(conversationId) request.conversationState.conversationId = conversationId @@ -1426,8 +1416,8 @@ export class ChatController { } this.telemetryHelper.recordEnterFocusConversation(triggerEvent.tabID) this.telemetryHelper.recordStartConversation(triggerEvent, triggerPayload) - if (request.conversationState.currentMessage) { - chatHistory.appendUserMessage(request.conversationState.currentMessage) + if (currentMessage) { + chatHistory.appendUserMessage(currentMessage) } getLogger().info( diff --git a/packages/core/src/codewhispererChat/controllers/chat/model.ts b/packages/core/src/codewhispererChat/controllers/chat/model.ts index b07cea0cc4a..1ec4d9f1666 100644 --- a/packages/core/src/codewhispererChat/controllers/chat/model.ts +++ b/packages/core/src/codewhispererChat/controllers/chat/model.ts @@ -6,7 +6,6 @@ import * as vscode from 'vscode' import { AdditionalContentEntry, - ChatMessage, Origin, RelevantTextDocument, ToolResult, @@ -203,7 +202,6 @@ export interface TriggerPayload { traceId?: string contextLengths: ContextLengths workspaceRulesCount?: number - chatHistory?: ChatMessage[] toolResults?: ToolResult[] origin?: Origin } diff --git a/packages/core/src/codewhispererChat/storages/chatHistory.ts b/packages/core/src/codewhispererChat/storages/chatHistory.ts index 79ff5d557f0..1029e2eeec5 100644 --- a/packages/core/src/codewhispererChat/storages/chatHistory.ts +++ b/packages/core/src/codewhispererChat/storages/chatHistory.ts @@ -2,10 +2,9 @@ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ -import { ChatMessage, Tool, ToolResult, ToolResultStatus, ToolUse } from '@amzn/codewhisperer-streaming' +import { ChatMessage, ToolResult, ToolResultStatus, ToolUse } from '@amzn/codewhisperer-streaming' import { randomUUID } from '../../shared/crypto' import { getLogger } from '../../shared/logger/logger' -import { tools } from '../constants' // Maximum number of messages to keep in history const MaxConversationHistoryLength = 100 @@ -20,12 +19,10 @@ export class ChatHistoryManager { private history: ChatMessage[] = [] private logger = getLogger() private lastUserMessage?: ChatMessage - private tools: Tool[] = [] constructor(tabId?: string) { this.conversationId = randomUUID() this.tabId = tabId ?? randomUUID() - this.tools = tools } /** @@ -97,10 +94,10 @@ export class ChatHistoryManager { * 4. If the last message is from the assistant and it contains tool uses, and a next user * message is set without tool results, then the user message will have cancelled tool results. */ - public fixHistory(newUserMessage: ChatMessage): ChatMessage { + public fixHistory(newUserMessage: ChatMessage): void { this.trimConversationHistory() this.ensureLastMessageFromAssistant() - return this.handleToolUses(newUserMessage) + this.ensureCurrentMessageIsValid(newUserMessage) } private trimConversationHistory(): void { @@ -145,42 +142,27 @@ export class ChatHistoryManager { } } - private handleToolUses(newUserMessage: ChatMessage): ChatMessage { + private ensureCurrentMessageIsValid(newUserMessage: ChatMessage): void { const lastHistoryMessage = this.history[this.history.length - 1] - if (!lastHistoryMessage || !lastHistoryMessage.assistantResponseMessage || !newUserMessage) { - return newUserMessage - } - - const toolUses = lastHistoryMessage.assistantResponseMessage.toolUses - if (!toolUses || toolUses.length === 0) { - return newUserMessage - } - - return this.addToolResultsToUserMessage(newUserMessage, toolUses) - } - - private addToolResultsToUserMessage(newUserMessage: ChatMessage, toolUses: ToolUse[]): ChatMessage { - if (!newUserMessage.userInputMessage) { - return newUserMessage + if (!lastHistoryMessage) { + return } - const toolResults = this.createToolResults(toolUses) + if (lastHistoryMessage.assistantResponseMessage?.toolUses?.length) { + const toolResults = newUserMessage.userInputMessage?.userInputMessageContext?.toolResults + if (!toolResults || toolResults.length === 0) { + const abandonedToolResults = this.createAbandonedToolResults( + lastHistoryMessage.assistantResponseMessage.toolUses + ) - if (newUserMessage.userInputMessage.userInputMessageContext) { - newUserMessage.userInputMessage.userInputMessageContext.toolResults = toolResults - } else { - newUserMessage.userInputMessage.userInputMessageContext = { - shellState: undefined, - envState: undefined, - toolResults: toolResults, - tools: this.tools.length === 0 ? undefined : [...this.tools], + if (newUserMessage.userInputMessage?.userInputMessageContext) { + newUserMessage.userInputMessage.userInputMessageContext.toolResults = abandonedToolResults + } } } - - return newUserMessage } - private createToolResults(toolUses: ToolUse[]): ToolResult[] { + private createAbandonedToolResults(toolUses: ToolUse[]): ToolResult[] { return toolUses.map((toolUse) => ({ toolUseId: toolUse.toolUseId, content: [