Skip to content

Commit 8d090f0

Browse files
committed
simplify hooks
1 parent 0b37792 commit 8d090f0

File tree

4 files changed

+334
-199
lines changed

4 files changed

+334
-199
lines changed

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 60 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from contextlib import asynccontextmanager
88
from dataclasses import dataclass, field, replace
99
from datetime import datetime
10-
from itertools import chain
1110
from typing import Any, Literal, cast, overload
1211

1312
from pydantic import ValidationError
@@ -650,7 +649,7 @@ async def _process_streamed_response(
650649
# so we set it from a later chunk in `OpenAIChatStreamedResponse`.
651650
model_name = first_chunk.model or self._model_name
652651

653-
return OpenAIStreamedResponse(
652+
return self._streamed_response_cls(
654653
model_request_parameters=model_request_parameters,
655654
_model_name=model_name,
656655
_model_profile=self.profile,
@@ -660,6 +659,10 @@ async def _process_streamed_response(
660659
_provider_url=self._provider.base_url,
661660
)
662661

662+
@property
663+
def _streamed_response_cls(self):
664+
return OpenAIStreamedResponse
665+
663666
def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
664667
return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]
665668

@@ -687,6 +690,10 @@ def _get_web_search_options(self, model_request_parameters: ModelRequestParamete
687690
)
688691

689692
def _map_model_response(self, message: ModelResponse) -> chat.ChatCompletionMessageParam:
693+
"""Hook that determines how `ModelResponse` is mapped into `ChatCompletionMessageParam` objects before sending.
694+
695+
Subclasses of `OpenAIChatModel` should override this method to provide their own mapping logic.
696+
"""
690697
texts: list[str] = []
691698
tool_calls: list[ChatCompletionMessageFunctionToolCallParam] = []
692699
for item in message.parts:
@@ -1702,7 +1709,49 @@ class OpenAIStreamedResponse(StreamedResponse):
17021709
_provider_name: str
17031710
_provider_url: str
17041711

1705-
def _handle_thinking_delta(self, chunk: ChatCompletionChunk):
1712+
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
1713+
async for chunk in self._validate_response():
1714+
self._usage += self._map_usage(chunk)
1715+
1716+
if chunk.id: # pragma: no branch
1717+
self.provider_response_id = chunk.id
1718+
1719+
if chunk.model:
1720+
self._model_name = chunk.model
1721+
1722+
try:
1723+
choice = chunk.choices[0]
1724+
except IndexError:
1725+
continue
1726+
1727+
# When using Azure OpenAI and an async content filter is enabled, the openai SDK can return None deltas.
1728+
if choice.delta is None: # pyright: ignore[reportUnnecessaryComparison]
1729+
continue
1730+
1731+
if raw_finish_reason := choice.finish_reason:
1732+
self.finish_reason = self._map_finish_reason(raw_finish_reason)
1733+
1734+
if provider_details := self._map_provider_details(chunk):
1735+
self.provider_details = provider_details
1736+
1737+
for event in self._map_part_delta(chunk):
1738+
yield event
1739+
1740+
async def _validate_response(self):
1741+
"""Hook that validates incoming chunks.
1742+
1743+
This method should be overridden by subclasses of `OpenAIStreamedResponse` to apply custom chunk validations.
1744+
1745+
By default, this is a no-op since `ChatCompletionChunk` is already validated.
1746+
"""
1747+
async for chunk in self._response:
1748+
yield chunk
1749+
1750+
def _map_part_delta(self, chunk: ChatCompletionChunk):
1751+
"""Hook that maps delta content to events.
1752+
1753+
This method should be overridden by subclasses of `OpenAIStreamedResponse` to customize the mapping.
1754+
"""
17061755
choice = chunk.choices[0]
17071756
# The `reasoning_content` field is only present in DeepSeek models.
17081757
# https://api-docs.deepseek.com/guides/reasoning_model
@@ -1725,14 +1774,8 @@ def _handle_thinking_delta(self, chunk: ChatCompletionChunk):
17251774
provider_name=self.provider_name,
17261775
)
17271776

1728-
def _handle_provider_details(self, chunk: ChatCompletionChunk) -> dict[str, str] | None:
1729-
choice = chunk.choices[0]
1730-
if raw_finish_reason := choice.finish_reason:
1731-
return {'finish_reason': raw_finish_reason}
1732-
1733-
def _handle_text_delta(self, chunk: ChatCompletionChunk):
17341777
# Handle the text part of the response
1735-
content = chunk.choices[0].delta.content
1778+
content = choice.delta.content
17361779
if content:
17371780
maybe_event = self._parts_manager.handle_text_delta(
17381781
vendor_part_id='content',
@@ -1746,8 +1789,6 @@ def _handle_text_delta(self, chunk: ChatCompletionChunk):
17461789
maybe_event.part.provider_name = self.provider_name
17471790
yield maybe_event
17481791

1749-
def _handle_tool_delta(self, chunk: ChatCompletionChunk):
1750-
choice = chunk.choices[0]
17511792
for dtc in choice.delta.tool_calls or []:
17521793
maybe_event = self._parts_manager.handle_tool_call_delta(
17531794
vendor_part_id=dtc.index,
@@ -1758,41 +1799,14 @@ def _handle_tool_delta(self, chunk: ChatCompletionChunk):
17581799
if maybe_event is not None:
17591800
yield maybe_event
17601801

1761-
async def _validate_response(self):
1762-
async for chunk in self._response:
1763-
yield chunk
1764-
1765-
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
1766-
async for chunk in self._validate_response():
1767-
self._usage += self._map_usage(chunk)
1768-
1769-
if chunk.id: # pragma: no branch
1770-
self.provider_response_id = chunk.id
1771-
1772-
if chunk.model:
1773-
self._model_name = chunk.model
1774-
1775-
try:
1776-
choice = chunk.choices[0]
1777-
except IndexError:
1778-
continue
1779-
1780-
# When using Azure OpenAI and an async content filter is enabled, the openai SDK can return None deltas.
1781-
if choice.delta is None: # pyright: ignore[reportUnnecessaryComparison]
1782-
continue
1802+
def _map_provider_details(self, chunk: ChatCompletionChunk) -> dict[str, str] | None:
1803+
"""Hook that generates the provider details from chunk content.
17831804
1784-
if raw_finish_reason := choice.finish_reason:
1785-
self.finish_reason = self._map_finish_reason(raw_finish_reason)
1786-
1787-
if provider_details := self._handle_provider_details(chunk):
1788-
self.provider_details = provider_details
1789-
1790-
for event in chain(
1791-
self._handle_thinking_delta(chunk),
1792-
self._handle_text_delta(chunk),
1793-
self._handle_tool_delta(chunk),
1794-
):
1795-
yield event
1805+
This method should be overridden by subclasses of `OpenAIStreamedResponse` to customize the provider details.
1806+
"""
1807+
choice = chunk.choices[0]
1808+
if raw_finish_reason := choice.finish_reason:
1809+
return {'finish_reason': raw_finish_reason}
17961810

17971811
def _map_usage(self, response: ChatCompletionChunk):
17981812
return _map_usage(response, self._provider_name, self._provider_url, self._model_name)

0 commit comments

Comments
 (0)