77from contextlib import asynccontextmanager
88from dataclasses import dataclass , field , replace
99from datetime import datetime
10- from itertools import chain
1110from typing import Any , Literal , cast , overload
1211
1312from pydantic import ValidationError
@@ -650,7 +649,7 @@ async def _process_streamed_response(
650649 # so we set it from a later chunk in `OpenAIChatStreamedResponse`.
651650 model_name = first_chunk .model or self ._model_name
652651
653- return OpenAIStreamedResponse (
652+ return self . _streamed_response_cls (
654653 model_request_parameters = model_request_parameters ,
655654 _model_name = model_name ,
656655 _model_profile = self .profile ,
@@ -660,6 +659,10 @@ async def _process_streamed_response(
660659 _provider_url = self ._provider .base_url ,
661660 )
662661
662+ @property
663+ def _streamed_response_cls (self ):
664+ return OpenAIStreamedResponse
665+
663666 def _get_tools (self , model_request_parameters : ModelRequestParameters ) -> list [chat .ChatCompletionToolParam ]:
664667 return [self ._map_tool_definition (r ) for r in model_request_parameters .tool_defs .values ()]
665668
@@ -687,6 +690,10 @@ def _get_web_search_options(self, model_request_parameters: ModelRequestParamete
687690 )
688691
689692 def _map_model_response (self , message : ModelResponse ) -> chat .ChatCompletionMessageParam :
693+ """Hook that determines how `ModelResponse` is mapped into `ChatCompletionMessageParam` objects before sending.
694+
695+ Subclasses of `OpenAIChatModel` should override this method to provide their own mapping logic.
696+ """
690697 texts : list [str ] = []
691698 tool_calls : list [ChatCompletionMessageFunctionToolCallParam ] = []
692699 for item in message .parts :
@@ -1702,7 +1709,49 @@ class OpenAIStreamedResponse(StreamedResponse):
17021709 _provider_name : str
17031710 _provider_url : str
17041711
1705- def _handle_thinking_delta (self , chunk : ChatCompletionChunk ):
1712+ async def _get_event_iterator (self ) -> AsyncIterator [ModelResponseStreamEvent ]:
1713+ async for chunk in self ._validate_response ():
1714+ self ._usage += self ._map_usage (chunk )
1715+
1716+ if chunk .id : # pragma: no branch
1717+ self .provider_response_id = chunk .id
1718+
1719+ if chunk .model :
1720+ self ._model_name = chunk .model
1721+
1722+ try :
1723+ choice = chunk .choices [0 ]
1724+ except IndexError :
1725+ continue
1726+
1727+ # When using Azure OpenAI and an async content filter is enabled, the openai SDK can return None deltas.
1728+ if choice .delta is None : # pyright: ignore[reportUnnecessaryComparison]
1729+ continue
1730+
1731+ if raw_finish_reason := choice .finish_reason :
1732+ self .finish_reason = self ._map_finish_reason (raw_finish_reason )
1733+
1734+ if provider_details := self ._map_provider_details (chunk ):
1735+ self .provider_details = provider_details
1736+
1737+ for event in self ._map_part_delta (chunk ):
1738+ yield event
1739+
1740+ async def _validate_response (self ):
1741+ """Hook that validates incoming chunks.
1742+
1743+ This method should be overridden by subclasses of `OpenAIStreamedResponse` to apply custom chunk validations.
1744+
1745+ By default, this is a no-op since `ChatCompletionChunk` is already validated.
1746+ """
1747+ async for chunk in self ._response :
1748+ yield chunk
1749+
1750+ def _map_part_delta (self , chunk : ChatCompletionChunk ):
1751+ """Hook that maps delta content to events.
1752+
1753+ This method should be overridden by subclasses of `OpenAIStreamedResponse` to customize the mapping.
1754+ """
17061755 choice = chunk .choices [0 ]
17071756 # The `reasoning_content` field is only present in DeepSeek models.
17081757 # https://api-docs.deepseek.com/guides/reasoning_model
@@ -1725,14 +1774,8 @@ def _handle_thinking_delta(self, chunk: ChatCompletionChunk):
17251774 provider_name = self .provider_name ,
17261775 )
17271776
1728- def _handle_provider_details (self , chunk : ChatCompletionChunk ) -> dict [str , str ] | None :
1729- choice = chunk .choices [0 ]
1730- if raw_finish_reason := choice .finish_reason :
1731- return {'finish_reason' : raw_finish_reason }
1732-
1733- def _handle_text_delta (self , chunk : ChatCompletionChunk ):
17341777 # Handle the text part of the response
1735- content = chunk . choices [ 0 ] .delta .content
1778+ content = choice .delta .content
17361779 if content :
17371780 maybe_event = self ._parts_manager .handle_text_delta (
17381781 vendor_part_id = 'content' ,
@@ -1746,8 +1789,6 @@ def _handle_text_delta(self, chunk: ChatCompletionChunk):
17461789 maybe_event .part .provider_name = self .provider_name
17471790 yield maybe_event
17481791
1749- def _handle_tool_delta (self , chunk : ChatCompletionChunk ):
1750- choice = chunk .choices [0 ]
17511792 for dtc in choice .delta .tool_calls or []:
17521793 maybe_event = self ._parts_manager .handle_tool_call_delta (
17531794 vendor_part_id = dtc .index ,
@@ -1758,41 +1799,14 @@ def _handle_tool_delta(self, chunk: ChatCompletionChunk):
17581799 if maybe_event is not None :
17591800 yield maybe_event
17601801
1761- async def _validate_response (self ):
1762- async for chunk in self ._response :
1763- yield chunk
1764-
1765- async def _get_event_iterator (self ) -> AsyncIterator [ModelResponseStreamEvent ]:
1766- async for chunk in self ._validate_response ():
1767- self ._usage += self ._map_usage (chunk )
1768-
1769- if chunk .id : # pragma: no branch
1770- self .provider_response_id = chunk .id
1771-
1772- if chunk .model :
1773- self ._model_name = chunk .model
1774-
1775- try :
1776- choice = chunk .choices [0 ]
1777- except IndexError :
1778- continue
1779-
1780- # When using Azure OpenAI and an async content filter is enabled, the openai SDK can return None deltas.
1781- if choice .delta is None : # pyright: ignore[reportUnnecessaryComparison]
1782- continue
1802+ def _map_provider_details (self , chunk : ChatCompletionChunk ) -> dict [str , str ] | None :
1803+ """Hook that generates the provider details from chunk content.
17831804
1784- if raw_finish_reason := choice .finish_reason :
1785- self .finish_reason = self ._map_finish_reason (raw_finish_reason )
1786-
1787- if provider_details := self ._handle_provider_details (chunk ):
1788- self .provider_details = provider_details
1789-
1790- for event in chain (
1791- self ._handle_thinking_delta (chunk ),
1792- self ._handle_text_delta (chunk ),
1793- self ._handle_tool_delta (chunk ),
1794- ):
1795- yield event
1805+ This method should be overridden by subclasses of `OpenAIStreamedResponse` to customize the provider details.
1806+ """
1807+ choice = chunk .choices [0 ]
1808+ if raw_finish_reason := choice .finish_reason :
1809+ return {'finish_reason' : raw_finish_reason }
17961810
17971811 def _map_usage (self , response : ChatCompletionChunk ):
17981812 return _map_usage (response , self ._provider_name , self ._provider_url , self ._model_name )
0 commit comments