diff --git a/index.ts b/index.ts index baee7a7..cf66fc5 100644 --- a/index.ts +++ b/index.ts @@ -28,12 +28,9 @@ const plugin: Plugin = (async (ctx) => { const janitor = new Janitor( ctx.client, - state.prunedIds, - state.stats, + state, logger, - state.toolParameters, config.protectedTools, - state.model, config.model, config.showModelErrorToasts, config.strictModelSelection, diff --git a/lib/fetch-wrapper/gemini.ts b/lib/fetch-wrapper/gemini.ts index f2697a8..6e18884 100644 --- a/lib/fetch-wrapper/gemini.ts +++ b/lib/fetch-wrapper/gemini.ts @@ -53,7 +53,7 @@ export async function handleGemini( return { modified, body } } - const { allSessions, allPrunedIds } = await getAllPrunedIds(ctx.client, ctx.state) + const { allSessions, allPrunedIds } = await getAllPrunedIds(ctx.client, ctx.state, ctx.logger) if (allPrunedIds.size === 0) { return { modified, body } diff --git a/lib/fetch-wrapper/openai-chat.ts b/lib/fetch-wrapper/openai-chat.ts index 2483baf..14214ad 100644 --- a/lib/fetch-wrapper/openai-chat.ts +++ b/lib/fetch-wrapper/openai-chat.ts @@ -61,7 +61,7 @@ export async function handleOpenAIChatAndAnthropic( return false }) - const { allSessions, allPrunedIds } = await getAllPrunedIds(ctx.client, ctx.state) + const { allSessions, allPrunedIds } = await getAllPrunedIds(ctx.client, ctx.state, ctx.logger) if (toolMessages.length === 0 || allPrunedIds.size === 0) { return { modified, body } diff --git a/lib/fetch-wrapper/openai-responses.ts b/lib/fetch-wrapper/openai-responses.ts index 0725d22..1674712 100644 --- a/lib/fetch-wrapper/openai-responses.ts +++ b/lib/fetch-wrapper/openai-responses.ts @@ -55,7 +55,7 @@ export async function handleOpenAIResponses( return { modified, body } } - const { allSessions, allPrunedIds } = await getAllPrunedIds(ctx.client, ctx.state) + const { allSessions, allPrunedIds } = await getAllPrunedIds(ctx.client, ctx.state, ctx.logger) if (allPrunedIds.size === 0) { return { modified, body } diff --git a/lib/fetch-wrapper/types.ts b/lib/fetch-wrapper/types.ts index 91182d0..f23baf9 100644 --- a/lib/fetch-wrapper/types.ts +++ b/lib/fetch-wrapper/types.ts @@ -1,4 +1,4 @@ -import type { PluginState } from "../state" +import { type PluginState, ensureSessionRestored } from "../state" import type { Logger } from "../logger" import type { ToolTracker } from "../synth-instruction" import type { PluginConfig } from "../config" @@ -36,22 +36,19 @@ export interface PrunedIdData { allPrunedIds: Set } -/** - * Get all pruned IDs across all non-subagent sessions. - */ export async function getAllPrunedIds( client: any, - state: PluginState + state: PluginState, + logger?: Logger ): Promise { const allSessions = await client.session.list() const allPrunedIds = new Set() - if (allSessions.data) { - for (const session of allSessions.data) { - if (session.parentID) continue - const prunedIds = state.prunedIds.get(session.id) ?? [] - prunedIds.forEach((id: string) => allPrunedIds.add(id)) - } + const currentSession = getMostRecentActiveSession(allSessions) + if (currentSession) { + await ensureSessionRestored(state, currentSession.id, logger) + const prunedIds = state.prunedIds.get(currentSession.id) ?? [] + prunedIds.forEach((id: string) => allPrunedIds.add(id)) } return { allSessions, allPrunedIds } diff --git a/lib/janitor.ts b/lib/janitor.ts index 5a2b2f0..643beae 100644 --- a/lib/janitor.ts +++ b/lib/janitor.ts @@ -1,11 +1,14 @@ import { z } from "zod" import type { Logger } from "./logger" import type { PruningStrategy } from "./config" +import type { PluginState } from "./state" import { buildAnalysisPrompt } from "./prompt" import { selectModel, extractModelFromSession } from "./model-selector" import { estimateTokensBatch, formatTokenCount } from "./tokenizer" import { detectDuplicates } from "./deduplicator" import { extractParameterKey } from "./display-utils" +import { saveSessionState } from "./state-persistence" +import { ensureSessionRestored } from "./state" export interface SessionStats { totalToolsPruned: number @@ -29,20 +32,28 @@ export interface PruningOptions { } export class Janitor { + private prunedIdsState: Map + private statsState: Map + private toolParametersCache: Map + private modelCache: Map + constructor( private client: any, - private prunedIdsState: Map, - private statsState: Map, + private state: PluginState, private logger: Logger, - private toolParametersCache: Map, private protectedTools: string[], - private modelCache: Map, private configModel?: string, private showModelErrorToasts: boolean = true, private strictModelSelection: boolean = false, private pruningSummary: "off" | "minimal" | "detailed" = "detailed", private workingDirectory?: string - ) { } + ) { + // Bind state references for convenience + this.prunedIdsState = state.prunedIds + this.statsState = state.stats + this.toolParametersCache = state.toolParameters + this.modelCache = state.model + } private async sendIgnoredMessage(sessionID: string, text: string, agent?: string) { try { @@ -85,6 +96,9 @@ export class Janitor { return null } + // Ensure persisted state is restored before processing + await ensureSessionRestored(this.state, sessionID, this.logger) + const [sessionInfoResponse, messagesResponse] = await Promise.all([ this.client.session.get({ path: { id: sessionID } }), this.client.session.messages({ path: { id: sessionID }, query: { limit: 100 } }) @@ -97,8 +111,6 @@ export class Janitor { return null } - // Extract the current agent from the last user message to preserve agent context - // Following the same pattern as OpenCode's server.ts let currentAgent: string | undefined = undefined for (let i = messages.length - 1; i >= 0; i--) { const msg = messages[i] @@ -330,6 +342,11 @@ export class Janitor { const allPrunedIds = [...new Set([...alreadyPrunedIds, ...finalPrunedIds])] this.prunedIdsState.set(sessionID, allPrunedIds) + const sessionName = sessionInfo?.title + saveSessionState(sessionID, new Set(allPrunedIds), sessionStats, this.logger, sessionName).catch(err => { + this.logger.error("janitor", "Failed to persist state", { error: err.message }) + }) + const prunedCount = finalNewlyPrunedIds.length const keptCount = candidateCount - prunedCount const hasBoth = deduplicatedIds.length > 0 && llmPrunedIds.length > 0 diff --git a/lib/state-persistence.ts b/lib/state-persistence.ts new file mode 100644 index 0000000..384e610 --- /dev/null +++ b/lib/state-persistence.ts @@ -0,0 +1,110 @@ +/** + * State persistence module for DCP plugin. + * Persists pruned tool IDs across sessions so they survive OpenCode restarts. + * Storage location: ~/.local/share/opencode/storage/plugin/dcp/{sessionId}.json + */ + +import * as fs from "fs/promises"; +import { existsSync } from "fs"; +import { homedir } from "os"; +import { join } from "path"; +import type { SessionStats } from "./janitor"; +import type { Logger } from "./logger"; + +export interface PersistedSessionState { + sessionName?: string; + prunedIds: string[]; + stats: SessionStats; + lastUpdated: string; +} + +const STORAGE_DIR = join( + homedir(), + ".local", + "share", + "opencode", + "storage", + "plugin", + "dcp" +); + +async function ensureStorageDir(): Promise { + if (!existsSync(STORAGE_DIR)) { + await fs.mkdir(STORAGE_DIR, { recursive: true }); + } +} + +function getSessionFilePath(sessionId: string): string { + return join(STORAGE_DIR, `${sessionId}.json`); +} + +export async function saveSessionState( + sessionId: string, + prunedIds: Set, + stats: SessionStats, + logger?: Logger, + sessionName?: string +): Promise { + try { + await ensureStorageDir(); + + const state: PersistedSessionState = { + ...(sessionName && { sessionName }), + prunedIds: Array.from(prunedIds), + stats, + lastUpdated: new Date().toISOString(), + }; + + const filePath = getSessionFilePath(sessionId); + const content = JSON.stringify(state, null, 2); + await fs.writeFile(filePath, content, "utf-8"); + + logger?.info("persist", "Saved session state to disk", { + sessionId: sessionId.slice(0, 8), + prunedIds: prunedIds.size, + totalTokensSaved: stats.totalTokensSaved, + }); + } catch (error: any) { + logger?.error("persist", "Failed to save session state", { + sessionId: sessionId.slice(0, 8), + error: error?.message, + }); + } +} + +export async function loadSessionState( + sessionId: string, + logger?: Logger +): Promise { + try { + const filePath = getSessionFilePath(sessionId); + + if (!existsSync(filePath)) { + return null; + } + + const content = await fs.readFile(filePath, "utf-8"); + const state = JSON.parse(content) as PersistedSessionState; + + if (!state || !Array.isArray(state.prunedIds) || !state.stats) { + logger?.warn("persist", "Invalid session state file, ignoring", { + sessionId: sessionId.slice(0, 8), + }); + return null; + } + + logger?.info("persist", "Loaded session state from disk", { + sessionId: sessionId.slice(0, 8), + prunedIds: state.prunedIds.length, + totalTokensSaved: state.stats.totalTokensSaved, + }); + + return state; + } catch (error: any) { + logger?.warn("persist", "Failed to load session state", { + sessionId: sessionId.slice(0, 8), + error: error?.message, + }); + return null; + } +} diff --git a/lib/state.ts b/lib/state.ts index 9bad263..5145eae 100644 --- a/lib/state.ts +++ b/lib/state.ts @@ -1,4 +1,6 @@ import type { SessionStats } from "./janitor" +import type { Logger } from "./logger" +import { loadSessionState } from "./state-persistence" /** * Centralized state management for the DCP plugin. @@ -18,6 +20,8 @@ export interface PluginState { * Key: sessionID, Value: Map where positionKey is "toolName:index" */ googleToolCallMapping: Map> + /** Set of session IDs that have been restored from disk */ + restoredSessions: Set } export interface ToolParameterEntry { @@ -40,5 +44,32 @@ export function createPluginState(): PluginState { toolParameters: new Map(), model: new Map(), googleToolCallMapping: new Map(), + restoredSessions: new Set(), + } +} + +export async function ensureSessionRestored( + state: PluginState, + sessionId: string, + logger?: Logger +): Promise { + if (state.restoredSessions.has(sessionId)) { + return + } + + state.restoredSessions.add(sessionId) + + const persisted = await loadSessionState(sessionId, logger) + if (persisted) { + if (!state.prunedIds.has(sessionId)) { + state.prunedIds.set(sessionId, persisted.prunedIds) + logger?.info("persist", "Restored prunedIds from disk", { + sessionId: sessionId.slice(0, 8), + count: persisted.prunedIds.length, + }) + } + if (!state.stats.has(sessionId)) { + state.stats.set(sessionId, persisted.stats) + } } }