diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py
index 6f681d7a65ea1..181df5f0bf69a 100644
--- a/libs/core/langchain_core/messages/utils.py
+++ b/libs/core/langchain_core/messages/utils.py
@@ -688,6 +688,7 @@ def trim_messages(
     *,
     max_tokens: int,
     token_counter: Union[
+        Literal["approximate"],
         Callable[[list[BaseMessage]], int],
         Callable[[BaseMessage], int],
         BaseLanguageModel,
@@ -738,11 +739,16 @@ def trim_messages(
             BaseMessage. If a BaseLanguageModel is passed in then
             BaseLanguageModel.get_num_tokens_from_messages() will be used.
             Set to `len` to count the number of **messages** in the chat history.
+            You can also use string shortcuts for convenience:
+
+            - ``"approximate"``: Uses `count_tokens_approximately` for fast, approximate
+              token counts.
 
             .. note::
-                Use `count_tokens_approximately` to get fast, approximate token counts.
-                This is recommended for using `trim_messages` on the hot path, where
-                exact token counting is not necessary.
+                Use `count_tokens_approximately` (or the shortcut ``"approximate"``) to get
+                fast, approximate token counts. This is recommended for using
+                `trim_messages` on the hot path, where exact token counting is not
+                necessary.
 
         strategy: Strategy for trimming.
 
@@ -849,6 +855,35 @@
         HumanMessage(content="what do you call a speechless parrot"),
     ]
 
+    Trim chat history using approximate token counting with the "approximate" shortcut:
+
+    .. code-block:: python
+
+        trim_messages(
+            messages,
+            max_tokens=45,
+            strategy="last",
+            # Using the "approximate" shortcut for fast approximate token counting
+            token_counter="approximate",
+            start_on="human",
+            include_system=True,
+        )
+
+    This is equivalent to using `count_tokens_approximately` directly:
+
+    .. code-block:: python
+
+        from langchain_core.messages.utils import count_tokens_approximately
+
+        trim_messages(
+            messages,
+            max_tokens=45,
+            strategy="last",
+            token_counter=count_tokens_approximately,
+            start_on="human",
+            include_system=True,
+        )
+
     Trim chat history based on the message count, keeping the SystemMessage if
     present, and ensuring that the chat history starts with a HumanMessage (
     or a SystemMessage followed by a HumanMessage).
@@ -977,24 +1012,43 @@ def dummy_token_counter(messages: list[BaseMessage]) -> int:
         raise ValueError(msg)
 
     messages = convert_to_messages(messages)
-    if hasattr(token_counter, "get_num_tokens_from_messages"):
-        list_token_counter = token_counter.get_num_tokens_from_messages
-    elif callable(token_counter):
+
+    # Handle string shortcuts for token counter
+    if isinstance(token_counter, str):
+        if token_counter in _TOKEN_COUNTER_SHORTCUTS:
+            actual_token_counter = _TOKEN_COUNTER_SHORTCUTS[token_counter]
+        else:
+            available_shortcuts = ", ".join(
+                f"'{key}'" for key in _TOKEN_COUNTER_SHORTCUTS
+            )
+            msg = (
+                f"Invalid token_counter shortcut '{token_counter}'. "
+                f"Available shortcuts: {available_shortcuts}."
+            )
+            raise ValueError(msg)
+    else:
+        actual_token_counter = token_counter
+
+    if hasattr(actual_token_counter, "get_num_tokens_from_messages"):
+        list_token_counter = actual_token_counter.get_num_tokens_from_messages  # type: ignore[assignment]
+    elif callable(actual_token_counter):
         if (
-            next(iter(inspect.signature(token_counter).parameters.values())).annotation
+            next(
+                iter(inspect.signature(actual_token_counter).parameters.values())
+            ).annotation
             is BaseMessage
         ):
 
             def list_token_counter(messages: Sequence[BaseMessage]) -> int:
-                return sum(token_counter(msg) for msg in messages)  # type: ignore[arg-type, misc]
+                return sum(actual_token_counter(msg) for msg in messages)  # type: ignore[arg-type, misc]
 
         else:
-            list_token_counter = token_counter
+            list_token_counter = actual_token_counter  # type: ignore[assignment]
     else:
         msg = (
             f"'token_counter' expected to be a model that implements "
             f"'get_num_tokens_from_messages()' or a function. Received object of type "
-            f"{type(token_counter)}."
+            f"{type(actual_token_counter)}."
         )
         raise ValueError(msg)
 
@@ -1754,3 +1808,14 @@ def count_tokens_approximately(
 
     # round up once more time in case extra_tokens_per_message is a float
     return math.ceil(token_count)
+
+
+# Mapping from string shortcuts to token counter functions
+def _approximate_token_counter(messages: Sequence[BaseMessage]) -> int:
+    """Wrapper for count_tokens_approximately that matches expected signature."""
+    return count_tokens_approximately(messages)
+
+
+_TOKEN_COUNTER_SHORTCUTS = {
+    "approximate": _approximate_token_counter,
+}
diff --git a/libs/core/tests/unit_tests/messages/test_utils.py b/libs/core/tests/unit_tests/messages/test_utils.py
index bedd518589ea0..ac6c3388fa37a 100644
--- a/libs/core/tests/unit_tests/messages/test_utils.py
+++ b/libs/core/tests/unit_tests/messages/test_utils.py
@@ -660,6 +660,81 @@ def test_trim_messages_start_on_with_allow_partial() -> None:
     assert messages == messages_copy
 
 
+def test_trim_messages_token_counter_shortcut_approximate() -> None:
+    """Test that 'approximate' shortcut works for token_counter."""
+    messages = [
+        SystemMessage("This is a test message"),
+        HumanMessage("Another test message", id="first"),
+        AIMessage("AI response here", id="second"),
+    ]
+    messages_copy = [m.model_copy(deep=True) for m in messages]
+
+    # Test using the "approximate" shortcut
+    result_shortcut = trim_messages(
+        messages,
+        max_tokens=50,
+        token_counter="approximate",
+        strategy="last",
+    )
+
+    # Test using count_tokens_approximately directly
+    result_direct = trim_messages(
+        messages,
+        max_tokens=50,
+        token_counter=count_tokens_approximately,
+        strategy="last",
+    )
+
+    # Both should produce the same result
+    assert result_shortcut == result_direct
+    assert messages == messages_copy
+
+
+def test_trim_messages_token_counter_shortcut_invalid() -> None:
+    """Test that invalid token_counter shortcut raises ValueError."""
+    messages = [
+        SystemMessage("This is a test message"),
+        HumanMessage("Another test message"),
+    ]
+
+    # Test with invalid shortcut
+    with pytest.raises(ValueError, match="Invalid token_counter shortcut 'invalid'"):
+        trim_messages(
+            messages,
+            max_tokens=50,
+            token_counter="invalid",
+            strategy="last",
+        )
+
+
+def test_trim_messages_token_counter_shortcut_with_options() -> None:
+    """Test that 'approximate' shortcut works with different trim options."""
+    messages = [
+        SystemMessage("System instructions"),
+        HumanMessage("First human message", id="first"),
+        AIMessage("First AI response", id="ai1"),
+        HumanMessage("Second human message", id="second"),
+        AIMessage("Second AI response", id="ai2"),
+    ]
+    messages_copy = [m.model_copy(deep=True) for m in messages]
+
+    # Test with various options
+    result = trim_messages(
+        messages,
+        max_tokens=100,
+        token_counter="approximate",
+        strategy="last",
+        include_system=True,
+        start_on="human",
+    )
+
+    # Should include system message and start on human
+    assert len(result) >= 2
+    assert isinstance(result[0], SystemMessage)
+    assert any(isinstance(msg, HumanMessage) for msg in result[1:])
+    assert messages == messages_copy
+
+
 class FakeTokenCountingModel(FakeChatModel):
     @override
     def get_num_tokens_from_messages(