diff --git a/packages/types/src/provider-settings.ts b/packages/types/src/provider-settings.ts
index ea7089a81ea..ac472d647e6 100644
--- a/packages/types/src/provider-settings.ts
+++ b/packages/types/src/provider-settings.ts
@@ -63,6 +63,7 @@ export const DEFAULT_CONSECUTIVE_MISTAKE_LIMIT = 3
const baseProviderSettingsSchema = z.object({
includeMaxTokens: z.boolean().optional(),
diffEnabled: z.boolean().optional(),
+ applyEnabled: z.boolean().optional(),
todoListEnabled: z.boolean().optional(),
fuzzyMatchThreshold: z.number().optional(),
modelTemperature: z.number().nullish(),
diff --git a/packages/types/src/tool.ts b/packages/types/src/tool.ts
index 7a3fd211999..1fd14d245ea 100644
--- a/packages/types/src/tool.ts
+++ b/packages/types/src/tool.ts
@@ -19,6 +19,7 @@ export const toolNames = [
"read_file",
"write_to_file",
"apply_diff",
+ "apply_code",
"insert_content",
"search_and_replace",
"search_files",
diff --git a/src/core/assistant-message/presentAssistantMessage.ts b/src/core/assistant-message/presentAssistantMessage.ts
index ee3fa148b41..56eaa5dafc8 100644
--- a/src/core/assistant-message/presentAssistantMessage.ts
+++ b/src/core/assistant-message/presentAssistantMessage.ts
@@ -34,6 +34,7 @@ import { Task } from "../task/Task"
import { codebaseSearchTool } from "../tools/codebaseSearchTool"
import { experiments, EXPERIMENT_IDS } from "../../shared/experiments"
import { applyDiffToolLegacy } from "../tools/applyDiffTool"
+import { applyCodeTool } from "../tools/applyCodeTool"
/**
* Processes and presents assistant message content to the user interface.
@@ -214,6 +215,8 @@ export async function presentAssistantMessage(cline: Task) {
const modeName = getModeBySlug(mode, customModes)?.name ?? mode
return `[${block.name} in ${modeName} mode: '${message}']`
}
+ case "apply_code":
+ return `[${block.name} for '${block.params.path}' with instruction: '${block.params.instruction}']`
}
}
@@ -522,6 +525,9 @@ export async function presentAssistantMessage(cline: Task) {
askFinishSubTaskApproval,
)
break
+ case "apply_code":
+ await applyCodeTool(cline, block, askApproval, handleError, pushToolResult, removeClosingTag)
+ break
}
break
diff --git a/src/core/prompts/tools/apply-code.ts b/src/core/prompts/tools/apply-code.ts
new file mode 100644
index 00000000000..c5936eb95a6
--- /dev/null
+++ b/src/core/prompts/tools/apply-code.ts
@@ -0,0 +1,34 @@
+import { ToolArgs } from "./types"
+
+export function getApplyCodeDescription(args: ToolArgs): string {
+ return `## apply_code
+Description: Request to apply code changes using a two-stage approach for improved reliability. This tool first generates code based on your instruction, then creates an accurate diff to integrate it into the existing file. This approach separates creative code generation from technical diff creation, resulting in more reliable code modifications.
+
+Parameters:
+- path: (required) The path of the file to modify (relative to the current workspace directory ${args.cwd})
+- instruction: (required) Clear instruction describing what code changes to make
+
+Usage:
+
+File path here
+Your instruction for code changes
+
+
+Example: Adding a new function to an existing file
+
+src/utils.ts
+Add a function called calculateAverage that takes an array of numbers and returns their average
+
+
+Example: Modifying existing code
+
+src/api/handler.ts
+Update the error handling in the fetchData function to include retry logic with exponential backoff
+
+
+Benefits over apply_diff:
+- More reliable: Separates code generation from diff creation
+- Cleaner context: Each stage has focused, minimal context
+- Better success rate: Reduces failures due to inaccurate diffs
+- Natural instructions: Use plain language instead of crafting diffs`
+}
diff --git a/src/core/prompts/tools/index.ts b/src/core/prompts/tools/index.ts
index 9f4af7f312c..cf7507bcd75 100644
--- a/src/core/prompts/tools/index.ts
+++ b/src/core/prompts/tools/index.ts
@@ -23,6 +23,7 @@ import { getSwitchModeDescription } from "./switch-mode"
import { getNewTaskDescription } from "./new-task"
import { getCodebaseSearchDescription } from "./codebase-search"
import { getUpdateTodoListDescription } from "./update-todo-list"
+import { getApplyCodeDescription } from "./apply-code"
import { CodeIndexManager } from "../../../services/code-index/manager"
// Map of tool names to their description functions
@@ -46,6 +47,7 @@ const toolDescriptionMap: Record string | undefined>
search_and_replace: (args) => getSearchAndReplaceDescription(args),
apply_diff: (args) =>
args.diffStrategy ? args.diffStrategy.getToolDescription({ cwd: args.cwd, toolOptions: args.toolOptions }) : "",
+ apply_code: (args) => getApplyCodeDescription(args),
update_todo_list: (args) => getUpdateTodoListDescription(args),
}
@@ -114,6 +116,11 @@ export function getToolDescriptionsForMode(
tools.delete("update_todo_list")
}
+ // Conditionally exclude apply_code if disabled in settings
+ if (settings?.applyEnabled === false) {
+ tools.delete("apply_code")
+ }
+
// Map tool descriptions for allowed tools
const descriptions = Array.from(tools).map((toolName) => {
const descriptionFn = toolDescriptionMap[toolName]
@@ -148,4 +155,5 @@ export {
getInsertContentDescription,
getSearchAndReplaceDescription,
getCodebaseSearchDescription,
+ getApplyCodeDescription,
}
diff --git a/src/core/tools/__tests__/applyCodeTool.spec.ts b/src/core/tools/__tests__/applyCodeTool.spec.ts
new file mode 100644
index 00000000000..deedc8c3627
--- /dev/null
+++ b/src/core/tools/__tests__/applyCodeTool.spec.ts
@@ -0,0 +1,626 @@
+import { vi, describe, it, expect, beforeEach } from "vitest"
+import type { MockedFunction } from "vitest"
+import { applyCodeTool } from "../applyCodeTool"
+import { ToolUse, ToolResponse } from "../../../shared/tools"
+import { fileExistsAtPath } from "../../../utils/fs"
+import { getReadablePath } from "../../../utils/path"
+import * as path from "path"
+
+// Mock fs/promises before any imports
+vi.mock("fs/promises", () => ({
+ default: {
+ readFile: vi.fn(),
+ },
+ readFile: vi.fn(),
+}))
+
+// Mock dependencies
+vi.mock("path", async () => {
+ const originalPath = await vi.importActual("path")
+ return {
+ ...originalPath,
+ resolve: vi.fn().mockImplementation((...args) => {
+ const separator = process.platform === "win32" ? "\\" : "/"
+ return args.join(separator)
+ }),
+ }
+})
+
+vi.mock("../../../utils/fs", () => ({
+ fileExistsAtPath: vi.fn().mockResolvedValue(true),
+}))
+
+vi.mock("../../../utils/path", () => ({
+ getReadablePath: vi.fn().mockReturnValue("test/file.ts"),
+}))
+
+vi.mock("../../prompts/responses", () => ({
+ formatResponse: {
+ toolError: vi.fn((msg) => `Error: ${msg}`),
+ rooIgnoreError: vi.fn((path) => `Access denied: ${path}`),
+ createPrettyPatch: vi.fn(() => "mock-diff"),
+ },
+}))
+
+vi.mock("../../../api", () => ({
+ buildApiHandler: vi.fn().mockReturnValue({
+ createMessage: vi.fn(),
+ getModel: vi.fn().mockReturnValue({
+ id: "claude-3",
+ info: { contextWindow: 200000 },
+ }),
+ }),
+}))
+
+// Import after mocking to get the mocked version
+import fs from "fs/promises"
+import { buildApiHandler } from "../../../api"
+
+describe("applyCodeTool", () => {
+ // Test data
+ const testFilePath = "test/file.ts"
+ const absoluteFilePath = process.platform === "win32" ? "C:\\test\\file.ts" : "/test/file.ts"
+ const testInstruction = "Add error handling to the function"
+ const originalContent = `function getData() {
+ return fetch('/api/data').then(res => res.json());
+}`
+
+ // Mocked functions
+ const mockedFileExistsAtPath = fileExistsAtPath as MockedFunction
+ const mockedGetReadablePath = getReadablePath as MockedFunction
+ const mockedPathResolve = path.resolve as MockedFunction
+ const mockedReadFile = fs.readFile as MockedFunction
+ const mockedBuildApiHandler = buildApiHandler as MockedFunction
+
+ const mockCline: any = {}
+ let mockAskApproval: ReturnType
+ let mockHandleError: ReturnType
+ let mockPushToolResult: ReturnType
+ let mockRemoveClosingTag: ReturnType
+ let toolResult: ToolResponse | undefined
+
+ beforeEach(() => {
+ vi.clearAllMocks()
+
+ mockedPathResolve.mockReturnValue(absoluteFilePath)
+ mockedFileExistsAtPath.mockResolvedValue(true)
+ mockedGetReadablePath.mockReturnValue(testFilePath)
+
+ mockCline.cwd = "/"
+ mockCline.consecutiveMistakeCount = 0
+ mockCline.taskId = "test-task-id"
+ mockCline.apiConfiguration = { apiProvider: "anthropic", apiKey: "test-key" }
+ mockCline.api = {
+ createMessage: vi.fn(),
+ getModel: vi.fn().mockReturnValue({
+ id: "claude-3",
+ info: { contextWindow: 200000 },
+ }),
+ }
+ mockCline.providerRef = {
+ deref: vi.fn().mockReturnValue({
+ getState: vi.fn().mockResolvedValue({
+ applyEnabled: true,
+ diagnosticsEnabled: true,
+ writeDelayMs: 0,
+ }),
+ }),
+ }
+ mockCline.rooIgnoreController = {
+ validateAccess: vi.fn().mockReturnValue(true),
+ }
+ mockCline.rooProtectedController = {
+ isWriteProtected: vi.fn().mockReturnValue(false),
+ }
+ mockCline.diffStrategy = {
+ applyDiff: vi.fn().mockResolvedValue({
+ success: true,
+ content: "modified content",
+ }),
+ }
+ mockCline.diffViewProvider = {
+ reset: vi.fn().mockResolvedValue(undefined),
+ editType: undefined,
+ open: vi.fn().mockResolvedValue(undefined),
+ update: vi.fn().mockResolvedValue(undefined),
+ scrollToFirstDiff: vi.fn(),
+ revertChanges: vi.fn().mockResolvedValue(undefined),
+ saveChanges: vi.fn().mockResolvedValue(undefined),
+ pushToolWriteResult: vi.fn().mockResolvedValue("File updated successfully"),
+ }
+ mockCline.fileContextTracker = {
+ trackFileContext: vi.fn().mockResolvedValue(undefined),
+ }
+ mockCline.say = vi.fn().mockResolvedValue(undefined)
+ mockCline.ask = vi.fn().mockResolvedValue(undefined)
+ mockCline.recordToolError = vi.fn()
+ mockCline.recordToolUsage = vi.fn()
+ mockCline.sayAndCreateMissingParamError = vi.fn().mockResolvedValue("Missing param error")
+
+ mockAskApproval = vi.fn().mockResolvedValue(true)
+ mockHandleError = vi.fn().mockResolvedValue(undefined)
+ mockRemoveClosingTag = vi.fn((tag, content) => content)
+ mockPushToolResult = vi.fn()
+
+ toolResult = undefined
+ })
+
+ /**
+ * Helper function to execute the apply code tool
+ */
+ async function executeApplyCodeTool(
+ params: Partial = {},
+ options: {
+ fileExists?: boolean
+ isPartial?: boolean
+ accessAllowed?: boolean
+ applyEnabled?: boolean
+ } = {},
+ ): Promise {
+ const fileExists = options.fileExists ?? true
+ const isPartial = options.isPartial ?? false
+ const accessAllowed = options.accessAllowed ?? true
+ const applyEnabled = options.applyEnabled ?? true
+
+ mockedFileExistsAtPath.mockResolvedValue(fileExists)
+ mockCline.rooIgnoreController.validateAccess.mockReturnValue(accessAllowed)
+ mockCline.providerRef.deref().getState.mockResolvedValue({ applyEnabled })
+
+ const toolUse: ToolUse = {
+ type: "tool_use",
+ name: "apply_code",
+ params: {
+ path: testFilePath,
+ instruction: testInstruction,
+ ...params,
+ },
+ partial: isPartial,
+ }
+
+ await applyCodeTool(
+ mockCline,
+ toolUse,
+ mockAskApproval,
+ mockHandleError,
+ (result: ToolResponse) => {
+ toolResult = result
+ mockPushToolResult(result)
+ },
+ mockRemoveClosingTag,
+ )
+
+ return toolResult
+ }
+
+ describe("parameter validation", () => {
+ it("handles missing path parameter", async () => {
+ await executeApplyCodeTool({ path: undefined })
+
+ expect(mockCline.sayAndCreateMissingParamError).toHaveBeenCalledWith("apply_code", "path")
+ expect(mockPushToolResult).toHaveBeenCalledWith("Missing param error")
+ })
+
+ it("handles missing instruction parameter", async () => {
+ await executeApplyCodeTool({ instruction: undefined })
+
+ expect(mockCline.sayAndCreateMissingParamError).toHaveBeenCalledWith("apply_code", "instruction")
+ expect(mockPushToolResult).toHaveBeenCalledWith("Missing param error")
+ })
+ })
+
+ describe("file validation", () => {
+ it("validates access with rooIgnoreController", async () => {
+ await executeApplyCodeTool({}, { accessAllowed: false })
+
+ expect(mockCline.rooIgnoreController.validateAccess).toHaveBeenCalledWith(testFilePath)
+ expect(mockPushToolResult).toHaveBeenCalledWith(expect.stringContaining("Access denied"))
+ })
+ })
+
+ describe("two-stage API workflow", () => {
+ it("makes two API calls with correct prompts and isolated context", async () => {
+ // Mock file read
+ mockedReadFile.mockResolvedValue(originalContent)
+
+ // Mock successful API responses
+ const generatedCode = `try {
+ return fetch('/api/data').then(res => {
+ if (!res.ok) throw new Error('Failed to fetch');
+ return res.json();
+ });
+} catch (error) {
+ console.error('Error fetching data:', error);
+ throw error;
+}`
+
+ // First API call response (code generation) - returns JSON
+ const mockStream1 = {
+ [Symbol.asyncIterator]: vi.fn().mockReturnValue({
+ next: vi
+ .fn()
+ .mockResolvedValueOnce({
+ value: {
+ type: "text",
+ text: JSON.stringify({
+ file: testFilePath,
+ type: "snippet",
+ code: generatedCode,
+ }),
+ },
+ done: false,
+ })
+ .mockResolvedValueOnce({ done: true }),
+ }),
+ }
+
+ // Second API call response (diff generation)
+ const mockStream2 = {
+ [Symbol.asyncIterator]: vi.fn().mockReturnValue({
+ next: vi
+ .fn()
+ .mockResolvedValueOnce({
+ value: {
+ type: "text",
+ text: `<<<<<<< SEARCH
+function getData() {
+ return fetch('/api/data').then(res => res.json());
+}
+=======
+function getData() {
+ try {
+ return fetch('/api/data').then(res => {
+ if (!res.ok) throw new Error('Failed to fetch');
+ return res.json();
+ });
+ } catch (error) {
+ console.error('Error fetching data:', error);
+ throw error;
+ }
+}
+>>>>>>> REPLACE`,
+ },
+ done: false,
+ })
+ .mockResolvedValueOnce({ done: true }),
+ }),
+ }
+
+ // Mock the isolated API handler
+ const mockIsolatedApiHandler = {
+ createMessage: vi.fn().mockReturnValue(mockStream2),
+ getModel: vi.fn().mockReturnValue({
+ id: "claude-3",
+ info: { contextWindow: 200000 },
+ }),
+ countTokens: vi.fn().mockResolvedValue(100),
+ }
+
+ mockCline.api.createMessage.mockReturnValueOnce(mockStream1)
+ mockedBuildApiHandler.mockReturnValue(mockIsolatedApiHandler)
+
+ await executeApplyCodeTool()
+
+ // Verify first API call (code generation) uses main API
+ expect(mockCline.api.createMessage).toHaveBeenCalledTimes(1)
+ const firstCall = mockCline.api.createMessage.mock.calls[0]
+ expect(firstCall[0]).toContain("code generation expert")
+ expect(firstCall[1][0].content[0].text).toContain(testInstruction)
+
+ // Verify second API call uses isolated handler
+ expect(mockedBuildApiHandler).toHaveBeenCalledWith(mockCline.apiConfiguration)
+ expect(mockIsolatedApiHandler.createMessage).toHaveBeenCalledTimes(1)
+
+ // Verify the isolated call has the hardcoded system prompt
+ const secondCall = mockIsolatedApiHandler.createMessage.mock.calls[0]
+ expect(secondCall[0]).toContain("specialized diff generation model")
+ expect(secondCall[0]).toContain("Your ONLY task is to generate accurate diff patches")
+
+ // Verify the isolated call has clean context (no conversation history)
+ expect(secondCall[1]).toHaveLength(1)
+ expect(secondCall[1][0].role).toBe("user")
+ expect(secondCall[1][0].content[0].text).toContain("Original file content:")
+ expect(secondCall[1][0].content[0].text).toContain("New code to integrate:")
+ })
+
+ it("applies the generated diff using diffStrategy", async () => {
+ // Mock file read
+ mockedReadFile.mockResolvedValue(originalContent)
+
+ // Mock API responses
+ const mockStream1 = {
+ [Symbol.asyncIterator]: vi.fn().mockReturnValue({
+ next: vi
+ .fn()
+ .mockResolvedValueOnce({
+ value: {
+ type: "text",
+ text: JSON.stringify({
+ file: testFilePath,
+ type: "snippet",
+ code: "generated code",
+ }),
+ },
+ done: false,
+ })
+ .mockResolvedValueOnce({ done: true }),
+ }),
+ }
+
+ const mockStream2 = {
+ [Symbol.asyncIterator]: vi.fn().mockReturnValue({
+ next: vi
+ .fn()
+ .mockResolvedValueOnce({
+ value: {
+ type: "text",
+ text: "<<<<<<< SEARCH\noriginal\n=======\nmodified\n>>>>>>> REPLACE",
+ },
+ done: false,
+ })
+ .mockResolvedValueOnce({ done: true }),
+ }),
+ }
+
+ const mockIsolatedApiHandler = {
+ createMessage: vi.fn().mockReturnValue(mockStream2),
+ getModel: vi.fn().mockReturnValue({
+ id: "claude-3",
+ info: { contextWindow: 200000 },
+ }),
+ countTokens: vi.fn().mockResolvedValue(100),
+ }
+
+ mockCline.api.createMessage.mockReturnValueOnce(mockStream1)
+ mockedBuildApiHandler.mockReturnValue(mockIsolatedApiHandler)
+
+ await executeApplyCodeTool()
+
+ // Verify diffStrategy.applyDiff was called with the generated diff
+ expect(mockCline.diffStrategy.applyDiff).toHaveBeenCalledWith(
+ originalContent,
+ "<<<<<<< SEARCH\noriginal\n=======\nmodified\n>>>>>>> REPLACE",
+ )
+
+ // Verify the diff view was updated
+ expect(mockCline.diffViewProvider.update).toHaveBeenCalledWith("modified content", true)
+ expect(mockPushToolResult).toHaveBeenCalledWith("File updated successfully")
+ })
+
+ it("handles new file creation", async () => {
+ // Mock file doesn't exist
+ mockedFileExistsAtPath.mockResolvedValue(false)
+
+ // Mock API response for new file
+ const newFileContent = `export function newFunction() {
+ return "Hello, World!";
+}`
+
+ const mockStream = {
+ [Symbol.asyncIterator]: vi.fn().mockReturnValue({
+ next: vi
+ .fn()
+ .mockResolvedValueOnce({
+ value: {
+ type: "text",
+ text: JSON.stringify({
+ file: testFilePath,
+ type: "full_file",
+ code: newFileContent,
+ }),
+ },
+ done: false,
+ })
+ .mockResolvedValueOnce({ done: true }),
+ }),
+ }
+
+ mockCline.api.createMessage.mockReturnValue(mockStream)
+
+ await executeApplyCodeTool({}, { fileExists: false })
+
+ // Verify only one API call was made (no diff generation for new files)
+ expect(mockCline.api.createMessage).toHaveBeenCalledTimes(1)
+ expect(mockedBuildApiHandler).not.toHaveBeenCalled()
+
+ // Verify the file was created
+ expect(mockCline.diffViewProvider.editType).toBe("create")
+ expect(mockCline.diffViewProvider.update).toHaveBeenCalledWith(newFileContent, true)
+ })
+ })
+
+ describe("error handling", () => {
+ it("handles API errors in first stage", async () => {
+ // Mock file read
+ mockedReadFile.mockResolvedValue(originalContent)
+
+ // Mock API error
+ const mockStream = {
+ [Symbol.asyncIterator]: vi.fn().mockImplementation(() => ({
+ next: vi.fn().mockRejectedValue(new Error("API error")),
+ })),
+ }
+ mockCline.api.createMessage.mockReturnValue(mockStream)
+
+ await executeApplyCodeTool()
+
+ expect(mockHandleError).toHaveBeenCalledWith("applying code", expect.any(Error))
+ })
+
+ it("handles file read errors", async () => {
+ // Mock file read error
+ mockedReadFile.mockRejectedValue(new Error("File read error"))
+
+ await executeApplyCodeTool()
+
+ expect(mockHandleError).toHaveBeenCalledWith("applying code", expect.any(Error))
+ })
+
+ it("handles malformed JSON responses", async () => {
+ // Mock file read
+ mockedReadFile.mockResolvedValue(originalContent)
+
+ // Mock malformed response (invalid JSON)
+ const mockStream = {
+ [Symbol.asyncIterator]: vi.fn().mockReturnValue({
+ next: vi
+ .fn()
+ .mockResolvedValueOnce({
+ value: { type: "text", text: "Just some text without valid JSON" },
+ done: false,
+ })
+ .mockResolvedValueOnce({ done: true }),
+ }),
+ }
+
+ mockCline.api.createMessage.mockReturnValue(mockStream)
+
+ await executeApplyCodeTool()
+
+ expect(mockCline.say).toHaveBeenCalledWith(
+ "error",
+ expect.stringContaining("Failed to parse code generation response"),
+ )
+ })
+
+ it("handles diff application failures", async () => {
+ // Mock file read
+ mockedReadFile.mockResolvedValue(originalContent)
+
+ // Mock successful code generation
+ const mockStream1 = {
+ [Symbol.asyncIterator]: vi.fn().mockReturnValue({
+ next: vi
+ .fn()
+ .mockResolvedValueOnce({
+ value: {
+ type: "text",
+ text: JSON.stringify({
+ file: testFilePath,
+ type: "snippet",
+ code: "generated code",
+ }),
+ },
+ done: false,
+ })
+ .mockResolvedValueOnce({ done: true }),
+ }),
+ }
+
+ // Mock successful diff generation
+ const mockStream2 = {
+ [Symbol.asyncIterator]: vi.fn().mockReturnValue({
+ next: vi
+ .fn()
+ .mockResolvedValueOnce({
+ value: {
+ type: "text",
+ text: "<<<<<<< SEARCH\nwrong content\n=======\nmodified\n>>>>>>> REPLACE",
+ },
+ done: false,
+ })
+ .mockResolvedValueOnce({ done: true }),
+ }),
+ }
+
+ const mockIsolatedApiHandler = {
+ createMessage: vi.fn().mockReturnValue(mockStream2),
+ getModel: vi.fn().mockReturnValue({
+ id: "claude-3",
+ info: { contextWindow: 200000 },
+ }),
+ countTokens: vi.fn().mockResolvedValue(100),
+ }
+
+ mockCline.api.createMessage.mockReturnValueOnce(mockStream1)
+ mockedBuildApiHandler.mockReturnValue(mockIsolatedApiHandler)
+
+ // Mock diff application failure
+ mockCline.diffStrategy.applyDiff.mockResolvedValue({
+ success: false,
+ error: "Could not find search content",
+ })
+
+ await executeApplyCodeTool()
+
+ expect(mockCline.say).toHaveBeenCalledWith(
+ "error",
+ "Failed to apply generated diff: Could not find search content",
+ )
+ expect(mockPushToolResult).toHaveBeenCalledWith(
+ "Failed to apply generated diff: Could not find search content",
+ )
+ })
+ })
+
+ describe("partial execution", () => {
+ it("returns early for partial blocks", async () => {
+ await executeApplyCodeTool({}, { isPartial: true })
+
+ expect(mockCline.api.createMessage).not.toHaveBeenCalled()
+ expect(mockPushToolResult).not.toHaveBeenCalled()
+ })
+ })
+
+ describe("user approval", () => {
+ it("reverts changes when user denies approval", async () => {
+ // Mock file read
+ mockedReadFile.mockResolvedValue(originalContent)
+
+ // Mock successful API responses
+ const mockStream1 = {
+ [Symbol.asyncIterator]: vi.fn().mockReturnValue({
+ next: vi
+ .fn()
+ .mockResolvedValueOnce({
+ value: {
+ type: "text",
+ text: JSON.stringify({
+ file: testFilePath,
+ type: "snippet",
+ code: "generated code",
+ }),
+ },
+ done: false,
+ })
+ .mockResolvedValueOnce({ done: true }),
+ }),
+ }
+
+ const mockStream2 = {
+ [Symbol.asyncIterator]: vi.fn().mockReturnValue({
+ next: vi
+ .fn()
+ .mockResolvedValueOnce({
+ value: {
+ type: "text",
+ text: "<<<<<<< SEARCH\noriginal\n=======\nmodified\n>>>>>>> REPLACE",
+ },
+ done: false,
+ })
+ .mockResolvedValueOnce({ done: true }),
+ }),
+ }
+
+ const mockIsolatedApiHandler = {
+ createMessage: vi.fn().mockReturnValue(mockStream2),
+ getModel: vi.fn().mockReturnValue({
+ id: "claude-3",
+ info: { contextWindow: 200000 },
+ }),
+ countTokens: vi.fn().mockResolvedValue(100),
+ }
+
+ mockCline.api.createMessage.mockReturnValueOnce(mockStream1)
+ mockedBuildApiHandler.mockReturnValue(mockIsolatedApiHandler)
+
+ // User denies approval
+ mockAskApproval.mockResolvedValue(false)
+
+ await executeApplyCodeTool()
+
+ expect(mockCline.diffViewProvider.revertChanges).toHaveBeenCalled()
+ expect(mockCline.diffViewProvider.saveChanges).not.toHaveBeenCalled()
+ expect(mockPushToolResult).not.toHaveBeenCalled()
+ })
+ })
+})
diff --git a/src/core/tools/applyCodeTool.ts b/src/core/tools/applyCodeTool.ts
new file mode 100644
index 00000000000..3a01b7db734
--- /dev/null
+++ b/src/core/tools/applyCodeTool.ts
@@ -0,0 +1,295 @@
+import path from "path"
+import fs from "fs/promises"
+
+import { TelemetryService } from "@roo-code/telemetry"
+import { DEFAULT_WRITE_DELAY_MS } from "@roo-code/types"
+
+import { ClineSayTool } from "../../shared/ExtensionMessage"
+import { getReadablePath } from "../../utils/path"
+import { Task } from "../task/Task"
+import { ToolUse, RemoveClosingTag, AskApproval, HandleError, PushToolResult } from "../../shared/tools"
+import { formatResponse } from "../prompts/responses"
+import { fileExistsAtPath } from "../../utils/fs"
+import { RecordSource } from "../context-tracking/FileContextTrackerTypes"
+import { buildApiHandler } from "../../api"
+
+interface CodeGenerationResult {
+ file: string
+ type: "snippet" | "full_file"
+ code: string
+}
+
+export async function applyCodeTool(
+ cline: Task,
+ block: ToolUse,
+ askApproval: AskApproval,
+ handleError: HandleError,
+ pushToolResult: PushToolResult,
+ removeClosingTag: RemoveClosingTag,
+) {
+ const relPath: string | undefined = block.params.path
+ const instruction: string | undefined = block.params.instruction
+
+ const sharedMessageProps: ClineSayTool = {
+ tool: "applyCode",
+ path: getReadablePath(cline.cwd, removeClosingTag("path", relPath)),
+ instruction: removeClosingTag("instruction", instruction),
+ }
+
+ try {
+ if (block.partial) {
+ // Update GUI message
+ await cline.ask("tool", JSON.stringify(sharedMessageProps), block.partial).catch(() => {})
+ return
+ } else {
+ if (!relPath) {
+ cline.consecutiveMistakeCount++
+ cline.recordToolError("apply_code")
+ pushToolResult(await cline.sayAndCreateMissingParamError("apply_code", "path"))
+ return
+ }
+
+ if (!instruction) {
+ cline.consecutiveMistakeCount++
+ cline.recordToolError("apply_code")
+ pushToolResult(await cline.sayAndCreateMissingParamError("apply_code", "instruction"))
+ return
+ }
+
+ const accessAllowed = cline.rooIgnoreController?.validateAccess(relPath)
+
+ if (!accessAllowed) {
+ await cline.say("rooignore_error", relPath)
+ pushToolResult(formatResponse.toolError(formatResponse.rooIgnoreError(relPath)))
+ return
+ }
+
+ const absolutePath = path.resolve(cline.cwd, relPath)
+ const fileExists = await fileExistsAtPath(absolutePath)
+
+ // Read the original file content if it exists
+ let originalContent = ""
+ if (fileExists) {
+ originalContent = await fs.readFile(absolutePath, "utf-8")
+ }
+
+ // Stage 1: Creative Code Generation
+ const codeGenPrompt = `You are a code generation expert. Generate code based on the following instruction.
+
+File: ${relPath}
+${fileExists ? `Current content:\n\`\`\`\n${originalContent}\n\`\`\`` : "File does not exist yet."}
+
+Instruction: ${instruction}
+
+Respond with a JSON object in this exact format:
+{
+ "file": "${relPath}",
+ "type": "${fileExists ? "snippet" : "full_file"}",
+ "code": "your generated code here"
+}
+
+IMPORTANT:
+- For existing files, generate only the new/modified code snippet
+- For new files, generate the complete file content
+- Do not include any markdown code blocks in the "code" field
+- Ensure proper escaping of quotes and newlines in JSON`
+
+ // Make first API call for code generation
+ // This uses the existing API handler with full context
+ const codeGenMessages = [
+ {
+ role: "user" as const,
+ content: [{ type: "text" as const, text: codeGenPrompt }],
+ },
+ ]
+
+ const codeGenStream = cline.api.createMessage(
+ "You are a code generation expert. Generate code exactly as requested.",
+ codeGenMessages,
+ { taskId: cline.taskId, mode: "code_generation" },
+ )
+
+ let codeGenResponse = ""
+ for await (const chunk of codeGenStream) {
+ if (chunk.type === "text") {
+ codeGenResponse += chunk.text
+ }
+ }
+
+ // Parse the code generation result
+ let codeGenResult: CodeGenerationResult
+ try {
+ codeGenResult = JSON.parse(codeGenResponse)
+ } catch (error) {
+ cline.consecutiveMistakeCount++
+ cline.recordToolError("apply_code")
+ const formattedError = `Failed to parse code generation response: ${error.message}\n\nResponse: ${codeGenResponse}`
+ await cline.say("error", formattedError)
+ pushToolResult(formattedError)
+ return
+ }
+
+ // Stage 2: Focused Diff Generation with ISOLATED context
+ let diffContent = ""
+ if (fileExists && codeGenResult.type === "snippet") {
+ // HARDCODED OPTIMIZED PROMPT - This is the key difference
+ // This prompt is specifically designed for diff generation without any conversational noise
+ const DIFF_GENERATION_SYSTEM_PROMPT = `You are a specialized diff generation model. Your ONLY task is to generate accurate diff patches.
+
+RULES:
+1. You will receive EXACTLY two inputs: original file content and new code to integrate
+2. You must output ONLY the diff patch in the specified format
+3. Do NOT add any explanations, comments, or conversational text
+4. Focus ONLY on the mechanical task of creating the diff
+5. Ensure the SEARCH section matches the original content EXACTLY (including whitespace)
+6. Place the new code in the most logical location within the file
+
+OUTPUT FORMAT:
+<<<<<<< SEARCH
+[exact content from original file]
+=======
+[integrated content with new code]
+>>>>>>> REPLACE
+
+You may use multiple SEARCH/REPLACE blocks if needed.`
+
+ // Simplified prompt for isolated context - no conversational instructions
+ const diffGenPrompt = `Original file content:
+\`\`\`
+${originalContent}
+\`\`\`
+
+New code to integrate:
+\`\`\`
+${codeGenResult.code}
+\`\`\``
+
+ // Create a COMPLETELY ISOLATED API call
+ // This is a new, independent message array with NO conversation history
+ const isolatedDiffMessages = [
+ {
+ role: "user" as const,
+ content: [{ type: "text" as const, text: diffGenPrompt }],
+ },
+ ]
+
+ // Create a new API handler instance to ensure complete isolation
+ // This prevents any context bleeding from the main conversation
+ const isolatedApiHandler = buildApiHandler(cline.apiConfiguration)
+
+ // Make the isolated API call with the hardcoded system prompt
+ const diffGenStream = isolatedApiHandler.createMessage(
+ DIFF_GENERATION_SYSTEM_PROMPT,
+ isolatedDiffMessages,
+ { taskId: `${cline.taskId}-diff-gen`, mode: "diff_generation" },
+ )
+
+ for await (const chunk of diffGenStream) {
+ if (chunk.type === "text") {
+ diffContent += chunk.text
+ }
+ }
+
+ // Apply the diff using the existing diff strategy
+ const diffResult = (await cline.diffStrategy?.applyDiff(originalContent, diffContent)) ?? {
+ success: false,
+ error: "No diff strategy available",
+ }
+
+ if (!diffResult.success) {
+ cline.consecutiveMistakeCount++
+ cline.recordToolError("apply_code")
+ const formattedError = `Failed to apply generated diff: ${diffResult.error}`
+ await cline.say("error", formattedError)
+ pushToolResult(formattedError)
+ return
+ }
+
+ // Show diff view before asking for approval
+ cline.diffViewProvider.editType = "modify"
+ await cline.diffViewProvider.open(relPath)
+ await cline.diffViewProvider.update(diffResult.content, true)
+ cline.diffViewProvider.scrollToFirstDiff()
+
+ // Check if file is write-protected
+ const isWriteProtected = cline.rooProtectedController?.isWriteProtected(relPath) || false
+
+ const completeMessage = JSON.stringify({
+ ...sharedMessageProps,
+ diff: formatResponse.createPrettyPatch(relPath, originalContent, diffResult.content),
+ isProtected: isWriteProtected,
+ } satisfies ClineSayTool)
+
+ const didApprove = await askApproval("tool", completeMessage, undefined, isWriteProtected)
+
+ if (!didApprove) {
+ await cline.diffViewProvider.revertChanges()
+ return
+ }
+
+ // Save the changes
+ const provider = cline.providerRef.deref()
+ const state = await provider?.getState()
+ const diagnosticsEnabled = state?.diagnosticsEnabled ?? true
+ const writeDelayMs = state?.writeDelayMs ?? DEFAULT_WRITE_DELAY_MS
+ await cline.diffViewProvider.saveChanges(diagnosticsEnabled, writeDelayMs)
+ } else {
+ // For new files or full file replacements, use the generated code directly
+ cline.diffViewProvider.editType = fileExists ? "modify" : "create"
+ await cline.diffViewProvider.open(relPath)
+ await cline.diffViewProvider.update(codeGenResult.code, true)
+ cline.diffViewProvider.scrollToFirstDiff()
+
+ // Check if file is write-protected
+ const isWriteProtected = cline.rooProtectedController?.isWriteProtected(relPath) || false
+
+ const completeMessage = JSON.stringify({
+ ...sharedMessageProps,
+ content: fileExists ? undefined : codeGenResult.code,
+ diff: fileExists
+ ? formatResponse.createPrettyPatch(relPath, originalContent, codeGenResult.code)
+ : undefined,
+ isProtected: isWriteProtected,
+ } satisfies ClineSayTool)
+
+ const didApprove = await askApproval("tool", completeMessage, undefined, isWriteProtected)
+
+ if (!didApprove) {
+ await cline.diffViewProvider.revertChanges()
+ return
+ }
+
+ // Save the changes
+ const provider = cline.providerRef.deref()
+ const state = await provider?.getState()
+ const diagnosticsEnabled = state?.diagnosticsEnabled ?? true
+ const writeDelayMs = state?.writeDelayMs ?? DEFAULT_WRITE_DELAY_MS
+ await cline.diffViewProvider.saveChanges(diagnosticsEnabled, writeDelayMs)
+ }
+
+ // Track file edit operation
+ if (relPath) {
+ await cline.fileContextTracker.trackFileContext(relPath, "roo_edited" as RecordSource)
+ }
+
+ // Used to determine if we should wait for busy terminal to update before sending api request
+ cline.didEditFile = true
+
+ // Get the formatted response message
+ const message = await cline.diffViewProvider.pushToolWriteResult(cline, cline.cwd, !fileExists)
+
+ pushToolResult(message)
+
+ await cline.diffViewProvider.reset()
+
+ cline.consecutiveMistakeCount = 0
+ cline.recordToolUsage("apply_code")
+
+ return
+ }
+ } catch (error) {
+ await handleError("applying code", error)
+ await cline.diffViewProvider.reset()
+ return
+ }
+}
diff --git a/src/shared/ExtensionMessage.ts b/src/shared/ExtensionMessage.ts
index 000762e317a..49a9548fa15 100644
--- a/src/shared/ExtensionMessage.ts
+++ b/src/shared/ExtensionMessage.ts
@@ -315,6 +315,7 @@ export interface ClineSayTool {
tool:
| "editedExistingFile"
| "appliedDiff"
+ | "applyCode"
| "newFileCreated"
| "codebaseSearch"
| "readFile"
@@ -346,6 +347,7 @@ export interface ClineSayTool {
endLine?: number
lineNumber?: number
query?: string
+ instruction?: string
batchFiles?: Array<{
path: string
lineSnippet: string
diff --git a/src/shared/tools.ts b/src/shared/tools.ts
index 67972243fe7..69e4005c3c7 100644
--- a/src/shared/tools.ts
+++ b/src/shared/tools.ts
@@ -65,6 +65,7 @@ export const toolParamNames = [
"query",
"args",
"todos",
+ "instruction",
] as const
export type ToolParamName = (typeof toolParamNames)[number]
@@ -164,6 +165,11 @@ export interface SearchAndReplaceToolUse extends ToolUse {
Partial, "use_regex" | "ignore_case" | "start_line" | "end_line">>
}
+export interface ApplyCodeToolUse extends ToolUse {
+ name: "apply_code"
+ params: Partial, "path" | "instruction">>
+}
+
// Define tool group configuration
export type ToolGroupConfig = {
tools: readonly string[]
@@ -176,6 +182,7 @@ export const TOOL_DISPLAY_NAMES: Record = {
fetch_instructions: "fetch instructions",
write_to_file: "write files",
apply_diff: "apply changes",
+ apply_code: "apply code",
search_files: "search files",
list_files: "list files",
list_code_definition_names: "list definitions",
@@ -205,7 +212,7 @@ export const TOOL_GROUPS: Record = {
],
},
edit: {
- tools: ["apply_diff", "write_to_file", "insert_content", "search_and_replace"],
+ tools: ["apply_diff", "apply_code", "write_to_file", "insert_content", "search_and_replace"],
},
browser: {
tools: ["browser_action"],