Skip to content

Commit 4446831

Browse files
authored
feat(chat): Adding state Management to stop the response for new user input prompt (aws#6934)
## Problem Responses need to stop for new user prompt and fix user history ## Solution - featureflag to check the response state to stop the responses. - Fixing history incase of exception or stopped responses. --- - Treat all work as PUBLIC. Private `feature/x` branches will not be squash-merged at release time. - Your code changes must meet the guidelines in [CONTRIBUTING.md](https://github.com/aws/aws-toolkit-vscode/blob/master/CONTRIBUTING.md#guidelines). - License: I confirm that my contribution is made under the terms of the Apache 2.0 license.
1 parent 093d5bc commit 4446831

File tree

4 files changed

+132
-152
lines changed

4 files changed

+132
-152
lines changed

packages/core/src/codewhispererChat/controllers/chat/controller.ts

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,7 @@ export class ChatController {
368368
private async processStopResponseMessage(message: StopResponseMessage) {
369369
const session = this.sessionStorage.getSession(message.tabID)
370370
session.tokenSource.cancel()
371+
this.chatHistoryStorage.getTabHistory(message.tabID).clearRecentHistory()
371372
}
372373

373374
private async processTriggerTabIDReceived(message: TriggerTabIDReceived) {
@@ -650,6 +651,8 @@ export class ChatController {
650651
const session = this.sessionStorage.getSession(tabID)
651652
const toolUse = session.toolUse
652653
if (!toolUse || !toolUse.input) {
654+
// Turn off AgentLoop flag if there's no tool use
655+
this.sessionStorage.setAgentLoopInProgress(tabID, false)
653656
return
654657
}
655658
session.setToolUse(undefined)
@@ -723,7 +726,6 @@ export class ChatController {
723726
customization: getSelectedCustomization(),
724727
toolResults: toolResults,
725728
origin: Origin.IDE,
726-
chatHistory: this.chatHistoryStorage.getTabHistory(tabID).getHistory(),
727729
context: session.context ?? [],
728730
relevantTextDocuments: [],
729731
additionalContents: [],
@@ -899,10 +901,16 @@ export class ChatController {
899901
errorMessage = e.message
900902
}
901903

904+
// Turn off AgentLoop flag in case of exception
905+
if (tabID) {
906+
this.sessionStorage.setAgentLoopInProgress(tabID, false)
907+
}
908+
902909
this.messenger.sendErrorMessage(errorMessage, tabID, requestID)
903910
getLogger().error(`error: ${errorMessage} tabID: ${tabID} requestID: ${requestID}`)
904911

905912
this.sessionStorage.deleteSession(tabID)
913+
this.chatHistoryStorage.getTabHistory(tabID).clearRecentHistory()
906914
}
907915

908916
private async processContextMenuCommand(command: EditorContextCommand) {
@@ -1062,7 +1070,6 @@ export class ChatController {
10621070
codeQuery: lastTriggerEvent.context?.focusAreaContext?.names,
10631071
userIntent: message.userIntent,
10641072
customization: getSelectedCustomization(),
1065-
chatHistory: this.chatHistoryStorage.getTabHistory(message.tabID).getHistory(),
10661073
contextLengths: {
10671074
...defaultContextLengths,
10681075
},
@@ -1111,7 +1118,6 @@ export class ChatController {
11111118
codeQuery: context?.focusAreaContext?.names,
11121119
userIntent: undefined,
11131120
customization: getSelectedCustomization(),
1114-
chatHistory: this.chatHistoryStorage.getTabHistory(message.tabID).getHistory(),
11151121
origin: Origin.IDE,
11161122
context: message.context ?? [],
11171123
relevantTextDocuments: [],
@@ -1293,6 +1299,16 @@ export class ChatController {
12931299
}
12941300

12951301
const tabID = triggerEvent.tabID
1302+
if (this.sessionStorage.isAgentLoopInProgress(tabID)) {
1303+
// If a response is already in progress, stop it first
1304+
const stopResponseMessage: StopResponseMessage = {
1305+
tabID: tabID,
1306+
}
1307+
await this.processStopResponseMessage(stopResponseMessage)
1308+
}
1309+
1310+
// Ensure AgentLoop flag is set to true during response generation
1311+
this.sessionStorage.setAgentLoopInProgress(tabID, true)
12961312

12971313
const credentialsState = await AuthUtil.instance.getChatAuthState()
12981314

@@ -1355,6 +1371,7 @@ export class ChatController {
13551371
if (fixedHistoryMessage.userInputMessage?.userInputMessageContext) {
13561372
triggerPayload.toolResults = fixedHistoryMessage.userInputMessage.userInputMessageContext.toolResults
13571373
}
1374+
triggerPayload.chatHistory = chatHistory.getHistory()
13581375
const request = triggerPayloadToChatRequest(triggerPayload)
13591376
const conversationId = chatHistory.getConversationId() || randomUUID()
13601377
chatHistory.setConversationId(conversationId)
@@ -1417,8 +1434,13 @@ export class ChatController {
14171434
} metadata: ${inspect(response.$metadata, { depth: 12 })}`
14181435
)
14191436
await this.messenger.sendAIResponse(response, session, tabID, triggerID, triggerPayload, chatHistory)
1437+
1438+
// Turn off AgentLoop flag after sending the AI response
1439+
this.sessionStorage.setAgentLoopInProgress(tabID, false)
14201440
} catch (e: any) {
14211441
this.telemetryHelper.recordMessageResponseError(triggerPayload, tabID, getHttpStatusCode(e) ?? 0)
1442+
// Turn off AgentLoop flag in case of exception
1443+
this.sessionStorage.setAgentLoopInProgress(tabID, false)
14221444
// clears session, record telemetry before this call
14231445
this.processException(e, tabID)
14241446
}

packages/core/src/codewhispererChat/controllers/chat/messenger/messenger.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ export class Messenger {
320320
}
321321
return true
322322
},
323-
{ timeout: 60000, truthy: true }
323+
{ timeout: 600000, truthy: true }
324324
)
325325
.catch((error: any) => {
326326
let errorMessage = 'Error reading chat stream.'

packages/core/src/codewhispererChat/storages/chatHistory.ts

Lines changed: 86 additions & 148 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,7 @@
22
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
5-
import {
6-
ChatMessage,
7-
Tool,
8-
ToolResult,
9-
ToolResultStatus,
10-
UserInputMessage,
11-
UserInputMessageContext,
12-
} from '@amzn/codewhisperer-streaming'
5+
import { ChatMessage, Tool, ToolResult, ToolResultStatus, ToolUse } from '@amzn/codewhisperer-streaming'
136
import { randomUUID } from '../../shared/crypto'
147
import { getLogger } from '../../shared/logger/logger'
158
import { tools } from '../constants'
@@ -105,168 +98,99 @@ export class ChatHistoryManager {
10598
* message is set without tool results, then the user message will have cancelled tool results.
10699
*/
107100
public fixHistory(newUserMessage: ChatMessage): ChatMessage {
108-
// Trim the conversation history if it exceeds the maximum length
109-
if (this.history.length > MaxConversationHistoryLength) {
110-
// Find the second oldest user message without tool results
111-
let indexToTrim: number | undefined
101+
this.trimConversationHistory()
102+
this.ensureLastMessageFromAssistant()
103+
return this.handleToolUses(newUserMessage)
104+
}
112105

113-
for (let i = 1; i < this.history.length; i++) {
114-
const message = this.history[i]
115-
if (message.userInputMessage) {
116-
const userMessage = message.userInputMessage
117-
const ctx = userMessage.userInputMessageContext
118-
const hasNoToolResults = ctx && (!ctx.toolResults || ctx.toolResults.length === 0)
119-
if (hasNoToolResults && userMessage.content !== '') {
120-
indexToTrim = i
121-
break
122-
}
123-
}
124-
}
125-
if (indexToTrim !== undefined) {
126-
this.logger.debug(`Removing the first ${indexToTrim} elements in the history`)
127-
this.history.splice(0, indexToTrim)
128-
} else {
129-
this.logger.debug('No valid starting user message found in the history, clearing')
130-
this.history = []
131-
}
106+
private trimConversationHistory(): void {
107+
if (this.history.length <= MaxConversationHistoryLength) {
108+
return
132109
}
133110

134-
// Ensure the last message is from the assistant
135-
if (this.history.length > 0 && this.history[this.history.length - 1].userInputMessage !== undefined) {
136-
this.logger.debug('Last message in history is from the user, dropping')
137-
this.history.pop()
111+
const indexToTrim = this.findIndexToTrim()
112+
if (indexToTrim !== undefined) {
113+
this.logger.debug(`Removing the first ${indexToTrim} elements in the history`)
114+
this.history.splice(0, indexToTrim)
115+
} else {
116+
this.logger.debug('No valid starting user message found in the history, clearing')
117+
this.history = []
138118
}
119+
}
139120

140-
// If the last message from the assistant contains tool uses, ensure the next user message contains tool results
141-
142-
const lastHistoryMessage = this.history[this.history.length - 1]
143-
144-
if (
145-
lastHistoryMessage &&
146-
(lastHistoryMessage.assistantResponseMessage ||
147-
lastHistoryMessage.assistantResponseMessage !== undefined) &&
148-
newUserMessage
149-
) {
150-
const toolUses = lastHistoryMessage.assistantResponseMessage.toolUses
151-
152-
if (toolUses && toolUses.length > 0) {
153-
if (newUserMessage.userInputMessage) {
154-
if (newUserMessage.userInputMessage.userInputMessageContext) {
155-
const ctx = newUserMessage.userInputMessage.userInputMessageContext
156-
157-
if (!ctx.toolResults || ctx.toolResults.length === 0) {
158-
ctx.toolResults = toolUses.map((toolUse) => ({
159-
toolUseId: toolUse.toolUseId,
160-
content: [
161-
{
162-
type: 'Text',
163-
text: 'Tool use was cancelled by the user',
164-
},
165-
],
166-
status: ToolResultStatus.ERROR,
167-
}))
168-
}
169-
} else {
170-
const toolResults = toolUses.map((toolUse) => ({
171-
toolUseId: toolUse.toolUseId,
172-
content: [
173-
{
174-
type: 'Text',
175-
text: 'Tool use was cancelled by the user',
176-
},
177-
],
178-
status: ToolResultStatus.ERROR,
179-
}))
180-
181-
newUserMessage.userInputMessage.userInputMessageContext = {
182-
shellState: undefined,
183-
envState: undefined,
184-
toolResults: toolResults,
185-
tools: this.tools.length === 0 ? undefined : [...this.tools],
186-
}
187-
188-
return newUserMessage
189-
}
190-
}
121+
private findIndexToTrim(): number | undefined {
122+
for (let i = 1; i < this.history.length; i++) {
123+
const message = this.history[i]
124+
if (this.isValidUserMessageWithoutToolResults(message)) {
125+
return i
191126
}
192127
}
193-
194-
// Always return the message to fix the TypeScript error
195-
return newUserMessage
128+
return undefined
196129
}
197130

198-
/**
199-
* Adds tool results to the conversation.
200-
*/
201-
addToolResults(toolResults: ToolResult[]): void {
202-
const userInputMessageContext: UserInputMessageContext = {
203-
shellState: undefined,
204-
envState: undefined,
205-
toolResults: toolResults,
206-
tools: this.tools.length === 0 ? undefined : [...this.tools],
207-
}
208-
209-
const msg: UserInputMessage = {
210-
content: '',
211-
userInputMessageContext: userInputMessageContext,
131+
private isValidUserMessageWithoutToolResults(message: ChatMessage): boolean {
132+
if (!message.userInputMessage) {
133+
return false
212134
}
135+
const ctx = message.userInputMessage.userInputMessageContext
136+
return Boolean(
137+
ctx && (!ctx.toolResults || ctx.toolResults.length === 0) && message.userInputMessage.content !== ''
138+
)
139+
}
213140

214-
if (this.lastUserMessage?.userInputMessage) {
215-
this.lastUserMessage.userInputMessage = msg
141+
private ensureLastMessageFromAssistant(): void {
142+
if (this.history.length > 0 && this.history[this.history.length - 1].userInputMessage !== undefined) {
143+
this.logger.debug('Last message in history is from the user, dropping')
144+
this.history.pop()
216145
}
217146
}
218147

219-
/**
220-
* Checks if the latest message in history is an Assistant Message.
221-
* If it is and doesn't have toolUse, it will be removed.
222-
* If it has toolUse, an assistantResponse message with cancelled tool status will be added.
223-
*/
224-
public checkLatestAssistantMessage(): void {
225-
if (this.history.length === 0) {
226-
return
148+
private handleToolUses(newUserMessage: ChatMessage): ChatMessage {
149+
const lastHistoryMessage = this.history[this.history.length - 1]
150+
if (!lastHistoryMessage || !lastHistoryMessage.assistantResponseMessage || !newUserMessage) {
151+
return newUserMessage
227152
}
228153

229-
const lastMessage = this.history[this.history.length - 1]
230-
231-
if (lastMessage.assistantResponseMessage) {
232-
const toolUses = lastMessage.assistantResponseMessage.toolUses
154+
const toolUses = lastHistoryMessage.assistantResponseMessage.toolUses
155+
if (!toolUses || toolUses.length === 0) {
156+
return newUserMessage
157+
}
233158

234-
if (!toolUses || toolUses.length === 0) {
235-
// If there are no tool uses, remove the assistant message
236-
this.logger.debug('Removing assistant message without tool uses')
237-
this.history.pop()
238-
} else {
239-
// If there are tool uses, add cancelled tool results
240-
const toolResults = toolUses.map((toolUse) => ({
241-
toolUseId: toolUse.toolUseId,
242-
content: [
243-
{
244-
type: 'Text',
245-
text: 'Tool use was cancelled by the user',
246-
},
247-
],
248-
status: ToolResultStatus.ERROR,
249-
}))
159+
return this.addToolResultsToUserMessage(newUserMessage, toolUses)
160+
}
250161

251-
// Create a new user message with cancelled tool results
252-
const userInputMessageContext: UserInputMessageContext = {
253-
shellState: undefined,
254-
envState: undefined,
255-
toolResults: toolResults,
256-
tools: this.tools.length === 0 ? undefined : [...this.tools],
257-
}
162+
private addToolResultsToUserMessage(newUserMessage: ChatMessage, toolUses: ToolUse[]): ChatMessage {
163+
if (!newUserMessage.userInputMessage) {
164+
return newUserMessage
165+
}
258166

259-
const userMessage: ChatMessage = {
260-
userInputMessage: {
261-
content: '',
262-
userInputMessageContext: userInputMessageContext,
263-
},
264-
}
167+
const toolResults = this.createToolResults(toolUses)
265168

266-
this.history.push(this.formatChatHistoryMessage(userMessage))
267-
this.logger.debug('Added user message with cancelled tool results')
169+
if (newUserMessage.userInputMessage.userInputMessageContext) {
170+
newUserMessage.userInputMessage.userInputMessageContext.toolResults = toolResults
171+
} else {
172+
newUserMessage.userInputMessage.userInputMessageContext = {
173+
shellState: undefined,
174+
envState: undefined,
175+
toolResults: toolResults,
176+
tools: this.tools.length === 0 ? undefined : [...this.tools],
268177
}
269178
}
179+
180+
return newUserMessage
181+
}
182+
183+
private createToolResults(toolUses: ToolUse[]): ToolResult[] {
184+
return toolUses.map((toolUse) => ({
185+
toolUseId: toolUse.toolUseId,
186+
content: [
187+
{
188+
type: 'Text',
189+
text: 'Tool use was cancelled by the user',
190+
},
191+
],
192+
status: ToolResultStatus.ERROR,
193+
}))
270194
}
271195

272196
private formatChatHistoryMessage(message: ChatMessage): ChatMessage {
@@ -283,4 +207,18 @@ export class ChatHistoryManager {
283207
}
284208
return message
285209
}
210+
211+
public clearRecentHistory(): void {
212+
if (this.history.length === 0) {
213+
return
214+
}
215+
216+
const lastHistoryMessage = this.history[this.history.length - 1]
217+
218+
if (lastHistoryMessage.userInputMessage?.userInputMessageContext) {
219+
this.history.pop()
220+
} else if (lastHistoryMessage.assistantResponseMessage) {
221+
this.history.splice(-2)
222+
}
223+
}
286224
}

0 commit comments

Comments
 (0)