Skip to content

Commit adc4d0b

Browse files
committed
fix: enforce mode restrictions even with auto-approval enabled (#5448)
- Remove special handling for FileRestrictionError when auto-approval is enabled - Ensure all validation errors block tool execution consistently - Add tests to verify mode restrictions are enforced with auto-approval - Fixes issue where Architect mode could write any file type when auto-approval was on
1 parent 0e0da80 commit adc4d0b

File tree

2 files changed

+301
-2
lines changed

2 files changed

+301
-2
lines changed
Lines changed: 295 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,295 @@
1+
import { describe, it, expect, vi, beforeEach } from "vitest"
2+
import { Task } from "../../task/Task"
3+
import { presentAssistantMessage } from "../presentAssistantMessage"
4+
import { FileRestrictionError } from "../../../shared/modes"
5+
import type { ClineProvider } from "../../../core/webview/ClineProvider"
6+
7+
// Mock all the tool modules
8+
vi.mock("../../tools/writeToFileTool", () => ({
9+
writeToFileTool: vi
10+
.fn()
11+
.mockImplementation(async (cline, block, askApproval, handleError, pushToolResult, removeClosingTag) => {
12+
// Call askApproval to simulate the tool asking for approval
13+
await askApproval("tool", JSON.stringify({ tool: "write_to_file", path: block.params.path }))
14+
}),
15+
}))
16+
17+
vi.mock("../../tools/validateToolUse", () => ({
18+
validateToolUse: vi.fn(),
19+
}))
20+
21+
vi.mock("../../checkpoints", () => ({
22+
checkpointSave: vi.fn(),
23+
}))
24+
25+
// Import mocked functions
26+
import { validateToolUse } from "../../tools/validateToolUse"
27+
28+
describe("presentAssistantMessage - auto-approval with file restrictions", () => {
29+
let mockTask: Task
30+
let mockProvider: Partial<ClineProvider>
31+
let mockProviderRef: { deref: () => ClineProvider | undefined }
32+
33+
beforeEach(() => {
34+
vi.clearAllMocks()
35+
36+
// Create mock provider
37+
mockProvider = {
38+
getState: vi.fn().mockResolvedValue({
39+
mode: "architect",
40+
customModes: [],
41+
autoApprovalEnabled: true,
42+
alwaysAllowWrite: true,
43+
}),
44+
}
45+
46+
// Create provider ref
47+
mockProviderRef = {
48+
deref: () => mockProvider as ClineProvider,
49+
}
50+
51+
// Create mock task
52+
mockTask = {
53+
taskId: "test-task",
54+
instanceId: "test-instance",
55+
abort: false,
56+
presentAssistantMessageLocked: false,
57+
presentAssistantMessageHasPendingUpdates: false,
58+
currentStreamingContentIndex: 0,
59+
assistantMessageContent: [],
60+
didCompleteReadingStream: false,
61+
userMessageContentReady: false,
62+
didRejectTool: false,
63+
didAlreadyUseTool: false,
64+
userMessageContent: [],
65+
consecutiveMistakeCount: 0,
66+
providerRef: mockProviderRef,
67+
diffEnabled: false,
68+
fileContextTracker: {
69+
getAndClearCheckpointPossibleFile: vi.fn().mockReturnValue([]),
70+
},
71+
say: vi.fn(),
72+
ask: vi.fn().mockResolvedValue({
73+
response: "yesButtonClicked",
74+
text: undefined,
75+
images: undefined,
76+
}),
77+
recordToolUsage: vi.fn(),
78+
toolRepetitionDetector: {
79+
check: vi.fn().mockReturnValue({ allowExecution: true }),
80+
},
81+
browserSession: {
82+
closeBrowser: vi.fn(),
83+
},
84+
} as any
85+
86+
// Mock TelemetryService
87+
vi.mock("@roo-code/telemetry", () => ({
88+
TelemetryService: {
89+
instance: {
90+
captureToolUsage: vi.fn(),
91+
captureConsecutiveMistakeError: vi.fn(),
92+
},
93+
},
94+
}))
95+
})
96+
97+
it("should block file restriction errors even when auto-approval is enabled", async () => {
98+
// Setup: Architect mode trying to write a non-markdown file
99+
const mockValidateToolUse = vi.mocked(validateToolUse)
100+
mockValidateToolUse.mockImplementation(() => {
101+
throw new FileRestrictionError(
102+
"architect",
103+
"\\.md$",
104+
"Markdown files only",
105+
"src/index.ts",
106+
"write_to_file",
107+
)
108+
})
109+
110+
// Add a write_to_file tool use that would normally be blocked
111+
mockTask.assistantMessageContent = [
112+
{
113+
type: "tool_use",
114+
name: "write_to_file",
115+
params: {
116+
path: "src/index.ts",
117+
content: "console.log('hello')",
118+
},
119+
partial: false,
120+
},
121+
]
122+
123+
// Execute
124+
await presentAssistantMessage(mockTask)
125+
126+
// Verify validateToolUse was called
127+
expect(mockValidateToolUse).toHaveBeenCalledWith(
128+
"write_to_file",
129+
"architect",
130+
[],
131+
{ apply_diff: false },
132+
{
133+
path: "src/index.ts",
134+
content: "console.log('hello')",
135+
},
136+
)
137+
138+
// Verify the error was handled (auto-approval should not bypass mode restrictions)
139+
expect(mockTask.consecutiveMistakeCount).toBe(1)
140+
expect(mockTask.userMessageContent).toHaveLength(2) // Error message added
141+
expect(mockTask.userMessageContent[0]).toEqual({
142+
type: "text",
143+
text: "[write_to_file for 'src/index.ts'] Result:",
144+
})
145+
expect(mockTask.userMessageContent[1]).toEqual({
146+
type: "text",
147+
text: expect.stringContaining("can only edit files matching pattern"),
148+
})
149+
150+
// Verify ask was not called (tool was blocked)
151+
expect(mockTask.ask).not.toHaveBeenCalled()
152+
})
153+
154+
it("should still block file restriction errors when auto-approval is disabled", async () => {
155+
// Disable auto-approval
156+
mockProvider.getState = vi.fn().mockResolvedValue({
157+
mode: "architect",
158+
customModes: [],
159+
autoApprovalEnabled: false, // Disabled
160+
alwaysAllowWrite: true,
161+
})
162+
163+
const mockValidateToolUse = vi.mocked(validateToolUse)
164+
mockValidateToolUse.mockImplementation(() => {
165+
throw new FileRestrictionError(
166+
"architect",
167+
"\\.md$",
168+
"Markdown files only",
169+
"src/index.ts",
170+
"write_to_file",
171+
)
172+
})
173+
174+
// Add a write_to_file tool use
175+
mockTask.assistantMessageContent = [
176+
{
177+
type: "tool_use",
178+
name: "write_to_file",
179+
params: {
180+
path: "src/index.ts",
181+
content: "console.log('hello')",
182+
},
183+
partial: false,
184+
},
185+
]
186+
187+
// Execute
188+
await presentAssistantMessage(mockTask)
189+
190+
// Verify the error was handled
191+
expect(mockTask.consecutiveMistakeCount).toBe(1)
192+
expect(mockTask.userMessageContent).toHaveLength(2) // Error message added
193+
expect(mockTask.userMessageContent[0]).toEqual({
194+
type: "text",
195+
text: "[write_to_file for 'src/index.ts'] Result:",
196+
})
197+
expect(mockTask.userMessageContent[1]).toEqual({
198+
type: "text",
199+
text: expect.stringContaining("can only edit files matching pattern"),
200+
})
201+
})
202+
203+
it("should still block non-FileRestrictionError errors regardless of auto-approval", async () => {
204+
// Enable auto-approval
205+
mockProvider.getState = vi.fn().mockResolvedValue({
206+
mode: "code",
207+
customModes: [],
208+
autoApprovalEnabled: true,
209+
alwaysAllowWrite: true,
210+
})
211+
212+
const mockValidateToolUse = vi.mocked(validateToolUse)
213+
mockValidateToolUse.mockImplementation(() => {
214+
throw new Error("Some other validation error")
215+
})
216+
217+
// Add a write_to_file tool use
218+
mockTask.assistantMessageContent = [
219+
{
220+
type: "tool_use",
221+
name: "write_to_file",
222+
params: {
223+
path: "src/index.ts",
224+
content: "console.log('hello')",
225+
},
226+
partial: false,
227+
},
228+
]
229+
230+
// Execute
231+
await presentAssistantMessage(mockTask)
232+
233+
// Verify the error was handled
234+
expect(mockTask.consecutiveMistakeCount).toBe(1)
235+
expect(mockTask.userMessageContent).toHaveLength(2) // Error message added
236+
expect(mockTask.userMessageContent[1].type).toBe("text")
237+
if (mockTask.userMessageContent[1].type === "text") {
238+
expect(mockTask.userMessageContent[1].text).toContain("Some other validation error")
239+
}
240+
241+
// Verify ask was not called (tool was blocked)
242+
expect(mockTask.ask).not.toHaveBeenCalled()
243+
})
244+
245+
it("should allow auto-approved tools that pass validation", async () => {
246+
// Enable auto-approval for a valid operation
247+
mockProvider.getState = vi.fn().mockResolvedValue({
248+
mode: "code", // Code mode has no file restrictions
249+
customModes: [],
250+
autoApprovalEnabled: true,
251+
alwaysAllowWrite: true,
252+
})
253+
254+
const mockValidateToolUse = vi.mocked(validateToolUse)
255+
mockValidateToolUse.mockImplementation(() => {
256+
// No error - validation passes
257+
})
258+
259+
// Add a write_to_file tool use
260+
mockTask.assistantMessageContent = [
261+
{
262+
type: "tool_use",
263+
name: "write_to_file",
264+
params: {
265+
path: "src/index.ts",
266+
content: "console.log('hello')",
267+
},
268+
partial: false,
269+
},
270+
]
271+
272+
// Execute
273+
await presentAssistantMessage(mockTask)
274+
275+
// Verify validateToolUse was called
276+
expect(mockValidateToolUse).toHaveBeenCalledWith(
277+
"write_to_file",
278+
"code",
279+
[],
280+
{ apply_diff: false },
281+
{
282+
path: "src/index.ts",
283+
content: "console.log('hello')",
284+
},
285+
)
286+
287+
// Since validation passed, the tool should proceed to ask for approval
288+
expect(mockTask.ask).toHaveBeenCalledWith("tool", expect.any(String), false, undefined, false)
289+
290+
// No errors should be recorded
291+
expect(mockTask.didRejectTool).toBe(false)
292+
expect(mockTask.consecutiveMistakeCount).toBe(0)
293+
expect(mockTask.userMessageContent).toHaveLength(0)
294+
})
295+
})

src/core/assistant-message/presentAssistantMessage.ts

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -351,9 +351,10 @@ export async function presentAssistantMessage(cline: Task) {
351351
TelemetryService.instance.captureToolUsage(cline.taskId, block.name)
352352
}
353353

354-
// Validate tool use before execution.
354+
// Get the provider and state
355355
const { mode, customModes } = (await cline.providerRef.deref()?.getState()) ?? {}
356356

357+
// Always validate tool use
357358
try {
358359
validateToolUse(
359360
block.name as ToolName,
@@ -363,8 +364,11 @@ export async function presentAssistantMessage(cline: Task) {
363364
block.params,
364365
)
365366
} catch (error) {
367+
// All validation errors should be treated as errors, including FileRestrictionError
368+
// Auto-approval should not bypass mode restrictions
369+
const validationError = error as Error
366370
cline.consecutiveMistakeCount++
367-
pushToolResult(formatResponse.toolError(error.message))
371+
pushToolResult(formatResponse.toolError(validationError.message))
368372
break
369373
}
370374

0 commit comments

Comments
 (0)