Skip to content

Commit 28edfda

Browse files
Fix streaming crash on Bedrock while preserving OpenAI usage stats (#64)
- Remove global default for `stream_options`, which caused crashes on non-OpenAI providers (e.g., Bedrock, Vertex) by routing requests to the wrong endpoint. - Implement smart logic in `_stream` and `_astream` to automatically add `stream_options={"include_usage": True}` *only* when an OpenAI/Azure model is detected. - Maintain backward compatibility for OpenAI users who rely on streaming usage stats. - Allow users to manually override `stream_options` for any provider. Fixes #51.
1 parent ed37959 commit 28edfda

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

langchain_litellm/chat_models/litellm.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,6 @@
8383
class ChatLiteLLMException(Exception):
    """Raised for errors originating in the `LiteLLM I/O` library."""
8585

86-
8786
def _create_retry_decorator(
8887
llm: ChatLiteLLM,
8988
run_manager: Optional[
@@ -337,9 +336,7 @@ class ChatLiteLLM(BaseChatModel):
337336
client: Any = None #: :meta private:
338337
model: str = "gpt-3.5-turbo"
339338
model_name: Optional[str] = None
340-
stream_options: Optional[Dict[str, Any]] = Field(
341-
default_factory=lambda: {"include_usage": True}
342-
)
339+
stream_options: Optional[Dict[str, Any]] = None
343340
"""Model name to use."""
344341
openai_api_key: Optional[str] = None
345342
azure_api_key: Optional[str] = None
@@ -566,6 +563,15 @@ def _create_message_dicts(
566563
params["stop"] = stop
567564
message_dicts = [_convert_message_to_dict(m) for m in messages]
568565
return message_dicts, params
566+
567+
def _is_openai(self) -> bool:
    """Return True when the active model targets OpenAI or Azure.

    Checks the explicit ``custom_llm_provider`` first, then falls back to
    inspecting the model identifier: an ``"azure"`` substring or an exact
    match against the known ``_OPENAI_MODELS`` set counts as OpenAI/Azure.
    """
    if self.custom_llm_provider in ("openai", "azure"):
        return True
    # Fall back to the model identifier when no provider was given.
    name = self.model_name or self.model or ""
    return "azure" in name or name in _OPENAI_MODELS
569575

570576
def _stream(
571577
self,
@@ -576,7 +582,10 @@ def _stream(
576582
) -> Iterator[ChatGenerationChunk]:
577583
message_dicts, params = self._create_message_dicts(messages, stop)
578584
params = {**params, **kwargs, "stream": True}
579-
params["stream_options"] = self.stream_options
585+
if self.stream_options is not None:
586+
params["stream_options"] = self.stream_options
587+
elif self._is_openai():
588+
params["stream_options"] = {"include_usage": True}
580589
default_chunk_class = AIMessageChunk
581590

582591
for chunk in self.completion_with_retry(
@@ -632,7 +641,10 @@ async def _astream(
632641
) -> AsyncIterator[ChatGenerationChunk]:
633642
message_dicts, params = self._create_message_dicts(messages, stop)
634643
params = {**params, **kwargs, "stream": True}
635-
params["stream_options"] = self.stream_options
644+
if self.stream_options is not None:
645+
params["stream_options"] = self.stream_options
646+
elif self._is_openai():
647+
params["stream_options"] = {"include_usage": True}
636648
default_chunk_class = AIMessageChunk
637649

638650
async for chunk in await self.acompletion_with_retry(

0 commit comments

Comments
 (0)