diff --git a/lib/hooks.ts b/lib/hooks.ts index 72fda69..f24d52c 100644 --- a/lib/hooks.ts +++ b/lib/hooks.ts @@ -48,6 +48,11 @@ export function createEventHandler( return } + if (event.type === "session.compacted") { + logger.info("Session compaction detected - updating state") + state.lastCompaction = Date.now() + } + if (event.type === "session.status" && event.properties.status.type === "idle") { if (!config.strategies.onIdle.enabled) { return diff --git a/lib/messages/prune.ts b/lib/messages/prune.ts index 2ecb2bc..918056e 100644 --- a/lib/messages/prune.ts +++ b/lib/messages/prune.ts @@ -3,7 +3,7 @@ import type { Logger } from "../logger" import type { PluginConfig } from "../config" import { loadPrompt } from "../prompt" import { extractParameterKey, buildToolIdList } from "./utils" -import { getLastUserMessage } from "../shared-utils" +import { getLastUserMessage, isMessageCompacted } from "../shared-utils" import { UserMessage } from "@opencode-ai/sdk" const PRUNED_TOOL_INPUT_REPLACEMENT = '[Input removed to save context]' @@ -17,7 +17,7 @@ const buildPrunableToolsList = ( messages: WithParts[], ): string => { const lines: string[] = [] - const toolIdList: string[] = buildToolIdList(messages) + const toolIdList: string[] = buildToolIdList(state, messages, logger) state.toolParameters.forEach((toolParameterEntry, toolCallId) => { if (state.prune.toolIds.includes(toolCallId)) { @@ -26,9 +26,6 @@ const buildPrunableToolsList = ( if (config.strategies.pruneTool.protectedTools.includes(toolParameterEntry.tool)) { return } - if (toolParameterEntry.compacted) { - return - } const numericId = toolIdList.indexOf(toolCallId) const paramKey = extractParameterKey(toolParameterEntry.tool, toolParameterEntry.parameters) const description = paramKey ? `${toolParameterEntry.tool}, ${paramKey}` : toolParameterEntry.tool @@ -111,6 +108,10 @@ const pruneToolOutputs = ( messages: WithParts[] ): void => { for (const msg of messages) { + if (isMessageCompacted(state, msg)) { + continue + } + for (const part of msg.parts) { if (part.type !== 'tool') { continue diff --git a/lib/messages/utils.ts b/lib/messages/utils.ts index 1ae2c7d..48f453c 100644 --- a/lib/messages/utils.ts +++ b/lib/messages/utils.ts @@ -1,4 +1,6 @@ -import type { WithParts } from "../state" +import { Logger } from "../logger" +import { isMessageCompacted } from "../shared-utils" +import type { SessionState, WithParts } from "../state" /** * Extracts a human-readable key from tool metadata for display purposes. @@ -71,9 +73,16 @@ export const extractParameterKey = (tool: string, parameters: any): string => { return paramStr.substring(0, 50) } -export function buildToolIdList(messages: WithParts[]): string[] { +export function buildToolIdList( + state: SessionState, + messages: WithParts[], + logger: Logger +): string[] { const toolIds: string[] = [] for (const msg of messages) { + if (isMessageCompacted(state, msg)) { + continue + } if (msg.parts) { for (const part of msg.parts) { if (part.type === 'tool' && part.callID && part.tool) { diff --git a/lib/shared-utils.ts b/lib/shared-utils.ts index cdcfb5b..9cb60a1 100644 --- a/lib/shared-utils.ts +++ b/lib/shared-utils.ts @@ -1,4 +1,12 @@ -import { WithParts } from "./state" +import { Logger } from "./logger" +import { SessionState, WithParts } from "./state" + +export const isMessageCompacted = ( + state: SessionState, + msg: WithParts +): boolean => { + return msg.info.time.created < state.lastCompaction +} export const getLastUserMessage = ( messages: WithParts[] @@ -11,3 +19,13 @@ export const getLastUserMessage = ( } return null } + +export const checkForCompaction = ( + state: SessionState, + messages: WithParts[], + logger: Logger +): void => { + for (const msg of messages) { + + } +} diff --git a/lib/state/persistence.ts b/lib/state/persistence.ts index 21f0092..89d6772 100644 --- a/lib/state/persistence.ts +++ b/lib/state/persistence.ts @@ -16,6 +16,7 @@ export interface PersistedSessionState { prune: Prune stats: SessionStats; lastUpdated: string; + lastCompacted: number } const STORAGE_DIR = join( @@ -55,6 +56,7 @@ export async function saveSessionState( prune: sessionState.prune, stats: sessionState.stats, lastUpdated: new Date().toISOString(), + lastCompacted: sessionState.lastCompaction }; const filePath = getSessionFilePath(sessionState.sessionId); @@ -99,8 +101,7 @@ export async function loadSessionState( } logger.info("Loaded session state from disk", { - sessionId: sessionId, - totalTokensSaved: state.stats.totalPruneTokens + sessionId: sessionId }); return state; diff --git a/lib/state/state.ts b/lib/state/state.ts index 19dc854..035f81b 100644 --- a/lib/state/state.ts +++ b/lib/state/state.ts @@ -41,7 +41,8 @@ export function createSessionState(): SessionState { }, toolParameters: new Map(), nudgeCounter: 0, - lastToolPrune: false + lastToolPrune: false, + lastCompaction: 0 } } @@ -58,6 +59,7 @@ export function resetSessionState(state: SessionState): void { state.toolParameters.clear() state.nudgeCounter = 0 state.lastToolPrune = false + state.lastCompaction = 0 } export async function ensureSessionInitialized( @@ -95,4 +97,5 @@ export async function ensureSessionInitialized( pruneTokenCounter: persisted.stats?.pruneTokenCounter || 0, totalPruneTokens: persisted.stats?.totalPruneTokens || 0, } + state.lastCompaction = persisted.lastCompacted || 0 } diff --git a/lib/state/tool-cache.ts b/lib/state/tool-cache.ts index a6140c7..ee2e2dc 100644 --- a/lib/state/tool-cache.ts +++ b/lib/state/tool-cache.ts @@ -1,6 +1,7 @@ import type { SessionState, ToolStatus, WithParts } from "./index" import type { Logger } from "../logger" import { PluginConfig } from "../config" +import { isMessageCompacted } from "../shared-utils" const MAX_TOOL_CACHE_SIZE = 1000 @@ -19,10 +20,17 @@ export async function syncToolCache( state.nudgeCounter = 0 for (const msg of messages) { + if (isMessageCompacted(state, msg)) { + continue + } + for (const part of msg.parts) { if (part.type !== "tool" || !part.callID) { continue } + if (state.toolParameters.has(part.callID)) { + continue + } if (part.tool === "prune") { state.nudgeCounter = 0 @@ -31,10 +39,6 @@ export async function syncToolCache( } state.lastToolPrune = part.tool === "prune" - if (state.toolParameters.has(part.callID)) { - continue - } - state.toolParameters.set( part.callID, { @@ -42,14 +46,12 @@ export async function syncToolCache( parameters: part.state?.input ?? {}, status: part.state.status as ToolStatus | undefined, error: part.state.status === "error" ? part.state.error : undefined, - compacted: part.state.status === "completed" && !!part.state.time.compacted, } ) + logger.info("Cached tool id: " + part.callID) } } - - // logger.info(`nudgeCounter=${state.nudgeCounter}, lastToolPrune=${state.lastToolPrune}`) - + logger.info("Synced cache - size: " + state.toolParameters.size) trimToolParametersCache(state) } catch (error) { logger.warn("Failed to sync tool parameters from OpenCode", { diff --git a/lib/state/types.ts b/lib/state/types.ts index e1b92a7..678bf29 100644 --- a/lib/state/types.ts +++ b/lib/state/types.ts @@ -12,7 +12,6 @@ export interface ToolParameterEntry { parameters: any status?: ToolStatus error?: string - compacted?: boolean } export interface SessionStats { @@ -32,4 +31,5 @@ export interface SessionState { toolParameters: Map nudgeCounter: number lastToolPrune: boolean + lastCompaction: number } diff --git a/lib/strategies/deduplication.ts b/lib/strategies/deduplication.ts index eaa9798..21c4be6 100644 --- a/lib/strategies/deduplication.ts +++ b/lib/strategies/deduplication.ts @@ -20,7 +20,7 @@ export const deduplicate = ( } // Build list of all tool call IDs from messages (chronological order) - const allToolIds = buildToolIdList(messages) + const allToolIds = buildToolIdList(state, messages, logger) if (allToolIds.length === 0) { return } @@ -68,7 +68,7 @@ export const deduplicate = ( } } - state.stats.totalPruneTokens += calculateTokensSaved(messages, newPruneIds) + state.stats.totalPruneTokens += calculateTokensSaved(state, messages, newPruneIds) if (newPruneIds.length > 0) { state.prune.toolIds.push(...newPruneIds) diff --git a/lib/strategies/on-idle.ts b/lib/strategies/on-idle.ts index 50dbb2f..f0870c2 100644 --- a/lib/strategies/on-idle.ts +++ b/lib/strategies/on-idle.ts @@ -7,6 +7,7 @@ import { selectModel, ModelInfo } from "../model-selector" import { saveSessionState } from "../state/persistence" import { sendUnifiedNotification } from "../ui/notification" import { calculateTokensSaved, getCurrentParams } from "./utils" +import { isMessageCompacted } from "../shared-utils" export interface OnIdleResult { prunedCount: number @@ -18,6 +19,7 @@ export interface OnIdleResult { * Parse messages to extract tool information. */ function parseMessages( + state: SessionState, messages: WithParts[], toolParametersCache: Map ): { @@ -28,6 +30,9 @@ function parseMessages( const toolMetadata = new Map() for (const msg of messages) { + if (isMessageCompacted(state, msg)) { + continue + } if (msg.parts) { for (const part of msg.parts) { if (part.type === "tool" && part.callID) { @@ -224,7 +229,7 @@ export async function runOnIdle( } const currentParams = getCurrentParams(messages, logger) - const { toolCallIds, toolMetadata } = parseMessages(messages, state.toolParameters) + const { toolCallIds, toolMetadata } = parseMessages(state, messages, state.toolParameters) const alreadyPrunedIds = state.prune.toolIds const unprunedToolCallIds = toolCallIds.filter(id => !alreadyPrunedIds.includes(id)) @@ -273,7 +278,7 @@ export async function runOnIdle( const allPrunedIds = [...new Set([...alreadyPrunedIds, ...newlyPrunedIds])] state.prune.toolIds = allPrunedIds - state.stats.pruneTokenCounter += calculateTokensSaved(messages, newlyPrunedIds) + state.stats.pruneTokenCounter += calculateTokensSaved(state, messages, newlyPrunedIds) // Build tool metadata map for notification const prunedToolMetadata = new Map() diff --git a/lib/strategies/prune-tool.ts b/lib/strategies/prune-tool.ts index a83694f..e361325 100644 --- a/lib/strategies/prune-tool.ts +++ b/lib/strategies/prune-tool.ts @@ -41,6 +41,9 @@ export function createPruneTool( const { client, state, logger, config, workingDirectory } = ctx const sessionId = toolCtx.sessionID + logger.info("Prune tool invoked") + logger.info(JSON.stringify(args)) + if (!args.ids || args.ids.length === 0) { logger.debug("Prune tool called but args.ids is empty or undefined: " + JSON.stringify(args)) return "No IDs provided. Check the list for available IDs to prune." @@ -72,7 +75,7 @@ export function createPruneTool( const messages: WithParts[] = messagesResponse.data || messagesResponse const currentParams = getCurrentParams(messages, logger) - const toolIdList: string[] = buildToolIdList(messages) + const toolIdList: string[] = buildToolIdList(state, messages, logger) // Validate that all numeric IDs are within bounds if (numericToolIds.some(id => id < 0 || id >= toolIdList.length)) { @@ -102,7 +105,7 @@ export function createPruneTool( } } - state.stats.pruneTokenCounter += calculateTokensSaved(messages, pruneToolIds) + state.stats.pruneTokenCounter += calculateTokensSaved(state, messages, pruneToolIds) await sendUnifiedNotification( client, diff --git a/lib/strategies/utils.ts b/lib/strategies/utils.ts index 126e5e1..3c6a1b1 100644 --- a/lib/strategies/utils.ts +++ b/lib/strategies/utils.ts @@ -1,8 +1,8 @@ -import { WithParts } from "../state" +import { SessionState, WithParts } from "../state" import { UserMessage } from "@opencode-ai/sdk" import { Logger } from "../logger" import { encode } from 'gpt-tokenizer' -import { getLastUserMessage } from "../shared-utils" +import { getLastUserMessage, isMessageCompacted } from "../shared-utils" export function getCurrentParams( messages: WithParts[], @@ -40,12 +40,16 @@ function estimateTokensBatch(texts: string[]): number[] { * TODO: Make it count message content that are not tool outputs. Currently it ONLY covers tool outputs and errors */ export const calculateTokensSaved = ( + state: SessionState, messages: WithParts[], pruneToolIds: string[] ): number => { try { const contents: string[] = [] for (const msg of messages) { + if (isMessageCompacted(state, msg)) { + continue + } for (const part of msg.parts) { if (part.type !== 'tool' || !pruneToolIds.includes(part.callID)) { continue