Skip to content

Commit 0b37792

Browse files
committed
add stream hooks
1 parent 21a78e4 commit 0b37792

File tree

4 files changed

+430
-45
lines changed

4 files changed

+430
-45
lines changed

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 45 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
from contextlib import asynccontextmanager
88
from dataclasses import dataclass, field, replace
99
from datetime import datetime
10+
from itertools import chain
1011
from typing import Any, Literal, cast, overload
1112

12-
from openai.types.chat.chat_completion_chunk import Choice
1313
from pydantic import ValidationError
1414
from pydantic_core import to_json
1515
from typing_extensions import assert_never, deprecated
@@ -1702,7 +1702,8 @@ class OpenAIStreamedResponse(StreamedResponse):
17021702
_provider_name: str
17031703
_provider_url: str
17041704

1705-
def _handle_thinking_delta(self, choice: Choice):
1705+
def _handle_thinking_delta(self, chunk: ChatCompletionChunk):
1706+
choice = chunk.choices[0]
17061707
# The `reasoning_content` field is only present in DeepSeek models.
17071708
# https://api-docs.deepseek.com/guides/reasoning_model
17081709
if reasoning_content := getattr(choice.delta, 'reasoning_content', None):
@@ -1724,12 +1725,45 @@ def _handle_thinking_delta(self, choice: Choice):
17241725
provider_name=self.provider_name,
17251726
)
17261727

1727-
def _handle_provider_details(self, choice: Choice) -> dict[str, str] | None:
1728+
def _handle_provider_details(self, chunk: ChatCompletionChunk) -> dict[str, str] | None:
1729+
choice = chunk.choices[0]
17281730
if raw_finish_reason := choice.finish_reason:
17291731
return {'finish_reason': raw_finish_reason}
17301732

1731-
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
1733+
def _handle_text_delta(self, chunk: ChatCompletionChunk):
1734+
# Handle the text part of the response
1735+
content = chunk.choices[0].delta.content
1736+
if content:
1737+
maybe_event = self._parts_manager.handle_text_delta(
1738+
vendor_part_id='content',
1739+
content=content,
1740+
thinking_tags=self._model_profile.thinking_tags,
1741+
ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
1742+
)
1743+
if maybe_event is not None: # pragma: no branch
1744+
if isinstance(maybe_event, PartStartEvent) and isinstance(maybe_event.part, ThinkingPart):
1745+
maybe_event.part.id = 'content'
1746+
maybe_event.part.provider_name = self.provider_name
1747+
yield maybe_event
1748+
1749+
def _handle_tool_delta(self, chunk: ChatCompletionChunk):
1750+
choice = chunk.choices[0]
1751+
for dtc in choice.delta.tool_calls or []:
1752+
maybe_event = self._parts_manager.handle_tool_call_delta(
1753+
vendor_part_id=dtc.index,
1754+
tool_name=dtc.function and dtc.function.name,
1755+
args=dtc.function and dtc.function.arguments,
1756+
tool_call_id=dtc.id,
1757+
)
1758+
if maybe_event is not None:
1759+
yield maybe_event
1760+
1761+
async def _validate_response(self):
17321762
async for chunk in self._response:
1763+
yield chunk
1764+
1765+
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
1766+
async for chunk in self._validate_response():
17331767
self._usage += self._map_usage(chunk)
17341768

17351769
if chunk.id: # pragma: no branch
@@ -1750,36 +1784,15 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
17501784
if raw_finish_reason := choice.finish_reason:
17511785
self.finish_reason = self._map_finish_reason(raw_finish_reason)
17521786

1753-
if provider_details := self._handle_provider_details(choice):
1787+
if provider_details := self._handle_provider_details(chunk):
17541788
self.provider_details = provider_details
17551789

1756-
for thinking_part in self._handle_thinking_delta(choice):
1757-
yield thinking_part
1758-
1759-
# Handle the text part of the response
1760-
content = choice.delta.content
1761-
if content:
1762-
maybe_event = self._parts_manager.handle_text_delta(
1763-
vendor_part_id='content',
1764-
content=content,
1765-
thinking_tags=self._model_profile.thinking_tags,
1766-
ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
1767-
)
1768-
if maybe_event is not None: # pragma: no branch
1769-
if isinstance(maybe_event, PartStartEvent) and isinstance(maybe_event.part, ThinkingPart):
1770-
maybe_event.part.id = 'content'
1771-
maybe_event.part.provider_name = self.provider_name
1772-
yield maybe_event
1773-
1774-
for dtc in choice.delta.tool_calls or []:
1775-
maybe_event = self._parts_manager.handle_tool_call_delta(
1776-
vendor_part_id=dtc.index,
1777-
tool_name=dtc.function and dtc.function.name,
1778-
args=dtc.function and dtc.function.arguments,
1779-
tool_call_id=dtc.id,
1780-
)
1781-
if maybe_event is not None:
1782-
yield maybe_event
1790+
for event in chain(
1791+
self._handle_thinking_delta(chunk),
1792+
self._handle_text_delta(chunk),
1793+
self._handle_tool_delta(chunk),
1794+
):
1795+
yield event
17831796

17841797
def _map_usage(self, response: ChatCompletionChunk):
17851798
return _map_usage(response, self._provider_name, self._provider_url, self._model_name)

pydantic_ai_slim/pydantic_ai/models/openrouter.py

Lines changed: 85 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from openai import AsyncStream
55
from openai.types import chat
6-
from openai.types.chat.chat_completion import Choice
6+
from openai.types.chat import chat_completion, chat_completion_chunk
77
from pydantic import AliasChoices, BaseModel, Field, TypeAdapter
88
from typing_extensions import TypedDict, assert_never
99

@@ -346,7 +346,7 @@ class OpenRouterCompletionMessage(chat.ChatCompletionMessage):
346346
"""The reasoning details associated with the message, if any."""
347347

348348

349-
class OpenRouterChoice(Choice):
349+
class OpenRouterChoice(chat_completion.Choice):
350350
"""Wraps OpenAI chat completion choice with OpenRouter specific attributes."""
351351

352352
native_finish_reason: str
@@ -375,14 +375,40 @@ class OpenRouterChatCompletion(chat.ChatCompletion):
375375
"""OpenRouter specific error attribute."""
376376

377377

378+
class OpenRouterChoiceDelta(chat_completion_chunk.ChoiceDelta):
379+
"""Wraps OpenAI chat completion chunk delta with OpenRouter specific attributes."""
380+
381+
reasoning: str | None = None
382+
"""The reasoning text associated with the delta, if any."""
383+
384+
reasoning_details: list[OpenRouterReasoningDetail] | None = None
385+
"""The reasoning details associated with the delta, if any."""
386+
387+
388+
class OpenRouterChunkChoice(chat_completion_chunk.Choice):
389+
"""Wraps OpenAI chat completion chunk choice with OpenRouter specific attributes."""
390+
391+
native_finish_reason: str | None
392+
"""The finish reason reported by the downstream provider used by OpenRouter."""
393+
394+
finish_reason: Literal['stop', 'length', 'tool_calls', 'content_filter', 'error'] | None # type: ignore[reportIncompatibleVariableOverride]
395+
"""OpenRouter specific finish reasons for streaming chunks.
396+
397+
Notably, removes 'function_call' and adds 'error' finish reasons.
398+
"""
399+
400+
delta: OpenRouterChoiceDelta # type: ignore[reportIncompatibleVariableOverride]
401+
"""A wrapped chat completion delta with OpenRouter specific attributes."""
402+
403+
378404
class OpenRouterChatCompletionChunk(chat.ChatCompletionChunk):
379405
"""Wraps OpenAI chat completion chunk with OpenRouter specific attributes."""
380406

381407
provider: str
382408
"""The downstream provider that was used by OpenRouter."""
383409

384-
choices: list[OpenRouterChoice] # type: ignore[reportIncompatibleVariableOverride]
385-
"""A list of chat completion choices modified with OpenRouter specific attributes."""
410+
choices: list[OpenRouterChunkChoice] # type: ignore[reportIncompatibleVariableOverride]
411+
"""A list of chat completion chunk choices modified with OpenRouter specific attributes."""
386412

387413
error: OpenRouterError | None = None
388414
"""OpenRouter specific error attribute."""
@@ -428,6 +454,48 @@ class OpenRouterStreamedResponse(OpenAIStreamedResponse):
428454
def _map_usage(self, response: chat.ChatCompletionChunk):
429455
return _map_usage(response, self._provider_name, self._provider_url, self._model_name)
430456

457+
@override
458+
def _map_finish_reason( # type: ignore[reportIncompatibleMethodOverride]
459+
self, key: Literal['stop', 'length', 'tool_calls', 'content_filter', 'error']
460+
) -> FinishReason | None:
461+
return _CHAT_FINISH_REASON_MAP.get(key)
462+
463+
@override
464+
def _handle_thinking_delta(self, chunk: OpenRouterChatCompletionChunk): # type: ignore[reportIncompatibleMethodOverride]
465+
delta = chunk.choices[0].delta
466+
if reasoning_details := delta.reasoning_details:
467+
for detail in reasoning_details:
468+
thinking_part = OpenRouterThinkingPart.from_reasoning_detail(detail)
469+
yield self._parts_manager.handle_thinking_delta(
470+
vendor_part_id='reasoning_detail',
471+
id=thinking_part.id,
472+
content=thinking_part.content,
473+
provider_name=self._provider_name,
474+
)
475+
476+
@override
477+
def _handle_provider_details(self, chunk: chat.ChatCompletionChunk) -> dict[str, str] | None:
478+
native_chunk = OpenRouterChatCompletionChunk.model_validate(chunk.model_dump())
479+
480+
if provider_details := super()._handle_provider_details(chunk):
481+
if provider := native_chunk.provider:
482+
provider_details['downstream_provider'] = provider
483+
484+
if native_finish_reason := native_chunk.choices[0].native_finish_reason:
485+
provider_details['native_finish_reason'] = native_finish_reason
486+
487+
return provider_details
488+
489+
@override
490+
async def _validate_response(self):
491+
async for chunk in self._response:
492+
chunk = OpenRouterChatCompletionChunk.model_validate(chunk.model_dump())
493+
494+
if error := chunk.error:
495+
raise ModelHTTPError(status_code=error.code, model_name=chunk.model, body=error.message)
496+
497+
yield chunk
498+
431499

432500
def _openrouter_settings_to_openai_settings(model_settings: OpenRouterModelSettings) -> OpenAIChatModelSettings:
433501
"""Transforms an 'OpenRouterModelSettings' object into an 'OpenAIChatModelSettings' object.
@@ -475,6 +543,7 @@ def __init__(
475543
"""
476544
super().__init__(model_name, provider=provider or OpenRouterProvider(), profile=profile, settings=settings)
477545

546+
@override
478547
def prepare_request(
479548
self,
480549
model_settings: ModelSettings | None,
@@ -485,13 +554,13 @@ def prepare_request(
485554
return new_settings, customized_parameters
486555

487556
@override
488-
def _map_finish_reason(
557+
def _map_finish_reason( # type: ignore[reportIncompatibleMethodOverride]
489558
self, key: Literal['stop', 'length', 'tool_calls', 'content_filter', 'error']
490-
) -> FinishReason | None: # type: ignore[reportIncompatibleMethodOverride]
559+
) -> FinishReason | None:
491560
return _CHAT_FINISH_REASON_MAP.get(key)
492561

493562
@override
494-
def _process_reasoning(self, response: OpenRouterChatCompletion) -> list[ThinkingPart]:
563+
def _process_reasoning(self, response: OpenRouterChatCompletion) -> list[ThinkingPart]: # type: ignore[reportIncompatibleMethodOverride]
495564
message = response.choices[0].message
496565
items: list[ThinkingPart] = []
497566

@@ -502,10 +571,7 @@ def _process_reasoning(self, response: OpenRouterChatCompletion) -> list[Thinkin
502571
return items
503572

504573
@override
505-
def _process_provider_details(self, response: OpenRouterChatCompletion) -> dict[str, Any]:
506-
if error := response.error:
507-
raise ModelHTTPError(status_code=error.code, model_name=response.model, body=error.message)
508-
574+
def _process_provider_details(self, response: OpenRouterChatCompletion) -> dict[str, Any]: # type: ignore[reportIncompatibleMethodOverride]
509575
provider_details = super()._process_provider_details(response)
510576

511577
provider_details['downstream_provider'] = response.provider
@@ -515,8 +581,14 @@ def _process_provider_details(self, response: OpenRouterChatCompletion) -> dict[
515581

516582
@override
517583
def _validate_completion(self, response: chat.ChatCompletion) -> chat.ChatCompletion:
518-
return OpenRouterChatCompletion.model_validate(response.model_dump())
584+
response = OpenRouterChatCompletion.model_validate(response.model_dump())
519585

586+
if error := response.error:
587+
raise ModelHTTPError(status_code=error.code, model_name=response.model, body=error.message)
588+
589+
return response
590+
591+
@override
520592
async def _process_streamed_response(
521593
self, response: AsyncStream[chat.ChatCompletionChunk], model_request_parameters: ModelRequestParameters
522594
) -> OpenRouterStreamedResponse:
@@ -538,6 +610,7 @@ async def _process_streamed_response(
538610
_provider_url=self._provider.base_url,
539611
)
540612

613+
@override
541614
def _map_model_response(self, message: ModelResponse) -> chat.ChatCompletionMessageParam:
542615
texts: list[str] = []
543616
tool_calls: list[chat.ChatCompletionMessageFunctionToolCallParam] = []

0 commit comments

Comments
 (0)