
Commit 7e62e3a

core[patch]: store model names on usage callback handler (#30487)
This avoids mingling token counts from different models.
1 parent 3282776 commit 7e62e3a
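
The behavioral change, in one self-contained sketch that drives the handler's `on_llm_end` hook directly (no model provider involved; the model names and token counts below are made up):

```python
from langchain_core.callbacks import UsageMetadataCallbackHandler
from langchain_core.messages import AIMessage
from langchain_core.messages.ai import UsageMetadata
from langchain_core.outputs import ChatGeneration, LLMResult

handler = UsageMetadataCallbackHandler()

# Three fake responses: two from "model_a", one from "model_b".
for model, n in [("model_a", 10), ("model_b", 20), ("model_a", 5)]:
    message = AIMessage(
        content="ok",
        usage_metadata=UsageMetadata(
            input_tokens=n, output_tokens=n, total_tokens=2 * n
        ),
        response_metadata={"model_name": model},
    )
    handler.on_llm_end(LLMResult(generations=[[ChatGeneration(message=message)]]))

# Tokens are now bucketed per model name instead of merged into one counter:
# {'model_a': {'input_tokens': 15, ...}, 'model_b': {'input_tokens': 20, ...}}
print(handler.usage_metadata)
```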

File tree

2 files changed: +60 −17 lines changed

libs/core/langchain_core/callbacks/usage.py

Lines changed: 14 additions & 8 deletions

```diff
@@ -39,7 +39,7 @@ class UsageMetadataCallbackHandler(BaseCallbackHandler):
     def __init__(self) -> None:
         super().__init__()
         self._lock = threading.Lock()
-        self.usage_metadata: Optional[UsageMetadata] = None
+        self.usage_metadata: dict[str, UsageMetadata] = {}
 
     def __repr__(self) -> str:
         return str(self.usage_metadata)
@@ -51,21 +51,27 @@ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
             generation = response.generations[0][0]
         except IndexError:
             generation = None
+
+        usage_metadata = None
+        model_name = None
         if isinstance(generation, ChatGeneration):
             try:
                 message = generation.message
                 if isinstance(message, AIMessage):
                     usage_metadata = message.usage_metadata
-                else:
-                    usage_metadata = None
+                    model_name = message.response_metadata.get("model_name")
             except AttributeError:
-                usage_metadata = None
-        else:
-            usage_metadata = None
+                pass
 
         # update shared state behind lock
-        with self._lock:
-            self.usage_metadata = add_usage(self.usage_metadata, usage_metadata)
+        if usage_metadata and model_name:
+            with self._lock:
+                if model_name not in self.usage_metadata:
+                    self.usage_metadata[model_name] = usage_metadata
+                else:
+                    self.usage_metadata[model_name] = add_usage(
+                        self.usage_metadata[model_name], usage_metadata
+                    )
 
 
 @contextmanager
```
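
A side effect of the new `if usage_metadata and model_name:` guard: a response that reports usage but carries no `"model_name"` in its `response_metadata` is now skipped rather than merged into a global counter. A small sketch with a fabricated message:

```python
from langchain_core.callbacks import UsageMetadataCallbackHandler
from langchain_core.messages import AIMessage
from langchain_core.messages.ai import UsageMetadata
from langchain_core.outputs import ChatGeneration, LLMResult

handler = UsageMetadataCallbackHandler()
message = AIMessage(
    content="ok",
    usage_metadata=UsageMetadata(input_tokens=1, output_tokens=1, total_tokens=2),
    # No "model_name" key, so there is nothing to bucket the usage under.
    response_metadata={},
)
handler.on_llm_end(LLMResult(generations=[[ChatGeneration(message=message)]]))
print(handler.usage_metadata)  # {} — usage without a model name is dropped
```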

libs/core/tests/unit_tests/callbacks/test_usage_callback.py

Lines changed: 46 additions & 9 deletions

```diff
@@ -1,4 +1,4 @@
-from itertools import cycle
+from typing import Any
 
 from langchain_core.callbacks import (
     UsageMetadataCallbackHandler,
@@ -12,6 +12,7 @@
     UsageMetadata,
     add_usage,
 )
+from langchain_core.outputs import ChatResult
 
 usage1 = UsageMetadata(
     input_tokens=1,
@@ -45,41 +46,77 @@
 ]
 
 
+class FakeChatModelWithResponseMetadata(GenericFakeChatModel):
+    model_name: str
+
+    def _generate(self, *args: Any, **kwargs: Any) -> ChatResult:
+        result = super()._generate(*args, **kwargs)
+        result.generations[0].message.response_metadata = {
+            "model_name": self.model_name
+        }
+        return result
+
+
 def test_usage_callback() -> None:
-    llm = GenericFakeChatModel(messages=cycle(messages))
+    llm = FakeChatModelWithResponseMetadata(
+        messages=iter(messages), model_name="test_model"
+    )
 
     # Test context manager
     with get_usage_metadata_callback() as cb:
         _ = llm.invoke("Message 1")
         _ = llm.invoke("Message 2")
         total_1_2 = add_usage(usage1, usage2)
-        assert cb.usage_metadata == total_1_2
+        assert cb.usage_metadata == {"test_model": total_1_2}
         _ = llm.invoke("Message 3")
         _ = llm.invoke("Message 4")
         total_3_4 = add_usage(usage3, usage4)
-        assert cb.usage_metadata == add_usage(total_1_2, total_3_4)
+        assert cb.usage_metadata == {"test_model": add_usage(total_1_2, total_3_4)}
 
     # Test via config
+    llm = FakeChatModelWithResponseMetadata(
+        messages=iter(messages[:2]), model_name="test_model"
+    )
     callback = UsageMetadataCallbackHandler()
     _ = llm.batch(["Message 1", "Message 2"], config={"callbacks": [callback]})
-    assert callback.usage_metadata == total_1_2
+    assert callback.usage_metadata == {"test_model": total_1_2}
+
+    # Test multiple models
+    llm_1 = FakeChatModelWithResponseMetadata(
+        messages=iter(messages[:2]), model_name="test_model_1"
+    )
+    llm_2 = FakeChatModelWithResponseMetadata(
+        messages=iter(messages[2:4]), model_name="test_model_2"
+    )
+    callback = UsageMetadataCallbackHandler()
+    _ = llm_1.batch(["Message 1", "Message 2"], config={"callbacks": [callback]})
+    _ = llm_2.batch(["Message 3", "Message 4"], config={"callbacks": [callback]})
+    assert callback.usage_metadata == {
+        "test_model_1": total_1_2,
+        "test_model_2": total_3_4,
+    }
 
 
 async def test_usage_callback_async() -> None:
-    llm = GenericFakeChatModel(messages=cycle(messages))
+    llm = FakeChatModelWithResponseMetadata(
+        messages=iter(messages), model_name="test_model"
+    )
 
     # Test context manager
     with get_usage_metadata_callback() as cb:
         _ = await llm.ainvoke("Message 1")
         _ = await llm.ainvoke("Message 2")
         total_1_2 = add_usage(usage1, usage2)
-        assert cb.usage_metadata == total_1_2
+        assert cb.usage_metadata == {"test_model": total_1_2}
         _ = await llm.ainvoke("Message 3")
         _ = await llm.ainvoke("Message 4")
         total_3_4 = add_usage(usage3, usage4)
-        assert cb.usage_metadata == add_usage(total_1_2, total_3_4)
+        assert cb.usage_metadata == {"test_model": add_usage(total_1_2, total_3_4)}
 
     # Test via config
+    llm = FakeChatModelWithResponseMetadata(
+        messages=iter(messages[:2]), model_name="test_model"
+    )
     callback = UsageMetadataCallbackHandler()
     _ = await llm.abatch(["Message 1", "Message 2"], config={"callbacks": [callback]})
-    assert callback.usage_metadata == total_1_2
+    assert callback.usage_metadata == {"test_model": total_1_2}
```
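
If a single aggregate across all models is still wanted, the per-model dict folds back down with `add_usage`; a sketch with made-up numbers:

```python
from functools import reduce

from langchain_core.messages.ai import UsageMetadata, add_usage

per_model = {
    "model_a": UsageMetadata(input_tokens=1, output_tokens=2, total_tokens=3),
    "model_b": UsageMetadata(input_tokens=4, output_tokens=5, total_tokens=9),
}

# add_usage accepts None on either side, so None is a safe initializer.
total = reduce(add_usage, per_model.values(), None)
print(total)  # input_tokens=5, output_tokens=7, total_tokens=12
```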
