@@ -688,6 +688,7 @@ def trim_messages(
688
688
* ,
689
689
max_tokens : int ,
690
690
token_counter : Union [
691
+ str ,
691
692
Callable [[list [BaseMessage ]], int ],
692
693
Callable [[BaseMessage ], int ],
693
694
BaseLanguageModel ,
@@ -738,11 +739,16 @@ def trim_messages(
738
739
BaseMessage. If a BaseLanguageModel is passed in then
739
740
BaseLanguageModel.get_num_tokens_from_messages() will be used.
740
741
Set to `len` to count the number of **messages** in the chat history.
742
+ You can also use string shortcuts for convenience:
743
+
744
+ - ``"approx"``: Uses `count_tokens_approximately` for fast, approximate
745
+ token counts.
741
746
742
747
.. note::
743
- Use `count_tokens_approximately` to get fast, approximate token counts.
744
- This is recommended for using `trim_messages` on the hot path, where
745
- exact token counting is not necessary.
748
+ Use `count_tokens_approximately` (or the shortcut ``"approx"``) to get
749
+ fast, approximate token counts. This is recommended for using
750
+ `trim_messages` on the hot path, where exact token counting is not
751
+ necessary.
746
752
747
753
strategy: Strategy for trimming.
748
754
@@ -849,6 +855,35 @@ def trim_messages(
849
855
HumanMessage(content="what do you call a speechless parrot"),
850
856
]
851
857
858
+ Trim chat history using approximate token counting with the "approx" shortcut:
859
+
860
+ .. code-block:: python
861
+
862
+ trim_messages(
863
+ messages,
864
+ max_tokens=45,
865
+ strategy="last",
866
+ # Using the "approx" shortcut for fast approximate token counting
867
+ token_counter="approx",
868
+ start_on="human",
869
+ include_system=True,
870
+ )
871
+
872
+ This is equivalent to using `count_tokens_approximately` directly:
873
+
874
+ .. code-block:: python
875
+
876
+ from langchain_core.messages.utils import count_tokens_approximately
877
+
878
+ trim_messages(
879
+ messages,
880
+ max_tokens=45,
881
+ strategy="last",
882
+ token_counter=count_tokens_approximately,
883
+ start_on="human",
884
+ include_system=True,
885
+ )
886
+
852
887
Trim chat history based on the message count, keeping the SystemMessage if
853
888
present, and ensuring that the chat history starts with a HumanMessage (
854
889
or a SystemMessage followed by a HumanMessage).
@@ -977,24 +1012,43 @@ def dummy_token_counter(messages: list[BaseMessage]) -> int:
977
1012
raise ValueError (msg )
978
1013
979
1014
messages = convert_to_messages (messages )
980
- if hasattr (token_counter , "get_num_tokens_from_messages" ):
981
- list_token_counter = token_counter .get_num_tokens_from_messages
982
- elif callable (token_counter ):
1015
+
1016
+ # Handle string shortcuts for token counter
1017
+ if isinstance (token_counter , str ):
1018
+ if token_counter in _TOKEN_COUNTER_SHORTCUTS :
1019
+ actual_token_counter = _TOKEN_COUNTER_SHORTCUTS [token_counter ]
1020
+ else :
1021
+ available_shortcuts = ", " .join (
1022
+ f"'{ key } '" for key in _TOKEN_COUNTER_SHORTCUTS
1023
+ )
1024
+ msg = (
1025
+ f"Invalid token_counter shortcut '{ token_counter } '. "
1026
+ f"Available shortcuts: { available_shortcuts } ."
1027
+ )
1028
+ raise ValueError (msg )
1029
+ else :
1030
+ actual_token_counter = token_counter
1031
+
1032
+ if hasattr (actual_token_counter , "get_num_tokens_from_messages" ):
1033
+ list_token_counter = actual_token_counter .get_num_tokens_from_messages # type: ignore[assignment]
1034
+ elif callable (actual_token_counter ):
983
1035
if (
984
- next (iter (inspect .signature (token_counter ).parameters .values ())).annotation
1036
+ next (
1037
+ iter (inspect .signature (actual_token_counter ).parameters .values ())
1038
+ ).annotation
985
1039
is BaseMessage
986
1040
):
987
1041
988
1042
def list_token_counter (messages : Sequence [BaseMessage ]) -> int :
989
- return sum (token_counter (msg ) for msg in messages ) # type: ignore[arg-type, misc]
1043
+ return sum (actual_token_counter (msg ) for msg in messages ) # type: ignore[arg-type, misc]
990
1044
991
1045
else :
992
- list_token_counter = token_counter
1046
+ list_token_counter = actual_token_counter # type: ignore[assignment]
993
1047
else :
994
1048
msg = (
995
1049
f"'token_counter' expected to be a model that implements "
996
1050
f"'get_num_tokens_from_messages()' or a function. Received object of type "
997
- f"{ type (token_counter )} ."
1051
+ f"{ type (actual_token_counter )} ."
998
1052
)
999
1053
raise ValueError (msg )
1000
1054
@@ -1754,3 +1808,14 @@ def count_tokens_approximately(
1754
1808
1755
1809
# round up one more time in case extra_tokens_per_message is a float
1756
1810
return math .ceil (token_count )
1811
+
1812
+
1813
# String shortcuts for ``token_counter``: the adapter function comes first,
# followed by the lookup table that ``trim_messages`` consults.
def _approx_token_counter(messages: Sequence[BaseMessage]) -> int:
    """Return the approximate token count for ``messages``.

    Thin adapter around ``count_tokens_approximately``: it forwards only the
    message sequence, so any other parameters of that function keep their
    defaults. This is the callable stored under the ``"approx"`` key of
    ``_TOKEN_COUNTER_SHORTCUTS``.

    Args:
        messages: The messages whose tokens should be counted.

    Returns:
        The integer returned by ``count_tokens_approximately(messages)``.
    """
    return count_tokens_approximately(messages)
1817
+
1818
+
1819
# Lookup table from the string shortcuts accepted as ``token_counter`` in
# ``trim_messages`` to the token counter function each one stands for.
# A string that is not a key here is rejected there with a ``ValueError``
# listing the available keys.
_TOKEN_COUNTER_SHORTCUTS = {
    "approx": _approx_token_counter,
}
0 commit comments