Commit 741bb1f

core[patch]: revert change to stream type hint (#31501)
#31286 included an update to the return type of `BaseChatModel.(a)stream`, from `Iterator[BaseMessageChunk]` to `Iterator[BaseMessage]`. This change is correct in isolation: when streaming is disabled, the stream methods return an iterator of `BaseMessage`, and the inheritance is such that a `BaseMessage` is not a `BaseMessageChunk` (although the reverse is true).

However, LangChain documents a pattern of [summing BaseMessageChunks](https://python.langchain.com/docs/how_to/streaming/#llms-and-chat-models) to accumulate a chat model stream. This pattern is implemented in tests for most integration packages and appears in application code, so #31286 introduces mypy errors throughout the ecosystem (or, more accurately, it reveals that this pattern does not account for use of the `.stream` method when streaming is disabled).

Here we revert just the change to the stream return type to unblock things. A proper fix should address the docs and the integration packages (or, if we elect to force people to update their code, be explicit about that).
1 parent b149cce commit 741bb1f
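
For reference, the accumulation pattern at issue looks roughly like the sketch below, where `fake_stream` merely stands in for `model.stream(...)` and is not part of this commit:

from typing import Optional

from langchain_core.messages import AIMessageChunk, BaseMessageChunk


def fake_stream():
    # Stand-in for BaseChatModel.stream(); real code would iterate a chat model.
    yield AIMessageChunk(content="hello ")
    yield AIMessageChunk(content="world")


full: Optional[BaseMessageChunk] = None
for chunk in fake_stream():
    # Summing relies on BaseMessageChunk.__add__, which returns another chunk.
    # Under the Iterator[BaseMessage] hint from #31286 this stops type-checking,
    # because `+` on a plain BaseMessage is not typed to yield another message.
    full = chunk if full is None else full + chunk

print(full.content if full else None)  # -> "hello world"

Reverting to `Iterator[BaseMessageChunk]` lets this pattern type-check again, at the price of a hint that is not strictly accurate on the non-streaming fallback path.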

File tree

- libs/core/langchain_core/language_models/chat_models.py
- libs/core/tests/unit_tests/fake/test_fake_chat_model.py
- libs/core/tests/unit_tests/language_models/chat_models/test_base.py
- libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py

4 files changed: +30 −12 lines changed

libs/core/langchain_core/language_models/chat_models.py

Lines changed: 11 additions & 4 deletions
@@ -51,6 +51,7 @@
     AIMessage,
     AnyMessage,
     BaseMessage,
+    BaseMessageChunk,
     HumanMessage,
     convert_to_messages,
     convert_to_openai_image_block,
@@ -445,10 +446,13 @@ def stream(
         *,
         stop: Optional[list[str]] = None,
         **kwargs: Any,
-    ) -> Iterator[BaseMessage]:
+    ) -> Iterator[BaseMessageChunk]:
         if not self._should_stream(async_api=False, **{**kwargs, "stream": True}):
             # model doesn't implement streaming, so use default implementation
-            yield self.invoke(input, config=config, stop=stop, **kwargs)
+            yield cast(
+                "BaseMessageChunk",
+                self.invoke(input, config=config, stop=stop, **kwargs),
+            )
         else:
             config = ensure_config(config)
             messages = self._convert_input(input).to_messages()
@@ -533,10 +537,13 @@ async def astream(
         *,
         stop: Optional[list[str]] = None,
         **kwargs: Any,
-    ) -> AsyncIterator[BaseMessage]:
+    ) -> AsyncIterator[BaseMessageChunk]:
         if not self._should_stream(async_api=True, **{**kwargs, "stream": True}):
             # No async or sync stream is implemented, so fall back to ainvoke
-            yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
+            yield cast(
+                "BaseMessageChunk",
+                await self.ainvoke(input, config=config, stop=stop, **kwargs),
+            )
             return

         config = ensure_config(config)
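
A note on the `cast` in both fallback branches: `typing.cast` performs no runtime conversion or checking, so when streaming is disabled the yielded value is still a plain message. A minimal sketch of that behavior, with illustrative values not taken from this commit:

from typing import cast

from langchain_core.messages import AIMessage, BaseMessageChunk

msg = AIMessage(content="hello")
chunk = cast(BaseMessageChunk, msg)

# cast() only informs the type checker; the object itself is unchanged,
# so the "chunk" yielded on the fallback path is not actually a chunk.
assert chunk is msg
assert not isinstance(chunk, BaseMessageChunk)

This is exactly why the reverted hint is not strictly correct: the cast silences the checker without changing what callers receive.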

libs/core/tests/unit_tests/fake/test_fake_chat_model.py

Lines changed: 6 additions & 3 deletions
@@ -1,7 +1,5 @@
 """Tests for verifying that testing utility code works as expected."""

-import operator
-from functools import reduce
 from itertools import cycle
 from typing import Any, Optional, Union
 from uuid import UUID
@@ -117,7 +115,12 @@ async def test_generic_fake_chat_model_stream() -> None:
     ]
     assert len({chunk.id for chunk in chunks}) == 1

-    accumulate_chunks = reduce(operator.add, chunks)
+    accumulate_chunks = None
+    for chunk in chunks:
+        if accumulate_chunks is None:
+            accumulate_chunks = chunk
+        else:
+            accumulate_chunks += chunk

     assert accumulate_chunks == AIMessageChunk(
         content="",
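
The explicit loop above is behaviorally equivalent to the `reduce(operator.add, chunks)` form it replaces. A hedged sketch of the equivalence, using standalone chunks rather than the fake model from the test:

import operator
from functools import reduce
from typing import Optional

from langchain_core.messages import AIMessageChunk, BaseMessageChunk

chunks = [AIMessageChunk(content="hello "), AIMessageChunk(content="world")]

# Loop form, as used in the updated tests; seeding with None avoids the need
# for an initial empty chunk.
accumulated: Optional[BaseMessageChunk] = None
for chunk in chunks:
    accumulated = chunk if accumulated is None else accumulated + chunk

assert accumulated is not None
assert accumulated.content == "hello world"
# reduce form removed by this commit; it produces the same result here.
assert accumulated == reduce(operator.add, chunks)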

libs/core/tests/unit_tests/language_models/chat_models/test_base.py

Lines changed: 7 additions & 2 deletions
@@ -163,10 +163,15 @@ def _llm_type(self) -> str:

     model = ModelWithGenerate()
     chunks = list(model.stream("anything"))
-    assert chunks == [_any_id_ai_message(content="hello")]
+    # BaseChatModel.stream is typed to return Iterator[BaseMessageChunk].
+    # When streaming is disabled, it returns Iterator[BaseMessage], so the type hint
+    # is not strictly correct.
+    # LangChain documents a pattern of adding BaseMessageChunks to accumulate a stream.
+    # This may be better done with `reduce(operator.add, chunks)`.
+    assert chunks == [_any_id_ai_message(content="hello")]  # type: ignore[comparison-overlap]

     chunks = [chunk async for chunk in model.astream("anything")]
-    assert chunks == [_any_id_ai_message(content="hello")]
+    assert chunks == [_any_id_ai_message(content="hello")]  # type: ignore[comparison-overlap]


 async def test_astream_implementation_fallback_to_stream() -> None:
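
For downstream code hitting the same `[comparison-overlap]` error, one alternative to suppressing it is to widen one side of the comparison to the common base type. A hedged sketch, not taken from this commit, with illustrative values:

from langchain_core.messages import AIMessage, BaseMessage, BaseMessageChunk

streamed: list[BaseMessageChunk] = []    # what the reverted hint promises
expected = [AIMessage(content="hello")]  # what the fallback actually yields

# Under mypy's strict equality checking, comparing list[BaseMessageChunk]
# against list[AIMessage] is flagged as comparison-overlap; widening one
# side to list[BaseMessage] type-checks without an ignore comment.
widened: list[BaseMessage] = list(streamed)
print(widened == expected)  # False here; the point is the typing, not the value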

libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py

Lines changed: 6 additions & 3 deletions
@@ -1,7 +1,5 @@
 """Tests for verifying that testing utility code works as expected."""

-import operator
-from functools import reduce
 from itertools import cycle
 from typing import Any, Optional, Union
 from uuid import UUID
@@ -109,7 +107,12 @@ async def test_generic_fake_chat_model_stream() -> None:
         ),
     ]

-    accumulate_chunks = reduce(operator.add, chunks)
+    accumulate_chunks = None
+    for chunk in chunks:
+        if accumulate_chunks is None:
+            accumulate_chunks = chunk
+        else:
+            accumulate_chunks += chunk

     assert accumulate_chunks == AIMessageChunk(
         id="a1",
