
Commit de48e10

fix(core,openai,anthropic): delegate to core implementation on invoke when streaming=True (#33308)
1 parent 08bf8f3 commit de48e10

File tree

5 files changed: +15 −30 lines changed

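In short, setting `streaming=True` on a chat model previously triggered provider-specific shortcuts inside `_generate`/`_agenerate`; this commit moves that decision into `BaseChatModel._should_stream` so that `invoke()` delegates to the shared streaming path in `langchain_core`. A minimal usage sketch of the fixed behavior (model name borrowed from the tests below):

```python
from langchain_openai import ChatOpenAI

# With streaming=True, core's _should_stream() now returns True, so invoke()
# streams the response internally and merges the chunks into a single message.
llm = ChatOpenAI(model="gpt-4.1-mini", streaming=True)
result = llm.invoke("Hello")  # one AIMessage, produced via the streaming path
```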

libs/core/langchain_core/language_models/chat_models.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -471,6 +471,9 @@ def _should_stream(
         if "stream" in kwargs:
             return kwargs["stream"]
 
+        if getattr(self, "streaming", False):
+            return True
+
         # Check if any streaming callback handlers have been passed in.
         handlers = run_manager.handlers if run_manager else []
         return any(isinstance(h, _StreamingCallbackHandler) for h in handlers)
```
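For context, core consults this hook before each blocking generation; a simplified paraphrase of the surrounding `BaseChatModel` logic (not the verbatim source) looks like this:

```python
# Simplified paraphrase of the sync generate path in BaseChatModel.
if self._should_stream(async_api=False, run_manager=run_manager, **kwargs):
    # Stream the chunks and fold them into a single ChatResult in core,
    # instead of each provider re-implementing this branch.
    stream_iter = self._stream(messages, stop=stop, run_manager=run_manager, **kwargs)
    return generate_from_stream(stream_iter)
return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
```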

libs/partners/anthropic/langchain_anthropic/chat_models.py

Lines changed: 0 additions & 18 deletions
@@ -21,8 +21,6 @@
2121
from langchain_core.language_models.chat_models import (
2222
BaseChatModel,
2323
LangSmithParams,
24-
agenerate_from_stream,
25-
generate_from_stream,
2624
)
2725
from langchain_core.messages import (
2826
AIMessage,
@@ -1845,14 +1843,6 @@ def _generate(
18451843
run_manager: Optional[CallbackManagerForLLMRun] = None,
18461844
**kwargs: Any,
18471845
) -> ChatResult:
1848-
if self.streaming:
1849-
stream_iter = self._stream(
1850-
messages,
1851-
stop=stop,
1852-
run_manager=run_manager,
1853-
**kwargs,
1854-
)
1855-
return generate_from_stream(stream_iter)
18561846
payload = self._get_request_payload(messages, stop=stop, **kwargs)
18571847
try:
18581848
data = self._create(payload)
@@ -1867,14 +1857,6 @@ async def _agenerate(
18671857
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
18681858
**kwargs: Any,
18691859
) -> ChatResult:
1870-
if self.streaming:
1871-
stream_iter = self._astream(
1872-
messages,
1873-
stop=stop,
1874-
run_manager=run_manager,
1875-
**kwargs,
1876-
)
1877-
return await agenerate_from_stream(stream_iter)
18781860
payload = self._get_request_payload(messages, stop=stop, **kwargs)
18791861
try:
18801862
data = await self._acreate(payload)
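With those branches removed, `ChatAnthropic` relies entirely on the core check added above, so callers see the same behavior implemented in one place. A usage sketch (model name reused from the tests in this file):

```python
from langchain_anthropic import ChatAnthropic

# streaming=True is now honored by core's _should_stream, so _generate no
# longer needs its own generate_from_stream shortcut.
llm = ChatAnthropic(model="claude-3-5-sonnet-latest", streaming=True)
msg = llm.invoke("Say hi")  # streamed internally, returned as one AIMessage
```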

libs/partners/anthropic/tests/unit_tests/test_chat_models.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -47,6 +47,12 @@ def test_initialization() -> None:
     assert model.anthropic_api_url == "https://api.anthropic.com"
 
 
+@pytest.mark.parametrize("async_api", [True, False])
+def test_streaming_attribute_should_stream(async_api: bool) -> None:  # noqa: FBT001
+    llm = ChatAnthropic(model="foo", streaming=True)
+    assert llm._should_stream(async_api=async_api)
+
+
 def test_anthropic_client_caching() -> None:
     """Test that the OpenAI client is cached."""
     llm1 = ChatAnthropic(model="claude-3-5-sonnet-latest")
```

libs/partners/openai/langchain_openai/chat_models/base.py

Lines changed: 0 additions & 12 deletions
@@ -38,8 +38,6 @@
3838
from langchain_core.language_models.chat_models import (
3939
BaseChatModel,
4040
LangSmithParams,
41-
agenerate_from_stream,
42-
generate_from_stream,
4341
)
4442
from langchain_core.messages import (
4543
AIMessage,
@@ -1187,11 +1185,6 @@ def _generate(
11871185
run_manager: Optional[CallbackManagerForLLMRun] = None,
11881186
**kwargs: Any,
11891187
) -> ChatResult:
1190-
if self.streaming:
1191-
stream_iter = self._stream(
1192-
messages, stop=stop, run_manager=run_manager, **kwargs
1193-
)
1194-
return generate_from_stream(stream_iter)
11951188
payload = self._get_request_payload(messages, stop=stop, **kwargs)
11961189
generation_info = None
11971190
raw_response = None
@@ -1432,11 +1425,6 @@ async def _agenerate(
14321425
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
14331426
**kwargs: Any,
14341427
) -> ChatResult:
1435-
if self.streaming:
1436-
stream_iter = self._astream(
1437-
messages, stop=stop, run_manager=run_manager, **kwargs
1438-
)
1439-
return await agenerate_from_stream(stream_iter)
14401428
payload = self._get_request_payload(messages, stop=stop, **kwargs)
14411429
generation_info = None
14421430
raw_response = None
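As on the Anthropic side, the deleted code re-derived what core's `generate_from_stream` helper already provides: folding a chunk iterator into a single `ChatResult`. A rough sketch of that idea (`merge_stream` is an illustrative stand-in, not the actual helper):

```python
from typing import Iterator, Optional

from langchain_core.messages import message_chunk_to_message
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult


def merge_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
    """Illustrative stand-in for core's generate_from_stream (simplified)."""
    merged: Optional[ChatGenerationChunk] = None
    for chunk in stream:
        # ChatGenerationChunk defines __add__, concatenating content/metadata.
        merged = chunk if merged is None else merged + chunk
    if merged is None:
        raise ValueError("received no generation chunks")
    return ChatResult(
        generations=[
            ChatGeneration(
                message=message_chunk_to_message(merged.message),
                generation_info=merged.generation_info,
            )
        ]
    )
```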

libs/partners/openai/tests/unit_tests/chat_models/test_base.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -89,6 +89,12 @@ def test_openai_model_param() -> None:
     assert llm.max_tokens == 10
 
 
+@pytest.mark.parametrize("async_api", [True, False])
+def test_streaming_attribute_should_stream(async_api: bool) -> None:
+    llm = ChatOpenAI(model="foo", streaming=True)
+    assert llm._should_stream(async_api=async_api)
+
+
 def test_openai_client_caching() -> None:
     """Test that the OpenAI client is cached."""
     llm1 = ChatOpenAI(model="gpt-4.1-mini")
```
