Skip to content

Commit bbc3e3b

Browse files
author
Erick Friis
authored
openai: disable streaming for o1 by default (#29147)
Currently 400s https://community.openai.com/t/streaming-support-for-o1-o1-2024-12-17-resulting-in-400-unsupported-value/1085043 o1-mini and o1-preview stream fine
1 parent 62074ba commit bbc3e3b

File tree

2 files changed

+25
-0
lines changed
  • libs/partners/openai

2 files changed

+25
-0
lines changed

libs/partners/openai/langchain_openai/chat_models/base.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,15 @@ def validate_temperature(cls, values: Dict[str, Any]) -> Any:
562562
values["temperature"] = 1
563563
return values
564564

565+
@model_validator(mode="before")
566+
@classmethod
567+
def validate_disable_streaming(cls, values: Dict[str, Any]) -> Any:
568+
"""Disable streaming if n > 1."""
569+
model = values.get("model_name") or values.get("model") or ""
570+
if model == "o1" and values.get("disable_streaming") is None:
571+
values["disable_streaming"] = True
572+
return values
573+
565574
@model_validator(mode="after")
566575
def validate_environment(self) -> Self:
567576
"""Validate that api key and python package exists in environment."""

libs/partners/openai/tests/integration_tests/chat_models/test_base.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1192,3 +1192,19 @@ def test_o1(use_max_completion_tokens: bool) -> None:
11921192
assert isinstance(response, AIMessage)
11931193
assert isinstance(response.content, str)
11941194
assert response.content.upper() == response.content
1195+
1196+
1197+
@pytest.mark.scheduled
1198+
def test_o1_doesnt_stream() -> None:
1199+
"""
1200+
When this starts failing, remove the `disable_streaming` validator in
1201+
`BaseChatOpenAI`
1202+
"""
1203+
with pytest.raises(openai.BadRequestError):
1204+
list(ChatOpenAI(model="o1", disable_streaming=False).stream("how are you"))
1205+
1206+
1207+
@pytest.mark.scheduled
1208+
def test_o1_stream_default_works() -> None:
1209+
result = list(ChatOpenAI(model="o1").stream("say 'hi'"))
1210+
assert len(result) > 0

0 commit comments

Comments
 (0)