diff --git a/packages/core/src/codewhispererChat/clients/chat/v0/chat.ts b/packages/core/src/codewhispererChat/clients/chat/v0/chat.ts index 3cf030b9b8e..540135268e0 100644 --- a/packages/core/src/codewhispererChat/clients/chat/v0/chat.ts +++ b/packages/core/src/codewhispererChat/clients/chat/v0/chat.ts @@ -60,10 +60,6 @@ export class ChatSession { async chatSso(chatRequest: GenerateAssistantResponseRequest): Promise { const client = await createCodeWhispererChatStreamingClient() - if (this.sessionId !== undefined && chatRequest.conversationState !== undefined) { - chatRequest.conversationState.conversationId = this.sessionId - } - const response = await client.generateAssistantResponse(chatRequest) if (!response.generateAssistantResponseResponse) { throw new ToolkitError( diff --git a/packages/core/src/codewhispererChat/controllers/chat/chatRequest/converter.ts b/packages/core/src/codewhispererChat/controllers/chat/chatRequest/converter.ts index 0a34463058e..dd6edfbeff0 100644 --- a/packages/core/src/codewhispererChat/controllers/chat/chatRequest/converter.ts +++ b/packages/core/src/codewhispererChat/controllers/chat/chatRequest/converter.ts @@ -123,6 +123,7 @@ export function triggerPayloadToChatRequest(triggerPayload: TriggerPayload): { c }, chatTriggerType, customizationArn: customizationArn, + history: triggerPayload.chatHistory, }, } } diff --git a/packages/core/src/codewhispererChat/controllers/chat/controller.ts b/packages/core/src/codewhispererChat/controllers/chat/controller.ts index 846f3c6e445..667c3d418f6 100644 --- a/packages/core/src/codewhispererChat/controllers/chat/controller.ts +++ b/packages/core/src/codewhispererChat/controllers/chat/controller.ts @@ -80,6 +80,7 @@ import { contextMaxLength, } from '../../constants' import { ChatSession } from '../../clients/chat/v0/chat' +import { ChatHistoryManager } from '../../storages/chatHistory' export interface ChatControllerMessagePublishers { readonly processPromptChatMessage: MessagePublisher @@ -141,6 +142,7 @@ export class ChatController { private readonly userIntentRecognizer: UserIntentRecognizer private readonly telemetryHelper: CWCTelemetryHelper private userPromptsWatcher: vscode.FileSystemWatcher | undefined + private readonly chatHistoryManager: ChatHistoryManager public constructor( private readonly chatControllerMessageListeners: ChatControllerMessageListeners, @@ -158,6 +160,7 @@ export class ChatController { this.editorContentController = new EditorContentController() this.promptGenerator = new PromptsGenerator() this.userIntentRecognizer = new UserIntentRecognizer() + this.chatHistoryManager = new ChatHistoryManager() onDidChangeAmazonQVisibility((visible) => { if (visible) { @@ -395,6 +398,7 @@ export class ChatController { private async processTabCloseMessage(message: TabClosedMessage) { this.sessionStorage.deleteSession(message.tabID) + this.chatHistoryManager.clear() this.triggerEventsStorage.removeTabEvents(message.tabID) // this.telemetryHelper.recordCloseChat(message.tabID) } @@ -654,6 +658,7 @@ export class ChatController { getLogger().error(`error: ${errorMessage} tabID: ${tabID} requestID: ${requestID}`) this.sessionStorage.deleteSession(tabID) + this.chatHistoryManager.clear() } private async processContextMenuCommand(command: EditorContextCommand) { @@ -714,6 +719,13 @@ export class ChatController { command, }) + this.chatHistoryManager.appendUserMessage({ + userInputMessage: { + content: prompt, + userIntent: this.userIntentRecognizer.getFromContextMenuCommand(command), + }, + }) + return this.generateResponse( { message: prompt, @@ -727,6 +739,7 @@ export class ChatController { codeQuery: context?.focusAreaContext?.names, userIntent: this.userIntentRecognizer.getFromContextMenuCommand(command), customization: getSelectedCustomization(), + chatHistory: this.chatHistoryManager.getHistory(), }, triggerID ) @@ -766,6 +779,7 @@ export class ChatController { switch (message.command) { case 'clear': this.sessionStorage.deleteSession(message.tabID) + this.chatHistoryManager.clear() this.triggerEventsStorage.removeTabEvents(message.tabID) recordTelemetryChatRunCommand('clear') return @@ -791,6 +805,13 @@ export class ChatController { context: lastTriggerEvent.context, }) + this.chatHistoryManager.appendUserMessage({ + userInputMessage: { + content: message.message, + userIntent: message.userIntent, + }, + }) + return this.generateResponse( { message: message.message, @@ -804,6 +825,7 @@ export class ChatController { codeQuery: lastTriggerEvent.context?.focusAreaContext?.names, userIntent: message.userIntent, customization: getSelectedCustomization(), + chatHistory: this.chatHistoryManager.getHistory(), }, triggerID ) @@ -824,6 +846,12 @@ export class ChatController { type: 'chat_message', context, }) + this.chatHistoryManager.appendUserMessage({ + userInputMessage: { + content: message.message, + userIntent: message.userIntent, + }, + }) return this.generateResponse( { message: message.message, @@ -838,6 +866,7 @@ export class ChatController { userIntent: this.userIntentRecognizer.getFromPromptChatMessage(message), customization: getSelectedCustomization(), context: message.context, + chatHistory: this.chatHistoryManager.getHistory(), }, triggerID ) @@ -1104,6 +1133,15 @@ export class ChatController { triggerPayload.documentReferences = this.mergeRelevantTextDocuments(triggerPayload.relevantTextDocuments || []) const request = triggerPayloadToChatRequest(triggerPayload) + if ( + this.chatHistoryManager.getConversationId() !== undefined && + this.chatHistoryManager.getConversationId() !== '' + ) { + request.conversationState.conversationId = this.chatHistoryManager.getConversationId() + } else { + this.chatHistoryManager.setConversationId(randomUUID()) + request.conversationState.conversationId = this.chatHistoryManager.getConversationId() + } if (triggerPayload.documentReferences !== undefined) { const relativePathsOfMergedRelevantDocuments = triggerPayload.documentReferences.map( @@ -1157,7 +1195,14 @@ export class ChatController { response.$metadata.requestId } metadata: ${inspect(response.$metadata, { depth: 12 })}` ) - await this.messenger.sendAIResponse(response, session, tabID, triggerID, triggerPayload) + await this.messenger.sendAIResponse( + response, + session, + tabID, + triggerID, + triggerPayload, + this.chatHistoryManager + ) } catch (e: any) { this.telemetryHelper.recordMessageResponseError(triggerPayload, tabID, getHttpStatusCode(e) ?? 0) // clears session, record telemetry before this call diff --git a/packages/core/src/codewhispererChat/controllers/chat/messenger/messenger.ts b/packages/core/src/codewhispererChat/controllers/chat/messenger/messenger.ts index dd80676cf8b..3c5f181a3fb 100644 --- a/packages/core/src/codewhispererChat/controllers/chat/messenger/messenger.ts +++ b/packages/core/src/codewhispererChat/controllers/chat/messenger/messenger.ts @@ -38,6 +38,7 @@ import { extractCodeBlockLanguage } from '../../../../shared/markdown' import { extractAuthFollowUp } from '../../../../amazonq/util/authUtils' import { helpMessage } from '../../../../amazonq/webview/ui/texts/constants' import { ChatItemButton, ChatItemFormItem, MynahUIDataModel } from '@aws/mynah-ui' +import { ChatHistoryManager } from '../../../storages/chatHistory' export type StaticTextResponseType = 'quick-action-help' | 'onboarding-help' | 'transform' | 'help' @@ -121,7 +122,8 @@ export class Messenger { session: ChatSession, tabID: string, triggerID: string, - triggerPayload: TriggerPayload + triggerPayload: TriggerPayload, + chatHistoryManager: ChatHistoryManager ) { let message = '' const messageID = response.$metadata.requestId ?? '' @@ -331,6 +333,15 @@ export class Messenger { ) ) + chatHistoryManager.pushAssistantMessage({ + assistantResponseMessage: { + messageId: messageID, + content: message, + references: codeReference, + // TODO: Add tools data and follow up prompt details + }, + }) + getLogger().info( `All events received. requestId=%s counts=%s`, response.$metadata.requestId, diff --git a/packages/core/src/codewhispererChat/controllers/chat/model.ts b/packages/core/src/codewhispererChat/controllers/chat/model.ts index ae0c6b61063..62666316166 100644 --- a/packages/core/src/codewhispererChat/controllers/chat/model.ts +++ b/packages/core/src/codewhispererChat/controllers/chat/model.ts @@ -4,7 +4,7 @@ */ import * as vscode from 'vscode' -import { AdditionalContentEntry, RelevantTextDocument, UserIntent } from '@amzn/codewhisperer-streaming' +import { AdditionalContentEntry, ChatMessage, RelevantTextDocument, UserIntent } from '@amzn/codewhisperer-streaming' import { MatchPolicy, CodeQuery } from '../../clients/chat/v0/model' import { Selection } from 'vscode' import { TabOpenType } from '../../../amazonq/webview/ui/storages/tabsStorage' @@ -197,6 +197,7 @@ export interface TriggerPayload { additionalContextLengths?: AdditionalContextLengths truncatedAdditionalContextLengths?: AdditionalContextLengths workspaceRulesCount?: number + chatHistory?: ChatMessage[] } export type AdditionalContextLengths = { diff --git a/packages/core/src/codewhispererChat/storages/chatHistory.ts b/packages/core/src/codewhispererChat/storages/chatHistory.ts new file mode 100644 index 00000000000..5387a36ed09 --- /dev/null +++ b/packages/core/src/codewhispererChat/storages/chatHistory.ts @@ -0,0 +1,118 @@ +/*! + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +import { ChatMessage } from '@amzn/codewhisperer-streaming' +import { randomUUID } from '../../shared/crypto' +import { getLogger } from '../../shared/logger/logger' + +// Maximum number of messages to keep in history +const MaxConversationHistoryLength = 100 + +/** + * ChatHistoryManager handles the storage and manipulation of chat history + * for CodeWhisperer Chat sessions. + */ +export class ChatHistoryManager { + private conversationId: string + private history: ChatMessage[] = [] + private logger = getLogger() + private lastUserMessage?: ChatMessage + + constructor() { + this.conversationId = randomUUID() + this.logger.info(`Generated new conversation id: ${this.conversationId}`) + } + + /** + * Get the conversation ID + */ + public getConversationId(): string { + return this.conversationId + } + + public setConversationId(conversationId: string) { + this.conversationId = conversationId + } + + /** + * Get the full chat history + */ + public getHistory(): ChatMessage[] { + return [...this.history] + } + + /** + * Clear the conversation history + */ + public clear(): void { + this.history = [] + this.conversationId = '' + } + + /** + * Append a new user message to be sent + */ + public appendUserMessage(newMessage: ChatMessage): void { + this.fixHistory() + if (!newMessage.userInputMessage?.content || newMessage.userInputMessage?.content.trim() === '') { + this.logger.warn('input must not be empty when adding new messages') + // const emptyMessage: ChatMessage = { + // ...newMessage, + // userInputMessage: { + // ...newMessage.userInputMessage, + // content: 'Empty user input', + // }, + // } + // this.history.push(emptyMessage) + } + this.lastUserMessage = newMessage + this.history.push(newMessage) + } + + /** + * Push an assistant message to the history + */ + public pushAssistantMessage(newMessage: ChatMessage): void { + if (newMessage !== undefined && this.lastUserMessage !== undefined) { + this.logger.warn('last Message should not be defined when pushing an assistant message') + } + this.history.push(newMessage) + } + + /** + * Fixes the history to maintain the following invariants: + * 1. The history length is <= MAX_CONVERSATION_HISTORY_LENGTH. Oldest messages are dropped. + * 2. The first message is from the user. Oldest messages are dropped if needed. + * 3. The last message is from the assistant. The last message is dropped if it is from the user. + * 4. If the last message is from the assistant and it contains tool uses, and a next user + * message is set without tool results, then the user message will have cancelled tool results. + */ + public fixHistory(): void { + // Trim the conversation history if it exceeds the maximum length + if (this.history.length > MaxConversationHistoryLength) { + // Find the second oldest user message to be the new starting point + const secondUserMessageIndex = this.history + .slice(1) // Skip the first message which might be from the user + .findIndex((msg) => !msg.userInputMessage?.content || msg.userInputMessage?.content.trim() === '') + + if (secondUserMessageIndex !== -1) { + // +1 because we sliced off the first element + this.logger.debug(`Removing the first ${secondUserMessageIndex + 1} elements in the history`) + this.history = this.history.slice(secondUserMessageIndex + 1) + } else { + this.logger.debug('No valid starting user message found in the history, clearing') + this.history = [] + } + } + + // Ensure the last message is from the assistant + + if (this.history.length > 0 && this.history[this.history.length - 1].userInputMessage !== undefined) { + this.logger.debug('Last message in history is from the user, dropping') + this.history.pop() + } + + // TODO: If the last message from the assistant contains tool uses, ensure the next user message contains tool results + } +}