Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,6 @@ const plugin: Plugin = (async (ctx) => {
// Create tool tracker and load prompts for synthetic instruction injection
const toolTracker = createToolTracker()

// Wire up tool name lookup from the cached tool parameters
toolTracker.getToolName = (callId: string) => {
const entry = state.toolParameters.get(callId.toLowerCase())
return entry?.tool
}

const prompts = {
synthInstruction: loadPrompt("synthetic"),
nudgeInstruction: loadPrompt("nudge")
Expand Down
36 changes: 1 addition & 35 deletions lib/fetch-wrapper/formats/bedrock.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { FormatDescriptor, ToolOutput, ToolTracker } from "../types"
import type { FormatDescriptor, ToolOutput } from "../types"
import type { PluginState } from "../../state"

function isNudgeMessage(msg: any, nudgeText: string): boolean {
Expand Down Expand Up @@ -30,36 +30,6 @@ function injectSynth(messages: any[], instruction: string, nudgeText: string): b
return false
}

function trackNewToolResults(messages: any[], tracker: ToolTracker, protectedTools: Set<string>): number {
let newCount = 0
for (const m of messages) {
if (m.role === 'tool' && m.tool_call_id) {
if (!tracker.seenToolResultIds.has(m.tool_call_id)) {
tracker.seenToolResultIds.add(m.tool_call_id)
const toolName = tracker.getToolName?.(m.tool_call_id)
if (!toolName || !protectedTools.has(toolName)) {
tracker.toolResultCount++
newCount++
}
}
} else if (m.role === 'user' && Array.isArray(m.content)) {
for (const part of m.content) {
if (part.type === 'tool_result' && part.tool_use_id) {
if (!tracker.seenToolResultIds.has(part.tool_use_id)) {
tracker.seenToolResultIds.add(part.tool_use_id)
const toolName = tracker.getToolName?.(part.tool_use_id)
if (!toolName || !protectedTools.has(toolName)) {
tracker.toolResultCount++
newCount++
}
}
}
}
}
}
return newCount
}

function injectPrunableList(messages: any[], injection: string): boolean {
if (!injection) return false
messages.push({ role: 'user', content: injection })
Expand Down Expand Up @@ -90,10 +60,6 @@ export const bedrockFormat: FormatDescriptor = {
return injectSynth(data, instruction, nudgeText)
},

trackNewToolResults(data: any[], tracker: ToolTracker, protectedTools: Set<string>): number {
return trackNewToolResults(data, tracker, protectedTools)
},

injectPrunableList(data: any[], injection: string): boolean {
return injectPrunableList(data, injection)
},
Expand Down
29 changes: 1 addition & 28 deletions lib/fetch-wrapper/formats/gemini.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { FormatDescriptor, ToolOutput, ToolTracker } from "../types"
import type { FormatDescriptor, ToolOutput } from "../types"
import type { PluginState } from "../../state"

function isNudgeContent(content: any, nudgeText: string): boolean {
Expand Down Expand Up @@ -26,29 +26,6 @@ function injectSynth(contents: any[], instruction: string, nudgeText: string): b
return false
}

function trackNewToolResults(contents: any[], tracker: ToolTracker, protectedTools: Set<string>): number {
let newCount = 0
let positionCounter = 0
for (const content of contents) {
if (!Array.isArray(content.parts)) continue
for (const part of content.parts) {
if (part.functionResponse) {
const positionId = `gemini_pos_${positionCounter}`
positionCounter++
if (!tracker.seenToolResultIds.has(positionId)) {
tracker.seenToolResultIds.add(positionId)
const toolName = part.functionResponse.name
if (!toolName || !protectedTools.has(toolName)) {
tracker.toolResultCount++
newCount++
}
}
}
}
}
return newCount
}

function injectPrunableList(contents: any[], injection: string): boolean {
if (!injection) return false
contents.push({ role: 'user', parts: [{ text: injection }] })
Expand All @@ -75,10 +52,6 @@ export const geminiFormat: FormatDescriptor = {
return injectSynth(data, instruction, nudgeText)
},

trackNewToolResults(data: any[], tracker: ToolTracker, protectedTools: Set<string>): number {
return trackNewToolResults(data, tracker, protectedTools)
},

injectPrunableList(data: any[], injection: string): boolean {
return injectPrunableList(data, injection)
},
Expand Down
36 changes: 1 addition & 35 deletions lib/fetch-wrapper/formats/openai-chat.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { FormatDescriptor, ToolOutput, ToolTracker } from "../types"
import type { FormatDescriptor, ToolOutput } from "../types"
import type { PluginState } from "../../state"

function isNudgeMessage(msg: any, nudgeText: string): boolean {
Expand Down Expand Up @@ -30,36 +30,6 @@ function injectSynth(messages: any[], instruction: string, nudgeText: string): b
return false
}

function trackNewToolResults(messages: any[], tracker: ToolTracker, protectedTools: Set<string>): number {
let newCount = 0
for (const m of messages) {
if (m.role === 'tool' && m.tool_call_id) {
if (!tracker.seenToolResultIds.has(m.tool_call_id)) {
tracker.seenToolResultIds.add(m.tool_call_id)
const toolName = tracker.getToolName?.(m.tool_call_id)
if (!toolName || !protectedTools.has(toolName)) {
tracker.toolResultCount++
newCount++
}
}
} else if (m.role === 'user' && Array.isArray(m.content)) {
for (const part of m.content) {
if (part.type === 'tool_result' && part.tool_use_id) {
if (!tracker.seenToolResultIds.has(part.tool_use_id)) {
tracker.seenToolResultIds.add(part.tool_use_id)
const toolName = tracker.getToolName?.(part.tool_use_id)
if (!toolName || !protectedTools.has(toolName)) {
tracker.toolResultCount++
newCount++
}
}
}
}
}
}
return newCount
}

function injectPrunableList(messages: any[], injection: string): boolean {
if (!injection) return false
messages.push({ role: 'user', content: injection })
Expand All @@ -81,10 +51,6 @@ export const openaiChatFormat: FormatDescriptor = {
return injectSynth(data, instruction, nudgeText)
},

trackNewToolResults(data: any[], tracker: ToolTracker, protectedTools: Set<string>): number {
return trackNewToolResults(data, tracker, protectedTools)
},

injectPrunableList(data: any[], injection: string): boolean {
return injectPrunableList(data, injection)
},
Expand Down
23 changes: 1 addition & 22 deletions lib/fetch-wrapper/formats/openai-responses.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { FormatDescriptor, ToolOutput, ToolTracker } from "../types"
import type { FormatDescriptor, ToolOutput } from "../types"
import type { PluginState } from "../../state"

function isNudgeItem(item: any, nudgeText: string): boolean {
Expand Down Expand Up @@ -30,23 +30,6 @@ function injectSynth(input: any[], instruction: string, nudgeText: string): bool
return false
}

function trackNewToolResults(input: any[], tracker: ToolTracker, protectedTools: Set<string>): number {
let newCount = 0
for (const item of input) {
if (item.type === 'function_call_output' && item.call_id) {
if (!tracker.seenToolResultIds.has(item.call_id)) {
tracker.seenToolResultIds.add(item.call_id)
const toolName = tracker.getToolName?.(item.call_id)
if (!toolName || !protectedTools.has(toolName)) {
tracker.toolResultCount++
newCount++
}
}
}
}
return newCount
}

function injectPrunableList(input: any[], injection: string): boolean {
if (!injection) return false
input.push({ type: 'message', role: 'user', content: injection })
Expand All @@ -68,10 +51,6 @@ export const openaiResponsesFormat: FormatDescriptor = {
return injectSynth(data, instruction, nudgeText)
},

trackNewToolResults(data: any[], tracker: ToolTracker, protectedTools: Set<string>): number {
return trackNewToolResults(data, tracker, protectedTools)
},

injectPrunableList(data: any[], injection: string): boolean {
return injectPrunableList(data, injection)
},
Expand Down
12 changes: 5 additions & 7 deletions lib/fetch-wrapper/handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,11 @@ export async function handleFormat(
let modified = false

// Sync tool parameters from OpenCode's session API (single source of truth)
// Also tracks new tool results for nudge injection
const sessionId = ctx.state.lastSeenSessionId
const protectedSet = new Set(ctx.config.protectedTools)
if (sessionId) {
await syncToolParametersFromOpenCode(ctx.client, sessionId, ctx.state, ctx.logger)
await syncToolParametersFromOpenCode(ctx.client, sessionId, ctx.state, ctx.toolTracker, protectedSet, ctx.logger)
}

if (ctx.config.strategies.onTool.length > 0) {
Expand All @@ -91,8 +93,6 @@ export async function handleFormat(
)

if (prunableList) {
const protectedSet = new Set(ctx.config.protectedTools)
format.trackNewToolResults(data, ctx.toolTracker, protectedSet)
const includeNudge = ctx.config.nudge_freq > 0 && ctx.toolTracker.toolResultCount > ctx.config.nudge_freq

const endInjection = buildEndInjection(prunableList, includeNudge)
Expand All @@ -119,14 +119,12 @@ export async function handleFormat(
}

const toolOutputs = format.extractToolOutputs(data, ctx.state)
const protectedToolsLower = new Set(ctx.config.protectedTools.map(t => t.toLowerCase()))
let replacedCount = 0
let prunableCount = 0

for (const output of toolOutputs) {
if (output.toolName && protectedToolsLower.has(output.toolName.toLowerCase())) {
continue
}
// Skip tools not in cache (protected tools are excluded from cache)
if (!output.toolName) continue
prunableCount++

if (allPrunedIds.has(output.id)) {
Expand Down
1 change: 0 additions & 1 deletion lib/fetch-wrapper/tool-tracker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ export interface ToolTracker {
seenToolResultIds: Set<string>
toolResultCount: number // Tools since last prune
skipNextIdle: boolean
getToolName?: (callId: string) => string | undefined
}

export function createToolTracker(): ToolTracker {
Expand Down
1 change: 0 additions & 1 deletion lib/fetch-wrapper/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ export interface FormatDescriptor {
detect(body: any): boolean
getDataArray(body: any): any[] | undefined
injectSynth(data: any[], instruction: string, nudgeText: string): boolean
trackNewToolResults(data: any[], tracker: ToolTracker, protectedTools: Set<string>): number
injectPrunableList(data: any[], injection: string): boolean
extractToolOutputs(data: any[], state: PluginState): ToolOutput[]
replaceToolOutput(data: any[], toolId: string, prunedMessage: string, state: PluginState): boolean
Expand Down
17 changes: 14 additions & 3 deletions lib/state/tool-cache.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import type { PluginState, ToolStatus } from "./index"
import type { Logger } from "../logger"
import type { ToolTracker } from "../fetch-wrapper/tool-tracker"

/** Maximum number of entries to keep in the tool parameters cache */
const MAX_TOOL_CACHE_SIZE = 500
Expand All @@ -13,6 +14,8 @@ export async function syncToolParametersFromOpenCode(
client: any,
sessionId: string,
state: PluginState,
tracker?: ToolTracker,
protectedTools?: Set<string>,
logger?: Logger
): Promise<void> {
try {
Expand All @@ -36,8 +39,17 @@ export async function syncToolParametersFromOpenCode(

const id = part.callID.toLowerCase()

// Skip if already cached (optimization)
// Track tool results for nudge injection
if (tracker && !tracker.seenToolResultIds.has(id)) {
tracker.seenToolResultIds.add(id)
// Only count non-protected tools toward nudge threshold
if (!part.tool || !protectedTools?.has(part.tool)) {
tracker.toolResultCount++
}
}

if (state.toolParameters.has(id)) continue
if (part.tool && protectedTools?.has(part.tool)) continue

const status = part.state?.status as ToolStatus | undefined
state.toolParameters.set(id, {
Expand All @@ -55,8 +67,7 @@ export async function syncToolParametersFromOpenCode(
if (logger && synced > 0) {
logger.debug("tool-cache", "Synced tool parameters from OpenCode", {
sessionId: sessionId.slice(0, 8),
synced,
totalCached: state.toolParameters.size
synced
})
}
} catch (error) {
Expand Down