Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 31 additions & 39 deletions packages/core/src/codewhispererChat/controllers/chat/controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,11 @@ import {
defaultContextLengths,
} from '../../constants'
import { ChatSession } from '../../clients/chat/v0/chat'
import { ChatHistoryManager } from '../../storages/chatHistory'
import { amazonQTabSuffix } from '../../../shared/constants'
import { OutputKind } from '../../tools/toolShared'
import { ToolUtils, Tool, ToolType } from '../../tools/toolUtils'
import { ChatStream } from '../../tools/chatStream'
import { ChatHistoryStorage } from '../../storages/chatHistoryStorage'
import { FsWrite, FsWriteParams } from '../../tools/fsWrite'
import { tempDirPath } from '../../../shared/filesystemUtilities'

Expand Down Expand Up @@ -155,7 +155,7 @@ export class ChatController {
private readonly userIntentRecognizer: UserIntentRecognizer
private readonly telemetryHelper: CWCTelemetryHelper
private userPromptsWatcher: vscode.FileSystemWatcher | undefined
private readonly chatHistoryManager: ChatHistoryManager
private readonly chatHistoryStorage: ChatHistoryStorage

public constructor(
private readonly chatControllerMessageListeners: ChatControllerMessageListeners,
Expand All @@ -173,7 +173,7 @@ export class ChatController {
this.editorContentController = new EditorContentController()
this.promptGenerator = new PromptsGenerator()
this.userIntentRecognizer = new UserIntentRecognizer()
this.chatHistoryManager = new ChatHistoryManager()
this.chatHistoryStorage = new ChatHistoryStorage()

onDidChangeAmazonQVisibility((visible) => {
if (visible) {
Expand Down Expand Up @@ -424,7 +424,7 @@ export class ChatController {

private async processTabCloseMessage(message: TabClosedMessage) {
this.sessionStorage.deleteSession(message.tabID)
this.chatHistoryManager.clear()
this.chatHistoryStorage.deleteHistory(message.tabID)
this.triggerEventsStorage.removeTabEvents(message.tabID)
// this.telemetryHelper.recordCloseChat(message.tabID)
}
Expand Down Expand Up @@ -710,7 +710,7 @@ export class ChatController {
customization: getSelectedCustomization(),
toolResults: toolResults,
origin: Origin.IDE,
chatHistory: this.chatHistoryManager.getHistory(),
chatHistory: this.chatHistoryStorage.getHistory(tabID).getHistory(),
context: session.context ?? [],
relevantTextDocuments: [],
additionalContents: [],
Expand Down Expand Up @@ -890,7 +890,6 @@ export class ChatController {
getLogger().error(`error: ${errorMessage} tabID: ${tabID} requestID: ${requestID}`)

this.sessionStorage.deleteSession(tabID)
this.chatHistoryManager.clear()
}

private async processContextMenuCommand(command: EditorContextCommand) {
Expand Down Expand Up @@ -964,7 +963,6 @@ export class ChatController {
codeQuery: context?.focusAreaContext?.names,
userIntent: this.userIntentRecognizer.getFromContextMenuCommand(command),
customization: getSelectedCustomization(),
chatHistory: this.chatHistoryManager.getHistory(),
additionalContents: [],
relevantTextDocuments: [],
documentReferences: [],
Expand Down Expand Up @@ -1012,7 +1010,7 @@ export class ChatController {
switch (message.command) {
case 'clear':
this.sessionStorage.deleteSession(message.tabID)
this.chatHistoryManager.clear()
this.chatHistoryStorage.getHistory(message.tabID).clear()
this.triggerEventsStorage.removeTabEvents(message.tabID)
recordTelemetryChatRunCommand('clear')
return
Expand Down Expand Up @@ -1051,7 +1049,7 @@ export class ChatController {
codeQuery: lastTriggerEvent.context?.focusAreaContext?.names,
userIntent: message.userIntent,
customization: getSelectedCustomization(),
chatHistory: this.chatHistoryManager.getHistory(),
chatHistory: this.chatHistoryStorage.getHistory(message.tabID).getHistory(),
contextLengths: {
...defaultContextLengths,
},
Expand Down Expand Up @@ -1100,7 +1098,7 @@ export class ChatController {
codeQuery: context?.focusAreaContext?.names,
userIntent: this.userIntentRecognizer.getFromPromptChatMessage(message),
customization: getSelectedCustomization(),
chatHistory: this.chatHistoryManager.getHistory(),
chatHistory: this.chatHistoryStorage.getHistory(message.tabID).getHistory(),
origin: Origin.IDE,
context: message.context ?? [],
relevantTextDocuments: [],
Expand Down Expand Up @@ -1327,16 +1325,28 @@ export class ChatController {

triggerPayload.contextLengths.userInputContextLength = triggerPayload.message.length
triggerPayload.contextLengths.focusFileContextLength = triggerPayload.fileText.length
const request = triggerPayloadToChatRequest(triggerPayload)
if (
this.chatHistoryManager.getConversationId() !== undefined &&
this.chatHistoryManager.getConversationId() !== ''
) {
request.conversationState.conversationId = this.chatHistoryManager.getConversationId()
} else {
this.chatHistoryManager.setConversationId(randomUUID())
request.conversationState.conversationId = this.chatHistoryManager.getConversationId()

const chatHistory = this.chatHistoryStorage.getHistory(tabID)
const newUserMessage = {
userInputMessage: {
content: triggerPayload.message,
userIntent: triggerPayload.userIntent,
...(triggerPayload.origin && { origin: triggerPayload.origin }),
userInputMessageContext: {
tools: tools,
...(triggerPayload.toolResults && { toolResults: triggerPayload.toolResults }),
},
},
}
const fixedHistoryMessage = chatHistory.fixHistory(newUserMessage)
if (fixedHistoryMessage.userInputMessage?.userInputMessageContext) {
triggerPayload.toolResults = fixedHistoryMessage.userInputMessage.userInputMessageContext.toolResults
}
Comment on lines +1342 to 1344
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why can't we do this inside fixHistory?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can but i'm not sending triggerPayload to fixHistory. fixHistory could be used for other operations, so i set it up to take in only ChatMessage

const request = triggerPayloadToChatRequest(triggerPayload)
const conversationId = chatHistory.getConversationId() || randomUUID()
chatHistory.setConversationId(conversationId)
request.conversationState.conversationId = conversationId

triggerPayload.documentReferences = this.mergeRelevantTextDocuments(triggerPayload.relevantTextDocuments)

// Update context transparency after it's truncated dynamically to show users only the context sent.
Expand Down Expand Up @@ -1386,32 +1396,14 @@ export class ChatController {
}
this.telemetryHelper.recordEnterFocusConversation(triggerEvent.tabID)
this.telemetryHelper.recordStartConversation(triggerEvent, triggerPayload)

this.chatHistoryManager.appendUserMessage({
userInputMessage: {
content: triggerPayload.message,
userIntent: triggerPayload.userIntent,
...(triggerPayload.origin && { origin: triggerPayload.origin }),
userInputMessageContext: {
tools: tools,
...(triggerPayload.toolResults && { toolResults: triggerPayload.toolResults }),
},
},
})
chatHistory.appendUserMessage(fixedHistoryMessage)

getLogger().info(
`response to tab: ${tabID} conversationID: ${session.sessionIdentifier} requestID: ${
response.$metadata.requestId
} metadata: ${inspect(response.$metadata, { depth: 12 })}`
)
await this.messenger.sendAIResponse(
response,
session,
tabID,
triggerID,
triggerPayload,
this.chatHistoryManager
)
await this.messenger.sendAIResponse(response, session, tabID, triggerID, triggerPayload, chatHistory)
} catch (e: any) {
this.telemetryHelper.recordMessageResponseError(triggerPayload, tabID, getHttpStatusCode(e) ?? 0)
// clears session, record telemetry before this call
Expand Down
92 changes: 82 additions & 10 deletions packages/core/src/codewhispererChat/storages/chatHistory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,15 @@ const MaxConversationHistoryLength = 100
*/
export class ChatHistoryManager {
private conversationId: string
private tabId: string
private history: ChatMessage[] = []
private logger = getLogger()
private lastUserMessage?: ChatMessage
private tools: Tool[] = []

constructor() {
constructor(tabId?: string) {
this.conversationId = randomUUID()
this.logger.info(`Generated new conversation id: ${this.conversationId}`)
this.tabId = tabId ?? randomUUID()
this.tools = tools
}

Expand All @@ -45,6 +46,20 @@ export class ChatHistoryManager {
this.conversationId = conversationId
}

/**
* Get the tab ID
*/
public getTabId(): string {
return this.tabId
}

/**
* Set the tab ID
*/
public setTabId(tabId: string) {
this.tabId = tabId
}

/**
* Get the full chat history
*/
Expand All @@ -65,7 +80,6 @@ export class ChatHistoryManager {
*/
public appendUserMessage(newMessage: ChatMessage): void {
this.lastUserMessage = newMessage
this.fixHistory()
if (!newMessage.userInputMessage?.content || newMessage.userInputMessage?.content.trim() === '') {
this.logger.warn('input must not be empty when adding new messages')
}
Expand All @@ -90,7 +104,7 @@ export class ChatHistoryManager {
* 4. If the last message is from the assistant and it contains tool uses, and a next user
* message is set without tool results, then the user message will have cancelled tool results.
*/
public fixHistory(): void {
public fixHistory(newUserMessage: ChatMessage): ChatMessage {
// Trim the conversation history if it exceeds the maximum length
if (this.history.length > MaxConversationHistoryLength) {
// Find the second oldest user message without tool results
Expand Down Expand Up @@ -123,22 +137,22 @@ export class ChatHistoryManager {
this.history.pop()
}

// TODO: If the last message from the assistant contains tool uses, ensure the next user message contains tool results
// If the last message from the assistant contains tool uses, ensure the next user message contains tool results

const lastHistoryMessage = this.history[this.history.length - 1]

if (
lastHistoryMessage &&
(lastHistoryMessage.assistantResponseMessage ||
lastHistoryMessage.assistantResponseMessage !== undefined) &&
this.lastUserMessage
newUserMessage
) {
const toolUses = lastHistoryMessage.assistantResponseMessage.toolUses

if (toolUses && toolUses.length > 0) {
if (this.lastUserMessage.userInputMessage) {
if (this.lastUserMessage.userInputMessage.userInputMessageContext) {
const ctx = this.lastUserMessage.userInputMessage.userInputMessageContext
if (newUserMessage.userInputMessage) {
if (newUserMessage.userInputMessage.userInputMessageContext) {
const ctx = newUserMessage.userInputMessage.userInputMessageContext

if (!ctx.toolResults || ctx.toolResults.length === 0) {
ctx.toolResults = toolUses.map((toolUse) => ({
Expand All @@ -164,16 +178,21 @@ export class ChatHistoryManager {
status: ToolResultStatus.ERROR,
}))

this.lastUserMessage.userInputMessage.userInputMessageContext = {
newUserMessage.userInputMessage.userInputMessageContext = {
shellState: undefined,
envState: undefined,
toolResults: toolResults,
tools: this.tools.length === 0 ? undefined : [...this.tools],
}

return newUserMessage
}
}
}
}

// Always return the message to fix the TypeScript error
return newUserMessage
}

/**
Expand All @@ -197,6 +216,59 @@ export class ChatHistoryManager {
}
}

/**
* Checks if the latest message in history is an Assistant Message.
* If it is and doesn't have toolUse, it will be removed.
* If it has toolUse, an assistantResponse message with cancelled tool status will be added.
*/
public checkLatestAssistantMessage(): void {
if (this.history.length === 0) {
return
}

const lastMessage = this.history[this.history.length - 1]

if (lastMessage.assistantResponseMessage) {
const toolUses = lastMessage.assistantResponseMessage.toolUses

if (!toolUses || toolUses.length === 0) {
// If there are no tool uses, remove the assistant message
this.logger.debug('Removing assistant message without tool uses')
this.history.pop()
} else {
// If there are tool uses, add cancelled tool results
const toolResults = toolUses.map((toolUse) => ({
toolUseId: toolUse.toolUseId,
content: [
{
type: 'Text',
text: 'Tool use was cancelled by the user',
},
],
status: ToolResultStatus.ERROR,
}))

// Create a new user message with cancelled tool results
const userInputMessageContext: UserInputMessageContext = {
shellState: undefined,
envState: undefined,
toolResults: toolResults,
tools: this.tools.length === 0 ? undefined : [...this.tools],
}

const userMessage: ChatMessage = {
userInputMessage: {
content: '',
userInputMessageContext: userInputMessageContext,
},
}

this.history.push(this.formatChatHistoryMessage(userMessage))
this.logger.debug('Added user message with cancelled tool results')
}
}
}

private formatChatHistoryMessage(message: ChatMessage): ChatMessage {
if (message.userInputMessage !== undefined) {
return {
Expand Down
43 changes: 43 additions & 0 deletions packages/core/src/codewhispererChat/storages/chatHistoryStorage.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*!
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

import { ChatHistoryManager } from './chatHistory'

/**
* ChatHistoryStorage manages ChatHistoryManager instances for multiple tabs.
* Each tab has its own ChatHistoryManager to maintain separate chat histories.
*/
export class ChatHistoryStorage {
private histories: Map<string, ChatHistoryManager> = new Map()

/**
* Gets the ChatHistoryManager for a specific tab.
* If no history exists for the tab, creates a new one.
*
* @param tabId The ID of the tab
* @returns The ChatHistoryManager for the specified tab
*/
public getHistory(tabId: string): ChatHistoryManager {
const historyFromStorage = this.histories.get(tabId)
if (historyFromStorage !== undefined) {
return historyFromStorage
}

// Create a new ChatHistoryManager with the tabId
const newHistory = new ChatHistoryManager(tabId)
this.histories.set(tabId, newHistory)

return newHistory
}

/**
* Deletes the ChatHistoryManager for a specific tab.
*
* @param tabId The ID of the tab
*/
public deleteHistory(tabId: string) {
this.histories.delete(tabId)
}
}