|
7 | 7 | Tool, |
8 | 8 | ToolResult, |
9 | 9 | ToolResultStatus, |
| 10 | + ToolUse, |
10 | 11 | UserInputMessage, |
11 | 12 | UserInputMessageContext, |
12 | 13 | } from '@amzn/codewhisperer-streaming' |
@@ -105,115 +106,99 @@ export class ChatHistoryManager { |
105 | 106 | * message is set without tool results, then the user message will have cancelled tool results. |
106 | 107 | */ |
107 | 108 | public fixHistory(newUserMessage: ChatMessage): ChatMessage { |
108 | | - // Trim the conversation history if it exceeds the maximum length |
109 | | - if (this.history.length > MaxConversationHistoryLength) { |
110 | | - // Find the second oldest user message without tool results |
111 | | - let indexToTrim: number | undefined |
112 | | - |
113 | | - for (let i = 1; i < this.history.length; i++) { |
114 | | - const message = this.history[i] |
115 | | - if (message.userInputMessage) { |
116 | | - const userMessage = message.userInputMessage |
117 | | - const ctx = userMessage.userInputMessageContext |
118 | | - const hasNoToolResults = ctx && (!ctx.toolResults || ctx.toolResults.length === 0) |
119 | | - if (hasNoToolResults && userMessage.content !== '') { |
120 | | - indexToTrim = i |
121 | | - break |
122 | | - } |
123 | | - } |
124 | | - } |
125 | | - if (indexToTrim !== undefined) { |
126 | | - this.logger.debug(`Removing the first ${indexToTrim} elements in the history`) |
127 | | - this.history.splice(0, indexToTrim) |
128 | | - } else { |
129 | | - this.logger.debug('No valid starting user message found in the history, clearing') |
130 | | - this.history = [] |
| 109 | + this.trimConversationHistory() |
| 110 | + this.ensureLastMessageFromAssistant() |
| 111 | + return this.handleToolUses(newUserMessage) |
| 112 | + } |
| 113 | + |
| 114 | + private trimConversationHistory(): void { |
| 115 | + if (this.history.length <= MaxConversationHistoryLength) { |
| 116 | + return |
| 117 | + } |
| 118 | + |
| 119 | + const indexToTrim = this.findIndexToTrim() |
| 120 | + if (indexToTrim !== undefined) { |
| 121 | + this.logger.debug(`Removing the first ${indexToTrim} elements in the history`) |
| 122 | + this.history.splice(0, indexToTrim) |
| 123 | + } else { |
| 124 | + this.logger.debug('No valid starting user message found in the history, clearing') |
| 125 | + this.history = [] |
| 126 | + } |
| 127 | + } |
| 128 | + |
| 129 | + private findIndexToTrim(): number | undefined { |
| 130 | + for (let i = 1; i < this.history.length; i++) { |
| 131 | + const message = this.history[i] |
| 132 | + if (this.isValidUserMessageWithoutToolResults(message)) { |
| 133 | + return i |
131 | 134 | } |
132 | 135 | } |
| 136 | + return undefined |
| 137 | + } |
133 | 138 |
|
134 | | - // Ensure the last message is from the assistant |
| 139 | + private isValidUserMessageWithoutToolResults(message: ChatMessage): boolean { |
| 140 | + if (!message.userInputMessage) { |
| 141 | + return false |
| 142 | + } |
| 143 | + const ctx = message.userInputMessage.userInputMessageContext |
| 144 | + return Boolean( |
| 145 | + ctx && (!ctx.toolResults || ctx.toolResults.length === 0) && message.userInputMessage.content !== '' |
| 146 | + ) |
| 147 | + } |
| 148 | + |
| 149 | + private ensureLastMessageFromAssistant(): void { |
135 | 150 | if (this.history.length > 0 && this.history[this.history.length - 1].userInputMessage !== undefined) { |
136 | 151 | this.logger.debug('Last message in history is from the user, dropping') |
137 | 152 | this.history.pop() |
138 | 153 | } |
| 154 | + } |
139 | 155 |
|
140 | | - // If the last message from the assistant contains tool uses, ensure the next user message contains tool results |
141 | | - |
| 156 | + private handleToolUses(newUserMessage: ChatMessage): ChatMessage { |
142 | 157 | const lastHistoryMessage = this.history[this.history.length - 1] |
| 158 | + if (!lastHistoryMessage || !lastHistoryMessage.assistantResponseMessage || !newUserMessage) { |
| 159 | + return newUserMessage |
| 160 | + } |
143 | 161 |
|
144 | | - if ( |
145 | | - lastHistoryMessage && |
146 | | - (lastHistoryMessage.assistantResponseMessage || |
147 | | - lastHistoryMessage.assistantResponseMessage !== undefined) && |
148 | | - newUserMessage |
149 | | - ) { |
150 | | - const toolUses = lastHistoryMessage.assistantResponseMessage.toolUses |
151 | | - |
152 | | - if (toolUses && toolUses.length > 0) { |
153 | | - if (newUserMessage.userInputMessage) { |
154 | | - if (newUserMessage.userInputMessage.userInputMessageContext) { |
155 | | - const ctx = newUserMessage.userInputMessage.userInputMessageContext |
156 | | - |
157 | | - if (!ctx.toolResults || ctx.toolResults.length === 0) { |
158 | | - ctx.toolResults = toolUses.map((toolUse) => ({ |
159 | | - toolUseId: toolUse.toolUseId, |
160 | | - content: [ |
161 | | - { |
162 | | - type: 'Text', |
163 | | - text: 'Tool use was cancelled by the user', |
164 | | - }, |
165 | | - ], |
166 | | - status: ToolResultStatus.ERROR, |
167 | | - })) |
168 | | - } |
169 | | - } else { |
170 | | - const toolResults = toolUses.map((toolUse) => ({ |
171 | | - toolUseId: toolUse.toolUseId, |
172 | | - content: [ |
173 | | - { |
174 | | - type: 'Text', |
175 | | - text: 'Tool use was cancelled by the user', |
176 | | - }, |
177 | | - ], |
178 | | - status: ToolResultStatus.ERROR, |
179 | | - })) |
180 | | - |
181 | | - newUserMessage.userInputMessage.userInputMessageContext = { |
182 | | - shellState: undefined, |
183 | | - envState: undefined, |
184 | | - toolResults: toolResults, |
185 | | - tools: this.tools.length === 0 ? undefined : [...this.tools], |
186 | | - } |
187 | | - |
188 | | - return newUserMessage |
189 | | - } |
190 | | - } |
191 | | - } |
| 162 | + const toolUses = lastHistoryMessage.assistantResponseMessage.toolUses |
| 163 | + if (!toolUses || toolUses.length === 0) { |
| 164 | + return newUserMessage |
192 | 165 | } |
193 | 166 |
|
194 | | - // Always return the message to fix the TypeScript error |
195 | | - return newUserMessage |
| 167 | + return this.addToolResultsToUserMessage(newUserMessage, toolUses) |
196 | 168 | } |
197 | 169 |
|
198 | | - /** |
199 | | - * Adds tool results to the conversation. |
200 | | - */ |
201 | | - addToolResults(toolResults: ToolResult[]): void { |
202 | | - const userInputMessageContext: UserInputMessageContext = { |
203 | | - shellState: undefined, |
204 | | - envState: undefined, |
205 | | - toolResults: toolResults, |
206 | | - tools: this.tools.length === 0 ? undefined : [...this.tools], |
| 170 | + private addToolResultsToUserMessage(newUserMessage: ChatMessage, toolUses: ToolUse[]): ChatMessage { |
| 171 | + if (!newUserMessage.userInputMessage) { |
| 172 | + return newUserMessage |
207 | 173 | } |
208 | 174 |
|
209 | | - const msg: UserInputMessage = { |
210 | | - content: '', |
211 | | - userInputMessageContext: userInputMessageContext, |
212 | | - } |
| 175 | + const toolResults = this.createToolResults(toolUses) |
213 | 176 |
|
214 | | - if (this.lastUserMessage?.userInputMessage) { |
215 | | - this.lastUserMessage.userInputMessage = msg |
| 177 | + if (newUserMessage.userInputMessage.userInputMessageContext) { |
| 178 | + newUserMessage.userInputMessage.userInputMessageContext.toolResults = toolResults |
| 179 | + } else { |
| 180 | + newUserMessage.userInputMessage.userInputMessageContext = { |
| 181 | + shellState: undefined, |
| 182 | + envState: undefined, |
| 183 | + toolResults: toolResults, |
| 184 | + tools: this.tools.length === 0 ? undefined : [...this.tools], |
| 185 | + } |
216 | 186 | } |
| 187 | + |
| 188 | + return newUserMessage |
| 189 | + } |
| 190 | + |
| 191 | + private createToolResults(toolUses: ToolUse[]): ToolResult[] { |
| 192 | + return toolUses.map((toolUse) => ({ |
| 193 | + toolUseId: toolUse.toolUseId, |
| 194 | + content: [ |
| 195 | + { |
| 196 | + type: 'Text', |
| 197 | + text: 'Tool use was cancelled by the user', |
| 198 | + }, |
| 199 | + ], |
| 200 | + status: ToolResultStatus.ERROR, |
| 201 | + })) |
217 | 202 | } |
218 | 203 |
|
219 | 204 | private formatChatHistoryMessage(message: ChatMessage): ChatMessage { |
|
0 commit comments