diff --git a/src/core/task/Task.ts b/src/core/task/Task.ts index 2103dacb274..b8e2ba00f80 100644 --- a/src/core/task/Task.ts +++ b/src/core/task/Task.ts @@ -741,7 +741,21 @@ export class Task extends EventEmitter implements TaskLike { this.handleWebviewAskResponse("messageResponse", text, images) } - handleWebviewAskResponse(askResponse: ClineAskResponse, text?: string, images?: string[]) { + async handleWebviewAskResponse(askResponse: ClineAskResponse, text?: string, images?: string[]) { + // Save checkpoint immediately when user submits a message + // This allows users to easily revert to the state right before they typed their message + if (this.enableCheckpoints && askResponse === "messageResponse") { + try { + await this.checkpointSave(true) + } catch (error) { + console.error( + `[Task#handleWebviewAskResponse] Error saving checkpoint before user message: ${error.message}`, + error, + ) + // Don't block the user message if checkpoint fails + } + } + this.askResponse = askResponse this.askResponseText = text this.askResponseImages = images @@ -1533,6 +1547,10 @@ export class Task extends EventEmitter implements TaskLike { // results. const finalUserContent = [...parsedUserContent, { type: "text" as const, text: environmentDetails }] + // Note: Checkpoint is now saved in handleWebviewAskResponse when user submits a message, + // not here before the API request. This allows users to easily revert to the state + // right before they typed their message. + await this.addToApiConversationHistory({ role: "user", content: finalUserContent }) TelemetryService.instance.captureConversationMessage(this.taskId, "user") diff --git a/src/core/task/__tests__/Task.spec.ts b/src/core/task/__tests__/Task.spec.ts index 39df433814b..c3e90f7a104 100644 --- a/src/core/task/__tests__/Task.spec.ts +++ b/src/core/task/__tests__/Task.spec.ts @@ -1613,5 +1613,153 @@ describe("Cline", () => { consoleErrorSpy.mockRestore() }) }) + + describe("Checkpoint before user messages", () => { + it("should save checkpoint when user submits a message via handleWebviewAskResponse", async () => { + const task = new Task({ + provider: mockProvider, + apiConfiguration: mockApiConfig, + task: "test task", + enableCheckpoints: true, + startTask: false, + }) + + // Mock checkpointSave method + const checkpointSaveSpy = vi.spyOn(task, "checkpointSave").mockResolvedValue(undefined) + + // Call handleWebviewAskResponse with a user message + await task.handleWebviewAskResponse("messageResponse", "test user message", ["image.png"]) + + // Verify checkpoint was saved + expect(checkpointSaveSpy).toHaveBeenCalledWith(true) + + // Verify the response was set + expect(task["askResponse"]).toBe("messageResponse") + expect(task["askResponseText"]).toBe("test user message") + expect(task["askResponseImages"]).toEqual(["image.png"]) + }) + + it("should not save checkpoint for non-messageResponse ask responses", async () => { + const task = new Task({ + provider: mockProvider, + apiConfiguration: mockApiConfig, + task: "test task", + enableCheckpoints: true, + startTask: false, + }) + + // Mock checkpointSave method + const checkpointSaveSpy = vi.spyOn(task, "checkpointSave").mockResolvedValue(undefined) + + // Call handleWebviewAskResponse with a non-message response + await task.handleWebviewAskResponse("yesButtonClicked", undefined, undefined) + + // Verify checkpoint was NOT saved + expect(checkpointSaveSpy).not.toHaveBeenCalled() + + // Verify the response was set + expect(task["askResponse"]).toBe("yesButtonClicked") + }) + + it("should handle checkpoint save errors gracefully in handleWebviewAskResponse", async () => { + const task = new Task({ + provider: mockProvider, + apiConfiguration: mockApiConfig, + task: "test task", + enableCheckpoints: true, + startTask: false, + }) + + // Mock checkpointSave to throw an error + const checkpointError = new Error("Checkpoint save failed") + const checkpointSaveSpy = vi.spyOn(task, "checkpointSave").mockRejectedValue(checkpointError) + + // Mock console.error to verify error logging + const consoleErrorSpy = vi.spyOn(console, "error").mockImplementation(() => {}) + + // Call handleWebviewAskResponse + await task.handleWebviewAskResponse("messageResponse", "test user message", ["image.png"]) + + // Verify checkpoint save was attempted + expect(checkpointSaveSpy).toHaveBeenCalledWith(true) + + // Verify error was logged + expect(consoleErrorSpy).toHaveBeenCalledWith( + expect.stringContaining("Error saving checkpoint before user message"), + checkpointError, + ) + + // Verify the response was still set despite checkpoint error + expect(task["askResponse"]).toBe("messageResponse") + expect(task["askResponseText"]).toBe("test user message") + expect(task["askResponseImages"]).toEqual(["image.png"]) + + // Restore console.error + consoleErrorSpy.mockRestore() + }) + + it("should not save checkpoint when checkpoints are disabled", async () => { + const task = new Task({ + provider: mockProvider, + apiConfiguration: mockApiConfig, + task: "test task", + enableCheckpoints: false, + startTask: false, + }) + + // Mock checkpointSave method + const checkpointSaveSpy = vi.spyOn(task, "checkpointSave").mockResolvedValue(undefined) + + // Call handleWebviewAskResponse with a user message + await task.handleWebviewAskResponse("messageResponse", "test user message", ["image.png"]) + + // Verify checkpoint was NOT saved + expect(checkpointSaveSpy).not.toHaveBeenCalled() + + // Verify the response was still set + expect(task["askResponse"]).toBe("messageResponse") + expect(task["askResponseText"]).toBe("test user message") + expect(task["askResponseImages"]).toEqual(["image.png"]) + }) + + it("should not save checkpoint in recursivelyMakeClineRequests anymore", async () => { + const task = new Task({ + provider: mockProvider, + apiConfiguration: mockApiConfig, + task: "test task", + enableCheckpoints: true, + startTask: false, + }) + + // Mock checkpointSave method + const checkpointSaveSpy = vi.spyOn(task, "checkpointSave").mockResolvedValue(undefined) + + // Mock other required methods + vi.spyOn(task as any, "addToApiConversationHistory").mockResolvedValue(undefined) + vi.spyOn(task as any, "saveClineMessages").mockResolvedValue(undefined) + vi.spyOn(task.api, "createMessage").mockReturnValue({ + async *[Symbol.asyncIterator]() { + yield { type: "text", text: "response" } + }, + } as any) + + // Mock clineMessages + task.clineMessages = [ + { + ts: Date.now(), + type: "say", + say: "api_req_started", + text: JSON.stringify({ request: "test" }), + }, + ] + + // Call recursivelyMakeClineRequests + await task.recursivelyMakeClineRequests([{ type: "text", text: "test user message" }]) + + // Verify checkpoint was NOT saved in recursivelyMakeClineRequests + // (it should only be saved in handleWebviewAskResponse now) + expect(checkpointSaveSpy).not.toHaveBeenCalled() + }) + }) }) }) diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts index 7e25ae14dcd..76f3ae09c57 100644 --- a/src/core/webview/webviewMessageHandler.ts +++ b/src/core/webview/webviewMessageHandler.ts @@ -346,7 +346,10 @@ export const webviewMessageHandler = async ( await provider.postStateToWebview() break case "askResponse": - provider.getCurrentCline()?.handleWebviewAskResponse(message.askResponse!, message.text, message.images) + // handleWebviewAskResponse is now async to support checkpoint saving + await provider + .getCurrentCline() + ?.handleWebviewAskResponse(message.askResponse!, message.text, message.images) break case "autoCondenseContext": await updateGlobalState("autoCondenseContext", message.bool)