Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions packages/core/src/codewhispererChat/clients/chat/v0/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,6 @@ export class ChatSession {
async chatSso(chatRequest: GenerateAssistantResponseRequest): Promise<GenerateAssistantResponseCommandOutput> {
const client = await createCodeWhispererChatStreamingClient()

if (this.sessionId !== undefined && chatRequest.conversationState !== undefined) {
chatRequest.conversationState.conversationId = this.sessionId
}

const response = await client.generateAssistantResponse(chatRequest)
if (!response.generateAssistantResponseResponse) {
throw new ToolkitError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ export function triggerPayloadToChatRequest(triggerPayload: TriggerPayload): { c
},
chatTriggerType,
customizationArn: customizationArn,
history: triggerPayload.chatHistory,
},
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ import {
contextMaxLength,
} from '../../constants'
import { ChatSession } from '../../clients/chat/v0/chat'
import { ChatHistoryManager } from '../../storages/chatHistory'

export interface ChatControllerMessagePublishers {
readonly processPromptChatMessage: MessagePublisher<PromptMessage>
Expand Down Expand Up @@ -141,6 +142,7 @@ export class ChatController {
private readonly userIntentRecognizer: UserIntentRecognizer
private readonly telemetryHelper: CWCTelemetryHelper
private userPromptsWatcher: vscode.FileSystemWatcher | undefined
private readonly chatHistoryManager: ChatHistoryManager

public constructor(
private readonly chatControllerMessageListeners: ChatControllerMessageListeners,
Expand All @@ -158,6 +160,7 @@ export class ChatController {
this.editorContentController = new EditorContentController()
this.promptGenerator = new PromptsGenerator()
this.userIntentRecognizer = new UserIntentRecognizer()
this.chatHistoryManager = new ChatHistoryManager()

onDidChangeAmazonQVisibility((visible) => {
if (visible) {
Expand Down Expand Up @@ -395,6 +398,7 @@ export class ChatController {

private async processTabCloseMessage(message: TabClosedMessage) {
this.sessionStorage.deleteSession(message.tabID)
this.chatHistoryManager.clear()
this.triggerEventsStorage.removeTabEvents(message.tabID)
// this.telemetryHelper.recordCloseChat(message.tabID)
}
Expand Down Expand Up @@ -654,6 +658,7 @@ export class ChatController {
getLogger().error(`error: ${errorMessage} tabID: ${tabID} requestID: ${requestID}`)

this.sessionStorage.deleteSession(tabID)
this.chatHistoryManager.clear()
}

private async processContextMenuCommand(command: EditorContextCommand) {
Expand Down Expand Up @@ -714,6 +719,13 @@ export class ChatController {
command,
})

this.chatHistoryManager.appendUserMessage({
userInputMessage: {
content: prompt,
userIntent: this.userIntentRecognizer.getFromContextMenuCommand(command),
},
})

return this.generateResponse(
{
message: prompt,
Expand All @@ -727,6 +739,7 @@ export class ChatController {
codeQuery: context?.focusAreaContext?.names,
userIntent: this.userIntentRecognizer.getFromContextMenuCommand(command),
customization: getSelectedCustomization(),
chatHistory: this.chatHistoryManager.getHistory(),
},
triggerID
)
Expand Down Expand Up @@ -766,6 +779,7 @@ export class ChatController {
switch (message.command) {
case 'clear':
this.sessionStorage.deleteSession(message.tabID)
this.chatHistoryManager.clear()
this.triggerEventsStorage.removeTabEvents(message.tabID)
recordTelemetryChatRunCommand('clear')
return
Expand All @@ -791,6 +805,13 @@ export class ChatController {
context: lastTriggerEvent.context,
})

this.chatHistoryManager.appendUserMessage({
userInputMessage: {
content: message.message,
userIntent: message.userIntent,
},
})

return this.generateResponse(
{
message: message.message,
Expand All @@ -804,6 +825,7 @@ export class ChatController {
codeQuery: lastTriggerEvent.context?.focusAreaContext?.names,
userIntent: message.userIntent,
customization: getSelectedCustomization(),
chatHistory: this.chatHistoryManager.getHistory(),
},
triggerID
)
Expand All @@ -824,6 +846,12 @@ export class ChatController {
type: 'chat_message',
context,
})
this.chatHistoryManager.appendUserMessage({
userInputMessage: {
content: message.message,
userIntent: message.userIntent,
},
})
return this.generateResponse(
{
message: message.message,
Expand All @@ -838,6 +866,7 @@ export class ChatController {
userIntent: this.userIntentRecognizer.getFromPromptChatMessage(message),
customization: getSelectedCustomization(),
context: message.context,
chatHistory: this.chatHistoryManager.getHistory(),
},
triggerID
)
Expand Down Expand Up @@ -1104,6 +1133,15 @@ export class ChatController {
triggerPayload.documentReferences = this.mergeRelevantTextDocuments(triggerPayload.relevantTextDocuments || [])

const request = triggerPayloadToChatRequest(triggerPayload)
if (
this.chatHistoryManager.getConversationId() !== undefined &&
this.chatHistoryManager.getConversationId() !== ''
) {
request.conversationState.conversationId = this.chatHistoryManager.getConversationId()
} else {
this.chatHistoryManager.setConversationId(randomUUID())
request.conversationState.conversationId = this.chatHistoryManager.getConversationId()
}

if (triggerPayload.documentReferences !== undefined) {
const relativePathsOfMergedRelevantDocuments = triggerPayload.documentReferences.map(
Expand Down Expand Up @@ -1157,7 +1195,14 @@ export class ChatController {
response.$metadata.requestId
} metadata: ${inspect(response.$metadata, { depth: 12 })}`
)
await this.messenger.sendAIResponse(response, session, tabID, triggerID, triggerPayload)
await this.messenger.sendAIResponse(
response,
session,
tabID,
triggerID,
triggerPayload,
this.chatHistoryManager
)
} catch (e: any) {
this.telemetryHelper.recordMessageResponseError(triggerPayload, tabID, getHttpStatusCode(e) ?? 0)
// clears session, record telemetry before this call
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import { extractCodeBlockLanguage } from '../../../../shared/markdown'
import { extractAuthFollowUp } from '../../../../amazonq/util/authUtils'
import { helpMessage } from '../../../../amazonq/webview/ui/texts/constants'
import { ChatItemButton, ChatItemFormItem, MynahUIDataModel } from '@aws/mynah-ui'
import { ChatHistoryManager } from '../../../storages/chatHistory'

export type StaticTextResponseType = 'quick-action-help' | 'onboarding-help' | 'transform' | 'help'

Expand Down Expand Up @@ -121,7 +122,8 @@ export class Messenger {
session: ChatSession,
tabID: string,
triggerID: string,
triggerPayload: TriggerPayload
triggerPayload: TriggerPayload,
chatHistoryManager: ChatHistoryManager
) {
let message = ''
const messageID = response.$metadata.requestId ?? ''
Expand Down Expand Up @@ -331,6 +333,15 @@ export class Messenger {
)
)

chatHistoryManager.pushAssistantMessage({
assistantResponseMessage: {
messageId: messageID,
content: message,
references: codeReference,
// TODO: Add tools data and follow up prompt details
},
})

getLogger().info(
`All events received. requestId=%s counts=%s`,
response.$metadata.requestId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
*/

import * as vscode from 'vscode'
import { AdditionalContentEntry, RelevantTextDocument, UserIntent } from '@amzn/codewhisperer-streaming'
import { AdditionalContentEntry, ChatMessage, RelevantTextDocument, UserIntent } from '@amzn/codewhisperer-streaming'
import { MatchPolicy, CodeQuery } from '../../clients/chat/v0/model'
import { Selection } from 'vscode'
import { TabOpenType } from '../../../amazonq/webview/ui/storages/tabsStorage'
Expand Down Expand Up @@ -197,6 +197,7 @@ export interface TriggerPayload {
additionalContextLengths?: AdditionalContextLengths
truncatedAdditionalContextLengths?: AdditionalContextLengths
workspaceRulesCount?: number
chatHistory?: ChatMessage[]
}

export type AdditionalContextLengths = {
Expand Down
118 changes: 118 additions & 0 deletions packages/core/src/codewhispererChat/storages/chatHistory.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
/*!
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
import { ChatMessage } from '@amzn/codewhisperer-streaming'
import { randomUUID } from '../../shared/crypto'
import { getLogger } from '../../shared/logger/logger'

// Maximum number of messages to keep in history
const MaxConversationHistoryLength = 100

/**
* ChatHistoryManager handles the storage and manipulation of chat history
* for CodeWhisperer Chat sessions.
*/
export class ChatHistoryManager {
private conversationId: string
private history: ChatMessage[] = []
private logger = getLogger()
private lastUserMessage?: ChatMessage

constructor() {
this.conversationId = randomUUID()
this.logger.info(`Generated new conversation id: ${this.conversationId}`)
}

/**
* Get the conversation ID
*/
public getConversationId(): string {
return this.conversationId
}

public setConversationId(conversationId: string) {
this.conversationId = conversationId
}

/**
* Get the full chat history
*/
public getHistory(): ChatMessage[] {
return [...this.history]
}

/**
* Clear the conversation history
*/
public clear(): void {
this.history = []
this.conversationId = ''
}

/**
* Append a new user message to be sent
*/
public appendUserMessage(newMessage: ChatMessage): void {
this.fixHistory()
if (!newMessage.userInputMessage?.content || newMessage.userInputMessage?.content.trim() === '') {
this.logger.warn('input must not be empty when adding new messages')
// const emptyMessage: ChatMessage = {
// ...newMessage,
// userInputMessage: {
// ...newMessage.userInputMessage,
// content: 'Empty user input',
// },
// }
// this.history.push(emptyMessage)
}
this.lastUserMessage = newMessage
this.history.push(newMessage)
}

/**
* Push an assistant message to the history
*/
public pushAssistantMessage(newMessage: ChatMessage): void {
if (newMessage !== undefined && this.lastUserMessage !== undefined) {
this.logger.warn('last Message should not be defined when pushing an assistant message')
}
this.history.push(newMessage)
}

/**
* Fixes the history to maintain the following invariants:
* 1. The history length is <= MAX_CONVERSATION_HISTORY_LENGTH. Oldest messages are dropped.
* 2. The first message is from the user. Oldest messages are dropped if needed.
* 3. The last message is from the assistant. The last message is dropped if it is from the user.
* 4. If the last message is from the assistant and it contains tool uses, and a next user
* message is set without tool results, then the user message will have cancelled tool results.
*/
public fixHistory(): void {
// Trim the conversation history if it exceeds the maximum length
if (this.history.length > MaxConversationHistoryLength) {
// Find the second oldest user message to be the new starting point
const secondUserMessageIndex = this.history
.slice(1) // Skip the first message which might be from the user
.findIndex((msg) => !msg.userInputMessage?.content || msg.userInputMessage?.content.trim() === '')

if (secondUserMessageIndex !== -1) {
// +1 because we sliced off the first element
this.logger.debug(`Removing the first ${secondUserMessageIndex + 1} elements in the history`)
this.history = this.history.slice(secondUserMessageIndex + 1)
} else {
this.logger.debug('No valid starting user message found in the history, clearing')
this.history = []
}
}

// Ensure the last message is from the assistant

if (this.history.length > 0 && this.history[this.history.length - 1].userInputMessage !== undefined) {
this.logger.debug('Last message in history is from the user, dropping')
this.history.pop()
}

// TODO: If the last message from the assistant contains tool uses, ensure the next user message contains tool results
}
}