diff --git a/packages/types/src/global-settings.ts b/packages/types/src/global-settings.ts index 81c6ae6dfe..a39949f875 100644 --- a/packages/types/src/global-settings.ts +++ b/packages/types/src/global-settings.ts @@ -41,6 +41,7 @@ export const globalSettingsSchema = z.object({ lastShownAnnouncementId: z.string().optional(), customInstructions: z.string().optional(), taskHistory: z.array(historyItemSchema).optional(), + currentActiveTaskId: z.string().optional(), // Image generation settings (experimental) - flattened for simplicity openRouterImageApiKey: z.string().optional(), diff --git a/src/core/webview/ClineProvider.ts b/src/core/webview/ClineProvider.ts index 544922a187..379d142e4e 100644 --- a/src/core/webview/ClineProvider.ts +++ b/src/core/webview/ClineProvider.ts @@ -353,6 +353,9 @@ export class ClineProvider this.clineStack.push(task) task.emit(RooCodeEventName.TaskFocused) + // Persist the current active task ID + await this.updateGlobalState("currentActiveTaskId", task.taskId) + // Perform special setup provider specific tasks. await this.performPreparationTasks(task) @@ -417,6 +420,15 @@ export class ClineProvider // garbage collected. task = undefined } + + // Clear the current active task ID if no tasks remain + if (this.clineStack.length === 0) { + await this.updateGlobalState("currentActiveTaskId", undefined) + } else { + // Update to the new top task + const newCurrentTask = this.clineStack[this.clineStack.length - 1] + await this.updateGlobalState("currentActiveTaskId", newCurrentTask.taskId) + } } getTaskStackSize(): number { @@ -725,6 +737,22 @@ export class ClineProvider // If the extension is starting a new session, clear previous task state. await this.removeClineFromStack() + + // Attempt to restore the last active task if one was persisted + const lastActiveTaskId = this.getGlobalState("currentActiveTaskId") + if (lastActiveTaskId && typeof lastActiveTaskId === "string") { + try { + const { historyItem } = await this.getTaskWithId(lastActiveTaskId) + if (historyItem) { + await this.createTaskWithHistoryItem(historyItem) + this.log(`Restored last active task: ${lastActiveTaskId}`) + } + } catch (error) { + // Task may have been deleted or corrupted, clear the saved ID + this.log(`Failed to restore last active task ${lastActiveTaskId}: ${error}`) + await this.updateGlobalState("currentActiveTaskId", undefined) + } + } } public async createTaskWithHistoryItem(historyItem: HistoryItem & { rootTask?: Task; parentTask?: Task }) { diff --git a/src/core/webview/__tests__/ClineProvider.spec.ts b/src/core/webview/__tests__/ClineProvider.spec.ts index 400ce50468..696edcf3b2 100644 --- a/src/core/webview/__tests__/ClineProvider.spec.ts +++ b/src/core/webview/__tests__/ClineProvider.spec.ts @@ -359,9 +359,10 @@ describe("ClineProvider", () => { extensionUri: {} as vscode.Uri, globalState: { get: vi.fn().mockImplementation((key: string) => globalState[key]), - update: vi - .fn() - .mockImplementation((key: string, value: string | undefined) => (globalState[key] = value)), + update: vi.fn().mockImplementation((key: string, value: string | undefined) => { + globalState[key] = value + return Promise.resolve() + }), keys: vi.fn().mockImplementation(() => Object.keys(globalState)), }, secrets: { @@ -2262,7 +2263,17 @@ describe("Project MCP Settings", () => { onDidChangeVisibility: vi.fn(), } as unknown as vscode.WebviewView - provider = new ClineProvider(mockContext, mockOutputChannel, "sidebar", new ContextProxy(mockContext)) + // Create a mock ContextProxy with proper getValue/setValue methods + const mockContextProxy = new ContextProxy(mockContext) + // Mock the getValue method to use globalState + vi.spyOn(mockContextProxy, "getValue").mockImplementation((key: string) => { + return mockContext.globalState.get(key) + }) + vi.spyOn(mockContextProxy, "setValue").mockImplementation(async (key: string, value: any) => { + await mockContext.globalState.update(key, value) + }) + + provider = new ClineProvider(mockContext, mockOutputChannel, "sidebar", mockContextProxy) }) test.skip("handles openProjectMcpSettings message", async () => { @@ -3808,3 +3819,323 @@ describe("ClineProvider - Comprehensive Edit/Delete Edge Cases", () => { }) }) }) + +describe("ClineProvider - Task Persistence and Restoration", () => { + let provider: ClineProvider + let mockContext: vscode.ExtensionContext + let mockOutputChannel: vscode.OutputChannel + let mockWebviewView: vscode.WebviewView + let mockPostMessage: any + let defaultTaskOptions: TaskOptions + + beforeEach(() => { + vi.clearAllMocks() + + if (!TelemetryService.hasInstance()) { + TelemetryService.createInstance([]) + } + + const globalState: Record = { + mode: "code", + currentApiConfigName: "current-config", + } + + const secrets: Record = {} + + mockContext = { + extensionPath: "/test/path", + extensionUri: {} as vscode.Uri, + globalState: { + get: vi.fn().mockImplementation((key: string) => globalState[key]), + update: vi.fn().mockImplementation((key: string, value: string | undefined) => { + globalState[key] = value + return Promise.resolve() + }), + keys: vi.fn().mockImplementation(() => Object.keys(globalState)), + }, + secrets: { + get: vi.fn().mockImplementation((key: string) => secrets[key]), + store: vi.fn().mockImplementation((key: string, value: string | undefined) => (secrets[key] = value)), + delete: vi.fn().mockImplementation((key: string) => delete secrets[key]), + }, + subscriptions: [], + extension: { + packageJSON: { version: "1.0.0" }, + }, + globalStorageUri: { + fsPath: "/test/storage/path", + }, + } as unknown as vscode.ExtensionContext + + mockOutputChannel = { + appendLine: vi.fn(), + clear: vi.fn(), + dispose: vi.fn(), + } as unknown as vscode.OutputChannel + + mockPostMessage = vi.fn() + + mockWebviewView = { + webview: { + postMessage: mockPostMessage, + html: "", + options: {}, + onDidReceiveMessage: vi.fn(), + asWebviewUri: vi.fn(), + cspSource: "vscode-webview://test-csp-source", + }, + visible: true, + onDidDispose: vi.fn().mockImplementation((callback) => { + callback() + return { dispose: vi.fn() } + }), + onDidChangeVisibility: vi.fn().mockImplementation(() => ({ dispose: vi.fn() })), + } as unknown as vscode.WebviewView + + provider = new ClineProvider(mockContext, mockOutputChannel, "sidebar", new ContextProxy(mockContext)) + + defaultTaskOptions = { + provider, + apiConfiguration: { + apiProvider: "openrouter", + }, + } + + // Mock getMcpHub method + provider.getMcpHub = vi.fn().mockReturnValue({ + listTools: vi.fn().mockResolvedValue([]), + callTool: vi.fn().mockResolvedValue({ content: [] }), + listResources: vi.fn().mockResolvedValue([]), + readResource: vi.fn().mockResolvedValue({ contents: [] }), + getAllServers: vi.fn().mockReturnValue([]), + }) + }) + + describe("Task ID Persistence", () => { + test("persists current task ID when adding task to stack", async () => { + const mockTask = new Task(defaultTaskOptions) + Object.defineProperty(mockTask, "taskId", { value: "test-task-123", writable: true }) + + await provider.addClineToStack(mockTask) + + // Verify that currentActiveTaskId was saved to global state + expect(mockContext.globalState.update).toHaveBeenCalledWith("currentActiveTaskId", "test-task-123") + }) + + test("clears current task ID when removing last task from stack", async () => { + const mockTask = new Task(defaultTaskOptions) + Object.defineProperty(mockTask, "taskId", { value: "test-task-456", writable: true }) + + await provider.addClineToStack(mockTask) + await provider.removeClineFromStack() + + // Verify that currentActiveTaskId was cleared + expect(mockContext.globalState.update).toHaveBeenCalledWith("currentActiveTaskId", undefined) + }) + + test("updates current task ID when removing task from stack with remaining tasks", async () => { + const mockTask1 = new Task(defaultTaskOptions) + const mockTask2 = new Task(defaultTaskOptions) + Object.defineProperty(mockTask1, "taskId", { value: "task-1", writable: true }) + Object.defineProperty(mockTask2, "taskId", { value: "task-2", writable: true }) + + await provider.addClineToStack(mockTask1) + await provider.addClineToStack(mockTask2) + + // Remove the top task (task-2) + await provider.removeClineFromStack() + + // Verify that currentActiveTaskId was updated to the new top task + expect(mockContext.globalState.update).toHaveBeenLastCalledWith("currentActiveTaskId", "task-1") + }) + }) + + describe("Task Restoration on Startup", () => { + test("restores last active task when resolving webview", async () => { + // Set up a saved task ID in global state + const savedTaskId = "saved-task-789" + const globalState: Record = { + currentActiveTaskId: savedTaskId, + } + ;(mockContext.globalState.get as any).mockImplementation((key: string) => globalState[key]) + + // Also mock the ContextProxy getValue to return the saved task ID + const mockContextProxy = new ContextProxy(mockContext) + vi.spyOn(mockContextProxy, "getValue").mockImplementation((key: string) => { + if (key === "currentActiveTaskId") return savedTaskId + return mockContext.globalState.get(key) + }) + vi.spyOn(mockContextProxy, "setValue").mockImplementation(async (key: string, value: any) => { + await mockContext.globalState.update(key, value) + }) + + // Create provider with mocked context proxy + provider = new ClineProvider(mockContext, mockOutputChannel, "sidebar", mockContextProxy) + + // Mock getTaskWithId to return a valid history item + const mockHistoryItem = { + id: savedTaskId, + ts: Date.now(), + task: "Test task", + mode: "code", + number: 1, + tokensIn: 0, + tokensOut: 0, + totalCost: 0, + } + + ;(provider as any).getTaskWithId = vi.fn().mockResolvedValue({ + historyItem: mockHistoryItem, + taskDirPath: "/test/task/path", + apiConversationHistoryFilePath: "/test/api/history", + uiMessagesFilePath: "/test/ui/messages", + apiConversationHistory: [], + }) + + // Mock createTaskWithHistoryItem + const createTaskSpy = vi + .spyOn(provider, "createTaskWithHistoryItem") + .mockResolvedValue(new Task(defaultTaskOptions)) + + // Mock log method to verify logging + const logSpy = vi.spyOn(provider, "log") + + await provider.resolveWebviewView(mockWebviewView) + + // Verify that the task was restored + expect(provider.getTaskWithId).toHaveBeenCalledWith(savedTaskId) + expect(createTaskSpy).toHaveBeenCalledWith(mockHistoryItem) + expect(logSpy).toHaveBeenCalledWith(`Restored last active task: ${savedTaskId}`) + }) + + test("handles missing task gracefully when restoring", async () => { + // Set up a saved task ID that doesn't exist + const missingTaskId = "missing-task-999" + const globalState: Record = { + currentActiveTaskId: missingTaskId, + } + ;(mockContext.globalState.get as any).mockImplementation((key: string) => globalState[key]) + + // Also mock the ContextProxy getValue to return the missing task ID + const mockContextProxy = new ContextProxy(mockContext) + vi.spyOn(mockContextProxy, "getValue").mockImplementation((key: string) => { + if (key === "currentActiveTaskId") return missingTaskId + return mockContext.globalState.get(key) + }) + vi.spyOn(mockContextProxy, "setValue").mockImplementation(async (key: string, value: any) => { + await mockContext.globalState.update(key, value) + }) + + // Create provider with mocked context proxy + provider = new ClineProvider(mockContext, mockOutputChannel, "sidebar", mockContextProxy) + + // Mock getTaskWithId to throw an error (task not found) + ;(provider as any).getTaskWithId = vi.fn().mockRejectedValue(new Error("Task not found")) + + // Mock log method to verify error logging + const logSpy = vi.spyOn(provider, "log") + + await provider.resolveWebviewView(mockWebviewView) + + // Verify that the error was handled and logged + expect(provider.getTaskWithId).toHaveBeenCalledWith(missingTaskId) + expect(logSpy).toHaveBeenCalledWith( + expect.stringContaining(`Failed to restore last active task ${missingTaskId}`), + ) + + // Verify that the saved ID was cleared + expect(mockContext.globalState.update).toHaveBeenCalledWith("currentActiveTaskId", undefined) + }) + + test("does not attempt restoration when no saved task ID exists", async () => { + // No saved task ID in global state + const globalState: Record = {} + ;(mockContext.globalState.get as any).mockImplementation((key: string) => globalState[key]) + + // Mock getTaskWithId to ensure it's not called + ;(provider as any).getTaskWithId = vi.fn() + + // Mock createTaskWithHistoryItem to ensure it's not called + const createTaskSpy = vi.spyOn(provider, "createTaskWithHistoryItem") + + await provider.resolveWebviewView(mockWebviewView) + + // Verify that no restoration was attempted + expect(provider.getTaskWithId).not.toHaveBeenCalled() + expect(createTaskSpy).not.toHaveBeenCalled() + }) + + test("handles invalid task ID type gracefully", async () => { + // Set up an invalid task ID type (object instead of string) + const globalState: Record = { + currentActiveTaskId: { invalid: "object" }, + } + ;(mockContext.globalState.get as any).mockImplementation((key: string) => globalState[key]) + + // Mock getTaskWithId to ensure it's not called with invalid data + ;(provider as any).getTaskWithId = vi.fn() + + await provider.resolveWebviewView(mockWebviewView) + + // Verify that no restoration was attempted with invalid data + expect(provider.getTaskWithId).not.toHaveBeenCalled() + }) + }) + + describe("Task Stack Management with Persistence", () => { + test("maintains correct task ID persistence through multiple stack operations", async () => { + const task1 = new Task(defaultTaskOptions) + const task2 = new Task(defaultTaskOptions) + const task3 = new Task(defaultTaskOptions) + + Object.defineProperty(task1, "taskId", { value: "task-1", writable: true }) + Object.defineProperty(task2, "taskId", { value: "task-2", writable: true }) + Object.defineProperty(task3, "taskId", { value: "task-3", writable: true }) + + // Add tasks to stack + await provider.addClineToStack(task1) + expect(mockContext.globalState.update).toHaveBeenCalledWith("currentActiveTaskId", "task-1") + + await provider.addClineToStack(task2) + expect(mockContext.globalState.update).toHaveBeenCalledWith("currentActiveTaskId", "task-2") + + await provider.addClineToStack(task3) + expect(mockContext.globalState.update).toHaveBeenCalledWith("currentActiveTaskId", "task-3") + + // Remove tasks and verify ID updates + await provider.removeClineFromStack() // Remove task-3 + expect(mockContext.globalState.update).toHaveBeenCalledWith("currentActiveTaskId", "task-2") + + await provider.removeClineFromStack() // Remove task-2 + expect(mockContext.globalState.update).toHaveBeenCalledWith("currentActiveTaskId", "task-1") + + await provider.removeClineFromStack() // Remove task-1 + expect(mockContext.globalState.update).toHaveBeenCalledWith("currentActiveTaskId", undefined) + }) + + test("handles finishSubTask with task ID persistence", async () => { + const parentTask = new Task(defaultTaskOptions) + const childTask = new Task(defaultTaskOptions) + + Object.defineProperty(parentTask, "taskId", { value: "parent-task", writable: true }) + Object.defineProperty(childTask, "taskId", { value: "child-task", writable: true }) + + // Set up parent-child relationship + ;(childTask as any).parentTask = parentTask + ;(parentTask as any).completeSubtask = vi.fn() + + // Add tasks to stack + await provider.addClineToStack(parentTask) + await provider.addClineToStack(childTask) + + // Finish the subtask + await provider.finishSubTask("Subtask completed") + + // Verify that the current task ID was updated to parent + expect(mockContext.globalState.update).toHaveBeenLastCalledWith("currentActiveTaskId", "parent-task") + + // Verify parent's completeSubtask was called + expect(parentTask.completeSubtask).toHaveBeenCalledWith("Subtask completed") + }) + }) +})