Skip to content

Commit e8c3c81

Browse files
committed
fix coverage/linting
1 parent 8d090f0 commit e8c3c81

File tree

2 files changed

+54
-40
lines changed

2 files changed

+54
-40
lines changed

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations as _annotations
22

33
import base64
4+
import itertools
45
import json
56
import warnings
67
from collections.abc import AsyncIterable, AsyncIterator, Sequence
@@ -62,6 +63,7 @@
6263
ChatCompletionContentPartParam,
6364
ChatCompletionContentPartTextParam,
6465
)
66+
from openai.types.chat.chat_completion_chunk import Choice
6567
from openai.types.chat.chat_completion_content_part_image_param import ImageURL
6668
from openai.types.chat.chat_completion_content_part_input_audio_param import InputAudio
6769
from openai.types.chat.chat_completion_content_part_param import File, FileFile
@@ -530,9 +532,17 @@ async def _completions_create(
530532
raise # pragma: lax no cover
531533

532534
def _validate_completion(self, response: chat.ChatCompletion) -> chat.ChatCompletion:
535+
"""Hook that validates chat completions before processing.
536+
537+
This method may be overridden by subclasses of `OpenAIChatModel` to apply custom completion validations.
538+
"""
533539
return chat.ChatCompletion.model_validate(response.model_dump())
534540

535541
def _process_reasoning(self, response: chat.ChatCompletion) -> list[ThinkingPart]:
542+
"""Hook that maps reasoning tokens to thinking parts.
543+
544+
This method may be overridden by subclasses of `OpenAIChatModel` to apply custom mappings.
545+
"""
536546
message = response.choices[0].message
537547
items: list[ThinkingPart] = []
538548

@@ -550,6 +560,10 @@ def _process_reasoning(self, response: chat.ChatCompletion) -> list[ThinkingPart
550560
return items
551561

552562
def _process_provider_details(self, response: chat.ChatCompletion) -> dict[str, Any]:
563+
"""Hook that response content to provider details.
564+
565+
This method may be overridden by subclasses of `OpenAIChatModel` to apply custom mappings.
566+
"""
553567
choice = response.choices[0]
554568
provider_details: dict[str, Any] = {}
555569

@@ -692,7 +706,7 @@ def _get_web_search_options(self, model_request_parameters: ModelRequestParamete
692706
def _map_model_response(self, message: ModelResponse) -> chat.ChatCompletionMessageParam:
693707
"""Hook that determines how `ModelResponse` is mapped into `ChatCompletionMessageParam` objects before sending.
694708
695-
Subclasses of `OpenAIChatModel` should override this method to provide their own mapping logic.
709+
Subclasses of `OpenAIChatModel` may override this method to provide their own mapping logic.
696710
"""
697711
texts: list[str] = []
698712
tool_calls: list[ChatCompletionMessageFunctionToolCallParam] = []
@@ -1740,19 +1754,28 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
17401754
async def _validate_response(self):
17411755
"""Hook that validates incoming chunks.
17421756
1743-
This method should be overridden by subclasses of `OpenAIStreamedResponse` to apply custom chunk validations.
1757+
This method may be overridden by subclasses of `OpenAIStreamedResponse` to apply custom chunk validations.
17441758
17451759
By default, this is a no-op since `ChatCompletionChunk` is already validated.
17461760
"""
17471761
async for chunk in self._response:
17481762
yield chunk
17491763

17501764
def _map_part_delta(self, chunk: ChatCompletionChunk):
1751-
"""Hook that maps delta content to events.
1765+
"""Hook that determines the sequence of mappings that will be called to produce events.
17521766
1753-
This method should be overridden by subclasses of `OpenAIStreamResponse` to customize the mapping.
1767+
This method may be overridden by subclasses of `OpenAIStreamResponse` to customize the mapping.
17541768
"""
17551769
choice = chunk.choices[0]
1770+
return itertools.chain(
1771+
self._map_thinking_delta(choice), self._map_text_delta(choice), self._map_tool_call_delta(choice)
1772+
)
1773+
1774+
def _map_thinking_delta(self, choice: Choice):
1775+
"""Hook that maps thinking delta content to events.
1776+
1777+
This method may be overridden by subclasses of `OpenAIStreamResponse` to customize the mapping.
1778+
"""
17561779
# The `reasoning_content` field is only present in DeepSeek models.
17571780
# https://api-docs.deepseek.com/guides/reasoning_model
17581781
if reasoning_content := getattr(choice.delta, 'reasoning_content', None):
@@ -1774,6 +1797,11 @@ def _map_part_delta(self, chunk: ChatCompletionChunk):
17741797
provider_name=self.provider_name,
17751798
)
17761799

1800+
def _map_text_delta(self, choice: Choice):
1801+
"""Hook that maps text delta content to events.
1802+
1803+
This method may be overridden by subclasses of `OpenAIStreamResponse` to customize the mapping.
1804+
"""
17771805
# Handle the text part of the response
17781806
content = choice.delta.content
17791807
if content:
@@ -1789,6 +1817,11 @@ def _map_part_delta(self, chunk: ChatCompletionChunk):
17891817
maybe_event.part.provider_name = self.provider_name
17901818
yield maybe_event
17911819

1820+
def _map_tool_call_delta(self, choice: Choice):
1821+
"""Hook that maps tool call delta content to events.
1822+
1823+
This method may be overridden by subclasses of `OpenAIStreamResponse` to customize the mapping.
1824+
"""
17921825
for dtc in choice.delta.tool_calls or []:
17931826
maybe_event = self._parts_manager.handle_tool_call_delta(
17941827
vendor_part_id=dtc.index,
@@ -1802,7 +1835,7 @@ def _map_part_delta(self, chunk: ChatCompletionChunk):
18021835
def _map_provider_details(self, chunk: ChatCompletionChunk) -> dict[str, str] | None:
18031836
"""Hook that generates the provider details from chunk content.
18041837
1805-
This method should be overridden by subclasses of `OpenAIStreamResponse` to customize the provider details.
1838+
This method may be overridden by subclasses of `OpenAIStreamResponse` to customize the provider details.
18061839
"""
18071840
choice = chunk.choices[0]
18081841
if raw_finish_reason := choice.finish_reason:

pydantic_ai_slim/pydantic_ai/models/openrouter.py

Lines changed: 16 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
11
from dataclasses import asdict, dataclass
2-
from typing import Any, Literal, cast, override
2+
from typing import Any, Literal, cast
33

4-
from openai import APIError
5-
from openai.types import chat
6-
from openai.types.chat import chat_completion, chat_completion_chunk
74
from pydantic import AliasChoices, BaseModel, Field, TypeAdapter
8-
from typing_extensions import TypedDict, assert_never
5+
from typing_extensions import TypedDict, assert_never, override
96

107
from .. import _utils
118
from ..exceptions import ModelHTTPError
@@ -15,7 +12,6 @@
1512
FilePart,
1613
FinishReason,
1714
ModelResponse,
18-
PartStartEvent,
1915
TextPart,
2016
ThinkingPart,
2117
ToolCallPart,
@@ -25,7 +21,18 @@
2521
from ..settings import ModelSettings
2622
from ..usage import RequestUsage
2723
from . import ModelRequestParameters
28-
from .openai import OpenAIChatModel, OpenAIChatModelSettings, OpenAIStreamedResponse
24+
25+
try:
26+
from openai import APIError
27+
from openai.types import chat
28+
from openai.types.chat import chat_completion, chat_completion_chunk
29+
30+
from .openai import OpenAIChatModel, OpenAIChatModelSettings, OpenAIStreamedResponse
31+
except ImportError as _import_error:
32+
raise ImportError(
33+
'Please install `openai` to use the OpenRouter model, '
34+
'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
35+
) from _import_error
2936

3037
_CHAT_FINISH_REASON_MAP: dict[Literal['stop', 'length', 'tool_calls', 'content_filter', 'error'], FinishReason] = {
3138
'stop': 'stop',
@@ -592,10 +599,9 @@ async def _validate_response(self):
592599
raise ModelHTTPError(status_code=error.code, model_name=self._model_name, body=error.message)
593600

594601
@override
595-
def _map_part_delta(self, chunk: chat.ChatCompletionChunk):
602+
def _map_thinking_delta(self, choice: chat_completion_chunk.Choice):
596603
# We can cast with confidence because chunk was validated in `_validate_response`
597-
chunk = cast(OpenRouterChatCompletionChunk, chunk)
598-
choice = chunk.choices[0]
604+
choice = cast(OpenRouterChunkChoice, choice)
599605

600606
if reasoning_details := choice.delta.reasoning_details:
601607
for detail in reasoning_details:
@@ -607,31 +613,6 @@ def _map_part_delta(self, chunk: chat.ChatCompletionChunk):
607613
provider_name=self._provider_name,
608614
)
609615

610-
# Handle the text part of the response
611-
content = choice.delta.content
612-
if content:
613-
maybe_event = self._parts_manager.handle_text_delta(
614-
vendor_part_id='content',
615-
content=content,
616-
thinking_tags=self._model_profile.thinking_tags,
617-
ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
618-
)
619-
if maybe_event is not None: # pragma: no branch
620-
if isinstance(maybe_event, PartStartEvent) and isinstance(maybe_event.part, ThinkingPart):
621-
maybe_event.part.id = 'content'
622-
maybe_event.part.provider_name = self.provider_name
623-
yield maybe_event
624-
625-
for dtc in choice.delta.tool_calls or []:
626-
maybe_event = self._parts_manager.handle_tool_call_delta(
627-
vendor_part_id=dtc.index,
628-
tool_name=dtc.function and dtc.function.name,
629-
args=dtc.function and dtc.function.arguments,
630-
tool_call_id=dtc.id,
631-
)
632-
if maybe_event is not None:
633-
yield maybe_event
634-
635616
@override
636617
def _map_provider_details(self, chunk: chat.ChatCompletionChunk) -> dict[str, str] | None:
637618
chunk = cast(OpenRouterChatCompletionChunk, chunk)

0 commit comments

Comments
 (0)