Skip to content

Commit df06fa8

Browse files
committed
Add turn-based tool protection (protectedTurns config)
- Add protectedTurns config option to protect recent tools from pruning - Track currentTurn in session state using step-start parts - Store turn number on each cached tool parameter entry - Skip caching tools that are within the protected turn window - Exclude turn-protected tools from nudge counter
1 parent 8dd3aec commit df06fa8

File tree

7 files changed

+65
-10
lines changed

7 files changed

+65
-10
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ DCP uses its own config file:
7777
"enabled": true,
7878
// Additional tools to protect from pruning
7979
"protectedTools": [],
80+
// Protect tools from pruning for N turns after they are called (0 = disabled)
81+
"protectedTurns": 4,
8082
// Nudge the LLM to use the prune tool (every <frequency> tool results)
8183
"nudge": {
8284
"enabled": true,

lib/config.ts

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ export interface PruneToolNudge {
2525
export interface PruneTool {
2626
enabled: boolean
2727
protectedTools: string[]
28+
protectedTurns: number
2829
nudge: PruneToolNudge
2930
}
3031

@@ -72,6 +73,7 @@ export const VALID_CONFIG_KEYS = new Set([
7273
'strategies.pruneTool',
7374
'strategies.pruneTool.enabled',
7475
'strategies.pruneTool.protectedTools',
76+
'strategies.pruneTool.protectedTurns',
7577
'strategies.pruneTool.nudge',
7678
'strategies.pruneTool.nudge.enabled',
7779
'strategies.pruneTool.nudge.frequency'
@@ -158,6 +160,9 @@ function validateConfigTypes(config: Record<string, any>): ValidationError[] {
158160
if (strategies.pruneTool.protectedTools !== undefined && !Array.isArray(strategies.pruneTool.protectedTools)) {
159161
errors.push({ key: 'strategies.pruneTool.protectedTools', expected: 'string[]', actual: typeof strategies.pruneTool.protectedTools })
160162
}
163+
if (strategies.pruneTool.protectedTurns !== undefined && typeof strategies.pruneTool.protectedTurns !== 'number') {
164+
errors.push({ key: 'strategies.pruneTool.protectedTurns', expected: 'number', actual: typeof strategies.pruneTool.protectedTurns })
165+
}
161166
if (strategies.pruneTool.nudge) {
162167
if (strategies.pruneTool.nudge.enabled !== undefined && typeof strategies.pruneTool.nudge.enabled !== 'boolean') {
163168
errors.push({ key: 'strategies.pruneTool.nudge.enabled', expected: 'boolean', actual: typeof strategies.pruneTool.nudge.enabled })
@@ -240,6 +245,7 @@ const defaultConfig: PluginConfig = {
240245
pruneTool: {
241246
enabled: true,
242247
protectedTools: [...DEFAULT_PROTECTED_TOOLS],
248+
protectedTurns: 4,
243249
nudge: {
244250
enabled: true,
245251
frequency: 10
@@ -341,6 +347,8 @@ function createDefaultConfig(): void {
341347
"enabled": true,
342348
// Additional tools to protect from pruning
343349
"protectedTools": [],
350+
// Protect tools from pruning for N turns after they are called (0 = disabled)
351+
"protectedTurns": 4,
344352
// Nudge the LLM to use the prune tool (every <frequency> tool results)
345353
"nudge": {
346354
"enabled": true,
@@ -426,6 +434,7 @@ function mergeStrategies(
426434
...(override.pruneTool?.protectedTools ?? [])
427435
])
428436
],
437+
protectedTurns: override.pruneTool?.protectedTurns ?? base.pruneTool.protectedTurns,
429438
nudge: {
430439
enabled: override.pruneTool?.nudge?.enabled ?? base.pruneTool.nudge.enabled,
431440
frequency: override.pruneTool?.nudge?.frequency ?? base.pruneTool.nudge.frequency
@@ -452,6 +461,7 @@ function deepCloneConfig(config: PluginConfig): PluginConfig {
452461
pruneTool: {
453462
...config.strategies.pruneTool,
454463
protectedTools: [...config.strategies.pruneTool.protectedTools],
464+
protectedTurns: config.strategies.pruneTool.protectedTurns,
455465
nudge: { ...config.strategies.pruneTool.nudge }
456466
},
457467
supersedeWrites: {

lib/state/state.ts

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import type { SessionState, ToolParameterEntry, WithParts } from "./types"
22
import type { Logger } from "../logger"
33
import { loadSessionState } from "./persistence"
44
import { isSubAgentSession } from "./utils"
5-
import { getLastUserMessage } from "../shared-utils"
5+
import { getLastUserMessage, isMessageCompacted } from "../shared-utils"
66

77
export const checkSession = async (
88
client: any,
@@ -34,6 +34,8 @@ export const checkSession = async (
3434
state.prune.toolIds = []
3535
logger.info("Detected compaction from messages - cleared tool cache", { timestamp: lastCompactionTimestamp })
3636
}
37+
38+
state.currentTurn = countTurns(state, messages)
3739
}
3840

3941
export function createSessionState(): SessionState {
@@ -50,7 +52,8 @@ export function createSessionState(): SessionState {
5052
toolParameters: new Map<string, ToolParameterEntry>(),
5153
nudgeCounter: 0,
5254
lastToolPrune: false,
53-
lastCompaction: 0
55+
lastCompaction: 0,
56+
currentTurn: 0
5457
}
5558
}
5659

@@ -68,6 +71,7 @@ export function resetSessionState(state: SessionState): void {
6871
state.nudgeCounter = 0
6972
state.lastToolPrune = false
7073
state.lastCompaction = 0
74+
state.currentTurn = 0
7175
}
7276

7377
export async function ensureSessionInitialized(
@@ -92,6 +96,7 @@ export async function ensureSessionInitialized(
9296
logger.info("isSubAgent = " + isSubAgent)
9397

9498
state.lastCompaction = findLastCompactionTimestamp(messages)
99+
state.currentTurn = countTurns(state, messages)
95100

96101
const persisted = await loadSessionState(sessionId, logger)
97102
if (persisted === null) {
@@ -116,3 +121,18 @@ function findLastCompactionTimestamp(messages: WithParts[]): number {
116121
}
117122
return 0
118123
}
124+
125+
export function countTurns(state: SessionState, messages: WithParts[]): number {
126+
let turnCount = 0
127+
for (const msg of messages) {
128+
if (isMessageCompacted(state, msg)) {
129+
continue
130+
}
131+
for (const part of msg.parts) {
132+
if (part.type === "step-start") {
133+
turnCount++
134+
}
135+
}
136+
}
137+
return turnCount
138+
}

lib/state/tool-cache.ts

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,26 +18,44 @@ export async function syncToolCache(
1818
logger.info("Syncing tool parameters from OpenCode messages")
1919

2020
state.nudgeCounter = 0
21+
let turnCounter = 0
2122

2223
for (const msg of messages) {
2324
if (isMessageCompacted(state, msg)) {
2425
continue
2526
}
2627

2728
for (const part of msg.parts) {
28-
if (part.type !== "tool" || !part.callID) {
29+
if (part.type === "step-start") {
30+
turnCounter++
2931
continue
3032
}
31-
if (state.toolParameters.has(part.callID)) {
33+
34+
if (part.type !== "tool" || !part.callID) {
3235
continue
3336
}
3437

38+
const isProtectedByTurn = config.strategies.pruneTool.protectedTurns > 0 &&
39+
(state.currentTurn - turnCounter) < config.strategies.pruneTool.protectedTurns
40+
41+
state.lastToolPrune = part.tool === "prune"
42+
3543
if (part.tool === "prune") {
3644
state.nudgeCounter = 0
37-
} else if (!config.strategies.pruneTool.protectedTools.includes(part.tool)) {
45+
} else if (
46+
!config.strategies.pruneTool.protectedTools.includes(part.tool) &&
47+
!isProtectedByTurn
48+
) {
3849
state.nudgeCounter++
3950
}
40-
state.lastToolPrune = part.tool === "prune"
51+
52+
if (state.toolParameters.has(part.callID)) {
53+
continue
54+
}
55+
56+
if (isProtectedByTurn) {
57+
continue
58+
}
4159

4260
state.toolParameters.set(
4361
part.callID,
@@ -46,12 +64,14 @@ export async function syncToolCache(
4664
parameters: part.state?.input ?? {},
4765
status: part.state.status as ToolStatus | undefined,
4866
error: part.state.status === "error" ? part.state.error : undefined,
67+
turn: turnCounter,
4968
}
5069
)
51-
logger.info("Cached tool id: " + part.callID)
70+
logger.info(`Cached tool id: ${part.callID} (created on turn ${turnCounter})`)
5271
}
5372
}
54-
logger.info("Synced cache - size: " + state.toolParameters.size)
73+
74+
logger.info(`Synced cache - size: ${state.toolParameters.size}, currentTurn: ${state.currentTurn}, nudgeCounter: ${state.nudgeCounter}`)
5575
trimToolParametersCache(state)
5676
} catch (error) {
5777
logger.warn("Failed to sync tool parameters from OpenCode", {

lib/state/types.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ export interface ToolParameterEntry {
1212
parameters: any
1313
status?: ToolStatus
1414
error?: string
15+
turn: number // Which turn (step-start count) this tool was called on
1516
}
1617

1718
export interface SessionStats {
@@ -32,4 +33,5 @@ export interface SessionState {
3233
nudgeCounter: number
3334
lastToolPrune: boolean
3435
lastCompaction: number
36+
currentTurn: number // Current turn count derived from step-start parts
3537
}

lib/strategies/deduplication.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ export const deduplicate = (
4141
for (const id of unprunedIds) {
4242
const metadata = state.toolParameters.get(id)
4343
if (!metadata) {
44-
logger.warn(`Missing metadata for tool call ID: ${id}`)
44+
// logger.warn(`Missing metadata for tool call ID: ${id}`)
4545
continue
4646
}
4747

lib/strategies/on-idle.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ function parseMessages(
4545
tool: part.tool,
4646
parameters: parameters,
4747
status: part.state?.status,
48-
error: part.state?.status === "error" ? part.state.error : undefined
48+
error: part.state?.status === "error" ? part.state.error : undefined,
49+
turn: cachedData?.turn ?? 0
4950
})
5051
}
5152
}

0 commit comments

Comments
 (0)