Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions packages/core/src/codewhispererChat/controllers/chat/controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,7 @@ export class ChatController {
private async processStopResponseMessage(message: StopResponseMessage) {
const session = this.sessionStorage.getSession(message.tabID)
session.tokenSource.cancel()
this.chatHistoryStorage.getTabHistory(message.tabID).clearRecentHistory()
}

private async processTriggerTabIDReceived(message: TriggerTabIDReceived) {
Expand Down Expand Up @@ -650,6 +651,8 @@ export class ChatController {
const session = this.sessionStorage.getSession(tabID)
const toolUse = session.toolUse
if (!toolUse || !toolUse.input) {
// Turn off AgentLoop flag if there's no tool use
this.sessionStorage.setAgentLoopInProgress(tabID, false)
return
}
session.setToolUse(undefined)
Expand Down Expand Up @@ -711,7 +714,6 @@ export class ChatController {
customization: getSelectedCustomization(),
toolResults: toolResults,
origin: Origin.IDE,
chatHistory: this.chatHistoryStorage.getTabHistory(tabID).getHistory(),
context: session.context ?? [],
relevantTextDocuments: [],
additionalContents: [],
Expand Down Expand Up @@ -887,10 +889,16 @@ export class ChatController {
errorMessage = e.message
}

// Turn off AgentLoop flag in case of exception
if (tabID) {
this.sessionStorage.setAgentLoopInProgress(tabID, false)
}

this.messenger.sendErrorMessage(errorMessage, tabID, requestID)
getLogger().error(`error: ${errorMessage} tabID: ${tabID} requestID: ${requestID}`)

this.sessionStorage.deleteSession(tabID)
this.chatHistoryStorage.getTabHistory(tabID).clearRecentHistory()
}

private async processContextMenuCommand(command: EditorContextCommand) {
Expand Down Expand Up @@ -1050,7 +1058,6 @@ export class ChatController {
codeQuery: lastTriggerEvent.context?.focusAreaContext?.names,
userIntent: message.userIntent,
customization: getSelectedCustomization(),
chatHistory: this.chatHistoryStorage.getTabHistory(message.tabID).getHistory(),
contextLengths: {
...defaultContextLengths,
},
Expand Down Expand Up @@ -1099,7 +1106,6 @@ export class ChatController {
codeQuery: context?.focusAreaContext?.names,
userIntent: this.userIntentRecognizer.getFromPromptChatMessage(message),
customization: getSelectedCustomization(),
chatHistory: this.chatHistoryStorage.getTabHistory(message.tabID).getHistory(),
origin: Origin.IDE,
context: message.context ?? [],
relevantTextDocuments: [],
Expand Down Expand Up @@ -1281,6 +1287,16 @@ export class ChatController {
}

const tabID = triggerEvent.tabID
if (this.sessionStorage.isAgentLoopInProgress(tabID)) {
// If a response is already in progress, stop it first
const stopResponseMessage: StopResponseMessage = {
tabID: tabID,
}
await this.processStopResponseMessage(stopResponseMessage)
}

// Ensure AgentLoop flag is set to true during response generation
this.sessionStorage.setAgentLoopInProgress(tabID, true)

const credentialsState = await AuthUtil.instance.getChatAuthState()

Expand Down Expand Up @@ -1343,6 +1359,7 @@ export class ChatController {
if (fixedHistoryMessage.userInputMessage?.userInputMessageContext) {
triggerPayload.toolResults = fixedHistoryMessage.userInputMessage.userInputMessageContext.toolResults
}
triggerPayload.chatHistory = chatHistory.getHistory()
const request = triggerPayloadToChatRequest(triggerPayload)
const conversationId = chatHistory.getConversationId() || randomUUID()
chatHistory.setConversationId(conversationId)
Expand Down Expand Up @@ -1405,8 +1422,13 @@ export class ChatController {
} metadata: ${inspect(response.$metadata, { depth: 12 })}`
)
await this.messenger.sendAIResponse(response, session, tabID, triggerID, triggerPayload, chatHistory)

// Turn off AgentLoop flag after sending the AI response
this.sessionStorage.setAgentLoopInProgress(tabID, false)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this going to keep setting/resetting the flag within a single agentic loop?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

currently yes...i'll expand it to tool use next to stop tool execution.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a bit limited right now

} catch (e: any) {
this.telemetryHelper.recordMessageResponseError(triggerPayload, tabID, getHttpStatusCode(e) ?? 0)
// Turn off AgentLoop flag in case of exception
this.sessionStorage.setAgentLoopInProgress(tabID, false)
// clears session, record telemetry before this call
this.processException(e, tabID)
}
Expand Down
98 changes: 45 additions & 53 deletions packages/core/src/codewhispererChat/storages/chatHistory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,37 @@ export class ChatHistoryManager {
return newUserMessage
}
}
} else {
if (
newUserMessage.userInputMessage?.userInputMessageContext?.toolResults &&
newUserMessage.userInputMessage?.userInputMessageContext?.toolResults.length > 0
) {
// correct toolUse section of lastAssistantResponse in case of empty toolUse for user message with tool Results.
const toolResults = newUserMessage.userInputMessage.userInputMessageContext.toolResults
const updatedToolUses = toolResults.map((toolResult) => ({
toolUseId: toolResult.toolUseId,
name: lastHistoryMessage.assistantResponseMessage.toolUses?.find(
(tu) => tu.toolUseId === toolResult.toolUseId
)?.name,
input: lastHistoryMessage.assistantResponseMessage.toolUses?.find(
(tu) => tu.toolUseId === toolResult.toolUseId
)?.input,
}))

// Create a new assistant response message with updated toolUses
const updatedAssistantResponseMessage = {
...lastHistoryMessage.assistantResponseMessage,
toolUses: updatedToolUses,
}

// Create a new chat message with the updated assistant response
const updatedChatMessage: ChatMessage = {
assistantResponseMessage: updatedAssistantResponseMessage,
}

// Replace the last message in history
this.history[this.history.length - 1] = updatedChatMessage
}
}
}

Expand Down Expand Up @@ -216,59 +247,6 @@ export class ChatHistoryManager {
}
}

/**
* Checks if the latest message in history is an Assistant Message.
* If it is and doesn't have toolUse, it will be removed.
* If it has toolUse, an assistantResponse message with cancelled tool status will be added.
*/
public checkLatestAssistantMessage(): void {
if (this.history.length === 0) {
return
}

const lastMessage = this.history[this.history.length - 1]

if (lastMessage.assistantResponseMessage) {
const toolUses = lastMessage.assistantResponseMessage.toolUses

if (!toolUses || toolUses.length === 0) {
// If there are no tool uses, remove the assistant message
this.logger.debug('Removing assistant message without tool uses')
this.history.pop()
} else {
// If there are tool uses, add cancelled tool results
const toolResults = toolUses.map((toolUse) => ({
toolUseId: toolUse.toolUseId,
content: [
{
type: 'Text',
text: 'Tool use was cancelled by the user',
},
],
status: ToolResultStatus.ERROR,
}))

// Create a new user message with cancelled tool results
const userInputMessageContext: UserInputMessageContext = {
shellState: undefined,
envState: undefined,
toolResults: toolResults,
tools: this.tools.length === 0 ? undefined : [...this.tools],
}

const userMessage: ChatMessage = {
userInputMessage: {
content: '',
userInputMessageContext: userInputMessageContext,
},
}

this.history.push(this.formatChatHistoryMessage(userMessage))
this.logger.debug('Added user message with cancelled tool results')
}
}
}

private formatChatHistoryMessage(message: ChatMessage): ChatMessage {
if (message.userInputMessage !== undefined) {
return {
Expand All @@ -283,4 +261,18 @@ export class ChatHistoryManager {
}
return message
}

public clearRecentHistory(): void {
if (this.history.length === 0) {
return
}

const lastHistoryMessage = this.history[this.history.length - 1]

if (lastHistoryMessage.userInputMessage?.userInputMessageContext) {
this.history.pop()
} else if (lastHistoryMessage.assistantResponseMessage) {
this.history.splice(-2)
}
}
}
20 changes: 20 additions & 0 deletions packages/core/src/codewhispererChat/storages/chatSession.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { ChatSession } from '../clients/chat/v0/chat'

export class ChatSessionStorage {
private sessions: Map<string, ChatSession> = new Map()
private agentLoopInProgress: Map<string, boolean> = new Map()

public getSession(tabID: string): ChatSession {
const sessionFromStorage = this.sessions.get(tabID)
Expand All @@ -22,5 +23,24 @@ export class ChatSessionStorage {

public deleteSession(tabID: string) {
this.sessions.delete(tabID)
this.agentLoopInProgress.delete(tabID)
}

/**
* Check if agent loop is in progress for a specific tab
* @param tabID The tab ID to check
* @returns True if agent loop is in progress, false otherwise
*/
public isAgentLoopInProgress(tabID: string): boolean {
return this.agentLoopInProgress.get(tabID) === true
}

/**
* Set agent loop in progress state for a specific tab
* @param tabID The tab ID to set state for
* @param inProgress Whether the agent loop is in progress
*/
public setAgentLoopInProgress(tabID: string, inProgress: boolean): void {
this.agentLoopInProgress.set(tabID, inProgress)
}
}
Loading