Skip to content

Commit 13b662c

Browse files
chore: summarize_when_long with no tool call response
1 parent a3ebb7e commit 13b662c

File tree

2 files changed

+160
-20
lines changed

2 files changed

+160
-20
lines changed

dreadnode/agent/hooks/summarize.py

Lines changed: 75 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,56 @@ def _get_last_input_tokens(event: AgentEvent) -> int:
4444
return last_generation_event.usage.input_tokens if last_generation_event.usage else 0
4545

4646

47+
def _find_tool_aware_boundary(
48+
messages: list[rg.Message],
49+
min_messages_to_keep: int,
50+
) -> int:
51+
"""
52+
Find the best summarization boundary while preserving tool call/response pairs.
53+
54+
This prevents breaking tool messages that would cause API errors with strict models
55+
(OpenAI, Anthropic) that require every tool_call_id to have a matching response.
56+
57+
Args:
58+
messages: List of messages to analyze (excluding system message)
59+
min_messages_to_keep: Minimum messages that must be kept after boundary
60+
61+
Returns:
62+
Index where to split (messages[:idx] summarized, messages[idx:] kept)
63+
Returns 0 if no valid boundary found
64+
"""
65+
# Build tool_call_id -> assistant message index mapping
66+
tool_call_map: dict[str, int] = {}
67+
for i, msg in enumerate(messages):
68+
if msg.role == "assistant" and hasattr(msg, "tool_calls"):
69+
for tc in getattr(msg, "tool_calls", None) or []:
70+
if hasattr(tc, "id"):
71+
tool_call_map[tc.id] = i
72+
73+
# Walk backward from desired split point to find first valid boundary
74+
for boundary in range(len(messages) - min_messages_to_keep, -1, -1):
75+
# Check if this boundary would orphan any tool responses
76+
has_orphan = False
77+
for msg in messages[boundary:]:
78+
if msg.role == "tool" and hasattr(msg, "tool_call_id"):
79+
call_idx = tool_call_map.get(msg.tool_call_id)
80+
if call_idx is not None and call_idx < boundary:
81+
has_orphan = True
82+
break
83+
84+
if not has_orphan:
85+
return boundary
86+
87+
return 0 # No valid boundary found
88+
89+
4790
@component
4891
def summarize_when_long(
4992
model: str | rg.Generator | None = None,
5093
max_tokens: int = 100_000,
5194
min_messages_to_keep: int = 5,
5295
guidance: str = "",
96+
preserve_tool_pairs: bool = True,
5397
) -> "Hook":
5498
"""
5599
Creates a hook to manage the agent's context window by summarizing the conversation history.
@@ -66,6 +110,9 @@ def summarize_when_long(
66110
(default is None, meaning no proactive summarization).
67111
min_messages_to_keep: The minimum number of messages to retain after summarization (default is 5).
68112
guidance: Additional guidance for the summarization process (default is "").
113+
preserve_tool_pairs: If True, ensures tool call/response pairs stay together to avoid breaking
114+
strict API requirements (OpenAI, Anthropic). Defaults to True. Set to False to use legacy
115+
behavior that may break tool pairs but allows more aggressive summarization.
69116
"""
70117

71118
if min_messages_to_keep < 2:
@@ -91,6 +138,10 @@ async def summarize_when_long( # noqa: PLR0912
91138
guidance,
92139
help="Additional guidance for the summarization process",
93140
),
141+
preserve_tool_pairs: bool = Config(
142+
preserve_tool_pairs,
143+
help="Preserve tool call/response pairs to avoid breaking strict API requirements",
144+
),
94145
) -> Reaction | None:
95146
should_summarize = False
96147

@@ -123,26 +174,30 @@ async def summarize_when_long( # noqa: PLR0912
123174
messages.pop(0) if messages and messages[0].role == "system" else None
124175
)
125176

126-
# Find the best point to summarize by walking the message list once.
127-
# A boundary is valid after a simple assistant message or a finished tool block.
128-
best_summarize_boundary = 0
129-
for i, message in enumerate(messages):
130-
# If the remaining messages are less than or equal to our minimum, we can't slice any further.
131-
if len(messages) - i <= min_messages_to_keep:
132-
break
133-
134-
# Condition 1: The message is an assistant response without tool calls.
135-
is_simple_assistant = message.role == "assistant" and not getattr(
136-
message, "tool_calls", None
137-
)
138-
139-
# Condition 2: The message is the last in a block of tool responses.
140-
is_last_tool_in_block = message.role == "tool" and (
141-
i + 1 == len(messages) or messages[i + 1].role != "tool"
142-
)
143-
144-
if is_simple_assistant or is_last_tool_in_block:
145-
best_summarize_boundary = i + 1
177+
# Find the best point to summarize
178+
if preserve_tool_pairs:
179+
# Use tool-aware boundary finding to prevent breaking tool call/response pairs
180+
best_summarize_boundary = _find_tool_aware_boundary(messages, min_messages_to_keep)
181+
else:
182+
# Legacy behavior: walk the message list once looking for simple boundaries
183+
best_summarize_boundary = 0
184+
for i, message in enumerate(messages):
185+
# If the remaining messages are less than or equal to our minimum, we can't slice any further.
186+
if len(messages) - i <= min_messages_to_keep:
187+
break
188+
189+
# Condition 1: The message is an assistant response without tool calls.
190+
is_simple_assistant = message.role == "assistant" and not getattr(
191+
message, "tool_calls", None
192+
)
193+
194+
# Condition 2: The message is the last in a block of tool responses.
195+
is_last_tool_in_block = message.role == "tool" and (
196+
i + 1 == len(messages) or messages[i + 1].role != "tool"
197+
)
198+
199+
if is_simple_assistant or is_last_tool_in_block:
200+
best_summarize_boundary = i + 1
146201

147202
if best_summarize_boundary == 0:
148203
return None # No valid slice point was found.

tests/test_preserve_tool_pairs.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
"""Tests for preserve_tool_pairs functionality in summarize_when_long hook."""
2+
3+
import rigging as rg
4+
from dreadnode.agent.hooks.summarize import _find_tool_aware_boundary
5+
6+
7+
class ToolCall:
    """Bare-bones stand-in for a tool call; exposes only the ``id`` attribute."""

    def __init__(self, call_id: str) -> None:
        # `id` is the attribute the boundary finder reads from each tool call.
        self.id = call_id
11+
12+
13+
class ToolMessage(rg.Message):
    """Tool response message for testing.

    Builds an ``rg.Message`` with role ``"tool"`` and the given content, then sets
    ``tool_call_id`` so the message can be matched to the tool call it answers.
    """

    def __init__(self, call_id: str, content: str) -> None:
        super().__init__("tool", content)
        # Set after construction; presumably rg.Message permits assigning this
        # attribute on an instance (verify against the rigging version in use).
        self.tool_call_id = call_id
18+
19+
20+
def test_preserves_tool_pairs():
    """Tool call and response stay together when the naive split would separate them."""
    messages = [
        rg.Message("user", "Hello"),
        rg.Message("assistant", "Let me check", tool_calls=[ToolCall("call_1")]),
        ToolMessage("call_1", "Result"),
        rg.Message("assistant", "Done"),
        rg.Message("user", "Thanks"),
    ]

    # Keeping three messages puts the naive split at index 2, which would keep the
    # tool response while summarizing its call. (With only two messages kept the
    # tail holds no tool message, so nothing would be exercised.)
    boundary = _find_tool_aware_boundary(messages, min_messages_to_keep=3)

    # The boundary must back up to the assistant message that issued the call.
    assert boundary == 1, "Boundary should preserve tool call/response pair"
34+
35+
36+
def test_no_tools():
    """Without tool messages the split lands exactly at the minimum-keep position."""
    turns = [
        ("user", "Hello"),
        ("assistant", "Hi"),
        ("user", "How are you"),
        ("assistant", "Good"),
    ]
    messages = [rg.Message(role, text) for role, text in turns]

    split_at = _find_tool_aware_boundary(messages, min_messages_to_keep=2)

    assert split_at == 2, "Should split at natural boundary"
47+
48+
49+
def test_multiple_tool_pairs():
    """Handles multiple tool call/response pairs correctly."""
    messages = [
        rg.Message("user", "Do A and B"),
        rg.Message("assistant", "Running A", tool_calls=[ToolCall("a")]),
        ToolMessage("a", "A done"),
        rg.Message("assistant", "Running B", tool_calls=[ToolCall("b")]),
        ToolMessage("b", "B done"),
        rg.Message("user", "Thanks"),
    ]

    boundary = _find_tool_aware_boundary(messages, min_messages_to_keep=2)

    # The naive split (index 4) would separate call "b" from its response, so the
    # boundary must back up to the assistant message that issued "b".
    assert boundary == 3, "Should split just before the second tool pair"

    kept = messages[boundary:]
    assert len(kept) >= 2, "Should keep minimum messages"

    # Every kept tool response must have its issuing assistant message kept too.
    kept_call_ids = {
        tc.id
        for msg in kept
        if msg.role == "assistant"
        for tc in getattr(msg, "tool_calls", None) or []
    }
    for msg in kept:
        if msg.role == "tool":
            assert msg.tool_call_id in kept_call_ids, "Should not split between any tool pairs"
65+
66+
67+
def test_no_valid_boundary():
    """Returns 0 when the only allowed split would orphan a tool response."""
    messages = [
        rg.Message("assistant", "Start", tool_calls=[ToolCall("1")]),
        ToolMessage("1", "Result 1"),
        rg.Message("assistant", "Continue", tool_calls=[ToolCall("2")]),
        ToolMessage("2", "Result 2"),
    ]

    # Keeping three of the four messages leaves index 1 as the only non-zero split,
    # and that would keep "Result 1" without its call. (With two messages kept,
    # index 2 sits cleanly between the two pairs and is a valid split.)
    boundary = _find_tool_aware_boundary(messages, min_messages_to_keep=3)
    assert boundary == 0, "Should keep everything when no valid split exists"
78+
79+
80+
if __name__ == "__main__":
    # Allow running this file directly, without pytest: run each test in order and
    # report success only if none of them raised.
    for run_test in (
        test_preserves_tool_pairs,
        test_no_tools,
        test_multiple_tool_pairs,
        test_no_valid_boundary,
    ):
        run_test()
    print("All tests passed")

0 commit comments

Comments
 (0)