diff --git a/index.ts b/index.ts index f30b041..9a34d60 100644 --- a/index.ts +++ b/index.ts @@ -2,8 +2,8 @@ import type { Plugin } from "@opencode-ai/plugin" import { getConfig } from "./lib/config" import { Logger } from "./lib/logger" import { createSessionState } from "./lib/state" -import { createPruneTool } from "./lib/strategies/prune-tool" -import { createChatMessageTransformHandler } from "./lib/hooks" +import { createPruneTool } from "./lib/strategies" +import { createChatMessageTransformHandler, createEventHandler } from "./lib/hooks" const plugin: Plugin = (async (ctx) => { const config = getConfig(ctx) @@ -54,6 +54,7 @@ const plugin: Plugin = (async (ctx) => { logger.info("Added 'prune' to experimental.primary_tools via config mutation") } }, + event: createEventHandler(ctx.client, config, state, logger, ctx.directory), } }) satisfies Plugin diff --git a/lib/hooks.ts b/lib/hooks.ts index 88acd88..a48183e 100644 --- a/lib/hooks.ts +++ b/lib/hooks.ts @@ -5,6 +5,7 @@ import { syncToolCache } from "./state/tool-cache" import { deduplicate } from "./strategies" import { prune, insertPruneToolContext } from "./messages" import { checkSession } from "./state" +import { runOnIdle } from "./strategies/on-idle" export function createChatMessageTransformHandler( @@ -24,11 +25,48 @@ export function createChatMessageTransformHandler( syncToolCache(state, config, logger, output.messages); - - deduplicate(client, state, logger, config, output.messages) + deduplicate(state, logger, config, output.messages) prune(state, logger, config, output.messages) insertPruneToolContext(state, config, logger, output.messages) } } + +export function createEventHandler( + client: any, + config: PluginConfig, + state: SessionState, + logger: Logger, + workingDirectory?: string +) { + return async ( + { event }: { event: any } + ) => { + if (state.sessionId === null || state.isSubAgent) { + return + } + + if (event.type === "session.status" && event.properties.status.type === "idle") { + if (!config.strategies.onIdle.enabled) { + return + } + if (state.lastToolPrune) { + logger.info("Skipping OnIdle pruning - last tool was prune") + return + } + + try { + await runOnIdle( + client, + state, + logger, + config, + workingDirectory + ) + } catch (err: any) { + logger.error("OnIdle pruning failed", { error: err.message }) + } + } + } +} diff --git a/lib/messages/prune.ts b/lib/messages/prune.ts index cca70d1..cb0ba18 100644 --- a/lib/messages/prune.ts +++ b/lib/messages/prune.ts @@ -40,6 +40,10 @@ export const insertPruneToolContext = ( logger: Logger, messages: WithParts[] ): void => { + if (!config.strategies.pruneTool.enabled) { + return + } + const lastUserMessage = getLastUserMessage(messages) if (!lastUserMessage || lastUserMessage.info.role !== 'user') { return @@ -48,7 +52,7 @@ export const insertPruneToolContext = ( const prunableToolsList = buildPrunableToolsList(state, config, logger, messages) let nudgeString = "" - if (config.strategies.pruneTool.nudge.enabled && state.nudgeCounter >= config.strategies.pruneTool.nudge.frequency) { + if (state.nudgeCounter >= config.strategies.pruneTool.nudge.frequency) { logger.info("Inserting prune nudge message") nudgeString = "\n" + NUDGE_STRING } diff --git a/lib/state/state.ts b/lib/state/state.ts index 8edbf52..fc69883 100644 --- a/lib/state/state.ts +++ b/lib/state/state.ts @@ -43,7 +43,8 @@ export function createSessionState(): SessionState { totalPruneTokens: 0, }, toolParameters: new Map(), - nudgeCounter: 0 + nudgeCounter: 0, + lastToolPrune: false } } @@ -59,6 +60,7 @@ export function resetSessionState(state: SessionState): void { } state.toolParameters.clear() state.nudgeCounter = 0 + state.lastToolPrune = false } export async function ensureSessionInitialized( diff --git a/lib/state/tool-cache.ts b/lib/state/tool-cache.ts index d86005f..854aaaa 100644 --- a/lib/state/tool-cache.ts +++ b/lib/state/tool-cache.ts @@ -17,6 +17,7 @@ export async function syncToolCache( ): Promise { try { logger.info("Syncing tool parameters from OpenCode messages") + for (const msg of messages) { for (const part of msg.parts) { if (part.type !== "tool" || !part.callID || state.toolParameters.has(part.callID)) { @@ -36,6 +37,9 @@ export async function syncToolCache( if (!config.strategies.pruneTool.protectedTools.includes(part.tool)) { state.nudgeCounter++ } + + state.lastToolPrune = part.tool === "prune" + logger.info("lastToolPrune=" + String(state.lastToolPrune)) } } diff --git a/lib/state/types.ts b/lib/state/types.ts index f7353e0..89fc8e7 100644 --- a/lib/state/types.ts +++ b/lib/state/types.ts @@ -30,4 +30,5 @@ export interface SessionState { stats: SessionStats toolParameters: Map nudgeCounter: number + lastToolPrune: boolean } diff --git a/lib/strategies/deduplication.ts b/lib/strategies/deduplication.ts index f0887c2..f58a13a 100644 --- a/lib/strategies/deduplication.ts +++ b/lib/strategies/deduplication.ts @@ -9,7 +9,6 @@ import { calculateTokensSaved } from "../utils" * Modifies the session state in place to add pruned tool call IDs. */ export const deduplicate = ( - client: any, state: SessionState, logger: Logger, config: PluginConfig, diff --git a/lib/strategies/index.ts b/lib/strategies/index.ts index 0bd83ff..105d9c8 100644 --- a/lib/strategies/index.ts +++ b/lib/strategies/index.ts @@ -1,2 +1,3 @@ export { deduplicate } from "./deduplication" - +export { runOnIdle } from "./on-idle" +export { createPruneTool } from "./prune-tool" diff --git a/lib/strategies/on-idle.ts b/lib/strategies/on-idle.ts index e69de29..a0f07d0 100644 --- a/lib/strategies/on-idle.ts +++ b/lib/strategies/on-idle.ts @@ -0,0 +1,317 @@ +import { z } from "zod" +import type { SessionState, WithParts, ToolParameterEntry } from "../state" +import type { Logger } from "../logger" +import type { PluginConfig } from "../config" +import { buildAnalysisPrompt } from "../prompt" +import { selectModel, extractModelFromSession, ModelInfo } from "../model-selector" +import { calculateTokensSaved, findCurrentAgent } from "../utils" +import { saveSessionState } from "../state/persistence" +import { sendUnifiedNotification } from "../ui/notification" + +export interface OnIdleResult { + prunedCount: number + tokensSaved: number + prunedIds: string[] +} + +/** + * Parse messages to extract tool information. + */ +function parseMessages( + messages: WithParts[], + toolParametersCache: Map +): { + toolCallIds: string[] + toolMetadata: Map +} { + const toolCallIds: string[] = [] + const toolMetadata = new Map() + + for (const msg of messages) { + if (msg.parts) { + for (const part of msg.parts) { + if (part.type === "tool" && part.callID) { + toolCallIds.push(part.callID) + + const cachedData = toolParametersCache.get(part.callID) + const parameters = cachedData?.parameters ?? part.state?.input ?? {} + + toolMetadata.set(part.callID, { + tool: part.tool, + parameters: parameters, + status: part.state?.status, + error: part.state?.status === "error" ? part.state.error : undefined + }) + } + } + } + } + + return { toolCallIds, toolMetadata } +} + +/** + * Replace pruned tool outputs in messages for LLM analysis. + */ +function replacePrunedToolOutputs(messages: WithParts[], prunedIds: string[]): WithParts[] { + if (prunedIds.length === 0) return messages + + const prunedIdsSet = new Set(prunedIds) + + return messages.map(msg => { + if (!msg.parts) return msg + + return { + ...msg, + parts: msg.parts.map((part: any) => { + if (part.type === 'tool' && + part.callID && + prunedIdsSet.has(part.callID) && + part.state?.output) { + return { + ...part, + state: { + ...part.state, + output: '[Output removed to save context - information superseded or no longer needed]' + } + } + } + return part + }) + } + }) as WithParts[] +} + +/** + * Run LLM analysis to determine which tool calls can be pruned. + */ +async function runLlmAnalysis( + client: any, + state: SessionState, + logger: Logger, + config: PluginConfig, + messages: WithParts[], + unprunedToolCallIds: string[], + alreadyPrunedIds: string[], + toolMetadata: Map, + workingDirectory?: string +): Promise { + const protectedToolCallIds: string[] = [] + const prunableToolCallIds = unprunedToolCallIds.filter(id => { + const metadata = toolMetadata.get(id) + if (metadata && config.strategies.onIdle.protectedTools.includes(metadata.tool)) { + protectedToolCallIds.push(id) + return false + } + return true + }) + + if (prunableToolCallIds.length === 0) { + return [] + } + + // Get model info from messages + let validModelInfo: ModelInfo | undefined = undefined + if (messages.length > 0) { + const lastMessage = messages[messages.length - 1] + const model = (lastMessage.info as any)?.model + if (model?.providerID && model?.modelID) { + validModelInfo = { + providerID: model.providerID, + modelID: model.modelID + } + } + } + + const modelSelection = await selectModel( + validModelInfo, + logger, + config.strategies.onIdle.model, + workingDirectory + ) + + logger.info(`OnIdle Model: ${modelSelection.modelInfo.providerID}/${modelSelection.modelInfo.modelID}`, { + source: modelSelection.source + }) + + if (modelSelection.failedModel && config.strategies.onIdle.showModelErrorToasts) { + const skipAi = modelSelection.source === 'fallback' && config.strategies.onIdle.strictModelSelection + try { + await client.tui.showToast({ + body: { + title: skipAi ? "DCP: AI analysis skipped" : "DCP: Model fallback", + message: skipAi + ? `${modelSelection.failedModel.providerID}/${modelSelection.failedModel.modelID} failed\nAI analysis skipped (strictModelSelection enabled)` + : `${modelSelection.failedModel.providerID}/${modelSelection.failedModel.modelID} failed\nUsing ${modelSelection.modelInfo.providerID}/${modelSelection.modelInfo.modelID}`, + variant: "info", + duration: 5000 + } + }) + } catch { + // Ignore toast errors + } + } + + if (modelSelection.source === 'fallback' && config.strategies.onIdle.strictModelSelection) { + logger.info("Skipping AI analysis (fallback model, strictModelSelection enabled)") + return [] + } + + const { generateObject } = await import('ai') + + const sanitizedMessages = replacePrunedToolOutputs(messages, alreadyPrunedIds) + + const analysisPrompt = buildAnalysisPrompt( + prunableToolCallIds, + sanitizedMessages, + alreadyPrunedIds, + protectedToolCallIds + ) + + const result = await generateObject({ + model: modelSelection.model, + schema: z.object({ + pruned_tool_call_ids: z.array(z.string()), + reasoning: z.string(), + }), + prompt: analysisPrompt + }) + + const rawLlmPrunedIds = result.object.pruned_tool_call_ids + const llmPrunedIds = rawLlmPrunedIds.filter(id => + prunableToolCallIds.includes(id) + ) + + // Always log LLM output as debug + const reasoning = result.object.reasoning.replace(/\n+/g, ' ').replace(/\s+/g, ' ').trim() + logger.debug(`OnIdle LLM output`, { + pruned_tool_call_ids: rawLlmPrunedIds, + reasoning: reasoning + }) + + return llmPrunedIds +} + +/** + * Run the onIdle pruning strategy. + * This is called when the session transitions to idle state. + */ +export async function runOnIdle( + client: any, + state: SessionState, + logger: Logger, + config: PluginConfig, + workingDirectory?: string +): Promise { + try { + if (!state.sessionId) { + return null + } + + const sessionId = state.sessionId + + // Fetch session info and messages + const [sessionInfoResponse, messagesResponse] = await Promise.all([ + client.session.get({ path: { id: sessionId } }), + client.session.messages({ path: { id: sessionId }}) + ]) + + const sessionInfo = sessionInfoResponse.data + const messages: WithParts[] = messagesResponse.data || messagesResponse + + if (!messages || messages.length < 3) { + return null + } + + const currentAgent = findCurrentAgent(messages) + const { toolCallIds, toolMetadata } = parseMessages(messages, state.toolParameters) + + const alreadyPrunedIds = state.prune.toolIds + const unprunedToolCallIds = toolCallIds.filter(id => !alreadyPrunedIds.includes(id)) + + if (unprunedToolCallIds.length === 0) { + return null + } + + // Count prunable tools (excluding protected) + const candidateCount = unprunedToolCallIds.filter(id => { + const metadata = toolMetadata.get(id) + return !metadata || !config.strategies.onIdle.protectedTools.includes(metadata.tool) + }).length + + if (candidateCount === 0) { + return null + } + + // Run LLM analysis + const llmPrunedIds = await runLlmAnalysis( + client, + state, + logger, + config, + messages, + unprunedToolCallIds, + alreadyPrunedIds, + toolMetadata, + workingDirectory + ) + + const newlyPrunedIds = llmPrunedIds.filter(id => !alreadyPrunedIds.includes(id)) + + if (newlyPrunedIds.length === 0) { + return null + } + + // Log the tool IDs being pruned with their tool names + for (const id of newlyPrunedIds) { + const metadata = toolMetadata.get(id) + const toolName = metadata?.tool || 'unknown' + logger.info(`OnIdle pruning tool: ${toolName}`, { callID: id }) + } + + // Update state + const allPrunedIds = [...new Set([...alreadyPrunedIds, ...newlyPrunedIds])] + state.prune.toolIds = allPrunedIds + + state.stats.pruneTokenCounter += calculateTokensSaved(messages, newlyPrunedIds) + + // Build tool metadata map for notification + const prunedToolMetadata = new Map() + for (const id of newlyPrunedIds) { + const metadata = toolMetadata.get(id) + if (metadata) { + prunedToolMetadata.set(id, metadata) + } + } + + // Send notification + await sendUnifiedNotification( + client, + logger, + config, + state, + sessionId, + newlyPrunedIds, + prunedToolMetadata, + undefined, // reason + currentAgent, + workingDirectory || "" + ) + + state.stats.totalPruneTokens += state.stats.pruneTokenCounter + state.stats.pruneTokenCounter = 0 + state.nudgeCounter = 0 + state.lastToolPrune = true + + // Persist state + const sessionName = sessionInfo?.title + saveSessionState(state, logger, sessionName).catch(err => { + logger.error("Failed to persist state", { error: err.message }) + }) + + logger.info(`OnIdle: Pruned ${newlyPrunedIds.length}/${candidateCount} tools`) + } catch (error: any) { + logger.error("OnIdle analysis failed", { error: error.message }) + return null + } +}