Skip to content

Commit 9789925

Browse files
committed
feat: add checkpoints before file edits for increased checkpoint frequency
- Save checkpoints BEFORE file editing operations (write_to_file, apply_diff, insert_content, search_and_replace) - Maintain existing checkpoints AFTER file edits - Update checkpointSaveAndMark function to handle before/after timing - Add comprehensive tests for new checkpoint behavior This addresses user feedback requesting checkpoints before edits to allow reverting to the state just before making changes.
1 parent 67e6a22 commit 9789925

File tree

2 files changed

+121
-15
lines changed

2 files changed

+121
-15
lines changed

src/core/assistant-message/presentAssistantMessage.ts

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -410,9 +410,11 @@ export async function presentAssistantMessage(cline: Task) {
410410

411411
switch (block.name) {
412412
case "write_to_file":
413+
// Save checkpoint BEFORE file edit
414+
await checkpointSaveAndMark(cline, "before")
413415
await writeToFileTool(cline, block, askApproval, handleError, pushToolResult, removeClosingTag)
414416
// Save checkpoint AFTER file edit
415-
await checkpointSaveAndMark(cline)
417+
await checkpointSaveAndMark(cline, "after")
416418
break
417419
case "update_todo_list":
418420
await updateTodoListTool(cline, block, askApproval, handleError, pushToolResult, removeClosingTag)
@@ -430,6 +432,9 @@ export async function presentAssistantMessage(cline: Task) {
430432
)
431433
}
432434

435+
// Save checkpoint BEFORE file edit
436+
await checkpointSaveAndMark(cline, "before")
437+
433438
if (isMultiFileApplyDiffEnabled) {
434439
await applyDiffTool(cline, block, askApproval, handleError, pushToolResult, removeClosingTag)
435440
} else {
@@ -443,18 +448,22 @@ export async function presentAssistantMessage(cline: Task) {
443448
)
444449
}
445450
// Save checkpoint AFTER file edit
446-
await checkpointSaveAndMark(cline)
451+
await checkpointSaveAndMark(cline, "after")
447452
break
448453
}
449454
case "insert_content":
455+
// Save checkpoint BEFORE file edit
456+
await checkpointSaveAndMark(cline, "before")
450457
await insertContentTool(cline, block, askApproval, handleError, pushToolResult, removeClosingTag)
451458
// Save checkpoint AFTER file edit
452-
await checkpointSaveAndMark(cline)
459+
await checkpointSaveAndMark(cline, "after")
453460
break
454461
case "search_and_replace":
462+
// Save checkpoint BEFORE file edit
463+
await checkpointSaveAndMark(cline, "before")
455464
await searchAndReplaceTool(cline, block, askApproval, handleError, pushToolResult, removeClosingTag)
456465
// Save checkpoint AFTER file edit
457-
await checkpointSaveAndMark(cline)
466+
await checkpointSaveAndMark(cline, "after")
458467
break
459468
case "read_file":
460469
await readFileTool(cline, block, askApproval, handleError, pushToolResult, removeClosingTag)
@@ -586,16 +595,25 @@ export async function presentAssistantMessage(cline: Task) {
586595
/**
587596
* save checkpoint and mark done in the current streaming task.
588597
* @param task The Task instance to checkpoint save and mark.
598+
* @param timing Whether this is a "before" or "after" checkpoint for file edits
589599
* @returns
590600
*/
591-
async function checkpointSaveAndMark(task: Task) {
592-
if (task.currentStreamingDidCheckpoint) {
601+
async function checkpointSaveAndMark(task: Task, timing?: "before" | "after") {
602+
// For "before" checkpoints, always save regardless of currentStreamingDidCheckpoint
603+
// For "after" checkpoints or no timing specified, use the original logic
604+
if (timing !== "before" && task.currentStreamingDidCheckpoint) {
593605
return
594606
}
595607
try {
596608
await task.checkpointSave(true)
597-
task.currentStreamingDidCheckpoint = true
609+
// Only mark as done for "after" checkpoints or when no timing is specified
610+
if (timing !== "before") {
611+
task.currentStreamingDidCheckpoint = true
612+
}
598613
} catch (error) {
599-
console.error(`[Task#presentAssistantMessage] Error saving checkpoint: ${error.message}`, error)
614+
console.error(
615+
`[Task#presentAssistantMessage] Error saving checkpoint (${timing || "default"}): ${error.message}`,
616+
error,
617+
)
600618
}
601619
}

src/core/checkpoints/__tests__/checkpoint-timing.spec.ts

Lines changed: 95 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -216,8 +216,8 @@ describe("Checkpoint Timing", () => {
216216
vi.clearAllMocks()
217217
})
218218

219-
describe("Checkpoint after file edits", () => {
220-
it("should save checkpoint AFTER write_to_file tool execution", async () => {
219+
describe("Checkpoint before and after file edits", () => {
220+
it("should save checkpoint BEFORE and AFTER write_to_file tool execution", async () => {
221221
// Setup assistant message content with write_to_file tool
222222
mockTask.assistantMessageContent = [
223223
{
@@ -241,8 +241,10 @@ describe("Checkpoint Timing", () => {
241241
// Execute presentAssistantMessage
242242
await presentAssistantMessage(mockTask)
243243

244-
// Verify checkpoint was saved after the tool execution
244+
// Verify checkpoint was saved twice (before and after the tool execution)
245+
expect(mockTask.checkpointSave).toHaveBeenCalledTimes(2)
245246
expect(mockTask.checkpointSave).toHaveBeenCalledWith(true)
247+
// Note: currentStreamingDidCheckpoint is only set to true after the "after" checkpoint
246248
expect(mockTask.currentStreamingDidCheckpoint).toBe(true)
247249
})
248250

@@ -251,7 +253,7 @@ describe("Checkpoint Timing", () => {
251253
// through the other file editing tools (write_to_file, insert_content, search_and_replace)
252254
// which all follow the same pattern of saving checkpoints after file edits.
253255

254-
it("should save checkpoint AFTER insert_content tool execution", async () => {
256+
it("should save checkpoint BEFORE and AFTER insert_content tool execution", async () => {
255257
// Setup assistant message content with insert_content tool
256258
mockTask.assistantMessageContent = [
257259
{
@@ -276,12 +278,14 @@ describe("Checkpoint Timing", () => {
276278
// Execute presentAssistantMessage
277279
await presentAssistantMessage(mockTask)
278280

279-
// Verify checkpoint was saved after the tool execution
281+
// Verify checkpoint was saved twice (before and after the tool execution)
282+
expect(mockTask.checkpointSave).toHaveBeenCalledTimes(2)
280283
expect(mockTask.checkpointSave).toHaveBeenCalledWith(true)
284+
// Note: currentStreamingDidCheckpoint is only set to true after the "after" checkpoint
281285
expect(mockTask.currentStreamingDidCheckpoint).toBe(true)
282286
})
283287

284-
it("should save checkpoint AFTER search_and_replace tool execution", async () => {
288+
it("should save checkpoint BEFORE and AFTER search_and_replace tool execution", async () => {
285289
// Setup assistant message content with search_and_replace tool
286290
mockTask.assistantMessageContent = [
287291
{
@@ -306,10 +310,94 @@ describe("Checkpoint Timing", () => {
306310
// Execute presentAssistantMessage
307311
await presentAssistantMessage(mockTask)
308312

309-
// Verify checkpoint was saved after the tool execution
313+
// Verify checkpoint was saved twice (before and after the tool execution)
314+
expect(mockTask.checkpointSave).toHaveBeenCalledTimes(2)
310315
expect(mockTask.checkpointSave).toHaveBeenCalledWith(true)
316+
// Note: currentStreamingDidCheckpoint is only set to true after the "after" checkpoint
311317
expect(mockTask.currentStreamingDidCheckpoint).toBe(true)
312318
})
319+
320+
it("should handle checkpoint errors gracefully for file edit tools", async () => {
321+
// Setup assistant message content with write_to_file tool
322+
mockTask.assistantMessageContent = [
323+
{
324+
type: "tool_use",
325+
name: "write_to_file",
326+
params: {
327+
path: "test.txt",
328+
content: "test content",
329+
},
330+
partial: false,
331+
},
332+
]
333+
334+
// Mock the write_to_file tool execution
335+
const writeToFileModule = await import("../../tools/writeToFileTool")
336+
vi.spyOn(writeToFileModule, "writeToFileTool").mockImplementation(async () => {
337+
// Simulate tool execution
338+
return undefined
339+
})
340+
341+
// Mock checkpointSave to fail on first call (before) and succeed on second (after)
342+
mockTask.checkpointSave
343+
.mockRejectedValueOnce(new Error("Checkpoint before failed"))
344+
.mockResolvedValueOnce(undefined)
345+
346+
// Mock console.error to verify error logging
347+
const consoleErrorSpy = vi.spyOn(console, "error").mockImplementation(() => {})
348+
349+
// Execute presentAssistantMessage
350+
await presentAssistantMessage(mockTask)
351+
352+
// Verify checkpoint was attempted twice
353+
expect(mockTask.checkpointSave).toHaveBeenCalledTimes(2)
354+
355+
// Verify error was logged for the "before" checkpoint failure
356+
expect(consoleErrorSpy).toHaveBeenCalledWith(
357+
expect.stringContaining("Error saving checkpoint (before)"),
358+
expect.any(Error),
359+
)
360+
361+
// Clean up
362+
consoleErrorSpy.mockRestore()
363+
})
364+
365+
it("should not set currentStreamingDidCheckpoint for 'before' checkpoints", async () => {
366+
// Setup assistant message content with write_to_file tool
367+
mockTask.assistantMessageContent = [
368+
{
369+
type: "tool_use",
370+
name: "write_to_file",
371+
params: {
372+
path: "test.txt",
373+
content: "test content",
374+
},
375+
partial: false,
376+
},
377+
]
378+
379+
// Mock the write_to_file tool execution to track when it's called
380+
const writeToFileModule = await import("../../tools/writeToFileTool")
381+
let toolExecuted = false
382+
vi.spyOn(writeToFileModule, "writeToFileTool").mockImplementation(async () => {
383+
// At this point, the "before" checkpoint should have been saved
384+
// but currentStreamingDidCheckpoint should still be false
385+
expect(mockTask.checkpointSave).toHaveBeenCalledTimes(1)
386+
expect(mockTask.currentStreamingDidCheckpoint).toBe(false)
387+
toolExecuted = true
388+
return undefined
389+
})
390+
391+
// Execute presentAssistantMessage
392+
await presentAssistantMessage(mockTask)
393+
394+
// Verify tool was executed
395+
expect(toolExecuted).toBe(true)
396+
397+
// After execution, currentStreamingDidCheckpoint should be true (from "after" checkpoint)
398+
expect(mockTask.currentStreamingDidCheckpoint).toBe(true)
399+
expect(mockTask.checkpointSave).toHaveBeenCalledTimes(2)
400+
})
313401
})
314402

315403
describe("Checkpoint before new prompts", () => {

0 commit comments

Comments
 (0)