Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 75 additions & 10 deletions libs/core/langchain_core/messages/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,6 +688,7 @@ def trim_messages(
*,
max_tokens: int,
token_counter: Union[
Literal["approximate"],
Callable[[list[BaseMessage]], int],
Callable[[BaseMessage], int],
BaseLanguageModel,
Expand Down Expand Up @@ -738,11 +739,16 @@ def trim_messages(
BaseMessage. If a BaseLanguageModel is passed in then
BaseLanguageModel.get_num_tokens_from_messages() will be used.
Set to `len` to count the number of **messages** in the chat history.
You can also use string shortcuts for convenience:

- ``"approximate"``: Uses `count_tokens_approximately` for fast, approximate
token counts.

.. note::
Use `count_tokens_approximately` to get fast, approximate token counts.
This is recommended for using `trim_messages` on the hot path, where
exact token counting is not necessary.
Use `count_tokens_approximately` (or the shortcut ``"approximate"``) to get
fast, approximate token counts. This is recommended for using
`trim_messages` on the hot path, where exact token counting is not
necessary.

strategy: Strategy for trimming.

Expand Down Expand Up @@ -849,6 +855,35 @@ def trim_messages(
HumanMessage(content="what do you call a speechless parrot"),
]

Trim chat history using approximate token counting with the "approximate" shortcut:

.. code-block:: python

trim_messages(
messages,
max_tokens=45,
strategy="last",
# Using the "approximate" shortcut for fast approximate token counting
token_counter="approximate",
start_on="human",
include_system=True,
)

This is equivalent to using `count_tokens_approximately` directly:

.. code-block:: python

from langchain_core.messages.utils import count_tokens_approximately

trim_messages(
messages,
max_tokens=45,
strategy="last",
token_counter=count_tokens_approximately,
start_on="human",
include_system=True,
)

Trim chat history based on the message count, keeping the SystemMessage if
present, and ensuring that the chat history starts with a HumanMessage (
or a SystemMessage followed by a HumanMessage).
Expand Down Expand Up @@ -977,24 +1012,43 @@ def dummy_token_counter(messages: list[BaseMessage]) -> int:
raise ValueError(msg)

messages = convert_to_messages(messages)
if hasattr(token_counter, "get_num_tokens_from_messages"):
list_token_counter = token_counter.get_num_tokens_from_messages
elif callable(token_counter):

# Handle string shortcuts for token counter
if isinstance(token_counter, str):
if token_counter in _TOKEN_COUNTER_SHORTCUTS:
actual_token_counter = _TOKEN_COUNTER_SHORTCUTS[token_counter]
else:
available_shortcuts = ", ".join(
f"'{key}'" for key in _TOKEN_COUNTER_SHORTCUTS
)
msg = (
f"Invalid token_counter shortcut '{token_counter}'. "
f"Available shortcuts: {available_shortcuts}."
)
raise ValueError(msg)
else:
actual_token_counter = token_counter

if hasattr(actual_token_counter, "get_num_tokens_from_messages"):
list_token_counter = actual_token_counter.get_num_tokens_from_messages # type: ignore[assignment]
elif callable(actual_token_counter):
if (
next(iter(inspect.signature(token_counter).parameters.values())).annotation
next(
iter(inspect.signature(actual_token_counter).parameters.values())
).annotation
is BaseMessage
):

def list_token_counter(messages: Sequence[BaseMessage]) -> int:
return sum(token_counter(msg) for msg in messages) # type: ignore[arg-type, misc]
return sum(actual_token_counter(msg) for msg in messages) # type: ignore[arg-type, misc]

else:
list_token_counter = token_counter
list_token_counter = actual_token_counter # type: ignore[assignment]
else:
msg = (
f"'token_counter' expected to be a model that implements "
f"'get_num_tokens_from_messages()' or a function. Received object of type "
f"{type(token_counter)}."
f"{type(actual_token_counter)}."
)
raise ValueError(msg)

Expand Down Expand Up @@ -1754,3 +1808,14 @@ def count_tokens_approximately(

# round up one more time in case extra_tokens_per_message is a float
return math.ceil(token_count)


# Mapping from string shortcuts to token counter functions
def _approximate_token_counter(messages: Sequence[BaseMessage]) -> int:
    """Count tokens for ``messages`` using ``count_tokens_approximately``.

    Thin adapter that forwards the messages with no extra arguments; it is the
    callable registered under the ``"approximate"`` shortcut.

    Args:
        messages: The messages whose tokens should be counted.

    Returns:
        The value returned by ``count_tokens_approximately`` for ``messages``.
    """
    approximate_total = count_tokens_approximately(messages)
    return approximate_total


# Keys are the strings accepted as ``token_counter`` by ``trim_messages``;
# each value is the callable used to count tokens for that shortcut. Strings
# missing from this mapping are rejected by ``trim_messages`` with a ValueError.
_TOKEN_COUNTER_SHORTCUTS = {"approximate": _approximate_token_counter}
75 changes: 75 additions & 0 deletions libs/core/tests/unit_tests/messages/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,81 @@ def test_trim_messages_start_on_with_allow_partial() -> None:
assert messages == messages_copy


def test_trim_messages_token_counter_shortcut_approximate() -> None:
    """The 'approximate' shortcut must match ``count_tokens_approximately``."""
    history = [
        SystemMessage("This is a test message"),
        HumanMessage("Another test message", id="first"),
        AIMessage("AI response here", id="second"),
    ]
    # Deep copies let us verify afterwards that trimming did not mutate input.
    history_snapshot = [msg.model_copy(deep=True) for msg in history]

    # First pass: counter selected through the string shortcut.
    trimmed_via_shortcut = trim_messages(
        history,
        max_tokens=50,
        token_counter="approximate",
        strategy="last",
    )

    # Second pass: identical call, but passing the counting function itself.
    trimmed_via_function = trim_messages(
        history,
        max_tokens=50,
        token_counter=count_tokens_approximately,
        strategy="last",
    )

    # The two spellings must be interchangeable and leave the input untouched.
    assert trimmed_via_shortcut == trimmed_via_function
    assert history == history_snapshot


def test_trim_messages_token_counter_shortcut_invalid() -> None:
    """An unrecognised ``token_counter`` string must raise ``ValueError``."""
    history = [
        SystemMessage("This is a test message"),
        HumanMessage("Another test message"),
    ]
    expected_error = "Invalid token_counter shortcut 'invalid'"

    # "invalid" is not a registered shortcut, so the call must be rejected.
    with pytest.raises(ValueError, match=expected_error):
        trim_messages(
            history,
            max_tokens=50,
            token_counter="invalid",
            strategy="last",
        )


def test_trim_messages_token_counter_shortcut_with_options() -> None:
    """Test that 'approximate' shortcut works with different trim options.

    Combines the shortcut with ``include_system=True`` and ``start_on="human"``
    and checks that the trimmed history keeps the system message first and
    that the message right after it is a human message.
    """
    messages = [
        SystemMessage("System instructions"),
        HumanMessage("First human message", id="first"),
        AIMessage("First AI response", id="ai1"),
        HumanMessage("Second human message", id="second"),
        AIMessage("Second AI response", id="ai2"),
    ]
    # Deep copies let us verify afterwards that trimming did not mutate input.
    messages_copy = [m.model_copy(deep=True) for m in messages]

    # Test with various options
    result = trim_messages(
        messages,
        max_tokens=100,
        token_counter="approximate",
        strategy="last",
        include_system=True,
        start_on="human",
    )

    # Should include system message and start on human
    assert len(result) >= 2
    assert isinstance(result[0], SystemMessage)
    # Check the position, not mere presence: a history that resumed on an
    # AIMessage would still contain *some* HumanMessage further along.
    assert isinstance(result[1], HumanMessage)
    assert messages == messages_copy


class FakeTokenCountingModel(FakeChatModel):
@override
def get_num_tokens_from_messages(
Expand Down
Loading