Skip to content

Commit f308139

Browse files
feat(langchain_v1): Implement Context Editing Middleware (#33267)
Brings functionality similar to Anthropic's context editing (https://docs.claude.com/en/docs/build-with-claude/context-editing) to all chat models. --------- Co-authored-by: Sydney Runkle <[email protected]>
1 parent 95a451e commit f308139

File tree

3 files changed

+469
-0
lines changed

3 files changed

+469
-0
lines changed

libs/langchain_v1/langchain/agents/middleware/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
"""Middleware plugins for agents."""
22

3+
from .call_tracking import ModelCallLimitMiddleware
4+
from .context_editing import (
5+
ClearToolUsesEdit,
6+
ContextEditingMiddleware,
7+
)
38
from .human_in_the_loop import HumanInTheLoopMiddleware
49
from .model_fallback import ModelFallbackMiddleware
510
from .pii import PIIDetectionError, PIIMiddleware
@@ -24,8 +29,11 @@
2429
"AgentState",
2530
# should move to langchain-anthropic if we decide to keep it
2631
"AnthropicPromptCachingMiddleware",
32+
"ClearToolUsesEdit",
33+
"ContextEditingMiddleware",
2734
"HumanInTheLoopMiddleware",
2835
"LLMToolSelectorMiddleware",
36+
"ModelCallLimitMiddleware",
2937
"ModelFallbackMiddleware",
3038
"ModelRequest",
3139
"PIIDetectionError",
Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
"""Context editing middleware.
2+
3+
This middleware mirrors Anthropic's context editing capabilities by clearing
4+
older tool results once the conversation grows beyond a configurable token
5+
threshold. The implementation is intentionally model-agnostic so it can be used
6+
with any LangChain chat model.
7+
"""
8+
9+
from __future__ import annotations
10+
11+
from collections.abc import Callable, Iterable, Sequence
12+
from dataclasses import dataclass
13+
from typing import TYPE_CHECKING, Literal
14+
15+
from langchain_core.messages import (
16+
AIMessage,
17+
AnyMessage,
18+
BaseMessage,
19+
SystemMessage,
20+
ToolMessage,
21+
)
22+
from langchain_core.messages.utils import count_tokens_approximately
23+
from typing_extensions import Protocol
24+
25+
from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest
26+
27+
if TYPE_CHECKING:
28+
from langgraph.runtime import Runtime
29+
30+
# Text substituted for a tool result's content once it has been cleared.
DEFAULT_TOOL_PLACEHOLDER = "[cleared]"


# Callable that maps a sequence of messages to an (approximate or exact)
# token count; implementations are supplied by ContextEditingMiddleware.
TokenCounter = Callable[
    [Sequence[BaseMessage]],
    int,
]
37+
38+
39+
class ContextEdit(Protocol):
    """Protocol describing a context editing strategy."""

    def apply(
        self,
        messages: list[AnyMessage],
        *,
        count_tokens: TokenCounter,
    ) -> None:
        """Apply an edit to the message list in place.

        Args:
            messages: Mutable conversation history. Implementations edit this
                list in place (replacing or rewriting entries) rather than
                returning a new one.
            count_tokens: Counter used to measure the token footprint of
                ``messages`` before and after editing.
        """
        ...
50+
51+
52+
@dataclass(slots=True)
class ClearToolUsesEdit(ContextEdit):
    """Configuration for clearing tool outputs when token limits are exceeded."""

    trigger: int = 100_000
    """Token count that triggers the edit."""

    clear_at_least: int = 0
    """Minimum number of tokens to reclaim when the edit runs."""

    keep: int = 3
    """Number of most recent tool results that must be preserved."""

    clear_tool_inputs: bool = False
    """Whether to clear the originating tool call parameters on the AI message."""

    exclude_tools: Sequence[str] = ()
    """List of tool names to exclude from clearing."""

    placeholder: str = DEFAULT_TOOL_PLACEHOLDER
    """Placeholder text inserted for cleared tool outputs."""

    def apply(
        self,
        messages: list[AnyMessage],
        *,
        count_tokens: TokenCounter,
    ) -> None:
        """Apply the clear-tool-uses strategy.

        Clears older ``ToolMessage`` contents (oldest first) once the
        conversation exceeds ``trigger`` tokens, stopping early when at
        least ``clear_at_least`` tokens have been reclaimed. ``messages``
        is mutated in place.
        """
        tokens = count_tokens(messages)

        # Below the threshold: nothing to do.
        if tokens <= self.trigger:
            return

        # All tool results, oldest first, paired with their list positions.
        candidates = [
            (idx, msg) for idx, msg in enumerate(messages) if isinstance(msg, ToolMessage)
        ]

        # Always preserve the `keep` most recent tool results.
        if self.keep >= len(candidates):
            candidates = []
        elif self.keep:
            candidates = candidates[: -self.keep]

        cleared_tokens = 0
        excluded_tools = set(self.exclude_tools)

        for idx, tool_message in candidates:
            # Skip results that a previous run already cleared.
            if tool_message.response_metadata.get("context_editing", {}).get("cleared"):
                continue

            # Find the AI message that issued this tool call: the nearest
            # AIMessage preceding the tool result.
            ai_message = next(
                (m for m in reversed(messages[:idx]) if isinstance(m, AIMessage)), None
            )

            # Orphaned tool result (no originating AI message) — leave it alone.
            if ai_message is None:
                continue

            # Match the specific tool call by id.
            tool_call = next(
                (
                    call
                    for call in ai_message.tool_calls
                    if call.get("id") == tool_message.tool_call_id
                ),
                None,
            )

            if tool_call is None:
                continue

            # Respect the exclusion list (by result name, falling back to
            # the call's declared name).
            if (tool_message.name or tool_call["name"]) in excluded_tools:
                continue

            # Replace the tool result with a placeholder and mark it as
            # cleared so future runs skip it.
            messages[idx] = tool_message.model_copy(
                update={
                    "artifact": None,
                    "content": self.placeholder,
                    "response_metadata": {
                        **tool_message.response_metadata,
                        "context_editing": {
                            "cleared": True,
                            "strategy": "clear_tool_uses",
                        },
                    },
                }
            )

            if self.clear_tool_inputs:
                # NOTE(review): list.index matches by message equality —
                # assumes no duplicate-equal AIMessage earlier in the list.
                messages[messages.index(ai_message)] = self._build_cleared_tool_input_message(
                    ai_message,
                    tool_message.tool_call_id,
                )

            # Stop early once enough tokens have been reclaimed.
            if self.clear_at_least > 0:
                new_token_count = count_tokens(messages)
                cleared_tokens = max(0, tokens - new_token_count)
                if cleared_tokens >= self.clear_at_least:
                    break

        return

    def _build_cleared_tool_input_message(
        self,
        message: AIMessage,
        tool_call_id: str,
    ) -> AIMessage:
        """Return a copy of ``message`` with the args of ``tool_call_id`` emptied.

        Also records the cleared call id under
        ``response_metadata["context_editing"]["cleared_tool_inputs"]``.
        """
        updated_tool_calls = []
        cleared_any = False
        for tool_call in message.tool_calls:
            updated_call = dict(tool_call)
            if updated_call.get("id") == tool_call_id:
                # Blank out the arguments; the call id and name are kept so
                # the transcript remains structurally valid.
                updated_call["args"] = {}
                cleared_any = True
            updated_tool_calls.append(updated_call)

        metadata = dict(getattr(message, "response_metadata", {}))
        context_entry = dict(metadata.get("context_editing", {}))
        if cleared_any:
            # Track every cleared call id (sorted for deterministic output).
            cleared_ids = set(context_entry.get("cleared_tool_inputs", []))
            cleared_ids.add(tool_call_id)
            context_entry["cleared_tool_inputs"] = sorted(cleared_ids)
            metadata["context_editing"] = context_entry

        return message.model_copy(
            update={
                "tool_calls": updated_tool_calls,
                "response_metadata": metadata,
            }
        )
180+
181+
182+
class ContextEditingMiddleware(AgentMiddleware):
    """Automatically prunes tool results to keep the context window in check.

    Before each model invocation, every configured ``ContextEdit`` strategy
    is given a chance to rewrite the pending message list once the input
    token count crosses its threshold. Only the ``ClearToolUsesEdit``
    strategy ships today, matching Anthropic's ``clear_tool_uses_20250919``
    behaviour.
    """

    edits: list[ContextEdit]
    token_count_method: Literal["approximate", "model"]

    def __init__(
        self,
        *,
        edits: Iterable[ContextEdit] | None = None,
        token_count_method: Literal["approximate", "model"] = "approximate",  # noqa: S107
    ) -> None:
        """Create a context editing middleware instance.

        Args:
            edits: Edit strategies to run, in order. When omitted, a single
                ``ClearToolUsesEdit`` with Anthropic-style defaults is used.
            token_count_method: ``"approximate"`` uses the fast heuristic
                counter; ``"model"`` delegates to the chat model's own
                counting (potentially slower, more accurate).
        """
        super().__init__()
        self.edits = list(edits or (ClearToolUsesEdit(),))
        self.token_count_method = token_count_method

    def modify_model_request(
        self,
        request: ModelRequest,
        state: AgentState,  # noqa: ARG002
        runtime: Runtime,  # noqa: ARG002
    ) -> ModelRequest:
        """Run every configured edit over the request's messages before invocation."""
        if not request.messages:
            return request

        if self.token_count_method == "approximate":  # noqa: S105
            # The heuristic counter takes the message list directly.
            def counter(messages: Sequence[BaseMessage]) -> int:
                return count_tokens_approximately(messages)
        else:
            # Exact counting must include the system prompt (and tool
            # schemas), since both contribute to the model's real input size.
            prefix = (
                [SystemMessage(content=request.system_prompt)] if request.system_prompt else []
            )

            def counter(messages: Sequence[BaseMessage]) -> int:
                return request.model.get_num_tokens_from_messages(
                    prefix + list(messages), request.tools
                )

        for strategy in self.edits:
            strategy.apply(request.messages, count_tokens=counter)

        return request
240+
241+
242+
# Public API of this module.
__all__ = [
    "ClearToolUsesEdit",
    "ContextEditingMiddleware",
]

0 commit comments

Comments
 (0)