diff --git a/src/core/webview/__tests__/ClineProvider.test.ts b/src/core/webview/__tests__/ClineProvider.test.ts index 6ced4989a4..266857c01b 100644 --- a/src/core/webview/__tests__/ClineProvider.test.ts +++ b/src/core/webview/__tests__/ClineProvider.test.ts @@ -152,6 +152,9 @@ jest.mock("vscode", () => ({ window: { showInformationMessage: jest.fn(), showErrorMessage: jest.fn(), + createTextEditorDecorationType: jest.fn().mockReturnValue({ + dispose: jest.fn(), + }), }, workspace: { getConfiguration: jest.fn().mockReturnValue({ diff --git a/src/core/webview/__tests__/webviewMessageHandler.test.ts b/src/core/webview/__tests__/webviewMessageHandler.test.ts index 7f3bc49654..24d786a5ea 100644 --- a/src/core/webview/__tests__/webviewMessageHandler.test.ts +++ b/src/core/webview/__tests__/webviewMessageHandler.test.ts @@ -2,9 +2,20 @@ import { webviewMessageHandler } from "../webviewMessageHandler" import { ClineProvider } from "../ClineProvider" import { getModels } from "../../../api/providers/fetchers/modelCache" import { ModelRecord } from "../../../shared/api" +import type { ClineMessage } from "@roo-code/types" +import * as vscode from "vscode" // Mock dependencies jest.mock("../../../api/providers/fetchers/modelCache") +jest.mock("vscode", () => ({ + window: { + showWarningMessage: jest.fn(), + }, +})) +jest.mock("../../checkpoints", () => ({ + checkpointRestore: jest.fn(), +})) + const mockGetModels = getModels as jest.MockedFunction // Mock ClineProvider @@ -272,3 +283,93 @@ describe("webviewMessageHandler - requestRouterModels", () => { }) }) }) + +describe("webviewMessageHandler - editMessage", () => { + let mockCline: any + + beforeEach(() => { + jest.clearAllMocks() + + // Mock Cline instance + mockCline = { + taskId: "test-task-id", + clineMessages: [ + { ts: 1000, type: "say", say: "user_feedback", text: "First message" }, + { ts: 2000, type: "say", say: "user_feedback", text: "Second message" }, + { ts: 3000, type: "say", say: "checkpoint_saved", text: "Checkpoint saved" }, + { ts: 4000, type: "say", say: "user_feedback", text: "Third message" }, + ] as ClineMessage[], + apiConversationHistory: [ + { ts: 1000, role: "user", content: "First message" }, + { ts: 2000, role: "user", content: "Second message" }, + { ts: 4000, role: "user", content: "Third message" }, + ], + overwriteClineMessages: jest.fn(), + overwriteApiConversationHistory: jest.fn(), + } + + mockClineProvider.getCurrentCline = jest.fn().mockReturnValue(mockCline) + mockClineProvider.getState = jest.fn().mockResolvedValue({ enableCheckpoints: true }) + mockClineProvider.getTaskWithId = jest.fn().mockResolvedValue({ + historyItem: { clineMessages: mockCline.clineMessages }, + }) + mockClineProvider.postStateToWebview = jest.fn() + mockClineProvider.initClineWithHistoryItem = jest.fn() + }) + + it("handles basic message editing without confirmation", async () => { + // Mock no subsequent messages and no checkpoints + mockCline.clineMessages = [{ ts: 1000, type: "say", say: "user_feedback", text: "Only message" }] + mockClineProvider.getState = jest.fn().mockResolvedValue({ enableCheckpoints: false }) + + await webviewMessageHandler(mockClineProvider, { + type: "editMessage", + value: 1000, + text: "Edited message", + }) + + expect(mockClineProvider.initClineWithHistoryItem).toHaveBeenCalled() + }) + + it("shows confirmation dialog when editing affects subsequent messages", async () => { + const mockShowWarning = vscode.window.showWarningMessage as jest.Mock + mockShowWarning.mockResolvedValue("Edit Message") + + await webviewMessageHandler(mockClineProvider, { + type: "editMessage", + value: 2000, // Edit second message, affecting third message + text: "Edited second message", + }) + + expect(mockShowWarning).toHaveBeenCalledWith( + "Edit and delete subsequent messages?\n\n• 1 checkpoint(s) will be removed", + { modal: true }, + "Edit Message", + ) + expect(mockClineProvider.initClineWithHistoryItem).toHaveBeenCalled() + }) + + it("cancels edit when user declines confirmation", async () => { + const mockShowWarning = vscode.window.showWarningMessage as jest.Mock + mockShowWarning.mockResolvedValue(undefined) // User cancelled + + await webviewMessageHandler(mockClineProvider, { + type: "editMessage", + value: 2000, + text: "This edit should be cancelled", + }) + + expect(mockClineProvider.postStateToWebview).toHaveBeenCalled() + expect(mockClineProvider.initClineWithHistoryItem).not.toHaveBeenCalled() + }) + + it("handles invalid message parameters gracefully", async () => { + await webviewMessageHandler(mockClineProvider, { + type: "editMessage", + value: undefined, // Invalid value + text: "Should not process", + }) + + expect(mockClineProvider.initClineWithHistoryItem).not.toHaveBeenCalled() + }) +}) diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts index a4d9dafecf..87a0638560 100644 --- a/src/core/webview/webviewMessageHandler.ts +++ b/src/core/webview/webviewMessageHandler.ts @@ -3,7 +3,13 @@ import fs from "fs/promises" import pWaitFor from "p-wait-for" import * as vscode from "vscode" -import { type Language, type ProviderSettings, type GlobalState, TelemetryEventName } from "@roo-code/types" +import { + type Language, + type ProviderSettings, + type GlobalState, + TelemetryEventName, + type ClineMessage, +} from "@roo-code/types" import { CloudService } from "@roo-code/cloud" import { TelemetryService } from "@roo-code/telemetry" @@ -28,6 +34,7 @@ import { playTts, setTtsEnabled, setTtsSpeed, stopTts } from "../../utils/tts" import { singleCompletionHandler } from "../../utils/single-completion-handler" import { searchCommits } from "../../utils/git" import { exportSettings, importSettings } from "../config/importExport" +import { checkpointRestore } from "../checkpoints" import { getOpenAiModels } from "../../api/providers/openai" import { getOllamaModels } from "../../api/providers/ollama" import { getVsCodeLmModels } from "../../api/providers/vscode-lm" @@ -959,6 +966,179 @@ export const webviewMessageHandler = async ( } break } + case "editMessage": { + if ( + provider.getCurrentCline() && + typeof message.value === "number" && + message.value && + message.text !== undefined + ) { + const timeCutoff = message.value - 1000 // 1 second buffer before the message to edit + + const messageIndex = provider + .getCurrentCline()! + .clineMessages.findIndex((msg) => msg.ts && msg.ts >= timeCutoff) + + const apiConversationHistoryIndex = + provider + .getCurrentCline() + ?.apiConversationHistory.findIndex((msg) => msg.ts && msg.ts >= timeCutoff) ?? -1 + + if (messageIndex !== -1) { + // Check if there are subsequent messages that will be deleted + const totalMessages = provider.getCurrentCline()!.clineMessages.length + const hasSubsequentMessages = messageIndex < totalMessages - 1 + + // Check for checkpoints if enabled + const checkpointsEnabled = (await provider.getState()).enableCheckpoints + let affectedCheckpointsCount = 0 + let closestPreviousCheckpoint: ClineMessage | undefined + + if (checkpointsEnabled) { + const editMessageTimestamp = message.value + const checkpointMessages = provider + .getCurrentCline()! + .clineMessages.filter((msg) => msg.say === "checkpoint_saved") + .sort((a, b) => a.ts - b.ts) + + // Find checkpoints that will be affected (those after the edited message) + affectedCheckpointsCount = checkpointMessages.filter( + (cp) => cp.ts > editMessageTimestamp, + ).length + + // Find the closest checkpoint before the edited message + closestPreviousCheckpoint = checkpointMessages + .reverse() + .find((cp) => cp.ts < editMessageTimestamp) + } + + // Build confirmation message + let confirmationMessage = "Edit and delete subsequent messages?" + + if (checkpointsEnabled && affectedCheckpointsCount > 0) { + confirmationMessage += `\n\n• ${affectedCheckpointsCount} checkpoint(s) will be removed` + + if (closestPreviousCheckpoint) { + confirmationMessage += "\n• Files will restore to previous checkpoint" + } + } + + // Show confirmation dialog if there are subsequent messages or affected checkpoints + if (hasSubsequentMessages || affectedCheckpointsCount > 0) { + const confirmation = await vscode.window.showWarningMessage( + confirmationMessage, + { modal: true }, + "Edit Message", + ) + + if (confirmation !== "Edit Message") { + // User cancelled, update the webview to show the original state + await provider.postStateToWebview() + break + } + } + + const { historyItem } = await provider.getTaskWithId(provider.getCurrentCline()!.taskId) + + // Get messages up to and including the edited message + const updatedClineMessages = [ + ...provider.getCurrentCline()!.clineMessages.slice(0, messageIndex + 1), + ] + const messageToEdit = updatedClineMessages[messageIndex] + + if (messageToEdit && messageToEdit.type === "say" && messageToEdit.say === "user_feedback") { + // Update the text content + messageToEdit.text = message.text + + // Update images if provided + if (message.images) { + messageToEdit.images = message.images + } + + // Overwrite with only messages up to and including the edited one + await provider.getCurrentCline()!.overwriteClineMessages(updatedClineMessages) + + // Handle checkpoint restoration if checkpoints are enabled + if (checkpointsEnabled && closestPreviousCheckpoint) { + // Restore to the closest checkpoint before the edited message + const commitHash = closestPreviousCheckpoint.text // The commit hash is stored in the text field + if (commitHash) { + // Use "preview" mode to only restore files without affecting messages + // (we've already handled message cleanup above) + await checkpointRestore(provider.getCurrentCline()!, { + ts: closestPreviousCheckpoint.ts, + commitHash: commitHash, + mode: "preview", + }) + } + } + + // Update API conversation history if needed + if (apiConversationHistoryIndex !== -1) { + const updatedApiHistory = [ + ...provider + .getCurrentCline()! + .apiConversationHistory.slice(0, apiConversationHistoryIndex + 1), + ] + const apiMessage = updatedApiHistory[apiConversationHistoryIndex] + + if (apiMessage && apiMessage.role === "user") { + // Update the content in API history + if (typeof apiMessage.content === "string") { + apiMessage.content = message.text + } else if (Array.isArray(apiMessage.content)) { + // Find and update text content blocks + apiMessage.content = apiMessage.content.map((block: any) => { + if (block.type === "text") { + return { ...block, text: message.text } + } + return block + }) + + // Handle image updates if provided + if (message.images) { + // Remove existing image blocks + apiMessage.content = apiMessage.content.filter( + (block: any) => block.type !== "image", + ) + + // Add new image blocks + const imageBlocks = message.images.map((image) => ({ + type: "image" as const, + source: { + type: "base64" as const, + media_type: (image.startsWith("data:image/png") + ? "image/png" + : "image/jpeg") as + | "image/png" + | "image/jpeg" + | "image/gif" + | "image/webp", + data: image.split(",")[1] || image, + }, + })) + + // Add image blocks after text + apiMessage.content.push(...imageBlocks) + } + } + + // Overwrite with only API messages up to and including the edited one + await provider.getCurrentCline()!.overwriteApiConversationHistory(updatedApiHistory) + } + } + + await provider.initClineWithHistoryItem(historyItem) + // Force a state update to ensure the webview reflects the changes + await provider.postStateToWebview() + + // Note: Removed auto-resume logic to prevent duplicate messages. + // The user will manually send the edited message when ready. + } + } + } + break + } case "screenshotQuality": await updateGlobalState("screenshotQuality", message.value) await provider.postStateToWebview() diff --git a/src/shared/WebviewMessage.ts b/src/shared/WebviewMessage.ts index 5186c716b9..56cd2e9b1c 100644 --- a/src/shared/WebviewMessage.ts +++ b/src/shared/WebviewMessage.ts @@ -99,6 +99,7 @@ export interface WebviewMessage { | "enhancedPrompt" | "draggedImages" | "deleteMessage" + | "editMessage" | "terminalOutputLineLimit" | "terminalShellIntegrationTimeout" | "terminalShellIntegrationDisabled" diff --git a/webview-ui/src/components/chat/ChatRow.tsx b/webview-ui/src/components/chat/ChatRow.tsx index 43824c5902..e1b911d18a 100644 --- a/webview-ui/src/components/chat/ChatRow.tsx +++ b/webview-ui/src/components/chat/ChatRow.tsx @@ -107,11 +107,41 @@ export const ChatRowContent = ({ const [showCopySuccess, setShowCopySuccess] = useState(false) const { copyWithFeedback } = useCopyToClipboard() + // Edit mode state + const [isEditing, setIsEditing] = useState(false) + const [editValue, setEditValue] = useState(message.text || "") + const [editImages, setEditImages] = useState(message.images || []) + // Memoized callback to prevent re-renders caused by inline arrow functions const handleToggleExpand = useCallback(() => { onToggleExpand(message.ts) }, [onToggleExpand, message.ts]) + // Edit mode handlers + const handleEditSave = useCallback(() => { + if (editValue.trim() || editImages.length > 0) { + vscode.postMessage({ + type: "editMessage", + value: message.ts, + text: editValue.trim(), + images: editImages.length > 0 ? editImages : undefined, + }) + setIsEditing(false) + } + }, [editValue, editImages, message.ts]) + + const handleEditCancel = useCallback(() => { + setEditValue(message.text || "") + setEditImages(message.images || []) + setIsEditing(false) + }, [message.text, message.images]) + + const handleStartEdit = useCallback(() => { + setEditValue(message.text || "") + setEditImages(message.images || []) + setIsEditing(true) + }, [message.text, message.images]) + const [cost, apiReqCancelReason, apiReqStreamingFailedMessage] = useMemo(() => { if (message.text !== null && message.text !== undefined && message.say === "api_req_started") { const info = safeJsonParse(message.text) @@ -978,24 +1008,107 @@ export const ChatRowContent = ({ ) case "user_feedback": - return ( -
-
-
- + if (isEditing) { + return ( +
+
+