diff --git a/packages/types/src/global-settings.ts b/packages/types/src/global-settings.ts index 6a0dda1cf2..a820f56c87 100644 --- a/packages/types/src/global-settings.ts +++ b/packages/types/src/global-settings.ts @@ -109,6 +109,7 @@ export const globalSettingsSchema = z.object({ hasOpenedModeSelector: z.boolean().optional(), lastModeExportPath: z.string().optional(), lastModeImportPath: z.string().optional(), + chatTextDrafts: z.record(z.string(), z.string()).optional(), }) export type GlobalSettings = z.infer diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts index 0ec14ca27e..e4148cc4d5 100644 --- a/src/core/webview/webviewMessageHandler.ts +++ b/src/core/webview/webviewMessageHandler.ts @@ -58,6 +58,40 @@ export const webviewMessageHandler = async ( await provider.contextProxy.setValue(key, value) switch (message.type) { + case "updateChatTextDraft": { + // Update or remove the chat text draft in globalState under the fixed key "chatTextDraft" + const text = typeof message.text === "string" ? message.text : "" + const drafts = (await getGlobalState("chatTextDrafts")) ?? {} + if (text && text.trim()) { + // Set/Update + const updated = { ...drafts, chatTextDraft: text } + await updateGlobalState("chatTextDrafts", updated) + } else if (drafts.chatTextDraft) { + // Remove if empty + const { chatTextDraft: _, ...rest } = drafts + await updateGlobalState("chatTextDrafts", rest) + } + break + } + case "getChatTextDraft": { + // Return the chat text draft for the fixed key "chatTextDraft" + const drafts = (await getGlobalState("chatTextDrafts")) ?? {} + const text = drafts.chatTextDraft ?? "" + await provider.postMessageToWebview({ + type: "chatTextDraftValue", + text, + }) + break + } + case "clearChatTextDraft": { + // Remove the chat text draft from globalState under the fixed key "chatTextDraft" + const drafts = (await getGlobalState("chatTextDrafts")) ?? {} + if (drafts.chatTextDraft) { + const { chatTextDraft: _, ...rest } = drafts + await updateGlobalState("chatTextDrafts", rest) + } + break + } case "webviewDidLaunch": // Load custom modes first const customModes = await provider.customModesManager.getCustomModes() diff --git a/src/shared/ExtensionMessage.ts b/src/shared/ExtensionMessage.ts index 9db0889c88..66043cc4c4 100644 --- a/src/shared/ExtensionMessage.ts +++ b/src/shared/ExtensionMessage.ts @@ -105,6 +105,7 @@ export interface ExtensionMessage { | "shareTaskSuccess" | "codeIndexSettingsSaved" | "codeIndexSecretStatus" + | "chatTextDraftValue" text?: string payload?: any // Add a generic payload for now, can refine later action?: diff --git a/src/shared/WebviewMessage.ts b/src/shared/WebviewMessage.ts index a50e30b67e..679ccdca63 100644 --- a/src/shared/WebviewMessage.ts +++ b/src/shared/WebviewMessage.ts @@ -185,6 +185,9 @@ export interface WebviewMessage { | "checkRulesDirectoryResult" | "saveCodeIndexSettingsAtomic" | "requestCodeIndexSecretStatus" + | "getChatTextDraft" + | "updateChatTextDraft" + | "clearChatTextDraft" text?: string tab?: "settings" | "history" | "mcp" | "modes" | "chat" | "marketplace" | "account" disabled?: boolean diff --git a/webview-ui/src/components/chat/ChatTextArea.tsx b/webview-ui/src/components/chat/ChatTextArea.tsx index 51279062d2..71b3682871 100644 --- a/webview-ui/src/components/chat/ChatTextArea.tsx +++ b/webview-ui/src/components/chat/ChatTextArea.tsx @@ -9,6 +9,7 @@ import { ExtensionMessage } from "@roo/ExtensionMessage" import { vscode } from "@/utils/vscode" import { useExtensionState } from "@/context/ExtensionStateContext" +import { useChatTextDraft } from "./hooks/useChatTextDraft" import { useAppTranslation } from "@/i18n/TranslationContext" import { ContextMenuOptionType, @@ -69,6 +70,10 @@ const ChatTextArea = forwardRef( ref, ) => { const { t } = useAppTranslation() + + // Chat draft persistence + const { handleSendAndClearDraft } = useChatTextDraft(inputValue, setInputValue, onSend) + const { filePaths, openedTabs, @@ -389,7 +394,7 @@ const ChatTextArea = forwardRef( if (!sendingDisabled) { // Reset history navigation state when sending resetHistoryNavigation() - onSend() + handleSendAndClearDraft() } } @@ -438,22 +443,22 @@ const ChatTextArea = forwardRef( } }, [ - sendingDisabled, - onSend, showContextMenu, - searchQuery, + handleHistoryNavigation, selectedMenuIndex, - handleMentionSelect, - selectedType, + searchQuery, inputValue, - cursorPosition, - setInputValue, - justDeletedSpaceAfterMention, + selectedType, queryItems, - allModes, fileSearchResults, - handleHistoryNavigation, + allModes, + handleMentionSelect, + sendingDisabled, resetHistoryNavigation, + handleSendAndClearDraft, + cursorPosition, + justDeletedSpaceAfterMention, + setInputValue, ], ) @@ -1151,7 +1156,7 @@ const ChatTextArea = forwardRef( iconClass="codicon-send" title={t("chat:sendMessage")} disabled={sendingDisabled} - onClick={onSend} + onClick={handleSendAndClearDraft} /> diff --git a/webview-ui/src/components/chat/hooks/__tests__/useChatTextDraft.spec.ts b/webview-ui/src/components/chat/hooks/__tests__/useChatTextDraft.spec.ts new file mode 100644 index 0000000000..e705a18141 --- /dev/null +++ b/webview-ui/src/components/chat/hooks/__tests__/useChatTextDraft.spec.ts @@ -0,0 +1,146 @@ +// npx vitest webview-ui/src/components/chat/hooks/__tests__/useChatTextDraft.spec.ts + +import { renderHook, act } from "@testing-library/react" +import { useChatTextDraft } from "../useChatTextDraft" +import { vi } from "vitest" +import { vscode } from "@src/utils/vscode" + +describe("useChatTextDraft (postMessage version)", () => { + let setInputValue: (v: string) => void + let onSend: () => void + let postMessageMock: ReturnType + let addEventListenerMock: ReturnType + let removeEventListenerMock: ReturnType + let eventListener: ((event: MessageEvent) => void) | undefined + + beforeEach(() => { + setInputValue = vi.fn((_: string) => {}) + onSend = vi.fn() + postMessageMock = vi.fn() + addEventListenerMock = vi.fn((type, cb) => { + if (type === "message") eventListener = cb + }) + removeEventListenerMock = vi.fn((type, cb) => { + if (type === "message" && eventListener === cb) eventListener = undefined + }) + + global.window.addEventListener = addEventListenerMock + global.window.removeEventListener = removeEventListenerMock + // mock vscode.postMessage + vi.resetModules() + vi.clearAllMocks() + vscode.postMessage = postMessageMock + + vi.useFakeTimers() + }) + + afterEach(() => { + vi.clearAllTimers() + vi.useRealTimers() + vi.restoreAllMocks() + eventListener = undefined + }) + + it("should send getChatTextDraft on mount and set input value when chatTextDraftValue received", () => { + renderHook(() => useChatTextDraft("", setInputValue, onSend)) + expect(postMessageMock).toHaveBeenCalledWith({ type: "getChatTextDraft" }) + expect(setInputValue).not.toHaveBeenCalled() + // Simulate extension host response + act(() => { + eventListener?.({ data: { type: "chatTextDraftValue", text: "restored draft" } } as MessageEvent) + }) + expect(setInputValue).toHaveBeenCalledWith("restored draft") + }) + + it("should not set input value if inputValue is not empty when chatTextDraftValue received", () => { + renderHook(() => useChatTextDraft("already typed", setInputValue, onSend)) + act(() => { + eventListener?.({ data: { type: "chatTextDraftValue", text: "restored draft" } } as MessageEvent) + }) + expect(setInputValue).not.toHaveBeenCalled() + }) + + it("should debounce and send updateChatTextDraft with text after 2s if inputValue is non-empty", () => { + renderHook(({ value }) => useChatTextDraft(value, setInputValue, onSend), { + initialProps: { value: "hello world" }, + }) + expect(postMessageMock).toHaveBeenCalledWith({ type: "getChatTextDraft" }) + postMessageMock.mockClear() + act(() => { + vi.advanceTimersByTime(1999) + }) + expect(postMessageMock).not.toHaveBeenCalled() + act(() => { + vi.advanceTimersByTime(1) + }) + expect(postMessageMock).toHaveBeenCalledWith({ type: "updateChatTextDraft", text: "hello world" }) + }) + + it("should reset debounce timer when inputValue changes before debounce delay", () => { + const { rerender } = renderHook(({ value }) => useChatTextDraft(value, setInputValue, onSend), { + initialProps: { value: "foo" }, + }) + act(() => { + vi.advanceTimersByTime(1000) + }) + postMessageMock.mockClear() + rerender({ value: "bar" }) + act(() => { + vi.advanceTimersByTime(1999) + }) + expect(postMessageMock).not.toHaveBeenCalled() + act(() => { + vi.advanceTimersByTime(1) + }) + expect(postMessageMock).toHaveBeenCalledWith({ type: "updateChatTextDraft", text: "bar" }) + }) + + it("should send clearChatTextDraft if inputValue is empty after user has input", () => { + const { rerender } = renderHook(({ value }) => useChatTextDraft(value, setInputValue, onSend), { + initialProps: { value: "foo" }, + }) + act(() => { + vi.advanceTimersByTime(2000) + }) + postMessageMock.mockClear() + rerender({ value: "" }) + expect(postMessageMock).toHaveBeenCalledWith({ type: "clearChatTextDraft" }) + }) + + it("should send clearChatTextDraft and call onSend when handleSendAndClearDraft is called", () => { + const { result } = renderHook(() => useChatTextDraft("msg", setInputValue, onSend)) + postMessageMock.mockClear() + act(() => { + result.current.handleSendAndClearDraft() + }) + expect(postMessageMock).toHaveBeenCalledWith({ type: "clearChatTextDraft" }) + expect(onSend).toHaveBeenCalled() + }) + + it("should not send updateChatTextDraft and should warn if inputValue exceeds 100KB (ASCII)", () => { + const MAX_DRAFT_BYTES = 102400 + const largeStr = "a".repeat(MAX_DRAFT_BYTES + 5000) + const warnMock = vi.spyOn(console, "warn").mockImplementation(() => {}) + renderHook(() => useChatTextDraft(largeStr, setInputValue, onSend)) + act(() => { + vi.advanceTimersByTime(3000) + }) + expect(postMessageMock).not.toHaveBeenCalledWith({ type: "updateChatTextDraft", text: largeStr }) + expect(warnMock).toHaveBeenCalledWith(expect.stringContaining("exceeds 100KB")) + warnMock.mockRestore() + }) + + it("should not send updateChatTextDraft and should warn if inputValue exceeds 100KB (UTF-8 multi-byte)", () => { + const emoji = "😀" + const hanzi = "汉" + const utf8Str = emoji.repeat(20000) + hanzi.repeat(10000) + "abc" + const warnMock = vi.spyOn(console, "warn").mockImplementation(() => {}) + renderHook(() => useChatTextDraft(utf8Str, setInputValue, onSend)) + act(() => { + vi.advanceTimersByTime(3000) + }) + expect(postMessageMock).not.toHaveBeenCalledWith({ type: "updateChatTextDraft", text: utf8Str }) + expect(warnMock).toHaveBeenCalledWith(expect.stringContaining("exceeds 100KB")) + warnMock.mockRestore() + }) +}) diff --git a/webview-ui/src/components/chat/hooks/useChatTextDraft.ts b/webview-ui/src/components/chat/hooks/useChatTextDraft.ts new file mode 100644 index 0000000000..465a1b077c --- /dev/null +++ b/webview-ui/src/components/chat/hooks/useChatTextDraft.ts @@ -0,0 +1,90 @@ +import { useCallback, useEffect, useRef } from "react" +import { vscode } from "@src/utils/vscode" + +export const CHAT_DRAFT_SAVE_DEBOUNCE_MS = 2000 + +/** + * Hook for chat textarea draft persistence (extension globalState). + * Handles auto-save, restore on mount, and clear on send via postMessage. + * @param inputValue current textarea value + * @param setInputValue setter for textarea value + * @param onSend send callback + */ +export function useChatTextDraft(inputValue: string, setInputValue: (value: string) => void, onSend: () => void) { + // Restore draft from extension host on mount + useEffect(() => { + const handleDraftValue = (event: MessageEvent) => { + const msg = event.data + if (msg && msg.type === "chatTextDraftValue") { + if (typeof msg.text === "string" && msg.text && !inputValue) { + setInputValue(msg.text) + } + } + } + window.addEventListener("message", handleDraftValue) + // Request draft from extension host + vscode.postMessage({ type: "getChatTextDraft" }) + return () => { + window.removeEventListener("message", handleDraftValue) + } + // eslint-disable-next-line react-hooks/exhaustive-deps + }, []) + + // Debounced save draft to extension host + const debounceTimerRef = useRef(null) + + const hasHadUserInput = useRef(false) + + useEffect(() => { + if (debounceTimerRef.current) { + clearTimeout(debounceTimerRef.current) + } + const MAX_DRAFT_BYTES = 102400 + if (inputValue && inputValue.trim()) { + hasHadUserInput.current = true + debounceTimerRef.current = setTimeout(() => { + try { + // Fast pre-check: if character count is much greater than max bytes, skip encoding + if (inputValue.length > MAX_DRAFT_BYTES * 2) { + console.warn(`[useChatTextDraft] Draft is too long (chars=${inputValue.length}), not saving.`) + return + } + const encoder = new TextEncoder() + const bytes = encoder.encode(inputValue) + if (bytes.length > MAX_DRAFT_BYTES) { + console.warn(`[useChatTextDraft] Draft exceeds 100KB, not saving.`) + return + } + vscode.postMessage({ type: "updateChatTextDraft", text: inputValue }) + } catch (err) { + console.warn(`[useChatTextDraft] Failed to save draft:`, err) + } + }, CHAT_DRAFT_SAVE_DEBOUNCE_MS) + } else { + if (hasHadUserInput.current) { + try { + vscode.postMessage({ type: "clearChatTextDraft" }) + } catch (err) { + console.warn(`[useChatTextDraft] Failed to clear draft:`, err) + } + } + } + return () => { + if (debounceTimerRef.current) { + clearTimeout(debounceTimerRef.current) + } + } + }, [inputValue]) + + // Clear draft after send + const handleSendAndClearDraft = useCallback(() => { + try { + vscode.postMessage({ type: "clearChatTextDraft" }) + } catch (err) { + console.warn(`[useChatTextDraft] Failed to clear draft:`, err) + } + onSend() + }, [onSend]) + + return { handleSendAndClearDraft } +}