Skip to content

Commit 811f6b2

Browse files
committed
refactor: make batch a protected tool, remove cascade pruning
- Add 'batch' to default protectedTools in config - Remove batchToolChildren tracking from parseMessages() - Remove expandBatchIds() function - Batch children are now pruned individually by LLM decision
1 parent 74431b8 commit 811f6b2

File tree

2 files changed

+9
-36
lines changed

2 files changed

+9
-36
lines changed

lib/config.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ export interface ConfigResult {
3030
const defaultConfig: PluginConfig = {
3131
enabled: true,
3232
debug: false,
33-
protectedTools: ['task', 'todowrite', 'todoread', 'prune'],
33+
protectedTools: ['task', 'todowrite', 'todoread', 'prune', 'batch'],
3434
showModelErrorToasts: true,
3535
strictModelSelection: false,
3636
pruning_summary: 'detailed',

lib/janitor.ts

Lines changed: 8 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ async function runWithStrategies(
131131
}
132132

133133
const currentAgent = findCurrentAgent(messages)
134-
const { toolCallIds, toolOutputs, toolMetadata, batchToolChildren } = parseMessages(messages, state.toolParameters)
134+
const { toolCallIds, toolOutputs, toolMetadata } = parseMessages(messages, state.toolParameters)
135135

136136
const alreadyPrunedIds = state.prunedIds.get(sessionID) ?? []
137137
const unprunedToolCallIds = toolCallIds.filter(id => !alreadyPrunedIds.includes(id))
@@ -161,15 +161,13 @@ async function runWithStrategies(
161161
)
162162
}
163163

164-
// PHASE 2: EXPAND BATCH CHILDREN
165164
if (llmPrunedIds.length === 0) {
166165
return null
167166
}
168167

169-
const expandedPrunedIds = expandBatchIds(llmPrunedIds, batchToolChildren)
170-
const finalNewlyPrunedIds = expandedPrunedIds.filter(id => !alreadyPrunedIds.includes(id))
168+
const finalNewlyPrunedIds = llmPrunedIds.filter(id => !alreadyPrunedIds.includes(id))
171169

172-
// PHASE 3: CALCULATE STATS & NOTIFICATION
170+
// PHASE 2: CALCULATE STATS & NOTIFICATION
173171
const tokensSaved = await calculateTokensSaved(finalNewlyPrunedIds, toolOutputs)
174172

175173
const currentStats = state.stats.get(sessionID) ?? { totalToolsPruned: 0, totalTokensSaved: 0 }
@@ -182,15 +180,15 @@ async function runWithStrategies(
182180
await sendPruningSummary(
183181
ctx.notificationCtx,
184182
sessionID,
185-
expandedPrunedIds,
183+
llmPrunedIds,
186184
toolMetadata,
187185
tokensSaved,
188186
sessionStats,
189187
currentAgent
190188
)
191189

192-
// PHASE 4: STATE UPDATE
193-
const allPrunedIds = [...new Set([...alreadyPrunedIds, ...expandedPrunedIds])]
190+
// PHASE 3: STATE UPDATE
191+
const allPrunedIds = [...new Set([...alreadyPrunedIds, ...llmPrunedIds])]
194192
state.prunedIds.set(sessionID, allPrunedIds)
195193

196194
const sessionName = sessionInfo?.title
@@ -211,7 +209,7 @@ async function runWithStrategies(
211209
return {
212210
prunedCount: finalNewlyPrunedIds.length,
213211
tokensSaved,
214-
llmPrunedIds: expandedPrunedIds,
212+
llmPrunedIds,
215213
toolMetadata,
216214
sessionStats
217215
}
@@ -345,7 +343,6 @@ interface ParsedMessages {
345343
toolCallIds: string[]
346344
toolOutputs: Map<string, string>
347345
toolMetadata: Map<string, { tool: string, parameters?: any }>
348-
batchToolChildren: Map<string, string[]>
349346
}
350347

351348
function parseMessages(
@@ -355,8 +352,6 @@ function parseMessages(
355352
const toolCallIds: string[] = []
356353
const toolOutputs = new Map<string, string>()
357354
const toolMetadata = new Map<string, { tool: string, parameters?: any }>()
358-
const batchToolChildren = new Map<string, string[]>()
359-
let currentBatchId: string | null = null
360355

361356
for (const msg of messages) {
362357
if (msg.parts) {
@@ -376,21 +371,12 @@ function parseMessages(
376371
if (part.state?.status === "completed" && part.state.output) {
377372
toolOutputs.set(normalizedId, part.state.output)
378373
}
379-
380-
if (part.tool === "batch") {
381-
currentBatchId = normalizedId
382-
batchToolChildren.set(normalizedId, [])
383-
} else if (currentBatchId && normalizedId.startsWith('prt_')) {
384-
batchToolChildren.get(currentBatchId)!.push(normalizedId)
385-
} else if (currentBatchId && !normalizedId.startsWith('prt_')) {
386-
currentBatchId = null
387-
}
388374
}
389375
}
390376
}
391377
}
392378

393-
return { toolCallIds, toolOutputs, toolMetadata, batchToolChildren }
379+
return { toolCallIds, toolOutputs, toolMetadata }
394380
}
395381

396382
function findCurrentAgent(messages: any[]): string | undefined {
@@ -408,19 +394,6 @@ function findCurrentAgent(messages: any[]): string | undefined {
408394
// Helpers
409395
// ============================================================================
410396

411-
function expandBatchIds(ids: string[], batchToolChildren: Map<string, string[]>): string[] {
412-
const expanded = new Set<string>()
413-
for (const id of ids) {
414-
const normalizedId = id.toLowerCase()
415-
expanded.add(normalizedId)
416-
const children = batchToolChildren.get(normalizedId)
417-
if (children) {
418-
children.forEach(childId => expanded.add(childId))
419-
}
420-
}
421-
return Array.from(expanded)
422-
}
423-
424397
function replacePrunedToolOutputs(messages: any[], prunedIds: string[]): any[] {
425398
if (prunedIds.length === 0) return messages
426399

0 commit comments

Comments
 (0)