Skip to content

Commit 4b9e2e5

Browse files
authored
core[patch]: add token counting callback handler (#30481)
Stripped-down version of [OpenAICallbackHandler](https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/callbacks/openai_info.py) that just tracks `AIMessage.usage_metadata`. ```python from langchain_core.callbacks import get_usage_metadata_callback from langgraph.prebuilt import create_react_agent def get_weather(location: str) -> str: """Get the weather at a location.""" return "It's sunny." tools = [get_weather] agent = create_react_agent("openai:gpt-4o-mini", tools) with get_usage_metadata_callback() as cb: result = await agent.ainvoke({"messages": "What's the weather in Boston?"}) print(cb.usage_metadata) ```
1 parent 1d2b1d8 commit 4b9e2e5

File tree

4 files changed

+201
-0
lines changed

4 files changed

+201
-0
lines changed

libs/core/langchain_core/callbacks/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@
4343
)
4444
from langchain_core.callbacks.stdout import StdOutCallbackHandler
4545
from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
46+
from langchain_core.callbacks.usage import (
47+
UsageMetadataCallbackHandler,
48+
get_usage_metadata_callback,
49+
)
4650

4751
__all__ = [
4852
"dispatch_custom_event",
@@ -77,4 +81,6 @@
7781
"StdOutCallbackHandler",
7882
"StreamingStdOutCallbackHandler",
7983
"FileCallbackHandler",
84+
"UsageMetadataCallbackHandler",
85+
"get_usage_metadata_callback",
8086
]
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
"""Callback Handler that tracks AIMessage.usage_metadata."""
2+
3+
import threading
4+
from collections.abc import Generator
5+
from contextlib import contextmanager
6+
from contextvars import ContextVar
7+
from typing import Any, Optional
8+
9+
from langchain_core.callbacks import BaseCallbackHandler
10+
from langchain_core.messages import AIMessage
11+
from langchain_core.messages.ai import UsageMetadata, add_usage
12+
from langchain_core.outputs import ChatGeneration, LLMResult
13+
14+
15+
class UsageMetadataCallbackHandler(BaseCallbackHandler):
    """Callback Handler that tracks AIMessage.usage_metadata.

    Accumulates ``usage_metadata`` from every ``AIMessage`` seen in
    ``on_llm_end`` into a single running total on ``self.usage_metadata``.

    Example:
        .. code-block:: python

            from langchain.chat_models import init_chat_model
            from langchain_core.callbacks import UsageMetadataCallbackHandler

            llm = init_chat_model(model="openai:gpt-4o-mini")

            callback = UsageMetadataCallbackHandler()
            results = llm.batch(["Hello", "Goodbye"], config={"callbacks": [callback]})
            print(callback.usage_metadata)

        .. code-block:: none

            {'output_token_details': {'audio': 0, 'reasoning': 0}, 'input_tokens': 17, 'output_tokens': 31, 'total_tokens': 48, 'input_token_details': {'cache_read': 0, 'audio': 0}}

    .. versionadded:: 0.3.49
    """  # noqa: E501

    def __init__(self) -> None:
        super().__init__()
        # Lock guards the running total: batch calls may invoke
        # on_llm_end from multiple threads.
        self._lock = threading.Lock()
        # Aggregated usage across all calls seen so far; None until the
        # first on_llm_end fires.
        self.usage_metadata: Optional[UsageMetadata] = None

    def __repr__(self) -> str:
        return str(self.usage_metadata)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Collect token usage."""
        # usage_metadata requires langchain-core >= 0.2.2; default to None
        # whenever the first generation does not carry an AIMessage with it.
        usage_metadata: Optional[UsageMetadata] = None
        try:
            generation = response.generations[0][0]
        except IndexError:
            generation = None
        if isinstance(generation, ChatGeneration):
            try:
                message = generation.message
            except AttributeError:
                pass
            else:
                if isinstance(message, AIMessage):
                    usage_metadata = message.usage_metadata

        # Fold the new usage into the shared running total behind the lock.
        with self._lock:
            self.usage_metadata = add_usage(self.usage_metadata, usage_metadata)
67+
68+
69+
@contextmanager
def get_usage_metadata_callback(
    name: str = "usage_metadata_callback",
) -> Generator[UsageMetadataCallbackHandler, None, None]:
    """Get context manager for tracking usage metadata across chat model calls using
    ``AIMessage.usage_metadata``.

    Args:
        name (str): The name of the context variable. Defaults to
            ``"usage_metadata_callback"``.

    Yields:
        UsageMetadataCallbackHandler: the handler that accumulates usage
        metadata for all chat model calls made inside the ``with`` block.

    Example:
        .. code-block:: python

            from langchain.chat_models import init_chat_model
            from langchain_core.callbacks import get_usage_metadata_callback

            llm = init_chat_model(model="openai:gpt-4o-mini")

            with get_usage_metadata_callback() as cb:
                llm.invoke("Hello")
                llm.invoke("Goodbye")
                print(cb.usage_metadata)

        .. code-block:: none

            {'output_token_details': {'audio': 0, 'reasoning': 0}, 'input_tokens': 17, 'output_tokens': 31, 'total_tokens': 48, 'input_token_details': {'cache_read': 0, 'audio': 0}}

    .. versionadded:: 0.3.49
    """  # noqa: E501
    # Local import to avoid a circular import between callbacks and tracers.
    from langchain_core.tracers.context import register_configure_hook

    usage_metadata_callback_var: ContextVar[Optional[UsageMetadataCallbackHandler]] = (
        ContextVar(name, default=None)
    )
    # Hook makes any callback stored in this context var automatically
    # attached to runs configured within the current context.
    register_configure_hook(usage_metadata_callback_var, True)
    cb = UsageMetadataCallbackHandler()
    usage_metadata_callback_var.set(cb)
    try:
        yield cb
    finally:
        # Always clear the context var, even if the with-block raises;
        # otherwise the handler would leak into subsequent runs in this
        # context (the hook stays registered and would keep attaching it).
        usage_metadata_callback_var.set(None)

libs/core/tests/unit_tests/callbacks/test_imports.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
"FileCallbackHandler",
3434
"adispatch_custom_event",
3535
"dispatch_custom_event",
36+
"UsageMetadataCallbackHandler",
37+
"get_usage_metadata_callback",
3638
]
3739

3840

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
from itertools import cycle
2+
3+
from langchain_core.callbacks import (
4+
UsageMetadataCallbackHandler,
5+
get_usage_metadata_callback,
6+
)
7+
from langchain_core.language_models import GenericFakeChatModel
8+
from langchain_core.messages import AIMessage
9+
from langchain_core.messages.ai import (
10+
InputTokenDetails,
11+
OutputTokenDetails,
12+
UsageMetadata,
13+
add_usage,
14+
)
15+
16+
# Simple usage payloads without per-token detail breakdowns.
usage1 = UsageMetadata(
    input_tokens=1,
    output_tokens=2,
    total_tokens=3,
)
usage2 = UsageMetadata(
    input_tokens=4,
    output_tokens=5,
    total_tokens=9,
)
# Usage payloads that also carry input/output token details, to verify the
# detail dicts are merged (not dropped) when usages are added together.
usage3 = UsageMetadata(
    input_tokens=10,
    output_tokens=20,
    total_tokens=30,
    input_token_details=InputTokenDetails(audio=5),
    output_token_details=OutputTokenDetails(reasoning=10),
)
usage4 = UsageMetadata(
    input_tokens=5,
    output_tokens=10,
    total_tokens=15,
    input_token_details=InputTokenDetails(audio=3),
    output_token_details=OutputTokenDetails(reasoning=5),
)
# Responses the fake model will emit in order (tests cycle through them).
messages = [
    AIMessage("Response 1", usage_metadata=usage1),
    AIMessage("Response 2", usage_metadata=usage2),
    AIMessage("Response 3", usage_metadata=usage3),
    AIMessage("Response 4", usage_metadata=usage4),
]
46+
47+
48+
def test_usage_callback() -> None:
    """Usage metadata accumulates across sync calls (context manager and config)."""
    model = GenericFakeChatModel(messages=cycle(messages))

    expected_first_two = add_usage(usage1, usage2)
    expected_last_two = add_usage(usage3, usage4)

    # Context-manager API: the running total grows as calls happen inside the block.
    with get_usage_metadata_callback() as cb:
        model.invoke("Message 1")
        model.invoke("Message 2")
        assert cb.usage_metadata == expected_first_two
        model.invoke("Message 3")
        model.invoke("Message 4")
        assert cb.usage_metadata == add_usage(expected_first_two, expected_last_two)

    # Explicit-handler API: pass the callback via the run config.
    # (cycle() wraps back to usage1/usage2 here.)
    handler = UsageMetadataCallbackHandler()
    model.batch(["Message 1", "Message 2"], config={"callbacks": [handler]})
    assert handler.usage_metadata == expected_first_two
66+
67+
68+
async def test_usage_callback_async() -> None:
    """Usage metadata accumulates across async calls (context manager and config)."""
    model = GenericFakeChatModel(messages=cycle(messages))

    expected_first_two = add_usage(usage1, usage2)
    expected_last_two = add_usage(usage3, usage4)

    # Context-manager API: the running total grows as calls happen inside the block.
    with get_usage_metadata_callback() as cb:
        await model.ainvoke("Message 1")
        await model.ainvoke("Message 2")
        assert cb.usage_metadata == expected_first_two
        await model.ainvoke("Message 3")
        await model.ainvoke("Message 4")
        assert cb.usage_metadata == add_usage(expected_first_two, expected_last_two)

    # Explicit-handler API: pass the callback via the run config.
    # (cycle() wraps back to usage1/usage2 here.)
    handler = UsageMetadataCallbackHandler()
    await model.abatch(["Message 1", "Message 2"], config={"callbacks": [handler]})
    assert handler.usage_metadata == expected_first_two

0 commit comments

Comments
 (0)