diff --git a/index.ts b/index.ts index a4536c4..6a97e05 100644 --- a/index.ts +++ b/index.ts @@ -43,12 +43,6 @@ const plugin: Plugin = (async (ctx) => { // Create tool tracker and load prompts for synthetic instruction injection const toolTracker = createToolTracker() - // Wire up tool name lookup from the cached tool parameters - toolTracker.getToolName = (callId: string) => { - const entry = state.toolParameters.get(callId.toLowerCase()) - return entry?.tool - } - const prompts = { synthInstruction: loadPrompt("synthetic"), nudgeInstruction: loadPrompt("nudge") diff --git a/lib/fetch-wrapper/formats/bedrock.ts b/lib/fetch-wrapper/formats/bedrock.ts index ea1396b..2aaedc6 100644 --- a/lib/fetch-wrapper/formats/bedrock.ts +++ b/lib/fetch-wrapper/formats/bedrock.ts @@ -1,4 +1,4 @@ -import type { FormatDescriptor, ToolOutput, ToolTracker } from "../types" +import type { FormatDescriptor, ToolOutput } from "../types" import type { PluginState } from "../../state" function isNudgeMessage(msg: any, nudgeText: string): boolean { @@ -30,36 +30,6 @@ function injectSynth(messages: any[], instruction: string, nudgeText: string): b return false } -function trackNewToolResults(messages: any[], tracker: ToolTracker, protectedTools: Set): number { - let newCount = 0 - for (const m of messages) { - if (m.role === 'tool' && m.tool_call_id) { - if (!tracker.seenToolResultIds.has(m.tool_call_id)) { - tracker.seenToolResultIds.add(m.tool_call_id) - const toolName = tracker.getToolName?.(m.tool_call_id) - if (!toolName || !protectedTools.has(toolName)) { - tracker.toolResultCount++ - newCount++ - } - } - } else if (m.role === 'user' && Array.isArray(m.content)) { - for (const part of m.content) { - if (part.type === 'tool_result' && part.tool_use_id) { - if (!tracker.seenToolResultIds.has(part.tool_use_id)) { - tracker.seenToolResultIds.add(part.tool_use_id) - const toolName = tracker.getToolName?.(part.tool_use_id) - if (!toolName || !protectedTools.has(toolName)) { - tracker.toolResultCount++ - newCount++ - } - } - } - } - } - } - return newCount -} - function injectPrunableList(messages: any[], injection: string): boolean { if (!injection) return false messages.push({ role: 'user', content: injection }) @@ -90,10 +60,6 @@ export const bedrockFormat: FormatDescriptor = { return injectSynth(data, instruction, nudgeText) }, - trackNewToolResults(data: any[], tracker: ToolTracker, protectedTools: Set): number { - return trackNewToolResults(data, tracker, protectedTools) - }, - injectPrunableList(data: any[], injection: string): boolean { return injectPrunableList(data, injection) }, diff --git a/lib/fetch-wrapper/formats/gemini.ts b/lib/fetch-wrapper/formats/gemini.ts index 8e2f569..c1c0feb 100644 --- a/lib/fetch-wrapper/formats/gemini.ts +++ b/lib/fetch-wrapper/formats/gemini.ts @@ -1,4 +1,4 @@ -import type { FormatDescriptor, ToolOutput, ToolTracker } from "../types" +import type { FormatDescriptor, ToolOutput } from "../types" import type { PluginState } from "../../state" function isNudgeContent(content: any, nudgeText: string): boolean { @@ -26,29 +26,6 @@ function injectSynth(contents: any[], instruction: string, nudgeText: string): b return false } -function trackNewToolResults(contents: any[], tracker: ToolTracker, protectedTools: Set): number { - let newCount = 0 - let positionCounter = 0 - for (const content of contents) { - if (!Array.isArray(content.parts)) continue - for (const part of content.parts) { - if (part.functionResponse) { - const positionId = `gemini_pos_${positionCounter}` - positionCounter++ - if (!tracker.seenToolResultIds.has(positionId)) { - tracker.seenToolResultIds.add(positionId) - const toolName = part.functionResponse.name - if (!toolName || !protectedTools.has(toolName)) { - tracker.toolResultCount++ - newCount++ - } - } - } - } - } - return newCount -} - function injectPrunableList(contents: any[], injection: string): boolean { if (!injection) return false contents.push({ role: 'user', parts: [{ text: injection }] }) @@ -75,10 +52,6 @@ export const geminiFormat: FormatDescriptor = { return injectSynth(data, instruction, nudgeText) }, - trackNewToolResults(data: any[], tracker: ToolTracker, protectedTools: Set): number { - return trackNewToolResults(data, tracker, protectedTools) - }, - injectPrunableList(data: any[], injection: string): boolean { return injectPrunableList(data, injection) }, diff --git a/lib/fetch-wrapper/formats/openai-chat.ts b/lib/fetch-wrapper/formats/openai-chat.ts index 141f03f..2ac3793 100644 --- a/lib/fetch-wrapper/formats/openai-chat.ts +++ b/lib/fetch-wrapper/formats/openai-chat.ts @@ -1,4 +1,4 @@ -import type { FormatDescriptor, ToolOutput, ToolTracker } from "../types" +import type { FormatDescriptor, ToolOutput } from "../types" import type { PluginState } from "../../state" function isNudgeMessage(msg: any, nudgeText: string): boolean { @@ -30,36 +30,6 @@ function injectSynth(messages: any[], instruction: string, nudgeText: string): b return false } -function trackNewToolResults(messages: any[], tracker: ToolTracker, protectedTools: Set): number { - let newCount = 0 - for (const m of messages) { - if (m.role === 'tool' && m.tool_call_id) { - if (!tracker.seenToolResultIds.has(m.tool_call_id)) { - tracker.seenToolResultIds.add(m.tool_call_id) - const toolName = tracker.getToolName?.(m.tool_call_id) - if (!toolName || !protectedTools.has(toolName)) { - tracker.toolResultCount++ - newCount++ - } - } - } else if (m.role === 'user' && Array.isArray(m.content)) { - for (const part of m.content) { - if (part.type === 'tool_result' && part.tool_use_id) { - if (!tracker.seenToolResultIds.has(part.tool_use_id)) { - tracker.seenToolResultIds.add(part.tool_use_id) - const toolName = tracker.getToolName?.(part.tool_use_id) - if (!toolName || !protectedTools.has(toolName)) { - tracker.toolResultCount++ - newCount++ - } - } - } - } - } - } - return newCount -} - function injectPrunableList(messages: any[], injection: string): boolean { if (!injection) return false messages.push({ role: 'user', content: injection }) @@ -81,10 +51,6 @@ export const openaiChatFormat: FormatDescriptor = { return injectSynth(data, instruction, nudgeText) }, - trackNewToolResults(data: any[], tracker: ToolTracker, protectedTools: Set): number { - return trackNewToolResults(data, tracker, protectedTools) - }, - injectPrunableList(data: any[], injection: string): boolean { return injectPrunableList(data, injection) }, diff --git a/lib/fetch-wrapper/formats/openai-responses.ts b/lib/fetch-wrapper/formats/openai-responses.ts index 549c56b..6b84891 100644 --- a/lib/fetch-wrapper/formats/openai-responses.ts +++ b/lib/fetch-wrapper/formats/openai-responses.ts @@ -1,4 +1,4 @@ -import type { FormatDescriptor, ToolOutput, ToolTracker } from "../types" +import type { FormatDescriptor, ToolOutput } from "../types" import type { PluginState } from "../../state" function isNudgeItem(item: any, nudgeText: string): boolean { @@ -30,23 +30,6 @@ function injectSynth(input: any[], instruction: string, nudgeText: string): bool return false } -function trackNewToolResults(input: any[], tracker: ToolTracker, protectedTools: Set): number { - let newCount = 0 - for (const item of input) { - if (item.type === 'function_call_output' && item.call_id) { - if (!tracker.seenToolResultIds.has(item.call_id)) { - tracker.seenToolResultIds.add(item.call_id) - const toolName = tracker.getToolName?.(item.call_id) - if (!toolName || !protectedTools.has(toolName)) { - tracker.toolResultCount++ - newCount++ - } - } - } - } - return newCount -} - function injectPrunableList(input: any[], injection: string): boolean { if (!injection) return false input.push({ type: 'message', role: 'user', content: injection }) @@ -68,10 +51,6 @@ export const openaiResponsesFormat: FormatDescriptor = { return injectSynth(data, instruction, nudgeText) }, - trackNewToolResults(data: any[], tracker: ToolTracker, protectedTools: Set): number { - return trackNewToolResults(data, tracker, protectedTools) - }, - injectPrunableList(data: any[], injection: string): boolean { return injectPrunableList(data, injection) }, diff --git a/lib/fetch-wrapper/handler.ts b/lib/fetch-wrapper/handler.ts index cd9b683..8874e32 100644 --- a/lib/fetch-wrapper/handler.ts +++ b/lib/fetch-wrapper/handler.ts @@ -67,9 +67,11 @@ export async function handleFormat( let modified = false // Sync tool parameters from OpenCode's session API (single source of truth) + // Also tracks new tool results for nudge injection const sessionId = ctx.state.lastSeenSessionId + const protectedSet = new Set(ctx.config.protectedTools) if (sessionId) { - await syncToolParametersFromOpenCode(ctx.client, sessionId, ctx.state, ctx.logger) + await syncToolParametersFromOpenCode(ctx.client, sessionId, ctx.state, ctx.toolTracker, protectedSet, ctx.logger) } if (ctx.config.strategies.onTool.length > 0) { @@ -91,8 +93,6 @@ export async function handleFormat( ) if (prunableList) { - const protectedSet = new Set(ctx.config.protectedTools) - format.trackNewToolResults(data, ctx.toolTracker, protectedSet) const includeNudge = ctx.config.nudge_freq > 0 && ctx.toolTracker.toolResultCount > ctx.config.nudge_freq const endInjection = buildEndInjection(prunableList, includeNudge) @@ -119,14 +119,12 @@ export async function handleFormat( } const toolOutputs = format.extractToolOutputs(data, ctx.state) - const protectedToolsLower = new Set(ctx.config.protectedTools.map(t => t.toLowerCase())) let replacedCount = 0 let prunableCount = 0 for (const output of toolOutputs) { - if (output.toolName && protectedToolsLower.has(output.toolName.toLowerCase())) { - continue - } + // Skip tools not in cache (protected tools are excluded from cache) + if (!output.toolName) continue prunableCount++ if (allPrunedIds.has(output.id)) { diff --git a/lib/fetch-wrapper/tool-tracker.ts b/lib/fetch-wrapper/tool-tracker.ts index 4048925..639b99b 100644 --- a/lib/fetch-wrapper/tool-tracker.ts +++ b/lib/fetch-wrapper/tool-tracker.ts @@ -2,7 +2,6 @@ export interface ToolTracker { seenToolResultIds: Set toolResultCount: number // Tools since last prune skipNextIdle: boolean - getToolName?: (callId: string) => string | undefined } export function createToolTracker(): ToolTracker { diff --git a/lib/fetch-wrapper/types.ts b/lib/fetch-wrapper/types.ts index c7ebc68..7ea1f83 100644 --- a/lib/fetch-wrapper/types.ts +++ b/lib/fetch-wrapper/types.ts @@ -14,7 +14,6 @@ export interface FormatDescriptor { detect(body: any): boolean getDataArray(body: any): any[] | undefined injectSynth(data: any[], instruction: string, nudgeText: string): boolean - trackNewToolResults(data: any[], tracker: ToolTracker, protectedTools: Set): number injectPrunableList(data: any[], injection: string): boolean extractToolOutputs(data: any[], state: PluginState): ToolOutput[] replaceToolOutput(data: any[], toolId: string, prunedMessage: string, state: PluginState): boolean diff --git a/lib/state/tool-cache.ts b/lib/state/tool-cache.ts index 8d2f8b2..b29d0d6 100644 --- a/lib/state/tool-cache.ts +++ b/lib/state/tool-cache.ts @@ -1,5 +1,6 @@ import type { PluginState, ToolStatus } from "./index" import type { Logger } from "../logger" +import type { ToolTracker } from "../fetch-wrapper/tool-tracker" /** Maximum number of entries to keep in the tool parameters cache */ const MAX_TOOL_CACHE_SIZE = 500 @@ -13,6 +14,8 @@ export async function syncToolParametersFromOpenCode( client: any, sessionId: string, state: PluginState, + tracker?: ToolTracker, + protectedTools?: Set, logger?: Logger ): Promise { try { @@ -36,8 +39,17 @@ export async function syncToolParametersFromOpenCode( const id = part.callID.toLowerCase() - // Skip if already cached (optimization) + // Track tool results for nudge injection + if (tracker && !tracker.seenToolResultIds.has(id)) { + tracker.seenToolResultIds.add(id) + // Only count non-protected tools toward nudge threshold + if (!part.tool || !protectedTools?.has(part.tool)) { + tracker.toolResultCount++ + } + } + if (state.toolParameters.has(id)) continue + if (part.tool && protectedTools?.has(part.tool)) continue const status = part.state?.status as ToolStatus | undefined state.toolParameters.set(id, { @@ -55,8 +67,7 @@ export async function syncToolParametersFromOpenCode( if (logger && synced > 0) { logger.debug("tool-cache", "Synced tool parameters from OpenCode", { sessionId: sessionId.slice(0, 8), - synced, - totalCached: state.toolParameters.size + synced }) } } catch (error) {