222 changes: 198 additions & 24 deletions libs/langchain_v1/langchain/agents/middleware/summarization.py
@@ -1,8 +1,9 @@
"""Summarization middleware."""

import uuid
from collections.abc import Callable, Iterable
from typing import Any, cast
import warnings
from collections.abc import Callable, Iterable, Mapping
from typing import Any, Literal, cast

from langchain_core.messages import (
AIMessage,
@@ -51,13 +52,17 @@
{messages}
</messages>""" # noqa: E501

SUMMARY_PREFIX = "## Previous conversation summary:"

_DEFAULT_MESSAGES_TO_KEEP = 20
_DEFAULT_TRIM_TOKEN_LIMIT = 4000
_DEFAULT_FALLBACK_MESSAGE_COUNT = 15
_SEARCH_RANGE_FOR_TOOL_PAIRS = 5

ContextFraction = tuple[Literal["fraction"], float]
ContextTokens = tuple[Literal["tokens"], int]
ContextMessages = tuple[Literal["messages"], int]

ContextSize = ContextFraction | ContextTokens | ContextMessages
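
These aliases are plain `(kind, value)` tuples; a minimal sketch of the three shapes (the values below are invented examples, not defaults):

```python
# Illustrative ContextSize values (examples, not defaults):
by_fraction = ("fraction", 0.8)   # 80% of the model's max input tokens
by_tokens = ("tokens", 3000)      # absolute token threshold
by_messages = ("messages", 50)    # message-count threshold
```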


class SummarizationMiddleware(AgentMiddleware):
"""Summarizes conversation history when token limits are approached.
@@ -70,48 +75,97 @@ class SummarizationMiddleware(AgentMiddleware):
def __init__(
self,
model: str | BaseChatModel,
max_tokens_before_summary: int | None = None,
messages_to_keep: int = _DEFAULT_MESSAGES_TO_KEEP,
*,
trigger: ContextSize | list[ContextSize] | None = None,
keep: ContextSize = ("messages", _DEFAULT_MESSAGES_TO_KEEP),
token_counter: TokenCounter = count_tokens_approximately,
summary_prompt: str = DEFAULT_SUMMARY_PROMPT,
summary_prefix: str = SUMMARY_PREFIX,
[Review comment — collaborator, author, on `summary_prefix`]: this parameter is unused
trim_tokens_to_summarize: int | None = _DEFAULT_TRIM_TOKEN_LIMIT,
**deprecated_kwargs: Any,
) -> None:
"""Initialize the summarization middleware.

Args:
model: The language model to use for generating summaries.
max_tokens_before_summary: Token threshold to trigger summarization.
If `None`, summarization is disabled.
messages_to_keep: Number of recent messages to preserve after summarization.
trigger: One or more thresholds that trigger summarization. Provide a single
`ContextSize` tuple or a list of tuples; with a list, summarization runs
when any threshold is breached. Examples: `("messages", 50)`,
`("tokens", 3000)`, `[("fraction", 0.8), ("messages", 100)]`.
keep: Context retention policy applied after summarization. Provide a
`ContextSize` tuple to specify how much history to preserve. Defaults to
keeping the most recent 20 messages. Examples: `("messages", 20)`,
`("tokens", 3000)`, or `("fraction", 0.3)`.
token_counter: Function to count tokens in messages.
summary_prompt: Prompt template for generating summaries.
summary_prefix: Prefix added to system message when including summary.
trim_tokens_to_summarize: Maximum tokens to keep when preparing messages for the
summarization call. Pass `None` to skip trimming entirely.
"""
# Handle deprecated parameters
if "max_tokens_before_summary" in deprecated_kwargs:
value = deprecated_kwargs["max_tokens_before_summary"]
warnings.warn(
"max_tokens_before_summary is deprecated. Use trigger=('tokens', value) instead.",
DeprecationWarning,
stacklevel=2,
)
if trigger is None and value is not None:
trigger = ("tokens", value)

if "messages_to_keep" in deprecated_kwargs:
value = deprecated_kwargs["messages_to_keep"]
warnings.warn(
"messages_to_keep is deprecated. Use keep=('messages', value) instead.",
DeprecationWarning,
stacklevel=2,
)
if keep == ("messages", _DEFAULT_MESSAGES_TO_KEEP):
keep = ("messages", value)

super().__init__()

if isinstance(model, str):
model = init_chat_model(model)

self.model = model
self.max_tokens_before_summary = max_tokens_before_summary
self.messages_to_keep = messages_to_keep
if trigger is None:
self.trigger: ContextSize | list[ContextSize] | None = None
trigger_conditions: list[ContextSize] = []
elif isinstance(trigger, list):
validated_list = [self._validate_context_size(item, "trigger") for item in trigger]
self.trigger = validated_list
trigger_conditions = validated_list
else:
validated = self._validate_context_size(trigger, "trigger")
self.trigger = validated
trigger_conditions = [validated]
self._trigger_conditions = trigger_conditions

self.keep = self._validate_context_size(keep, "keep")
self.token_counter = token_counter
self.summary_prompt = summary_prompt
self.summary_prefix = summary_prefix
self.trim_tokens_to_summarize = trim_tokens_to_summarize

requires_profile = any(condition[0] == "fraction" for condition in self._trigger_conditions)
if self.keep[0] == "fraction":
requires_profile = True
if requires_profile and self._get_profile_limits() is None:
msg = (
"Model profile information is required to use fractional token limits. "
'Install it with `pip install "langchain[model-profiles]"` or use '
"absolute token counts instead."
)
raise ValueError(msg)
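
To make the new surface concrete, a hedged usage sketch (the model identifier and thresholds are illustrative; fractional values additionally require model profile data, per the check above):

```python
from langchain.agents.middleware.summarization import SummarizationMiddleware

# New-style configuration: summarize when either threshold trips, then keep
# roughly 30% of the context window after summarization.
middleware = SummarizationMiddleware(
    model="openai:gpt-4o-mini",  # illustrative model identifier
    trigger=[("fraction", 0.8), ("messages", 100)],
    keep=("fraction", 0.3),
)

# Legacy keywords still work but emit DeprecationWarning and are mapped
# onto the new tuples:
legacy = SummarizationMiddleware(
    model="openai:gpt-4o-mini",
    max_tokens_before_summary=3000,  # -> trigger=("tokens", 3000)
    messages_to_keep=10,             # -> keep=("messages", 10)
)
```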

def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
"""Process messages before model invocation, potentially triggering summarization."""
messages = state["messages"]
self._ensure_message_ids(messages)

total_tokens = self.token_counter(messages)
if (
self.max_tokens_before_summary is not None
and total_tokens < self.max_tokens_before_summary
):
if not self._should_summarize(messages, total_tokens):
return None

cutoff_index = self._find_safe_cutoff(messages)
cutoff_index = self._determine_cutoff_index(messages)

if cutoff_index <= 0:
return None
@@ -129,6 +183,124 @@ def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] |
]
}

def _should_summarize(self, messages: list[AnyMessage], total_tokens: int) -> bool:
"""Determine whether summarization should run for the current token usage."""
if not self._trigger_conditions:
return False

for kind, value in self._trigger_conditions:
if kind == "messages" and len(messages) >= value:
return True
if kind == "tokens" and total_tokens >= value:
return True
if kind == "fraction":
max_input_tokens = self._get_profile_limits()
if max_input_tokens is None:
continue
threshold = int(max_input_tokens * value)
if threshold <= 0:
threshold = 1
if total_tokens >= threshold:
return True
return False
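
A worked pass through these conditions, assuming a profile that reports `max_input_tokens = 128_000` and `trigger=[("fraction", 0.8), ("messages", 100)]`:

```python
# Fraction condition: threshold = int(128_000 * 0.8) = 102_400 tokens.
# Summarization fires as soon as EITHER condition holds:
#   total_tokens >= 102_400, or len(messages) >= 100.
max_input_tokens = 128_000
threshold = max(1, int(max_input_tokens * 0.8))  # mirrors the <= 0 clamp above
assert threshold == 102_400
```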

def _determine_cutoff_index(self, messages: list[AnyMessage]) -> int:
"""Choose cutoff index respecting retention configuration."""
kind, value = self.keep
if kind in {"tokens", "fraction"}:
token_based_cutoff = self._find_token_based_cutoff(messages)
if token_based_cutoff is not None:
return token_based_cutoff
# A None cutoff means model profile data is unavailable (validated in
# __init__, but guarded here for safety); fall back to a message-count cutoff.
return self._find_safe_cutoff(messages, _DEFAULT_MESSAGES_TO_KEEP)
return self._find_safe_cutoff(messages, cast("int", value))

def _find_token_based_cutoff(self, messages: list[AnyMessage]) -> int | None:
"""Find cutoff index based on target token retention."""
if not messages:
return 0

kind, value = self.keep
if kind == "fraction":
max_input_tokens = self._get_profile_limits()
if max_input_tokens is None:
return None
target_token_count = int(max_input_tokens * value)
elif kind == "tokens":
target_token_count = int(value)
else:
return None

if target_token_count <= 0:
target_token_count = 1

if self.token_counter(messages) <= target_token_count:
return 0

# Use binary search to identify the earliest message index that keeps the
# suffix within the token budget.
left, right = 0, len(messages)
cutoff_candidate = len(messages)
max_iterations = len(messages).bit_length() + 1
for _ in range(max_iterations):
if left >= right:
break

mid = (left + right) // 2
if self.token_counter(messages[mid:]) <= target_token_count:
cutoff_candidate = mid
right = mid
else:
left = mid + 1

if cutoff_candidate == len(messages):
cutoff_candidate = left

if cutoff_candidate >= len(messages):
if len(messages) == 1:
return 0
cutoff_candidate = len(messages) - 1

for i in range(cutoff_candidate, -1, -1):
if self._is_safe_cutoff_point(messages, i):
return i

return 0
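
The search above leans on a monotonicity fact: as the cutoff index grows, the suffix `messages[mid:]` shrinks, so its token count is non-increasing and binary search over the cutoff is valid. A standalone sketch of the same idea (simplified: no AI/Tool pair safety pass, and the helper name and toy counter are invented for illustration):

```python
from collections.abc import Callable, Sequence

def earliest_suffix_within_budget(
    items: Sequence[str],
    budget: int,
    count_tokens: Callable[[Sequence[str]], int],
) -> int:
    """Smallest index i such that count_tokens(items[i:]) <= budget."""
    left, right = 0, len(items)
    while left < right:
        mid = (left + right) // 2
        if count_tokens(items[mid:]) <= budget:
            right = mid        # suffix fits; try to keep more history
        else:
            left = mid + 1     # suffix too large; drop more of the prefix
    return left

# Crude whitespace "token" counter, purely for illustration:
msgs = ["hello world", "foo bar baz", "ok"]
count = lambda seq: sum(len(s.split()) for s in seq)
assert earliest_suffix_within_budget(msgs, 4, count) == 1  # keep last two
```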

def _get_profile_limits(self) -> int | None:
"""Retrieve max input token limit from the model profile."""
try:
profile = self.model.profile
except (AttributeError, ImportError):
return None

if not isinstance(profile, Mapping):
return None

max_input_tokens = profile.get("max_input_tokens")

if not isinstance(max_input_tokens, int):
return None

return max_input_tokens
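
For context, a sketch of the profile mapping this helper expects; the dict below is an assumed example, not an API guarantee:

```python
# Hypothetical profile mapping; only "max_input_tokens" is read here.
profile = {"max_input_tokens": 128_000, "max_output_tokens": 16_384}

limit = profile.get("max_input_tokens")
assert isinstance(limit, int)  # non-int or missing values yield None upstream
```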

def _validate_context_size(self, context: ContextSize, parameter_name: str) -> ContextSize:
"""Validate context configuration tuples."""
kind, value = context
if kind == "fraction":
if not 0 < value <= 1:
msg = f"Fractional {parameter_name} values must be between 0 and 1, got {value}."
raise ValueError(msg)
elif kind in {"tokens", "messages"}:
if value <= 0:
msg = f"{parameter_name} thresholds must be greater than 0, got {value}."
raise ValueError(msg)
else:
msg = f"Unsupported context size type {kind} for {parameter_name}."
raise ValueError(msg)
return context
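
For illustration, values the checks above would reject (hypothetical inputs, each raising ValueError when passed as trigger= or keep=):

```python
# ("fraction", 1.5)  -> fraction must lie in (0, 1]
# ("fraction", 0)    -> likewise rejected (0 is excluded)
# ("tokens", -10)    -> token/message thresholds must be > 0
# ("chars", 100)     -> unsupported context size type
```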

def _build_new_messages(self, summary: str) -> list[HumanMessage]:
return [
HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}")
@@ -151,16 +323,16 @@ def _partition_messages(

return messages_to_summarize, preserved_messages

def _find_safe_cutoff(self, messages: list[AnyMessage]) -> int:
def _find_safe_cutoff(self, messages: list[AnyMessage], messages_to_keep: int) -> int:
"""Find safe cutoff point that preserves AI/Tool message pairs.

Returns the index where messages can be safely cut without separating
related AI and Tool messages. Returns 0 if no safe cutoff is found.
"""
if len(messages) <= self.messages_to_keep:
if len(messages) <= messages_to_keep:
return 0

target_cutoff = len(messages) - self.messages_to_keep
target_cutoff = len(messages) - messages_to_keep

for i in range(target_cutoff, -1, -1):
if self._is_safe_cutoff_point(messages, i):
@@ -229,16 +401,18 @@ def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:

try:
response = self.model.invoke(self.summary_prompt.format(messages=trimmed_messages))
return cast("str", response.content).strip()
return response.text.strip()
except Exception as e: # noqa: BLE001
return f"Error generating summary: {e!s}"

def _trim_messages_for_summary(self, messages: list[AnyMessage]) -> list[AnyMessage]:
"""Trim messages to fit within summary generation limits."""
try:
if self.trim_tokens_to_summarize is None:
return messages
return trim_messages(
messages,
max_tokens=_DEFAULT_TRIM_TOKEN_LIMIT,
max_tokens=self.trim_tokens_to_summarize,
token_counter=self.token_counter,
start_on="human",
strategy="last",