145 changes: 134 additions & 11 deletions libs/langchain_v1/langchain/agents/middleware/summarization.py
@@ -1,8 +1,8 @@
"""Summarization middleware."""

import uuid
from collections.abc import Callable, Iterable
from typing import Any, cast
from collections.abc import Callable, Iterable, Mapping
from typing import Any, Final, cast

from langchain_core.messages import (
    AIMessage,
@@ -59,6 +59,16 @@
_SEARCH_RANGE_FOR_TOOL_PAIRS = 5


class _UnsetType:
    """Sentinel indicating that max tokens should be inferred from the model profile."""

    def __repr__(self) -> str:
        return "UNSET"


UNSET: Final = _UnsetType()


class SummarizationMiddleware(AgentMiddleware):
"""Summarizes conversation history when token limits are approached.

@@ -70,22 +80,35 @@ class SummarizationMiddleware(AgentMiddleware):
    def __init__(
        self,
        model: str | BaseChatModel,
        max_tokens_before_summary: int | None = None,
        max_tokens_before_summary: int | None | _UnsetType = UNSET,
        messages_to_keep: int = _DEFAULT_MESSAGES_TO_KEEP,
        token_counter: TokenCounter = count_tokens_approximately,
        summary_prompt: str = DEFAULT_SUMMARY_PROMPT,
        summary_prefix: str = SUMMARY_PREFIX,
Collaborator Author:
this parameter is unused

        *,
Collaborator:
let's move this up higher

        buffer_tokens: int = 0,
        target_retention_frac: float | None = None,
        trim_token_limit: int | None = _DEFAULT_TRIM_TOKEN_LIMIT,
    ) -> None:
        """Initialize the summarization middleware.

        Args:
            model: The language model to use for generating summaries.
            max_tokens_before_summary: Token threshold to trigger summarization.
                If `None`, summarization is disabled.
                If `None`, summarization is disabled. If `UNSET`, limits are inferred
                from the model profile when available.
            messages_to_keep: Number of recent messages to preserve after summarization.
                Used whenever token-based retention is unavailable or disabled.
            token_counter: Function to count tokens in messages.
            summary_prompt: Prompt template for generating summaries.
            summary_prefix: Prefix added to system message when including summary.
            buffer_tokens: Additional buffer to reserve when estimating token usage.
            target_retention_frac: Optional fraction (0, 1) of `max_input_tokens` to retain
                in context. If the model profile is missing or incomplete, this falls
                back to the `messages_to_keep` strategy.
            trim_token_limit: Maximum tokens to keep when preparing messages for the
                summarization call. Pass `None` to skip trimming entirely (risking
                summary model overflows if the history is too long).
        """
        super().__init__()

@@ -98,20 +121,25 @@ def __init__(
        self.token_counter = token_counter
        self.summary_prompt = summary_prompt
        self.summary_prefix = summary_prefix
        self.buffer_tokens = buffer_tokens
        self.trim_token_limit = trim_token_limit
Contributor:
I'm not sure how necessary trim_token_limit is

It is a bit confusing as an argument (all the thoughts I had):

  • does this mean the maximum number of tokens that we can trim?
  • does this mean the max number of tokens left after trimming?
  • when does trimming even occur, and why! Are we trimming instead of summarizing?

If someone runs into an error while summarizing, shouldn't they just lower the max_tokens_before_summary?

Contributor:
> Pass None to skip trimming entirely (risking summary model overflows if the history is too long).

Imo the best solution here is just "compact earlier" instead of "trimming at compaction time". Summarization call is like an extra few thousand tokens max maybe?

Collaborator Author:
I renamed this parameter to trim_tokens_to_summarize. I added it as a parameter so that the trimming can be circumvented entirely. The current behavior is to always trim what is sent to the LLM for summarization to 4000 tokens:

    def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
        """Generate summary for the given messages."""
        if not messages_to_summarize:
            return "No previous conversation history."

        trimmed_messages = self._trim_messages_for_summary(messages_to_summarize)
        if not trimmed_messages:
            return "Previous conversation was too long to summarize."

        try:
            response = self.model.invoke(self.summary_prompt.format(messages=trimmed_messages))
            return cast("str", response.content).strip()
        except Exception as e:  # noqa: BLE001
            return f"Error generating summary: {e!s}"

    def _trim_messages_for_summary(self, messages: list[AnyMessage]) -> list[AnyMessage]:
        """Trim messages to fit within summary generation limits."""
        try:
            return trim_messages(
                messages,
                max_tokens=_DEFAULT_TRIM_TOKEN_LIMIT,
                token_counter=self.token_counter,
                start_on="human",
                strategy="last",
                allow_partial=True,
                include_system=True,
            )
        except Exception:  # noqa: BLE001
            return messages[-_DEFAULT_FALLBACK_MESSAGE_COUNT:]

So this was a minimal, non-breaking change that lets us disable that behavior. Let me know if that makes sense, or if I misunderstood you.
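For orientation only (not part of this diff): a minimal usage sketch based on the constructor signature shown above. The import path simply mirrors this file's location, the model string and numeric values are placeholders, and the keyword uses the trim_token_limit name as it appears in this revision rather than the later rename.

from langchain.agents.middleware.summarization import SummarizationMiddleware

# Rely on the model profile: max_tokens_before_summary stays UNSET, so summarization
# triggers when history + max_output_tokens (+ buffer) would exceed max_input_tokens.
# Retain roughly 30% of the input window after compaction and send the untrimmed
# history to the summary model.
profile_based = SummarizationMiddleware(
    model="openai:gpt-4o-mini",  # placeholder model identifier
    target_retention_frac=0.3,
    trim_token_limit=None,
)

# Or keep an explicit threshold with the existing message-count retention.
explicit = SummarizationMiddleware(
    model="openai:gpt-4o-mini",
    max_tokens_before_summary=50_000,
    messages_to_keep=20,
)

Passing trim_token_limit=None is the escape hatch debated in this thread; the alternative the reviewers suggest is simply to compact earlier via a lower threshold.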


        if target_retention_frac is not None and not (0 < target_retention_frac < 1):
            error_msg = "target_retention_frac must be between 0 and 1."
            raise ValueError(error_msg)

        self.target_retention_frac = target_retention_frac

    def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:  # noqa: ARG002
        """Process messages before model invocation, potentially triggering summarization."""
        messages = state["messages"]
        self._ensure_message_ids(messages)

        total_tokens = self.token_counter(messages)
        if (
            self.max_tokens_before_summary is not None
            and total_tokens < self.max_tokens_before_summary
        ):
        if not self._should_summarize(total_tokens):
            return None

        cutoff_index = self._find_safe_cutoff(messages)
        cutoff_index = self._determine_cutoff_index(messages)

        if cutoff_index <= 0:
            return None
@@ -129,6 +157,99 @@ def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] |
            ]
        }

    def _should_summarize(self, total_tokens: int) -> bool:
        """Determine whether summarization should run for the current token usage."""
        if self.max_tokens_before_summary is UNSET:
            return self._should_summarize_with_profile(total_tokens)

        if self.max_tokens_before_summary is None:
            return False

        return total_tokens >= cast("int", self.max_tokens_before_summary)

    def _should_summarize_with_profile(self, total_tokens: int) -> bool:
        """Infer summarization threshold from the model profile when available."""
        limits = self._get_profile_limits()
        if limits is None:
            return False

        max_input_tokens, max_output_tokens = limits

        return total_tokens + max_output_tokens + self.buffer_tokens > max_input_tokens

    def _determine_cutoff_index(self, messages: list[AnyMessage]) -> int:
        """Choose cutoff index respecting retention configuration."""
        if self.target_retention_frac is not None:
            token_based_cutoff = self._find_token_based_cutoff(messages)
            if token_based_cutoff is not None:
                return token_based_cutoff
        return self._find_safe_cutoff(messages)

    def _find_token_based_cutoff(self, messages: list[AnyMessage]) -> int | None:
        """Find cutoff index based on target token retention percentage."""
        if not messages:
            return 0

        limits = self._get_profile_limits()
        if limits is None:
            return None

        max_input_tokens, _ = limits
        target_token_count = int(max_input_tokens * cast("float", self.target_retention_frac))
        if target_token_count <= 0:
            target_token_count = 1

        if self.token_counter(messages) <= target_token_count:
            return 0

        # Use binary search to identify the earliest message index that keeps the
        # suffix within the token budget.
        left, right = 0, len(messages)
        cutoff_candidate = len(messages)
        max_iterations = len(messages).bit_length() + 1
        for _ in range(max_iterations):
            if left >= right:
                break

            mid = (left + right) // 2
            if self.token_counter(messages[mid:]) <= target_token_count:
                cutoff_candidate = mid
                right = mid
            else:
                left = mid + 1

        if cutoff_candidate == len(messages):
            cutoff_candidate = left

        if cutoff_candidate >= len(messages):
            if len(messages) == 1:
                return 0
            cutoff_candidate = len(messages) - 1

        for i in range(cutoff_candidate, -1, -1):
            if self._is_safe_cutoff_point(messages, i):
                return i

        return 0

    def _get_profile_limits(self) -> tuple[int, int] | None:
        """Retrieve max input and output token limits from the model profile."""
        try:
            profile = self.model.profile
        except (AttributeError, ImportError):
            return None

        if not isinstance(profile, Mapping):
            return None

        max_input_tokens = profile.get("max_input_tokens")
        max_output_tokens = profile.get("max_output_tokens")

        if not isinstance(max_input_tokens, int) or not isinstance(max_output_tokens, int):
            return None

        return max_input_tokens, max_output_tokens

    def _build_new_messages(self, summary: str) -> list[HumanMessage]:
        return [
            HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}")
@@ -229,16 +350,18 @@ def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:

        try:
            response = self.model.invoke(self.summary_prompt.format(messages=trimmed_messages))
            return cast("str", response.content).strip()
            return response.text.strip()
        except Exception as e:  # noqa: BLE001
            return f"Error generating summary: {e!s}"

    def _trim_messages_for_summary(self, messages: list[AnyMessage]) -> list[AnyMessage]:
        """Trim messages to fit within summary generation limits."""
        try:
            if self.trim_token_limit is None:
                return messages
            return trim_messages(
                messages,
                max_tokens=_DEFAULT_TRIM_TOKEN_LIMIT,
                max_tokens=self.trim_token_limit,
                token_counter=self.token_counter,
                start_on="human",
                strategy="last",