Commit 7404338

fix(core): fix string content when streaming output_version="v1" (#33261)

Authored by: ccurme and mdrxy
Co-authored-by: Mason Daugherty <[email protected]>

1 parent: f308139

File tree: 2 files changed, +201 −4 lines
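
What the change does: when output_version="v1", streamed content blocks now get an
"index" key assigned as they arrive, so that consecutive blocks of the same type
share an index and the index increments whenever the block type changes. A minimal
standalone sketch of that pass (illustrative only; assign_block_indices is a
hypothetical helper, while the real logic runs inline in stream, astream,
_generate_with_cache, and _agenerate_with_cache):

# Sketch of the indexing pass this commit adds, under the assumptions above.
def assign_block_indices(
    blocks: list[dict], index: int = -1, index_type: str = ""
) -> tuple[int, str]:
    """Mutate blocks in place; return the (index, index_type) state to carry forward."""
    for block in blocks:
        if block["type"] != index_type:
            # Block type changed: start a new index.
            index_type = block["type"]
            index += 1
        if "index" not in block:
            block["index"] = index
    return index, index_type

# The state persists across chunks, so a block split over several chunks keeps
# a single index:
state = (-1, "")
chunk1 = [{"type": "reasoning", "reasoning": "<rea"}]
chunk2 = [{"type": "reasoning", "reasoning": "soning>"}, {"type": "text", "text": "hi"}]
state = assign_block_indices(chunk1, *state)
state = assign_block_indices(chunk2, *state)
# chunk1[0]["index"] == 0; chunk2[0]["index"] == 0; chunk2[1]["index"] == 1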

libs/core/langchain_core/language_models/chat_models.py
Lines changed: 41 additions & 0 deletions

@@ -43,6 +43,7 @@
     is_data_content_block,
     message_chunk_to_message,
 )
+from langchain_core.messages import content as types
 from langchain_core.messages.block_translators.openai import (
     convert_to_openai_image_block,
 )
@@ -533,6 +534,8 @@ def stream(
             input_messages = _normalize_messages(messages)
             run_id = "-".join((LC_ID_PREFIX, str(run_manager.run_id)))
             yielded = False
+            index = -1
+            index_type = ""
             for chunk in self._stream(input_messages, stop=stop, **kwargs):
                 if chunk.message.id is None:
                     chunk.message.id = run_id
@@ -542,6 +545,14 @@ def stream(
                     chunk.message = _update_message_content_to_blocks(
                         chunk.message, "v1"
                     )
+                    for block in cast(
+                        "list[types.ContentBlock]", chunk.message.content
+                    ):
+                        if block["type"] != index_type:
+                            index_type = block["type"]
+                            index = index + 1
+                        if "index" not in block:
+                            block["index"] = index
                 run_manager.on_llm_new_token(
                     cast("str", chunk.message.content), chunk=chunk
                 )
@@ -651,6 +662,8 @@ async def astream(
             input_messages = _normalize_messages(messages)
             run_id = "-".join((LC_ID_PREFIX, str(run_manager.run_id)))
             yielded = False
+            index = -1
+            index_type = ""
             async for chunk in self._astream(
                 input_messages,
                 stop=stop,
@@ -664,6 +677,14 @@ async def astream(
                     chunk.message = _update_message_content_to_blocks(
                         chunk.message, "v1"
                     )
+                    for block in cast(
+                        "list[types.ContentBlock]", chunk.message.content
+                    ):
+                        if block["type"] != index_type:
+                            index_type = block["type"]
+                            index = index + 1
+                        if "index" not in block:
+                            block["index"] = index
                 await run_manager.on_llm_new_token(
                     cast("str", chunk.message.content), chunk=chunk
                 )
@@ -1145,13 +1166,23 @@ def _generate_with_cache(
             f"{LC_ID_PREFIX}-{run_manager.run_id}" if run_manager else None
         )
         yielded = False
+        index = -1
+        index_type = ""
         for chunk in self._stream(messages, stop=stop, **kwargs):
             chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
             if self.output_version == "v1":
                 # Overwrite .content with .content_blocks
                 chunk.message = _update_message_content_to_blocks(
                     chunk.message, "v1"
                 )
+                for block in cast(
+                    "list[types.ContentBlock]", chunk.message.content
+                ):
+                    if block["type"] != index_type:
+                        index_type = block["type"]
+                        index = index + 1
+                    if "index" not in block:
+                        block["index"] = index
             if run_manager:
                 if chunk.message.id is None:
                     chunk.message.id = run_id
@@ -1253,13 +1284,23 @@ async def _agenerate_with_cache(
             f"{LC_ID_PREFIX}-{run_manager.run_id}" if run_manager else None
         )
         yielded = False
+        index = -1
+        index_type = ""
         async for chunk in self._astream(messages, stop=stop, **kwargs):
             chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
             if self.output_version == "v1":
                 # Overwrite .content with .content_blocks
                 chunk.message = _update_message_content_to_blocks(
                     chunk.message, "v1"
                 )
+                for block in cast(
+                    "list[types.ContentBlock]", chunk.message.content
+                ):
+                    if block["type"] != index_type:
+                        index_type = block["type"]
+                        index = index + 1
+                    if "index" not in block:
+                        block["index"] = index
             if run_manager:
                 if chunk.message.id is None:
                     chunk.message.id = run_id
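
Taken together, the four instrumented code paths above mean every v1 streaming path
emits indexed blocks. A usage sketch based on the unit tests added below
(GenericFakeChatModel is the fake model the test suite uses; the expected value is
what the tests assert):

from langchain_core.language_models import GenericFakeChatModel
from langchain_core.messages import AIMessage

llm = GenericFakeChatModel(messages=iter([AIMessage("foo bar")]), output_version="v1")
full = None
for chunk in llm.stream("hello"):
    # Each chunk's content is a list of blocks carrying "index" keys.
    full = chunk if full is None else full + chunk
# Per the tests: full.content == [{"type": "text", "text": "foo bar", "index": 0}]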

libs/core/tests/unit_tests/language_models/chat_models/test_base.py
Lines changed: 160 additions & 4 deletions

@@ -8,7 +8,10 @@
 import pytest
 from typing_extensions import override
 
-from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain_core.language_models import (
     BaseChatModel,
     FakeListChatModel,
@@ -23,7 +26,6 @@
     AIMessage,
     AIMessageChunk,
     BaseMessage,
-    BaseMessageChunk,
     HumanMessage,
     SystemMessage,
 )
@@ -907,6 +909,56 @@ async def test_output_version_ainvoke(monkeypatch: Any) -> None:
     assert response.response_metadata["output_version"] == "v1"
 
 
+class _AnotherFakeChatModel(BaseChatModel):
+    responses: Iterator[AIMessage]
+    """Responses for _generate."""
+
+    chunks: Iterator[AIMessageChunk]
+    """Responses for _stream."""
+
+    @property
+    def _llm_type(self) -> str:
+        return "another-fake-chat-model"
+
+    def _generate(
+        self,
+        messages: list[BaseMessage],  # noqa: ARG002
+        stop: list[str] | None = None,  # noqa: ARG002
+        run_manager: CallbackManagerForLLMRun | None = None,  # noqa: ARG002
+        **kwargs: Any,  # noqa: ARG002
+    ) -> ChatResult:
+        return ChatResult(generations=[ChatGeneration(message=next(self.responses))])
+
+    async def _agenerate(
+        self,
+        messages: list[BaseMessage],  # noqa: ARG002
+        stop: list[str] | None = None,  # noqa: ARG002
+        run_manager: AsyncCallbackManagerForLLMRun | None = None,  # noqa: ARG002
+        **kwargs: Any,  # noqa: ARG002
+    ) -> ChatResult:
+        return ChatResult(generations=[ChatGeneration(message=next(self.responses))])
+
+    def _stream(
+        self,
+        messages: list[BaseMessage],  # noqa: ARG002
+        stop: list[str] | None = None,  # noqa: ARG002
+        run_manager: CallbackManagerForLLMRun | None = None,  # noqa: ARG002
+        **kwargs: Any,  # noqa: ARG002
+    ) -> Iterator[ChatGenerationChunk]:
+        for chunk in self.chunks:
+            yield ChatGenerationChunk(message=chunk)
+
+    async def _astream(
+        self,
+        messages: list[BaseMessage],  # noqa: ARG002
+        stop: list[str] | None = None,  # noqa: ARG002
+        run_manager: AsyncCallbackManagerForLLMRun | None = None,  # noqa: ARG002
+        **kwargs: Any,  # noqa: ARG002
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        for chunk in self.chunks:
+            yield ChatGenerationChunk(message=chunk)
+
+
 def test_output_version_stream(monkeypatch: Any) -> None:
     messages = [AIMessage("foo bar")]
 
@@ -923,7 +975,7 @@ def test_output_version_stream(monkeypatch: Any) -> None:
 
     # v1
     llm = GenericFakeChatModel(messages=iter(messages), output_version="v1")
-    full_v1: BaseMessageChunk | None = None
+    full_v1: AIMessageChunk | None = None
     for chunk in llm.stream("hello"):
         assert isinstance(chunk, AIMessageChunk)
         assert isinstance(chunk.content, list)
@@ -936,6 +988,58 @@ def test_output_version_stream(monkeypatch: Any) -> None:
     assert isinstance(full_v1, AIMessageChunk)
     assert full_v1.response_metadata["output_version"] == "v1"
 
+    assert full_v1.content == [{"type": "text", "text": "foo bar", "index": 0}]
+
+    # Test text blocks
+    llm_with_rich_content = _AnotherFakeChatModel(
+        responses=iter([]),
+        chunks=iter(
+            [
+                AIMessageChunk(content="foo "),
+                AIMessageChunk(content="bar"),
+            ]
+        ),
+        output_version="v1",
+    )
+    full_v1 = None
+    for chunk in llm_with_rich_content.stream("hello"):
+        full_v1 = chunk if full_v1 is None else full_v1 + chunk
+    assert isinstance(full_v1, AIMessageChunk)
+    assert full_v1.content_blocks == [{"type": "text", "text": "foo bar", "index": 0}]
+
+    # Test content blocks of different types
+    chunks = [
+        AIMessageChunk(content="", additional_kwargs={"reasoning_content": "<rea"}),
+        AIMessageChunk(content="", additional_kwargs={"reasoning_content": "soning>"}),
+        AIMessageChunk(content="<some "),
+        AIMessageChunk(content="text>"),
+    ]
+    llm_with_rich_content = _AnotherFakeChatModel(
+        responses=iter([]),
+        chunks=iter(chunks),
+        output_version="v1",
+    )
+    full_v1 = None
+    for chunk in llm_with_rich_content.stream("hello"):
+        full_v1 = chunk if full_v1 is None else full_v1 + chunk
+    assert isinstance(full_v1, AIMessageChunk)
+    assert full_v1.content_blocks == [
+        {"type": "reasoning", "reasoning": "<reasoning>", "index": 0},
+        {"type": "text", "text": "<some text>", "index": 1},
+    ]
+
+    # Test invoke with stream=True
+    llm_with_rich_content = _AnotherFakeChatModel(
+        responses=iter([]),
+        chunks=iter(chunks),
+        output_version="v1",
+    )
+    response_v1 = llm_with_rich_content.invoke("hello", stream=True)
+    assert response_v1.content_blocks == [
+        {"type": "reasoning", "reasoning": "<reasoning>", "index": 0},
+        {"type": "text", "text": "<some text>", "index": 1},
+    ]
+
     # v1 from env var
     monkeypatch.setenv("LC_OUTPUT_VERSION", "v1")
     llm = GenericFakeChatModel(messages=iter(messages))
@@ -969,7 +1073,7 @@ async def test_output_version_astream(monkeypatch: Any) -> None:
 
     # v1
     llm = GenericFakeChatModel(messages=iter(messages), output_version="v1")
-    full_v1: BaseMessageChunk | None = None
+    full_v1: AIMessageChunk | None = None
     async for chunk in llm.astream("hello"):
         assert isinstance(chunk, AIMessageChunk)
         assert isinstance(chunk.content, list)
@@ -982,6 +1086,58 @@ async def test_output_version_astream(monkeypatch: Any) -> None:
     assert isinstance(full_v1, AIMessageChunk)
     assert full_v1.response_metadata["output_version"] == "v1"
 
+    assert full_v1.content == [{"type": "text", "text": "foo bar", "index": 0}]
+
+    # Test text blocks
+    llm_with_rich_content = _AnotherFakeChatModel(
+        responses=iter([]),
+        chunks=iter(
+            [
+                AIMessageChunk(content="foo "),
+                AIMessageChunk(content="bar"),
+            ]
+        ),
+        output_version="v1",
+    )
+    full_v1 = None
+    async for chunk in llm_with_rich_content.astream("hello"):
+        full_v1 = chunk if full_v1 is None else full_v1 + chunk
+    assert isinstance(full_v1, AIMessageChunk)
+    assert full_v1.content_blocks == [{"type": "text", "text": "foo bar", "index": 0}]
+
+    # Test content blocks of different types
+    chunks = [
+        AIMessageChunk(content="", additional_kwargs={"reasoning_content": "<rea"}),
+        AIMessageChunk(content="", additional_kwargs={"reasoning_content": "soning>"}),
+        AIMessageChunk(content="<some "),
+        AIMessageChunk(content="text>"),
+    ]
+    llm_with_rich_content = _AnotherFakeChatModel(
+        responses=iter([]),
+        chunks=iter(chunks),
+        output_version="v1",
+    )
+    full_v1 = None
+    async for chunk in llm_with_rich_content.astream("hello"):
+        full_v1 = chunk if full_v1 is None else full_v1 + chunk
+    assert isinstance(full_v1, AIMessageChunk)
+    assert full_v1.content_blocks == [
+        {"type": "reasoning", "reasoning": "<reasoning>", "index": 0},
+        {"type": "text", "text": "<some text>", "index": 1},
+    ]
+
+    # Test invoke with stream=True
+    llm_with_rich_content = _AnotherFakeChatModel(
+        responses=iter([]),
+        chunks=iter(chunks),
+        output_version="v1",
+    )
+    response_v1 = await llm_with_rich_content.ainvoke("hello", stream=True)
+    assert response_v1.content_blocks == [
+        {"type": "reasoning", "reasoning": "<reasoning>", "index": 0},
+        {"type": "text", "text": "<some text>", "index": 1},
+    ]
+
     # v1 from env var
     monkeypatch.setenv("LC_OUTPUT_VERSION", "v1")
     llm = GenericFakeChatModel(messages=iter(messages))
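
The "index" key is what lets chunk addition reassemble split blocks. A quick sketch
(assuming AIMessageChunk addition coalesces list-content blocks that share the same
"index", which is the behavior the tests above rely on):

from langchain_core.messages import AIMessageChunk

a = AIMessageChunk(content=[{"type": "text", "text": "foo ", "index": 0}])
b = AIMessageChunk(content=[{"type": "text", "text": "bar", "index": 0}])
merged = a + b
# Expected: merged.content == [{"type": "text", "text": "foo bar", "index": 0}]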
