Skip to content
Merged
31 changes: 22 additions & 9 deletions src/core/assistant-message/presentAssistantMessage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import { switchModeTool } from "../tools/switchModeTool"
import { attemptCompletionTool } from "../tools/attemptCompletionTool"
import { newTaskTool } from "../tools/newTaskTool"

import { checkpointSave } from "../checkpoints"
import { updateTodoListTool } from "../tools/updateTodoListTool"

import { formatResponse } from "../prompts/responses"
Expand Down Expand Up @@ -411,6 +410,7 @@ export async function presentAssistantMessage(cline: Task) {

switch (block.name) {
case "write_to_file":
await checkpointSaveAndMark(cline)
await writeToFileTool(cline, block, askApproval, handleError, pushToolResult, removeClosingTag)
break
case "update_todo_list":
Expand All @@ -430,8 +430,10 @@ export async function presentAssistantMessage(cline: Task) {
}

if (isMultiFileApplyDiffEnabled) {
await checkpointSaveAndMark(cline)
await applyDiffTool(cline, block, askApproval, handleError, pushToolResult, removeClosingTag)
} else {
await checkpointSaveAndMark(cline)
await applyDiffToolLegacy(
cline,
block,
Expand All @@ -444,9 +446,11 @@ export async function presentAssistantMessage(cline: Task) {
break
}
case "insert_content":
await checkpointSaveAndMark(cline)
await insertContentTool(cline, block, askApproval, handleError, pushToolResult, removeClosingTag)
break
case "search_and_replace":
await checkpointSaveAndMark(cline)
await searchAndReplaceTool(cline, block, askApproval, handleError, pushToolResult, removeClosingTag)
break
case "read_file":
Expand Down Expand Up @@ -527,14 +531,6 @@ export async function presentAssistantMessage(cline: Task) {
break
}

const recentlyModifiedFiles = cline.fileContextTracker.getAndClearCheckpointPossibleFile()

if (recentlyModifiedFiles.length > 0) {
// TODO: We can track what file changes were made and only
// checkpoint those files, this will be save storage.
await checkpointSave(cline)
}

// Seeing out of bounds is fine, it means that the next too call is being
// built up and ready to add to assistantMessageContent to present.
// When you see the UI inactive during this, it means that a tool is
Expand Down Expand Up @@ -583,3 +579,20 @@ export async function presentAssistantMessage(cline: Task) {
presentAssistantMessage(cline)
}
}

/**
* save checkpoint and mark done in the current streaming task.
* @param task The Task instance to checkpoint save and mark.
* @returns
*/
async function checkpointSaveAndMark(task: Task) {
if (task.currentStreamingDidCheckpoint) {
return
}
try {
await task.checkpointSave(true)
task.currentStreamingDidCheckpoint = true
} catch (error) {
console.error(`[Task#presentAssistantMessage] Error saving checkpoint: ${error.message}`, error)
}
}
113 changes: 46 additions & 67 deletions src/core/checkpoints/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,29 @@ import { DIFF_VIEW_URI_SCHEME } from "../../integrations/editor/DiffViewProvider

import { CheckpointServiceOptions, RepoPerTaskCheckpointService } from "../../services/checkpoints"

export function getCheckpointService(cline: Task) {
export async function getCheckpointService(
cline: Task,
{ interval = 250, timeout = 15_000 }: { interval?: number; timeout?: number } = {},
) {
if (!cline.enableCheckpoints) {
return undefined
}

if (cline.checkpointService) {
return cline.checkpointService
}

if (cline.checkpointServiceInitializing) {
console.log("[Task#getCheckpointService] checkpoint service is still initializing")
return undefined
if (cline.checkpointServiceInitializing) {
console.log("[Task#getCheckpointService] checkpoint service is still initializing")
const service = cline.checkpointService
await pWaitFor(
() => {
console.log("[Task#getCheckpointService] waiting for service to initialize")
return service.isInitialized
},
{ interval, timeout },
)
return service.isInitialized ? cline.checkpointService : undefined
} else {
return cline.checkpointService
}
}

const provider = cline.providerRef.deref()
Expand Down Expand Up @@ -69,15 +80,20 @@ export function getCheckpointService(cline: Task) {
}

const service = RepoPerTaskCheckpointService.create(options)

cline.checkpointServiceInitializing = true

// Check if Git is installed before initializing the service
// Note: This is intentionally fire-and-forget to match the original IIFE pattern
// The service is returned immediately while Git check happens asynchronously
checkGitInstallation(cline, service, log, provider)

return service
// Only assign the service after successful initialization
try {
await checkGitInstallation(cline, service, log, provider)
cline.checkpointService = service
return service
} catch (err) {
// Clean up on failure
cline.checkpointServiceInitializing = false
cline.enableCheckpoints = false
throw err
}
} catch (err) {
log(`[Task#getCheckpointService] ${err.message}`)
cline.enableCheckpoints = false
Expand Down Expand Up @@ -115,22 +131,7 @@ async function checkGitInstallation(
// Git is installed, proceed with initialization
service.on("initialize", () => {
log("[Task#getCheckpointService] service initialized")

try {
const isCheckpointNeeded =
typeof cline.clineMessages.find(({ say }) => say === "checkpoint_saved") === "undefined"

cline.checkpointService = service
cline.checkpointServiceInitializing = false

if (isCheckpointNeeded) {
log("[Task#getCheckpointService] no checkpoints found, saving initial checkpoint")
checkpointSave(cline)
}
} catch (err) {
log("[Task#getCheckpointService] caught error in on('initialize'), disabling checkpoints")
cline.enableCheckpoints = false
}
cline.checkpointServiceInitializing = false
})

service.on("checkpoint", ({ isFirst, fromHash: from, toHash: to }) => {
Expand All @@ -153,11 +154,12 @@ async function checkGitInstallation(
})

log("[Task#getCheckpointService] initializing shadow git")

service.initShadowGit().catch((err) => {
try {
await service.initShadowGit()
} catch (err) {
log(`[Task#getCheckpointService] initShadowGit -> ${err.message}`)
cline.enableCheckpoints = false
})
}
} catch (err) {
log(`[Task#getCheckpointService] Unexpected error during Git check: ${err.message}`)
console.error("Git check error:", err)
Expand All @@ -166,33 +168,8 @@ async function checkGitInstallation(
}
}

async function getInitializedCheckpointService(
cline: Task,
{ interval = 250, timeout = 15_000 }: { interval?: number; timeout?: number } = {},
) {
const service = getCheckpointService(cline)

if (!service || service.isInitialized) {
return service
}

try {
await pWaitFor(
() => {
console.log("[Task#getCheckpointService] waiting for service to initialize")
return service.isInitialized
},
{ interval, timeout },
)

return service
} catch (err) {
return undefined
}
}

export async function checkpointSave(cline: Task, force = false) {
const service = getCheckpointService(cline)
const service = await getCheckpointService(cline)

if (!service) {
return
Expand Down Expand Up @@ -221,7 +198,7 @@ export type CheckpointRestoreOptions = {
}

export async function checkpointRestore(cline: Task, { ts, commitHash, mode }: CheckpointRestoreOptions) {
const service = await getInitializedCheckpointService(cline)
const service = await getCheckpointService(cline)

if (!service) {
return
Expand Down Expand Up @@ -289,25 +266,27 @@ export type CheckpointDiffOptions = {
}

export async function checkpointDiff(cline: Task, { ts, previousCommitHash, commitHash, mode }: CheckpointDiffOptions) {
const service = await getInitializedCheckpointService(cline)
const service = await getCheckpointService(cline)

if (!service) {
return
}

TelemetryService.instance.captureCheckpointDiffed(cline.taskId)

if (!previousCommitHash && mode === "checkpoint") {
const previousCheckpoint = cline.clineMessages
.filter(({ say }) => say === "checkpoint_saved")
.sort((a, b) => b.ts - a.ts)
.find((message) => message.ts < ts)
let prevHash = commitHash
let nextHash: string | undefined

previousCommitHash = previousCheckpoint?.text
const checkpoints = typeof service.getCheckpoints === "function" ? service.getCheckpoints() : []
const idx = checkpoints.indexOf(commitHash)
if (idx !== -1 && idx < checkpoints.length - 1) {
nextHash = checkpoints[idx + 1]
} else {
nextHash = undefined
}

try {
const changes = await service.getDiff({ from: previousCommitHash, to: commitHash })
const changes = await service.getDiff({ from: prevHash, to: nextHash })

if (!changes?.length) {
vscode.window.showInformationMessage("No changes found.")
Expand Down
2 changes: 2 additions & 0 deletions src/core/task/Task.ts
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ export class Task extends EventEmitter<ClineEvents> {
isWaitingForFirstChunk = false
isStreaming = false
currentStreamingContentIndex = 0
currentStreamingDidCheckpoint = false
assistantMessageContent: AssistantMessageContent[] = []
presentAssistantMessageLocked = false
presentAssistantMessageHasPendingUpdates = false
Expand Down Expand Up @@ -1523,6 +1524,7 @@ export class Task extends EventEmitter<ClineEvents> {

// Reset streaming state.
this.currentStreamingContentIndex = 0
this.currentStreamingDidCheckpoint = false
this.assistantMessageContent = []
this.didCompleteReadingStream = false
this.userMessageContent = []
Expand Down
4 changes: 4 additions & 0 deletions src/services/checkpoints/ShadowCheckpointService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ export abstract class ShadowCheckpointService extends EventEmitter {
return !!this.git
}

public getCheckpoints(): string[] {
return this._checkpoints.slice()
}

constructor(taskId: string, checkpointsDir: string, workspaceDir: string, log: (message: string) => void) {
super()

Expand Down
Loading
Loading