diff --git a/src/core/tools/__tests__/writeToFileTool.spec.ts b/src/core/tools/__tests__/writeToFileTool.spec.ts index 78e60cbaa5..c7ace0a916 100644 --- a/src/core/tools/__tests__/writeToFileTool.spec.ts +++ b/src/core/tools/__tests__/writeToFileTool.spec.ts @@ -1,4 +1,5 @@ import * as path from "path" +import * as fs from "fs/promises" import type { MockedFunction } from "vitest" @@ -10,6 +11,7 @@ import { unescapeHtmlEntities } from "../../../utils/text-normalization" import { everyLineHasLineNumbers, stripLineNumbers } from "../../../integrations/misc/extract-text" import { ToolUse, ToolResponse } from "../../../shared/tools" import { writeToFileTool } from "../writeToFileTool" +import { experiments, EXPERIMENT_IDS } from "../../../shared/experiments" vi.mock("path", async () => { const originalPath = await vi.importActual("path") @@ -27,6 +29,19 @@ vi.mock("delay", () => ({ default: vi.fn(), })) +vi.mock("fs/promises", () => ({ + readFile: vi.fn().mockResolvedValue("original content"), +})) + +vi.mock("../../../shared/experiments", () => ({ + experiments: { + isEnabled: vi.fn().mockReturnValue(false), // Default to disabled + }, + EXPERIMENT_IDS: { + PREVENT_FOCUS_DISRUPTION: "preventFocusDisruption", + }, +})) + vi.mock("../../../utils/fs", () => ({ fileExistsAtPath: vi.fn().mockResolvedValue(false), })) @@ -108,6 +123,8 @@ describe("writeToFileTool", () => { const mockedEveryLineHasLineNumbers = everyLineHasLineNumbers as MockedFunction const mockedStripLineNumbers = stripLineNumbers as MockedFunction const mockedPathResolve = path.resolve as MockedFunction + const mockedExperimentsIsEnabled = experiments.isEnabled as MockedFunction + const mockedFsReadFile = fs.readFile as MockedFunction const mockCline: any = {} let mockAskApproval: ReturnType @@ -127,6 +144,8 @@ describe("writeToFileTool", () => { mockedUnescapeHtmlEntities.mockImplementation((content) => content) mockedEveryLineHasLineNumbers.mockReturnValue(false) mockedStripLineNumbers.mockImplementation((content) => content) + mockedExperimentsIsEnabled.mockReturnValue(false) // Default to disabled + mockedFsReadFile.mockResolvedValue("original content") mockCline.cwd = "/" mockCline.consecutiveMistakeCount = 0 @@ -416,4 +435,98 @@ describe("writeToFileTool", () => { expect(mockCline.diffViewProvider.reset).toHaveBeenCalled() }) }) + + describe("PREVENT_FOCUS_DISRUPTION experiment", () => { + beforeEach(() => { + // Reset the experiments mock for these tests + mockedExperimentsIsEnabled.mockReset() + }) + + it("should NOT save file before user approval when experiment is enabled", async () => { + // Enable the PREVENT_FOCUS_DISRUPTION experiment + mockedExperimentsIsEnabled.mockReturnValue(true) + + mockCline.providerRef.deref().getState.mockResolvedValue({ + diagnosticsEnabled: true, + writeDelayMs: 1000, + experiments: { + preventFocusDisruption: true, + }, + }) + + // Mock saveDirectly to track when it's called + const saveDirectlySpy = vi.fn().mockResolvedValue({ + newProblemsMessage: "", + userEdits: undefined, + finalContent: testContent, + }) + mockCline.diffViewProvider.saveDirectly = saveDirectlySpy + + // User rejects the approval + mockAskApproval.mockResolvedValue(false) + + await executeWriteFileTool({}, { fileExists: false }) + + // Verify that askApproval was called + expect(mockAskApproval).toHaveBeenCalled() + + // Verify that saveDirectly was NOT called since user rejected + expect(saveDirectlySpy).not.toHaveBeenCalled() + + // Verify that the diffViewProvider state was reset + expect(mockCline.diffViewProvider.editType).toBe(undefined) + expect(mockCline.diffViewProvider.originalContent).toBe(undefined) + }) + + it("should save file AFTER user approval when experiment is enabled", async () => { + // Enable the PREVENT_FOCUS_DISRUPTION experiment + mockedExperimentsIsEnabled.mockReturnValue(true) + + mockCline.providerRef.deref().getState.mockResolvedValue({ + diagnosticsEnabled: true, + writeDelayMs: 1000, + experiments: { + preventFocusDisruption: true, + }, + }) + + // Mock saveDirectly to track when it's called + const saveDirectlySpy = vi.fn().mockResolvedValue({ + newProblemsMessage: "", + userEdits: undefined, + finalContent: testContent, + }) + mockCline.diffViewProvider.saveDirectly = saveDirectlySpy + + // Mock pushToolWriteResult + mockCline.diffViewProvider.pushToolWriteResult = vi.fn().mockResolvedValue("Tool result message") + + // User approves + mockAskApproval.mockResolvedValue(true) + + // Track the order of calls + const callOrder: string[] = [] + mockAskApproval.mockImplementation(async () => { + callOrder.push("askApproval") + return true + }) + saveDirectlySpy.mockImplementation(async () => { + callOrder.push("saveDirectly") + return { + newProblemsMessage: "", + userEdits: undefined, + finalContent: testContent, + } + }) + + await executeWriteFileTool({}, { fileExists: false }) + + // Verify that askApproval was called BEFORE saveDirectly + expect(callOrder).toEqual(["askApproval", "saveDirectly"]) + + // Verify both were called + expect(mockAskApproval).toHaveBeenCalled() + expect(saveDirectlySpy).toHaveBeenCalledWith(testFilePath, testContent, false, true, 1000) + }) + }) }) diff --git a/src/core/tools/writeToFileTool.ts b/src/core/tools/writeToFileTool.ts index e82eab92bc..575085c011 100644 --- a/src/core/tools/writeToFileTool.ts +++ b/src/core/tools/writeToFileTool.ts @@ -202,6 +202,16 @@ export async function writeToFileTool( } } + // Set up diffViewProvider properties needed for saveDirectly BEFORE asking for approval + // This ensures we have the original content for comparison but don't write anything yet + cline.diffViewProvider.editType = fileExists ? "modify" : "create" + if (fileExists) { + const absolutePath = path.resolve(cline.cwd, relPath) + cline.diffViewProvider.originalContent = await fs.readFile(absolutePath, "utf-8") + } else { + cline.diffViewProvider.originalContent = "" + } + const completeMessage = JSON.stringify({ ...sharedMessageProps, content: newContent, @@ -210,19 +220,13 @@ export async function writeToFileTool( const didApprove = await askApproval("tool", completeMessage, undefined, isWriteProtected) if (!didApprove) { + // Reset the diffViewProvider state since we're not proceeding + cline.diffViewProvider.editType = undefined + cline.diffViewProvider.originalContent = undefined return } - // Set up diffViewProvider properties needed for saveDirectly - cline.diffViewProvider.editType = fileExists ? "modify" : "create" - if (fileExists) { - const absolutePath = path.resolve(cline.cwd, relPath) - cline.diffViewProvider.originalContent = await fs.readFile(absolutePath, "utf-8") - } else { - cline.diffViewProvider.originalContent = "" - } - - // Save directly without showing diff view or opening the file + // Only save directly AFTER user approval await cline.diffViewProvider.saveDirectly(relPath, newContent, false, diagnosticsEnabled, writeDelayMs) } else { // Original behavior with diff view