Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,6 @@ export function triggerPayloadToChatRequest(triggerPayload: TriggerPayload): { c
},
chatTriggerType,
customizationArn: customizationArn,
history: triggerPayload.chatHistory,
},
}
}
Expand Down
28 changes: 9 additions & 19 deletions packages/core/src/codewhispererChat/controllers/chat/controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ import {
createSavedPromptCommandId,
aditionalContentNameLimit,
additionalContentInnerContextLimit,
tools,
workspaceChunkMaxSize,
defaultContextLengths,
} from '../../constants'
Expand Down Expand Up @@ -1355,24 +1354,15 @@ export class ChatController {
triggerPayload.contextLengths.userInputContextLength = triggerPayload.message.length
triggerPayload.contextLengths.focusFileContextLength = triggerPayload.fileText.length

const request = triggerPayloadToChatRequest(triggerPayload)

const chatHistory = this.chatHistoryStorage.getTabHistory(tabID)
const newUserMessage = {
userInputMessage: {
content: triggerPayload.message,
userIntent: triggerPayload.userIntent,
...(triggerPayload.origin && { origin: triggerPayload.origin }),
userInputMessageContext: {
tools: tools,
...(triggerPayload.toolResults && { toolResults: triggerPayload.toolResults }),
},
},
const currentMessage = request.conversationState.currentMessage
if (currentMessage) {
chatHistory.fixHistory(currentMessage)
}
const fixedHistoryMessage = chatHistory.fixHistory(newUserMessage)
if (fixedHistoryMessage.userInputMessage?.userInputMessageContext) {
triggerPayload.toolResults = fixedHistoryMessage.userInputMessage.userInputMessageContext.toolResults
}
triggerPayload.chatHistory = chatHistory.getHistory()
const request = triggerPayloadToChatRequest(triggerPayload)
request.conversationState.history = chatHistory.getHistory()

const conversationId = chatHistory.getConversationId() || randomUUID()
chatHistory.setConversationId(conversationId)
request.conversationState.conversationId = conversationId
Expand Down Expand Up @@ -1426,8 +1416,8 @@ export class ChatController {
}
this.telemetryHelper.recordEnterFocusConversation(triggerEvent.tabID)
this.telemetryHelper.recordStartConversation(triggerEvent, triggerPayload)
if (request.conversationState.currentMessage) {
chatHistory.appendUserMessage(request.conversationState.currentMessage)
if (currentMessage) {
chatHistory.appendUserMessage(currentMessage)
}

getLogger().info(
Expand Down
2 changes: 0 additions & 2 deletions packages/core/src/codewhispererChat/controllers/chat/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import * as vscode from 'vscode'
import {
AdditionalContentEntry,
ChatMessage,
Origin,
RelevantTextDocument,
ToolResult,
Expand Down Expand Up @@ -203,7 +202,6 @@ export interface TriggerPayload {
traceId?: string
contextLengths: ContextLengths
workspaceRulesCount?: number
chatHistory?: ChatMessage[]
toolResults?: ToolResult[]
origin?: Origin
}
Expand Down
50 changes: 16 additions & 34 deletions packages/core/src/codewhispererChat/storages/chatHistory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
import { ChatMessage, Tool, ToolResult, ToolResultStatus, ToolUse } from '@amzn/codewhisperer-streaming'
import { ChatMessage, ToolResult, ToolResultStatus, ToolUse } from '@amzn/codewhisperer-streaming'
import { randomUUID } from '../../shared/crypto'
import { getLogger } from '../../shared/logger/logger'
import { tools } from '../constants'

// Maximum number of messages to keep in history
const MaxConversationHistoryLength = 100
Expand All @@ -20,12 +19,10 @@ export class ChatHistoryManager {
private history: ChatMessage[] = []
private logger = getLogger()
private lastUserMessage?: ChatMessage
private tools: Tool[] = []

constructor(tabId?: string) {
this.conversationId = randomUUID()
this.tabId = tabId ?? randomUUID()
this.tools = tools
}

/**
Expand Down Expand Up @@ -97,10 +94,10 @@ export class ChatHistoryManager {
* 4. If the last message is from the assistant and it contains tool uses, and a next user
* message is set without tool results, then the user message will have cancelled tool results.
*/
public fixHistory(newUserMessage: ChatMessage): ChatMessage {
public fixHistory(newUserMessage: ChatMessage): void {
this.trimConversationHistory()
this.ensureLastMessageFromAssistant()
return this.handleToolUses(newUserMessage)
this.ensureCurrentMessageIsValid(newUserMessage)
}

private trimConversationHistory(): void {
Expand Down Expand Up @@ -145,42 +142,27 @@ export class ChatHistoryManager {
}
}

private handleToolUses(newUserMessage: ChatMessage): ChatMessage {
private ensureCurrentMessageIsValid(newUserMessage: ChatMessage): void {
const lastHistoryMessage = this.history[this.history.length - 1]
if (!lastHistoryMessage || !lastHistoryMessage.assistantResponseMessage || !newUserMessage) {
return newUserMessage
}

const toolUses = lastHistoryMessage.assistantResponseMessage.toolUses
if (!toolUses || toolUses.length === 0) {
return newUserMessage
}

return this.addToolResultsToUserMessage(newUserMessage, toolUses)
}

private addToolResultsToUserMessage(newUserMessage: ChatMessage, toolUses: ToolUse[]): ChatMessage {
if (!newUserMessage.userInputMessage) {
return newUserMessage
if (!lastHistoryMessage) {
return
}

const toolResults = this.createToolResults(toolUses)
if (lastHistoryMessage.assistantResponseMessage?.toolUses?.length) {
const toolResults = newUserMessage.userInputMessage?.userInputMessageContext?.toolResults
if (!toolResults || toolResults.length === 0) {
const abandonedToolResults = this.createAbandonedToolResults(
lastHistoryMessage.assistantResponseMessage.toolUses
)

if (newUserMessage.userInputMessage.userInputMessageContext) {
newUserMessage.userInputMessage.userInputMessageContext.toolResults = toolResults
} else {
newUserMessage.userInputMessage.userInputMessageContext = {
shellState: undefined,
envState: undefined,
toolResults: toolResults,
tools: this.tools.length === 0 ? undefined : [...this.tools],
if (newUserMessage.userInputMessage?.userInputMessageContext) {
newUserMessage.userInputMessage.userInputMessageContext.toolResults = abandonedToolResults
}
}
}

return newUserMessage
}

private createToolResults(toolUses: ToolUse[]): ToolResult[] {
private createAbandonedToolResults(toolUses: ToolUse[]): ToolResult[] {
return toolUses.map((toolUse) => ({
toolUseId: toolUse.toolUseId,
content: [
Expand Down
Loading