Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions src/core/tools/__tests__/writeToFileTool.spec.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import * as path from "path"
import * as fs from "fs/promises"

import type { MockedFunction } from "vitest"

Expand All @@ -10,6 +11,7 @@ import { unescapeHtmlEntities } from "../../../utils/text-normalization"
import { everyLineHasLineNumbers, stripLineNumbers } from "../../../integrations/misc/extract-text"
import { ToolUse, ToolResponse } from "../../../shared/tools"
import { writeToFileTool } from "../writeToFileTool"
import { experiments, EXPERIMENT_IDS } from "../../../shared/experiments"

vi.mock("path", async () => {
const originalPath = await vi.importActual("path")
Expand All @@ -27,6 +29,19 @@ vi.mock("delay", () => ({
default: vi.fn(),
}))

vi.mock("fs/promises", () => ({
readFile: vi.fn().mockResolvedValue("original content"),
}))

vi.mock("../../../shared/experiments", () => ({
experiments: {
isEnabled: vi.fn().mockReturnValue(false), // Default to disabled
},
EXPERIMENT_IDS: {
PREVENT_FOCUS_DISRUPTION: "preventFocusDisruption",
},
}))

vi.mock("../../../utils/fs", () => ({
fileExistsAtPath: vi.fn().mockResolvedValue(false),
}))
Expand Down Expand Up @@ -108,6 +123,8 @@ describe("writeToFileTool", () => {
const mockedEveryLineHasLineNumbers = everyLineHasLineNumbers as MockedFunction<typeof everyLineHasLineNumbers>
const mockedStripLineNumbers = stripLineNumbers as MockedFunction<typeof stripLineNumbers>
const mockedPathResolve = path.resolve as MockedFunction<typeof path.resolve>
const mockedExperimentsIsEnabled = experiments.isEnabled as MockedFunction<typeof experiments.isEnabled>
const mockedFsReadFile = fs.readFile as MockedFunction<typeof fs.readFile>

const mockCline: any = {}
let mockAskApproval: ReturnType<typeof vi.fn>
Expand All @@ -127,6 +144,8 @@ describe("writeToFileTool", () => {
mockedUnescapeHtmlEntities.mockImplementation((content) => content)
mockedEveryLineHasLineNumbers.mockReturnValue(false)
mockedStripLineNumbers.mockImplementation((content) => content)
mockedExperimentsIsEnabled.mockReturnValue(false) // Default to disabled
mockedFsReadFile.mockResolvedValue("original content")

mockCline.cwd = "/"
mockCline.consecutiveMistakeCount = 0
Expand Down Expand Up @@ -416,4 +435,98 @@ describe("writeToFileTool", () => {
expect(mockCline.diffViewProvider.reset).toHaveBeenCalled()
})
})

describe("PREVENT_FOCUS_DISRUPTION experiment", () => {
beforeEach(() => {
// Reset the experiments mock for these tests
mockedExperimentsIsEnabled.mockReset()
})

it("should NOT save file before user approval when experiment is enabled", async () => {
// Enable the PREVENT_FOCUS_DISRUPTION experiment
mockedExperimentsIsEnabled.mockReturnValue(true)

mockCline.providerRef.deref().getState.mockResolvedValue({
diagnosticsEnabled: true,
writeDelayMs: 1000,
experiments: {
preventFocusDisruption: true,
},
})

// Mock saveDirectly to track when it's called
const saveDirectlySpy = vi.fn().mockResolvedValue({
newProblemsMessage: "",
userEdits: undefined,
finalContent: testContent,
})
mockCline.diffViewProvider.saveDirectly = saveDirectlySpy

// User rejects the approval
mockAskApproval.mockResolvedValue(false)

await executeWriteFileTool({}, { fileExists: false })

// Verify that askApproval was called
expect(mockAskApproval).toHaveBeenCalled()

// Verify that saveDirectly was NOT called since user rejected
expect(saveDirectlySpy).not.toHaveBeenCalled()

// Verify that the diffViewProvider state was reset
expect(mockCline.diffViewProvider.editType).toBe(undefined)
expect(mockCline.diffViewProvider.originalContent).toBe(undefined)
})

it("should save file AFTER user approval when experiment is enabled", async () => {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great test coverage for the happy path! Consider adding a test case for when fs.readFile fails to ensure we handle that error scenario gracefully.

// Enable the PREVENT_FOCUS_DISRUPTION experiment
mockedExperimentsIsEnabled.mockReturnValue(true)

mockCline.providerRef.deref().getState.mockResolvedValue({
diagnosticsEnabled: true,
writeDelayMs: 1000,
experiments: {
preventFocusDisruption: true,
},
})

// Mock saveDirectly to track when it's called
const saveDirectlySpy = vi.fn().mockResolvedValue({
newProblemsMessage: "",
userEdits: undefined,
finalContent: testContent,
})
mockCline.diffViewProvider.saveDirectly = saveDirectlySpy

// Mock pushToolWriteResult
mockCline.diffViewProvider.pushToolWriteResult = vi.fn().mockResolvedValue("Tool result message")

// User approves
mockAskApproval.mockResolvedValue(true)

// Track the order of calls
const callOrder: string[] = []
mockAskApproval.mockImplementation(async () => {
callOrder.push("askApproval")
return true
})
saveDirectlySpy.mockImplementation(async () => {
callOrder.push("saveDirectly")
return {
newProblemsMessage: "",
userEdits: undefined,
finalContent: testContent,
}
})

await executeWriteFileTool({}, { fileExists: false })

// Verify that askApproval was called BEFORE saveDirectly
expect(callOrder).toEqual(["askApproval", "saveDirectly"])

// Verify both were called
expect(mockAskApproval).toHaveBeenCalled()
expect(saveDirectlySpy).toHaveBeenCalledWith(testFilePath, testContent, false, true, 1000)
})
})
})
24 changes: 14 additions & 10 deletions src/core/tools/writeToFileTool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,16 @@ export async function writeToFileTool(
}
}

// Set up diffViewProvider properties needed for saveDirectly BEFORE asking for approval
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment explains WHAT we're doing but not WHY it's critical for security. Could we make this more explicit? Something like:

Suggested change
// Set up diffViewProvider properties needed for saveDirectly BEFORE asking for approval
// Set up diffViewProvider properties needed for saveDirectly BEFORE asking for approval
// This ensures we have the original content for comparison but critically prevents any
// file writes from occurring until the user explicitly approves the operation

// This ensures we have the original content for comparison but don't write anything yet
cline.diffViewProvider.editType = fileExists ? "modify" : "create"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I notice this block (lines 207-213) duplicates similar logic from the non-experiment path. Could we extract this into a helper function to reduce duplication? Something like prepareDiffViewState(fileExists, relPath)?

if (fileExists) {
const absolutePath = path.resolve(cline.cwd, relPath)
cline.diffViewProvider.originalContent = await fs.readFile(absolutePath, "utf-8")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add error handling here in case fs.readFile fails? Currently if reading the original file fails, the error would bubble up and might not be handled gracefully.

} else {
cline.diffViewProvider.originalContent = ""
}

const completeMessage = JSON.stringify({
...sharedMessageProps,
content: newContent,
Expand All @@ -210,19 +220,13 @@ export async function writeToFileTool(
const didApprove = await askApproval("tool", completeMessage, undefined, isWriteProtected)

if (!didApprove) {
// Reset the diffViewProvider state since we're not proceeding
cline.diffViewProvider.editType = undefined
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good defensive programming! Properly resetting the state when the user rejects ensures no lingering state that could cause issues.

cline.diffViewProvider.originalContent = undefined
return
}

// Set up diffViewProvider properties needed for saveDirectly
cline.diffViewProvider.editType = fileExists ? "modify" : "create"
if (fileExists) {
const absolutePath = path.resolve(cline.cwd, relPath)
cline.diffViewProvider.originalContent = await fs.readFile(absolutePath, "utf-8")
} else {
cline.diffViewProvider.originalContent = ""
}

// Save directly without showing diff view or opening the file
// Only save directly AFTER user approval
await cline.diffViewProvider.saveDirectly(relPath, newContent, false, diagnosticsEnabled, writeDelayMs)
} else {
// Original behavior with diff view
Expand Down
Loading