Skip to content

Commit 035285f

Browse files
committed
refactor: consolidate format handlers with FormatDescriptor pattern and fix tool ID case handling
1 parent 0c1017a commit 035285f

File tree

14 files changed

+609
-494
lines changed

14 files changed

+609
-494
lines changed

index.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ const plugin: Plugin = (async (ctx) => {
4545

4646
// Wire up tool name lookup from the cached tool parameters
4747
toolTracker.getToolName = (callId: string) => {
48-
const entry = state.toolParameters.get(callId)
48+
const entry = state.toolParameters.get(callId.toLowerCase())
4949
return entry?.tool
5050
}
5151

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
import type { FormatDescriptor, ToolOutput } from "../types"
2+
import { PRUNED_CONTENT_MESSAGE } from "../types"
3+
import type { PluginState } from "../../state"
4+
import type { Logger } from "../../logger"
5+
import type { ToolTracker } from "../../api-formats/synth-instruction"
6+
import { injectSynthGemini, trackNewToolResultsGemini } from "../../api-formats/synth-instruction"
7+
import { injectPrunableListGemini } from "../../api-formats/prunable-list"
8+
9+
/**
10+
* Format descriptor for Google/Gemini API.
11+
*
12+
* Uses body.contents array with:
13+
* - parts[].functionCall for tool invocations
14+
* - parts[].functionResponse for tool results
15+
*
16+
* IMPORTANT: Gemini doesn't include tool call IDs in its native format.
17+
* We use position-based correlation via state.googleToolCallMapping which maps
18+
* "toolName:index" -> "toolCallId" (populated by hooks.ts from message events).
19+
*/
20+
export const geminiFormat: FormatDescriptor = {
21+
name: 'gemini',
22+
23+
detect(body: any): boolean {
24+
return body.contents && Array.isArray(body.contents)
25+
},
26+
27+
getDataArray(body: any): any[] | undefined {
28+
return body.contents
29+
},
30+
31+
cacheToolParameters(_data: any[], _state: PluginState, _logger?: Logger): void {
32+
// Gemini format doesn't include tool parameters in the request body.
33+
// Tool parameters are captured via message events in hooks.ts and stored
34+
// in state.googleToolCallMapping for position-based correlation.
35+
// No-op here.
36+
},
37+
38+
injectSynth(data: any[], instruction: string, nudgeText: string): boolean {
39+
return injectSynthGemini(data, instruction, nudgeText)
40+
},
41+
42+
trackNewToolResults(data: any[], tracker: ToolTracker, protectedTools: Set<string>): number {
43+
return trackNewToolResultsGemini(data, tracker, protectedTools)
44+
},
45+
46+
injectPrunableList(data: any[], injection: string): boolean {
47+
return injectPrunableListGemini(data, injection)
48+
},
49+
50+
extractToolOutputs(data: any[], state: PluginState): ToolOutput[] {
51+
const outputs: ToolOutput[] = []
52+
53+
// We need the position mapping to correlate functionResponses to tool call IDs
54+
// Find the mapping from any active session
55+
let positionMapping: Map<string, string> | undefined
56+
for (const [_sessionId, mapping] of state.googleToolCallMapping) {
57+
if (mapping && mapping.size > 0) {
58+
positionMapping = mapping
59+
break
60+
}
61+
}
62+
63+
if (!positionMapping) {
64+
return outputs
65+
}
66+
67+
// Track position counters per tool name
68+
const toolPositionCounters = new Map<string, number>()
69+
70+
for (const content of data) {
71+
if (!Array.isArray(content.parts)) continue
72+
73+
for (const part of content.parts) {
74+
if (part.functionResponse) {
75+
const funcName = part.functionResponse.name?.toLowerCase()
76+
if (funcName) {
77+
const currentIndex = toolPositionCounters.get(funcName) || 0
78+
toolPositionCounters.set(funcName, currentIndex + 1)
79+
80+
const positionKey = `${funcName}:${currentIndex}`
81+
const toolCallId = positionMapping.get(positionKey)
82+
83+
if (toolCallId) {
84+
outputs.push({
85+
id: toolCallId.toLowerCase(),
86+
toolName: funcName
87+
})
88+
}
89+
}
90+
}
91+
}
92+
}
93+
94+
return outputs
95+
},
96+
97+
replaceToolOutput(data: any[], toolId: string, prunedMessage: string, state: PluginState): boolean {
98+
// Find the position mapping
99+
let positionMapping: Map<string, string> | undefined
100+
for (const [_sessionId, mapping] of state.googleToolCallMapping) {
101+
if (mapping && mapping.size > 0) {
102+
positionMapping = mapping
103+
break
104+
}
105+
}
106+
107+
if (!positionMapping) {
108+
return false
109+
}
110+
111+
const toolIdLower = toolId.toLowerCase()
112+
const toolPositionCounters = new Map<string, number>()
113+
let replaced = false
114+
115+
for (let i = 0; i < data.length; i++) {
116+
const content = data[i]
117+
if (!Array.isArray(content.parts)) continue
118+
119+
let contentModified = false
120+
const newParts = content.parts.map((part: any) => {
121+
if (part.functionResponse) {
122+
const funcName = part.functionResponse.name?.toLowerCase()
123+
if (funcName) {
124+
const currentIndex = toolPositionCounters.get(funcName) || 0
125+
toolPositionCounters.set(funcName, currentIndex + 1)
126+
127+
const positionKey = `${funcName}:${currentIndex}`
128+
const mappedToolId = positionMapping!.get(positionKey)
129+
130+
if (mappedToolId?.toLowerCase() === toolIdLower) {
131+
contentModified = true
132+
replaced = true
133+
// Preserve thoughtSignature if present (required for Gemini 3 Pro)
134+
return {
135+
...part,
136+
functionResponse: {
137+
...part.functionResponse,
138+
response: {
139+
name: part.functionResponse.name,
140+
content: prunedMessage
141+
}
142+
}
143+
}
144+
}
145+
}
146+
}
147+
return part
148+
})
149+
150+
if (contentModified) {
151+
data[i] = { ...content, parts: newParts }
152+
}
153+
}
154+
155+
return replaced
156+
},
157+
158+
hasToolOutputs(data: any[]): boolean {
159+
return data.some((content: any) =>
160+
Array.isArray(content.parts) &&
161+
content.parts.some((part: any) => part.functionResponse)
162+
)
163+
},
164+
165+
getLogMetadata(data: any[], replacedCount: number, inputUrl: string): Record<string, any> {
166+
return {
167+
url: inputUrl,
168+
replacedCount,
169+
totalContents: data.length,
170+
format: 'google-gemini'
171+
}
172+
}
173+
}

lib/fetch-wrapper/formats/index.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
export { openaiChatFormat } from './openai-chat'
2+
export { openaiResponsesFormat } from './openai-responses'
3+
export { geminiFormat } from './gemini'
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
import type { FormatDescriptor, ToolOutput } from "../types"
2+
import { PRUNED_CONTENT_MESSAGE } from "../types"
3+
import type { PluginState } from "../../state"
4+
import type { Logger } from "../../logger"
5+
import type { ToolTracker } from "../../api-formats/synth-instruction"
6+
import { cacheToolParametersFromMessages } from "../../state/tool-cache"
7+
import { injectSynth, trackNewToolResults } from "../../api-formats/synth-instruction"
8+
import { injectPrunableList } from "../../api-formats/prunable-list"
9+
10+
/**
11+
* Format descriptor for OpenAI Chat Completions and Anthropic APIs.
12+
*
13+
* OpenAI Chat format:
14+
* - Messages with role='tool' and tool_call_id
15+
* - Assistant messages with tool_calls[] array
16+
*
17+
* Anthropic format:
18+
* - Messages with role='user' containing content[].type='tool_result' and tool_use_id
19+
* - Assistant messages with content[].type='tool_use'
20+
*/
21+
export const openaiChatFormat: FormatDescriptor = {
22+
name: 'openai-chat',
23+
24+
detect(body: any): boolean {
25+
return body.messages && Array.isArray(body.messages)
26+
},
27+
28+
getDataArray(body: any): any[] | undefined {
29+
return body.messages
30+
},
31+
32+
cacheToolParameters(data: any[], state: PluginState, logger?: Logger): void {
33+
cacheToolParametersFromMessages(data, state, logger)
34+
},
35+
36+
injectSynth(data: any[], instruction: string, nudgeText: string): boolean {
37+
return injectSynth(data, instruction, nudgeText)
38+
},
39+
40+
trackNewToolResults(data: any[], tracker: ToolTracker, protectedTools: Set<string>): number {
41+
return trackNewToolResults(data, tracker, protectedTools)
42+
},
43+
44+
injectPrunableList(data: any[], injection: string): boolean {
45+
return injectPrunableList(data, injection)
46+
},
47+
48+
extractToolOutputs(data: any[], state: PluginState): ToolOutput[] {
49+
const outputs: ToolOutput[] = []
50+
51+
for (const m of data) {
52+
// OpenAI Chat format: role='tool' with tool_call_id
53+
if (m.role === 'tool' && m.tool_call_id) {
54+
const metadata = state.toolParameters.get(m.tool_call_id.toLowerCase())
55+
outputs.push({
56+
id: m.tool_call_id.toLowerCase(),
57+
toolName: metadata?.tool
58+
})
59+
}
60+
61+
// Anthropic format: role='user' with content[].type='tool_result'
62+
if (m.role === 'user' && Array.isArray(m.content)) {
63+
for (const part of m.content) {
64+
if (part.type === 'tool_result' && part.tool_use_id) {
65+
const metadata = state.toolParameters.get(part.tool_use_id.toLowerCase())
66+
outputs.push({
67+
id: part.tool_use_id.toLowerCase(),
68+
toolName: metadata?.tool
69+
})
70+
}
71+
}
72+
}
73+
}
74+
75+
return outputs
76+
},
77+
78+
replaceToolOutput(data: any[], toolId: string, prunedMessage: string, _state: PluginState): boolean {
79+
const toolIdLower = toolId.toLowerCase()
80+
let replaced = false
81+
82+
for (let i = 0; i < data.length; i++) {
83+
const m = data[i]
84+
85+
// OpenAI Chat format
86+
if (m.role === 'tool' && m.tool_call_id?.toLowerCase() === toolIdLower) {
87+
data[i] = { ...m, content: prunedMessage }
88+
replaced = true
89+
}
90+
91+
// Anthropic format
92+
if (m.role === 'user' && Array.isArray(m.content)) {
93+
let messageModified = false
94+
const newContent = m.content.map((part: any) => {
95+
if (part.type === 'tool_result' && part.tool_use_id?.toLowerCase() === toolIdLower) {
96+
messageModified = true
97+
return { ...part, content: prunedMessage }
98+
}
99+
return part
100+
})
101+
if (messageModified) {
102+
data[i] = { ...m, content: newContent }
103+
replaced = true
104+
}
105+
}
106+
}
107+
108+
return replaced
109+
},
110+
111+
hasToolOutputs(data: any[]): boolean {
112+
for (const m of data) {
113+
if (m.role === 'tool') return true
114+
if (m.role === 'user' && Array.isArray(m.content)) {
115+
for (const part of m.content) {
116+
if (part.type === 'tool_result') return true
117+
}
118+
}
119+
}
120+
return false
121+
},
122+
123+
getLogMetadata(data: any[], replacedCount: number, inputUrl: string): Record<string, any> {
124+
return {
125+
url: inputUrl,
126+
replacedCount,
127+
totalMessages: data.length
128+
}
129+
}
130+
}

0 commit comments

Comments
 (0)