Commit 22d1a7d
standard-tests[patch]: require model_name in response_metadata if returns_usage_metadata (#30497)
We are implementing a token-counting callback handler in `langchain-core` that is intended to work with all chat models supporting usage metadata. The callback will aggregate usage metadata by model, which requires responses to include the model name in their metadata. To support this, if a model sets `returns_usage_metadata`, we check that its responses include a string model name under the `"model_name"` key in `response_metadata`. More context: #30487
1 parent 20f8250 commit 22d1a7d
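
The requirement in practice: a chat model that reports token usage should attach the model name to each response. A minimal sketch of a conforming message (the model name and token counts are illustrative, not part of this commit):

from langchain_core.messages import AIMessage

# A response that passes the new standard-test check: usage_metadata is
# populated and response_metadata carries the model name as a string.
message = AIMessage(
    content="Hello!",
    response_metadata={"model_name": "my-model-001"},  # illustrative model name
    usage_metadata={"input_tokens": 2, "output_tokens": 3, "total_tokens": 5},
)

assert isinstance(message.response_metadata["model_name"], str)

This mirrors the assertions added to `test_usage_metadata` and `test_usage_metadata_streaming` in the diffs below.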

File tree: 9 files changed (+75, −12 lines)

docs/docs/how_to/custom_chat_model.ipynb

Lines changed: 5 additions & 1 deletion
@@ -247,6 +247,7 @@
 " additional_kwargs={}, # Used to add additional payload to the message\n",
 " response_metadata={ # Use for response metadata\n",
 " \"time_in_seconds\": 3,\n",
+" \"model_name\": self.model_name,\n",
 " },\n",
 " usage_metadata={\n",
 " \"input_tokens\": ct_input_tokens,\n",
@@ -309,7 +310,10 @@
 "\n",
 " # Let's add some other information (e.g., response metadata)\n",
 " chunk = ChatGenerationChunk(\n",
-" message=AIMessageChunk(content=\"\", response_metadata={\"time_in_sec\": 3})\n",
+" message=AIMessageChunk(\n",
+" content=\"\",\n",
+" response_metadata={\"time_in_sec\": 3, \"model_name\": self.model_name},\n",
+" )\n",
 " )\n",
 " if run_manager:\n",
 " # This is optional in newer versions of LangChain\n",

libs/cli/langchain_cli/integration_template/integration_template/chat_models.py

Lines changed: 5 additions & 1 deletion
@@ -329,6 +329,7 @@ def _generate(
             additional_kwargs={},  # Used to add additional payload to the message
             response_metadata={  # Use for response metadata
                 "time_in_seconds": 3,
+                "model_name": self.model_name,
             },
             usage_metadata={
                 "input_tokens": ct_input_tokens,
@@ -391,7 +392,10 @@ def _stream(

         # Let's add some other information (e.g., response metadata)
         chunk = ChatGenerationChunk(
-            message=AIMessageChunk(content="", response_metadata={"time_in_sec": 3})
+            message=AIMessageChunk(
+                content="",
+                response_metadata={"time_in_sec": 3, "model_name": self.model_name},
+            )
         )
         if run_manager:
             # This is optional in newer versions of LangChain

libs/partners/fireworks/langchain_fireworks/chat_models.py

Lines changed: 2 additions & 0 deletions
@@ -471,6 +471,7 @@ def _stream(
             generation_info = {}
             if finish_reason := choice.get("finish_reason"):
                 generation_info["finish_reason"] = finish_reason
+                generation_info["model_name"] = self.model_name
             logprobs = choice.get("logprobs")
             if logprobs:
                 generation_info["logprobs"] = logprobs
@@ -565,6 +566,7 @@ async def _astream(
             generation_info = {}
             if finish_reason := choice.get("finish_reason"):
                 generation_info["finish_reason"] = finish_reason
+                generation_info["model_name"] = self.model_name
             logprobs = choice.get("logprobs")
             if logprobs:
                 generation_info["logprobs"] = logprobs

libs/partners/fireworks/tests/integration_tests/test_chat_models.py

Lines changed: 8 additions & 3 deletions
@@ -98,16 +98,19 @@ async def test_astream() -> None:

     full: Optional[BaseMessageChunk] = None
     chunks_with_token_counts = 0
+    chunks_with_response_metadata = 0
     async for token in llm.astream("I'm Pickle Rick"):
         assert isinstance(token, AIMessageChunk)
         assert isinstance(token.content, str)
         full = token if full is None else full + token
         if token.usage_metadata is not None:
             chunks_with_token_counts += 1
-    if chunks_with_token_counts != 1:
+        if token.response_metadata:
+            chunks_with_response_metadata += 1
+    if chunks_with_token_counts != 1 or chunks_with_response_metadata != 1:
         raise AssertionError(
-            "Expected exactly one chunk with token counts. "
-            "AIMessageChunk aggregation adds counts. Check that "
+            "Expected exactly one chunk with token counts or response_metadata. "
+            "AIMessageChunk aggregation adds / appends counts and metadata. Check that "
             "this is behaving properly."
         )
     assert isinstance(full, AIMessageChunk)
@@ -118,6 +121,8 @@ async def test_astream() -> None:
         full.usage_metadata["input_tokens"] + full.usage_metadata["output_tokens"]
         == full.usage_metadata["total_tokens"]
     )
+    assert isinstance(full.response_metadata["model_name"], str)
+    assert full.response_metadata["model_name"]


 async def test_abatch() -> None:

libs/partners/mistralai/langchain_mistralai/chat_models.py

Lines changed: 6 additions & 1 deletion
@@ -236,13 +236,15 @@ async def _completion_with_retry(**kwargs: Any) -> Any:
 def _convert_chunk_to_message_chunk(
     chunk: Dict, default_class: Type[BaseMessageChunk]
 ) -> BaseMessageChunk:
-    _delta = chunk["choices"][0]["delta"]
+    _choice = chunk["choices"][0]
+    _delta = _choice["delta"]
     role = _delta.get("role")
     content = _delta.get("content") or ""
     if role == "user" or default_class == HumanMessageChunk:
         return HumanMessageChunk(content=content)
     elif role == "assistant" or default_class == AIMessageChunk:
         additional_kwargs: Dict = {}
+        response_metadata = {}
         if raw_tool_calls := _delta.get("tool_calls"):
             additional_kwargs["tool_calls"] = raw_tool_calls
             try:
@@ -272,11 +274,14 @@ def _convert_chunk_to_message_chunk(
             }
         else:
             usage_metadata = None
+        if _choice.get("finish_reason") is not None:
+            response_metadata["model_name"] = chunk.get("model")
         return AIMessageChunk(
             content=content,
             additional_kwargs=additional_kwargs,
             tool_call_chunks=tool_call_chunks,  # type: ignore[arg-type]
             usage_metadata=usage_metadata,  # type: ignore[arg-type]
+            response_metadata=response_metadata,
         )
     elif role == "system" or default_class == SystemMessageChunk:
         return SystemMessageChunk(content=content)

libs/partners/mistralai/tests/integration_tests/test_chat_models.py

Lines changed: 10 additions & 5 deletions
@@ -20,7 +20,7 @@ def test_stream() -> None:
     """Test streaming tokens from ChatMistralAI."""
     llm = ChatMistralAI()

-    for token in llm.stream("I'm Pickle Rick"):
+    for token in llm.stream("Hello"):
         assert isinstance(token.content, str)


@@ -30,16 +30,19 @@ async def test_astream() -> None:

     full: Optional[BaseMessageChunk] = None
     chunks_with_token_counts = 0
-    async for token in llm.astream("I'm Pickle Rick"):
+    chunks_with_response_metadata = 0
+    async for token in llm.astream("Hello"):
         assert isinstance(token, AIMessageChunk)
         assert isinstance(token.content, str)
         full = token if full is None else full + token
         if token.usage_metadata is not None:
             chunks_with_token_counts += 1
-    if chunks_with_token_counts != 1:
+        if token.response_metadata:
+            chunks_with_response_metadata += 1
+    if chunks_with_token_counts != 1 or chunks_with_response_metadata != 1:
         raise AssertionError(
-            "Expected exactly one chunk with token counts. "
-            "AIMessageChunk aggregation adds counts. Check that "
+            "Expected exactly one chunk with token counts or response_metadata. "
+            "AIMessageChunk aggregation adds / appends counts and metadata. Check that "
             "this is behaving properly."
         )
     assert isinstance(full, AIMessageChunk)
@@ -50,6 +53,8 @@ async def test_astream() -> None:
         full.usage_metadata["input_tokens"] + full.usage_metadata["output_tokens"]
         == full.usage_metadata["total_tokens"]
     )
+    assert isinstance(full.response_metadata["model_name"], str)
+    assert full.response_metadata["model_name"]


 async def test_abatch() -> None:

libs/standard-tests/langchain_tests/integration_tests/chat_models.py

Lines changed: 31 additions & 0 deletions
@@ -337,6 +337,9 @@ def supports_image_inputs(self) -> bool:
                 def returns_usage_metadata(self) -> bool:
                     return False

+            Models supporting ``usage_metadata`` should also return the name of the
+            underlying model in the ``response_metadata`` of the AIMessage.
+
         .. dropdown:: supports_anthropic_inputs

             Boolean property indicating whether the chat model supports Anthropic-style
@@ -669,6 +672,11 @@ def test_usage_metadata(self, model: BaseChatModel) -> None:
         This test is optional and should be skipped if the model does not return
         usage metadata (see Configuration below).

+        .. versionchanged:: 0.3.17
+
+            Additionally check for the presence of `model_name` in the response
+            metadata, which is needed for usage tracking in callback handlers.
+
         .. dropdown:: Configuration

             By default, this test is run.
@@ -739,6 +747,9 @@ def supported_usage_metadata_details(self) -> dict:
                     )
                 )]
             )
+
+        Check also that the response includes a ``"model_name"`` key in its
+        ``usage_metadata``.
         """
         if not self.returns_usage_metadata:
             pytest.skip("Not implemented.")
@@ -750,6 +761,12 @@ def supported_usage_metadata_details(self) -> dict:
         assert isinstance(result.usage_metadata["output_tokens"], int)
         assert isinstance(result.usage_metadata["total_tokens"], int)

+        # Check model_name is in response_metadata
+        # Needed for langchain_core.callbacks.usage
+        model_name = result.response_metadata.get("model_name")
+        assert isinstance(model_name, str)
+        assert model_name
+
         if "audio_input" in self.supported_usage_metadata_details["invoke"]:
             msg = self.invoke_with_audio_input()
             assert msg.usage_metadata is not None
@@ -809,6 +826,11 @@ def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
         """
         Test to verify that the model returns correct usage metadata in streaming mode.

+        .. versionchanged:: 0.3.17
+
+            Additionally check for the presence of `model_name` in the response
+            metadata, which is needed for usage tracking in callback handlers.
+
         .. dropdown:: Configuration

             By default, this test is run.
@@ -891,6 +913,9 @@ def supported_usage_metadata_details(self) -> dict:
                     )
                 )]
             )
+
+        Check also that the aggregated response includes a ``"model_name"`` key
+        in its ``usage_metadata``.
         """
         if not self.returns_usage_metadata:
             pytest.skip("Not implemented.")
@@ -915,6 +940,12 @@ def supported_usage_metadata_details(self) -> dict:
         assert isinstance(full.usage_metadata["output_tokens"], int)
         assert isinstance(full.usage_metadata["total_tokens"], int)

+        # Check model_name is in response_metadata
+        # Needed for langchain_core.callbacks.usage
+        model_name = full.response_metadata.get("model_name")
+        assert isinstance(model_name, str)
+        assert model_name
+
         if "audio_input" in self.supported_usage_metadata_details["stream"]:
             msg = self.invoke_with_audio_input(stream=True)
             assert isinstance(msg.usage_metadata["input_token_details"]["audio"], int)  # type: ignore[index]
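
Integration packages pick up these assertions by subclassing the standard suite; when `returns_usage_metadata` is left enabled (it can be overridden to return `False`, as in the docstring example above), the new `model_name` checks run against their model. A minimal sketch, using the `ChatParrotLink` example from the standard-tests docs (treat the package and model names as placeholders):

from typing import Type

from langchain_parrot_link import ChatParrotLink  # placeholder integration package
from langchain_tests.integration_tests import ChatModelIntegrationTests


class TestChatParrotLinkIntegration(ChatModelIntegrationTests):
    @property
    def chat_model_class(self) -> Type[ChatParrotLink]:
        return ChatParrotLink

    @property
    def chat_model_params(self) -> dict:
        return {"model": "bird-brain-001"}  # illustrative model identifier

    # Leaving returns_usage_metadata enabled means test_usage_metadata and
    # test_usage_metadata_streaming now also assert that responses include a
    # string "model_name" in response_metadata.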

libs/standard-tests/langchain_tests/unit_tests/chat_models.py

Lines changed: 3 additions & 0 deletions
@@ -412,6 +412,9 @@ def supports_image_inputs(self) -> bool:
                 def returns_usage_metadata(self) -> bool:
                     return False

+            Models supporting ``usage_metadata`` should also return the name of the
+            underlying model in the ``response_metadata`` of the AIMessage.
+
         .. dropdown:: supports_anthropic_inputs

             Boolean property indicating whether the chat model supports Anthropic-style

libs/standard-tests/tests/unit_tests/custom_chat_model.py

Lines changed: 5 additions & 1 deletion
@@ -76,6 +76,7 @@ def _generate(
             additional_kwargs={},  # Used to add additional payload to the message
             response_metadata={  # Use for response metadata
                 "time_in_seconds": 3,
+                "model_name": self.model_name,
             },
             usage_metadata={
                 "input_tokens": ct_input_tokens,
@@ -138,7 +139,10 @@ def _stream(

         # Let's add some other information (e.g., response metadata)
         chunk = ChatGenerationChunk(
-            message=AIMessageChunk(content="", response_metadata={"time_in_sec": 3})
+            message=AIMessageChunk(
+                content="",
+                response_metadata={"time_in_sec": 3, "model_name": self.model_name},
+            )
         )
         if run_manager:
             # This is optional in newer versions of LangChain
