Skip to content

Commit 56a13a9

Browse files
committed
fix: prevent session state leakage and exclude protected tools from counts
- Add clearAllMappings() to id-mapping.ts for resetting module-level state
- Detect session changes in hooks.ts and clear ID mappings and the tool cache
- Make API format handlers mutually exclusive (else if) to prevent double-processing
- Exclude protected tools from the 'total' count in replacement logs
- Pass the protectedTools set to the trackNewToolResults functions for nudge frequency
1 parent a29d2de commit 56a13a9

File tree

7 files changed

+119
-30
lines changed

7 files changed

+119
-30
lines changed

lib/api-formats/synth-instruction.ts

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,25 +16,33 @@ export function resetToolTrackerCount(tracker: ToolTracker): void {
1616

1717
/**
1818
* Track new tool results in OpenAI/Anthropic messages.
19-
* Increments toolResultCount only for tools not already seen.
19+
* Increments toolResultCount only for tools not already seen and not protected.
2020
* Returns the number of NEW tools found (since last call).
2121
*/
22-
export function trackNewToolResults(messages: any[], tracker: ToolTracker): number {
22+
export function trackNewToolResults(messages: any[], tracker: ToolTracker, protectedTools: Set<string>): number {
2323
let newCount = 0
2424
for (const m of messages) {
2525
if (m.role === 'tool' && m.tool_call_id) {
2626
if (!tracker.seenToolResultIds.has(m.tool_call_id)) {
2727
tracker.seenToolResultIds.add(m.tool_call_id)
28-
tracker.toolResultCount++
29-
newCount++
28+
// Skip protected tools for nudge frequency counting
29+
const toolName = tracker.getToolName?.(m.tool_call_id)
30+
if (!toolName || !protectedTools.has(toolName)) {
31+
tracker.toolResultCount++
32+
newCount++
33+
}
3034
}
3135
} else if (m.role === 'user' && Array.isArray(m.content)) {
3236
for (const part of m.content) {
3337
if (part.type === 'tool_result' && part.tool_use_id) {
3438
if (!tracker.seenToolResultIds.has(part.tool_use_id)) {
3539
tracker.seenToolResultIds.add(part.tool_use_id)
36-
tracker.toolResultCount++
37-
newCount++
40+
// Skip protected tools for nudge frequency counting
41+
const toolName = tracker.getToolName?.(part.tool_use_id)
42+
if (!toolName || !protectedTools.has(toolName)) {
43+
tracker.toolResultCount++
44+
newCount++
45+
}
3846
}
3947
}
4048
}
@@ -48,7 +56,7 @@ export function trackNewToolResults(messages: any[], tracker: ToolTracker): numb
4856
* Uses position-based tracking since Gemini doesn't have tool call IDs.
4957
* Returns the number of NEW tools found (since last call).
5058
*/
51-
export function trackNewToolResultsGemini(contents: any[], tracker: ToolTracker): number {
59+
export function trackNewToolResultsGemini(contents: any[], tracker: ToolTracker, protectedTools: Set<string>): number {
5260
let newCount = 0
5361
let positionCounter = 0
5462
for (const content of contents) {
@@ -60,8 +68,12 @@ export function trackNewToolResultsGemini(contents: any[], tracker: ToolTracker)
6068
positionCounter++
6169
if (!tracker.seenToolResultIds.has(positionId)) {
6270
tracker.seenToolResultIds.add(positionId)
63-
tracker.toolResultCount++
64-
newCount++
71+
// Skip protected tools for nudge frequency counting
72+
const toolName = part.functionResponse.name
73+
if (!toolName || !protectedTools.has(toolName)) {
74+
tracker.toolResultCount++
75+
newCount++
76+
}
6577
}
6678
}
6779
}
@@ -73,14 +85,18 @@ export function trackNewToolResultsGemini(contents: any[], tracker: ToolTracker)
7385
* Track new tool results in OpenAI Responses API input.
7486
* Returns the number of NEW tools found (since last call).
7587
*/
76-
export function trackNewToolResultsResponses(input: any[], tracker: ToolTracker): number {
88+
export function trackNewToolResultsResponses(input: any[], tracker: ToolTracker, protectedTools: Set<string>): number {
7789
let newCount = 0
7890
for (const item of input) {
7991
if (item.type === 'function_call_output' && item.call_id) {
8092
if (!tracker.seenToolResultIds.has(item.call_id)) {
8193
tracker.seenToolResultIds.add(item.call_id)
82-
tracker.toolResultCount++
83-
newCount++
94+
// Skip protected tools for nudge frequency counting
95+
const toolName = tracker.getToolName?.(item.call_id)
96+
if (!toolName || !protectedTools.has(toolName)) {
97+
tracker.toolResultCount++
98+
newCount++
99+
}
84100
}
85101
}
86102
}

lib/fetch-wrapper/gemini.ts

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ export async function handleGemini(
4646

4747
if (prunableList) {
4848
// Track new tool results and check if nudge threshold is met
49-
trackNewToolResultsGemini(body.contents, ctx.toolTracker)
49+
const protectedSet = new Set(ctx.config.protectedTools)
50+
trackNewToolResultsGemini(body.contents, ctx.toolTracker, protectedSet)
5051
const includeNudge = ctx.config.nudge_freq > 0 && ctx.toolTracker.toolResultCount > ctx.config.nudge_freq
5152

5253
const endInjection = buildEndInjection(prunableList, includeNudge)
@@ -99,6 +100,8 @@ export async function handleGemini(
99100
const toolPositionCounters = new Map<string, number>()
100101
let replacedCount = 0
101102
let totalFunctionResponses = 0
103+
let prunableFunctionResponses = 0
104+
const protectedToolsLower = new Set(ctx.config.protectedTools.map(t => t.toLowerCase()))
102105

103106
body.contents = body.contents.map((content: any) => {
104107
if (!Array.isArray(content.parts)) return content
@@ -109,6 +112,11 @@ export async function handleGemini(
109112
totalFunctionResponses++
110113
const funcName = part.functionResponse.name?.toLowerCase()
111114

115+
// Count as prunable if not a protected tool
116+
if (!funcName || !protectedToolsLower.has(funcName)) {
117+
prunableFunctionResponses++
118+
}
119+
112120
if (funcName) {
113121
// Get current position for this tool name and increment counter
114122
const currentIndex = toolPositionCounters.get(funcName) || 0
@@ -148,7 +156,7 @@ export async function handleGemini(
148156
if (replacedCount > 0) {
149157
ctx.logger.info("fetch", "Replaced pruned tool outputs (Google/Gemini)", {
150158
replaced: replacedCount,
151-
total: totalFunctionResponses
159+
total: prunableFunctionResponses
152160
})
153161

154162
if (ctx.logger.enabled) {

lib/fetch-wrapper/index.ts

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -58,26 +58,24 @@ export function installFetchWrapper(
5858
// Capture tool IDs before handlers run to track what gets cached this request
5959
const toolIdsBefore = new Set(state.toolParameters.keys())
6060

61-
// Try each format handler in order
62-
// OpenAI Chat Completions & Anthropic style (body.messages)
63-
if (body.messages && Array.isArray(body.messages)) {
64-
const result = await handleOpenAIChatAndAnthropic(body, ctx, inputUrl)
61+
// Try each format handler - mutually exclusive to avoid double-processing
62+
// OpenAI Responses API style (body.input) - check first as it may also have messages
63+
if (body.input && Array.isArray(body.input)) {
64+
const result = await handleOpenAIResponses(body, ctx, inputUrl)
6565
if (result.modified) {
6666
modified = true
6767
}
6868
}
69-
70-
// Google/Gemini style (body.contents)
71-
if (body.contents && Array.isArray(body.contents)) {
72-
const result = await handleGemini(body, ctx, inputUrl)
69+
// OpenAI Chat Completions & Anthropic style (body.messages)
70+
else if (body.messages && Array.isArray(body.messages)) {
71+
const result = await handleOpenAIChatAndAnthropic(body, ctx, inputUrl)
7372
if (result.modified) {
7473
modified = true
7574
}
7675
}
77-
78-
// OpenAI Responses API style (body.input)
79-
if (body.input && Array.isArray(body.input)) {
80-
const result = await handleOpenAIResponses(body, ctx, inputUrl)
76+
// Google/Gemini style (body.contents)
77+
else if (body.contents && Array.isArray(body.contents)) {
78+
const result = await handleGemini(body, ctx, inputUrl)
8179
if (result.modified) {
8280
modified = true
8381
}

lib/fetch-wrapper/openai-chat.ts

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ export async function handleOpenAIChatAndAnthropic(
5151

5252
if (prunableList) {
5353
// Track new tool results and check if nudge threshold is met
54-
trackNewToolResults(body.messages, ctx.toolTracker)
54+
const protectedSet = new Set(ctx.config.protectedTools)
55+
trackNewToolResults(body.messages, ctx.toolTracker, protectedSet)
5556
const includeNudge = ctx.config.nudge_freq > 0 && ctx.toolTracker.toolResultCount > ctx.config.nudge_freq
5657

5758
const endInjection = buildEndInjection(prunableList, includeNudge)
@@ -70,6 +71,9 @@ export async function handleOpenAIChatAndAnthropic(
7071
// Check for tool messages in both formats:
7172
// 1. OpenAI style: role === 'tool'
7273
// 2. Anthropic style: role === 'user' with content containing tool_result
74+
const protectedToolsLower = new Set(ctx.config.protectedTools.map(t => t.toLowerCase()))
75+
76+
// Count all tool messages
7377
const toolMessages = body.messages.filter((m: any) => {
7478
if (m.role === 'tool') return true
7579
if (m.role === 'user' && Array.isArray(m.content)) {
@@ -79,6 +83,29 @@ export async function handleOpenAIChatAndAnthropic(
7983
}
8084
return false
8185
})
86+
87+
// Count only prunable (non-protected) tool messages for the total
88+
let prunableToolCount = 0
89+
for (const m of body.messages) {
90+
if (m.role === 'tool') {
91+
// Get tool name from cached metadata
92+
const toolId = m.tool_call_id?.toLowerCase()
93+
const metadata = toolId ? ctx.state.toolParameters.get(toolId) : undefined
94+
if (!metadata || !protectedToolsLower.has(metadata.tool.toLowerCase())) {
95+
prunableToolCount++
96+
}
97+
} else if (m.role === 'user' && Array.isArray(m.content)) {
98+
for (const part of m.content) {
99+
if (part.type === 'tool_result') {
100+
const toolId = part.tool_use_id?.toLowerCase()
101+
const metadata = toolId ? ctx.state.toolParameters.get(toolId) : undefined
102+
if (!metadata || !protectedToolsLower.has(metadata.tool.toLowerCase())) {
103+
prunableToolCount++
104+
}
105+
}
106+
}
107+
}
108+
}
82109

83110
const { allSessions, allPrunedIds } = await getAllPrunedIds(ctx.client, ctx.state, ctx.logger)
84111

@@ -123,7 +150,7 @@ export async function handleOpenAIChatAndAnthropic(
123150
if (replacedCount > 0) {
124151
ctx.logger.info("fetch", "Replaced pruned tool outputs", {
125152
replaced: replacedCount,
126-
total: toolMessages.length
153+
total: prunableToolCount
127154
})
128155

129156
if (ctx.logger.enabled) {

lib/fetch-wrapper/openai-responses.ts

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ export async function handleOpenAIResponses(
5151

5252
if (prunableList) {
5353
// Track new tool results and check if nudge threshold is met
54-
trackNewToolResultsResponses(body.input, ctx.toolTracker)
54+
const protectedSet = new Set(ctx.config.protectedTools)
55+
trackNewToolResultsResponses(body.input, ctx.toolTracker, protectedSet)
5556
const includeNudge = ctx.config.nudge_freq > 0 && ctx.toolTracker.toolResultCount > ctx.config.nudge_freq
5657

5758
const endInjection = buildEndInjection(prunableList, includeNudge)
@@ -80,6 +81,16 @@ export async function handleOpenAIResponses(
8081
return { modified, body }
8182
}
8283

84+
// Count only prunable (non-protected) function outputs for the total
85+
const protectedToolsLower = new Set(ctx.config.protectedTools.map(t => t.toLowerCase()))
86+
let prunableFunctionOutputCount = 0
87+
for (const item of functionOutputs) {
88+
const toolName = item.name?.toLowerCase()
89+
if (!toolName || !protectedToolsLower.has(toolName)) {
90+
prunableFunctionOutputCount++
91+
}
92+
}
93+
8394
let replacedCount = 0
8495

8596
body.input = body.input.map((item: any) => {
@@ -96,7 +107,7 @@ export async function handleOpenAIResponses(
96107
if (replacedCount > 0) {
97108
ctx.logger.info("fetch", "Replaced pruned tool outputs (Responses API)", {
98109
replaced: replacedCount,
99-
total: functionOutputs.length
110+
total: prunableFunctionOutputCount
100111
})
101112

102113
if (ctx.logger.enabled) {

lib/hooks.ts

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import { runOnIdle } from "./core/janitor"
55
import type { PluginConfig, PruningStrategy } from "./config"
66
import type { ToolTracker } from "./api-formats/synth-instruction"
77
import { resetToolTrackerCount } from "./api-formats/synth-instruction"
8+
import { clearAllMappings } from "./state/id-mapping"
89

910
export async function isSubagentSession(client: any, sessionID: string): Promise<boolean> {
1011
try {
@@ -72,6 +73,18 @@ export function createChatParamsHandler(
7273
providerID = input.message.model.providerID
7374
}
7475

76+
// Detect session change and reset per-session state
77+
if (state.lastSeenSessionId && state.lastSeenSessionId !== sessionId) {
78+
logger.info("chat.params", "Session changed, resetting state", {
79+
from: state.lastSeenSessionId.substring(0, 8),
80+
to: sessionId.substring(0, 8)
81+
})
82+
// Clear ID mappings from previous session
83+
clearAllMappings()
84+
// Clear tool parameters cache (not session-scoped, so must be cleared)
85+
state.toolParameters.clear()
86+
}
87+
7588
// Track the last seen session ID for fetch wrapper correlation
7689
state.lastSeenSessionId = sessionId
7790

lib/state/id-mapping.ts

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,22 @@ export function hasMapping(sessionId: string): boolean {
9191
return sessionMappings.has(sessionId)
9292
}
9393

94+
/**
 * Removes the entry stored under `sessionId` from `sessionMappings`,
 * discarding the ID mappings recorded for that one session.
 *
 * Intended to be called when a session ends or when switching to a new
 * session. Entries belonging to other sessions are left untouched.
 *
 * @param sessionId - Key of the session whose mappings should be dropped.
 */
export function clearSessionMapping(sessionId: string): void {
    sessionMappings.delete(sessionId)
}
101+
102+
/**
 * Empties `sessionMappings`, discarding the mappings of every session at once.
 *
 * Intended to be called when switching sessions so the next session starts
 * from a clean state.
 */
export function clearAllMappings(): void {
    sessionMappings.clear()
}
109+
94110
/**
95111
* Gets the next numeric ID that will be assigned (without assigning it).
96112
* Useful for knowing the current state.

0 commit comments

Comments
 (0)