Skip to content

Commit a5ec2c6

Browse files
committed
added feature for using alias approx alongside count_tokens_approximately
1 parent cdae9e4 commit a5ec2c6

File tree

2 files changed

+150
-10
lines changed

2 files changed

+150
-10
lines changed

libs/core/langchain_core/messages/utils.py

Lines changed: 75 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -688,6 +688,7 @@ def trim_messages(
688688
*,
689689
max_tokens: int,
690690
token_counter: Union[
691+
str,
691692
Callable[[list[BaseMessage]], int],
692693
Callable[[BaseMessage], int],
693694
BaseLanguageModel,
@@ -738,11 +739,16 @@ def trim_messages(
738739
BaseMessage. If a BaseLanguageModel is passed in then
739740
BaseLanguageModel.get_num_tokens_from_messages() will be used.
740741
Set to `len` to count the number of **messages** in the chat history.
742+
You can also use string shortcuts for convenience:
743+
744+
- ``"approx"``: Uses `count_tokens_approximately` for fast, approximate
745+
token counts.
741746
742747
.. note::
743-
Use `count_tokens_approximately` to get fast, approximate token counts.
744-
This is recommended for using `trim_messages` on the hot path, where
745-
exact token counting is not necessary.
748+
Use `count_tokens_approximately` (or the shortcut ``"approx"``) to get
749+
fast, approximate token counts. This is recommended for using
750+
`trim_messages` on the hot path, where exact token counting is not
751+
necessary.
746752
747753
strategy: Strategy for trimming.
748754
@@ -849,6 +855,35 @@ def trim_messages(
849855
HumanMessage(content="what do you call a speechless parrot"),
850856
]
851857
858+
Trim chat history using approximate token counting with the "approx" shortcut:
859+
860+
.. code-block:: python
861+
862+
trim_messages(
863+
messages,
864+
max_tokens=45,
865+
strategy="last",
866+
# Using the "approx" shortcut for fast approximate token counting
867+
token_counter="approx",
868+
start_on="human",
869+
include_system=True,
870+
)
871+
872+
This is equivalent to using `count_tokens_approximately` directly:
873+
874+
.. code-block:: python
875+
876+
from langchain_core.messages.utils import count_tokens_approximately
877+
878+
trim_messages(
879+
messages,
880+
max_tokens=45,
881+
strategy="last",
882+
token_counter=count_tokens_approximately,
883+
start_on="human",
884+
include_system=True,
885+
)
886+
852887
Trim chat history based on the message count, keeping the SystemMessage if
853888
present, and ensuring that the chat history starts with a HumanMessage (
854889
or a SystemMessage followed by a HumanMessage).
@@ -977,24 +1012,43 @@ def dummy_token_counter(messages: list[BaseMessage]) -> int:
9771012
raise ValueError(msg)
9781013

9791014
messages = convert_to_messages(messages)
980-
if hasattr(token_counter, "get_num_tokens_from_messages"):
981-
list_token_counter = token_counter.get_num_tokens_from_messages
982-
elif callable(token_counter):
1015+
1016+
# Handle string shortcuts for token counter
1017+
if isinstance(token_counter, str):
1018+
if token_counter in _TOKEN_COUNTER_SHORTCUTS:
1019+
actual_token_counter = _TOKEN_COUNTER_SHORTCUTS[token_counter]
1020+
else:
1021+
available_shortcuts = ", ".join(
1022+
f"'{key}'" for key in _TOKEN_COUNTER_SHORTCUTS
1023+
)
1024+
msg = (
1025+
f"Invalid token_counter shortcut '{token_counter}'. "
1026+
f"Available shortcuts: {available_shortcuts}."
1027+
)
1028+
raise ValueError(msg)
1029+
else:
1030+
actual_token_counter = token_counter
1031+
1032+
if hasattr(actual_token_counter, "get_num_tokens_from_messages"):
1033+
list_token_counter = actual_token_counter.get_num_tokens_from_messages # type: ignore[assignment]
1034+
elif callable(actual_token_counter):
9831035
if (
984-
next(iter(inspect.signature(token_counter).parameters.values())).annotation
1036+
next(
1037+
iter(inspect.signature(actual_token_counter).parameters.values())
1038+
).annotation
9851039
is BaseMessage
9861040
):
9871041

9881042
def list_token_counter(messages: Sequence[BaseMessage]) -> int:
989-
return sum(token_counter(msg) for msg in messages) # type: ignore[arg-type, misc]
1043+
return sum(actual_token_counter(msg) for msg in messages) # type: ignore[arg-type, misc]
9901044

9911045
else:
992-
list_token_counter = token_counter
1046+
list_token_counter = actual_token_counter # type: ignore[assignment]
9931047
else:
9941048
msg = (
9951049
f"'token_counter' expected to be a model that implements "
9961050
f"'get_num_tokens_from_messages()' or a function. Received object of type "
997-
f"{type(token_counter)}."
1051+
f"{type(actual_token_counter)}."
9981052
)
9991053
raise ValueError(msg)
10001054

@@ -1754,3 +1808,14 @@ def count_tokens_approximately(
17541808

17551809
# round up once more time in case extra_tokens_per_message is a float
17561810
return math.ceil(token_count)
1811+
1812+
1813+
# Mapping from string shortcuts to token counter functions
1814+
def _approx_token_counter(messages: Sequence[BaseMessage]) -> int:
1815+
"""Wrapper for count_tokens_approximately that matches expected signature."""
1816+
return count_tokens_approximately(messages)
1817+
1818+
1819+
_TOKEN_COUNTER_SHORTCUTS = {
1820+
"approx": _approx_token_counter,
1821+
}

libs/core/tests/unit_tests/messages/test_utils.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -660,6 +660,81 @@ def test_trim_messages_start_on_with_allow_partial() -> None:
660660
assert messages == messages_copy
661661

662662

663+
def test_trim_messages_token_counter_shortcut_approx() -> None:
664+
"""Test that 'approx' shortcut works for token_counter."""
665+
messages = [
666+
SystemMessage("This is a test message"),
667+
HumanMessage("Another test message", id="first"),
668+
AIMessage("AI response here", id="second"),
669+
]
670+
messages_copy = [m.model_copy(deep=True) for m in messages]
671+
672+
# Test using the "approx" shortcut
673+
result_shortcut = trim_messages(
674+
messages,
675+
max_tokens=50,
676+
token_counter="approx",
677+
strategy="last",
678+
)
679+
680+
# Test using count_tokens_approximately directly
681+
result_direct = trim_messages(
682+
messages,
683+
max_tokens=50,
684+
token_counter=count_tokens_approximately,
685+
strategy="last",
686+
)
687+
688+
# Both should produce the same result
689+
assert result_shortcut == result_direct
690+
assert messages == messages_copy
691+
692+
693+
def test_trim_messages_token_counter_shortcut_invalid() -> None:
694+
"""Test that invalid token_counter shortcut raises ValueError."""
695+
messages = [
696+
SystemMessage("This is a test message"),
697+
HumanMessage("Another test message"),
698+
]
699+
700+
# Test with invalid shortcut
701+
with pytest.raises(ValueError, match="Invalid token_counter shortcut 'invalid'"):
702+
trim_messages(
703+
messages,
704+
max_tokens=50,
705+
token_counter="invalid",
706+
strategy="last",
707+
)
708+
709+
710+
def test_trim_messages_token_counter_shortcut_with_options() -> None:
711+
"""Test that 'approx' shortcut works with different trim options."""
712+
messages = [
713+
SystemMessage("System instructions"),
714+
HumanMessage("First human message", id="first"),
715+
AIMessage("First AI response", id="ai1"),
716+
HumanMessage("Second human message", id="second"),
717+
AIMessage("Second AI response", id="ai2"),
718+
]
719+
messages_copy = [m.model_copy(deep=True) for m in messages]
720+
721+
# Test with various options
722+
result = trim_messages(
723+
messages,
724+
max_tokens=100,
725+
token_counter="approx",
726+
strategy="last",
727+
include_system=True,
728+
start_on="human",
729+
)
730+
731+
# Should include system message and start on human
732+
assert len(result) >= 2
733+
assert isinstance(result[0], SystemMessage)
734+
assert any(isinstance(msg, HumanMessage) for msg in result[1:])
735+
assert messages == messages_copy
736+
737+
663738
class FakeTokenCountingModel(FakeChatModel):
664739
@override
665740
def get_num_tokens_from_messages(

0 commit comments

Comments
 (0)