Skip to content

Commit 1079522

Browse files
committed
refactor _map_model_response with context
1 parent 79fbc20 commit 1079522

File tree

2 files changed

+105
-62
lines changed

2 files changed

+105
-62
lines changed

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 86 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -707,42 +707,105 @@ def _get_web_search_options(self, model_request_parameters: ModelRequestParamete
707707
f'`{tool.__class__.__name__}` is not supported by `OpenAIChatModel`. If it should be, please file an issue.'
708708
)
709709

710+
@dataclass
711+
class _MapModelResposeContext:
712+
"""Context object for mapping a `ModelResponse` to OpenAI chat completion parameters.
713+
714+
This class is designed to be subclassed to add new fields for custom logic,
715+
collecting various parts of the model response (like text and tool calls)
716+
to form a single assistant message.
717+
"""
718+
texts: list[str] = field(default_factory=list)
719+
tool_calls: list[ChatCompletionMessageFunctionToolCallParam] = field(default_factory=list)
720+
721+
def into_message_param(self) -> chat.ChatCompletionAssistantMessageParam:
722+
"""Converts the collected texts and tool calls into a single OpenAI `ChatCompletionAssistantMessageParam`.
723+
724+
This method serves as a hook that can be overridden by subclasses
725+
to implement custom logic for how collected parts are transformed into the final message parameter.
726+
727+
Returns:
728+
An OpenAI `ChatCompletionAssistantMessageParam` object representing the assistant's response.
729+
"""
730+
message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
731+
if self.texts:
732+
# Note: model responses from this model should only have one text item, so the following
733+
# shouldn't merge multiple texts into one unless you switch models between runs:
734+
message_param['content'] = '\n\n'.join(self.texts)
735+
else:
736+
message_param['content'] = None
737+
if self.tool_calls:
738+
message_param['tool_calls'] = self.tool_calls
739+
return message_param
740+
741+
def _map_response_text_part(self, ctx: _MapModelResposeContext, item: TextPart) -> None:
742+
"""Maps a `TextPart` to the response context.
743+
744+
This method serves as a hook that can be overridden by subclasses
745+
to implement custom logic for handling text parts.
746+
"""
747+
ctx.texts.append(item.content)
748+
749+
def _map_response_thinking_part(self, ctx: _MapModelResposeContext, item: ThinkingPart) -> None:
750+
"""Maps a `ThinkingPart` to the response context.
751+
752+
This method serves as a hook that can be overridden by subclasses
753+
to implement custom logic for handling thinking parts.
754+
"""
755+
# NOTE: DeepSeek `reasoning_content` field should NOT be sent back per https://api-docs.deepseek.com/guides/reasoning_model,
756+
# but we currently just send it in `<think>` tags anyway as we don't want DeepSeek-specific checks here.
757+
# If you need this changed, please file an issue.
758+
start_tag, end_tag = self.profile.thinking_tags
759+
ctx.texts.append('\n'.join([start_tag, item.content, end_tag]))
760+
761+
def _map_response_tool_call_part(self, ctx: _MapModelResposeContext, item: ToolCallPart) -> None:
762+
"""Maps a `ToolCallPart` to the response context.
763+
764+
This method serves as a hook that can be overridden by subclasses
765+
to implement custom logic for handling tool call parts.
766+
"""
767+
ctx.tool_calls.append(self._map_tool_call(item))
768+
769+
def _map_response_builtin_part(
770+
self, ctx: _MapModelResposeContext, item: BuiltinToolCallPart | BuiltinToolReturnPart
771+
) -> None:
772+
"""Maps a built-in tool call or return part to the response context.
773+
774+
This method serves as a hook that can be overridden by subclasses
775+
to implement custom logic for handling built-in tool parts.
776+
"""
777+
# OpenAI doesn't return built-in tool calls
778+
pass
779+
780+
def _map_response_file_part(self, ctx: _MapModelResposeContext, item: FilePart) -> None:
781+
"""Maps a `FilePart` to the response context.
782+
783+
This method serves as a hook that can be overridden by subclasses
784+
to implement custom logic for handling file parts.
785+
"""
786+
# Files generated by models are not sent back to models that don't themselves generate files.
787+
pass
788+
710789
def _map_model_response(self, message: ModelResponse) -> chat.ChatCompletionMessageParam:
711790
"""Hook that determines how `ModelResponse` is mapped into `ChatCompletionMessageParam` objects before sending.
712791
713792
Subclasses of `OpenAIChatModel` may override this method to provide their own mapping logic.
714793
"""
715-
texts: list[str] = []
716-
tool_calls: list[ChatCompletionMessageFunctionToolCallParam] = []
794+
ctx = self._MapModelResposeContext()
717795
for item in message.parts:
718796
if isinstance(item, TextPart):
719-
texts.append(item.content)
797+
self._map_response_text_part(ctx, item)
720798
elif isinstance(item, ThinkingPart):
721-
# NOTE: DeepSeek `reasoning_content` field should NOT be sent back per https://api-docs.deepseek.com/guides/reasoning_model,
722-
# but we currently just send it in `<think>` tags anyway as we don't want DeepSeek-specific checks here.
723-
# If you need this changed, please file an issue.
724-
start_tag, end_tag = self.profile.thinking_tags
725-
texts.append('\n'.join([start_tag, item.content, end_tag]))
799+
self._map_response_thinking_part(ctx, item)
726800
elif isinstance(item, ToolCallPart):
727-
tool_calls.append(self._map_tool_call(item))
728-
# OpenAI doesn't return built-in tool calls
801+
self._map_response_tool_call_part(ctx, item)
729802
elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart): # pragma: no cover
730-
pass
803+
self._map_response_builtin_part(ctx, item)
731804
elif isinstance(item, FilePart): # pragma: no cover
732-
# Files generated by models are not sent back to models that don't themselves generate files.
733-
pass
805+
self._map_response_file_part(ctx, item)
734806
else:
735807
assert_never(item)
736-
message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
737-
if texts:
738-
# Note: model responses from this model should only have one text item, so the following
739-
# shouldn't merge multiple texts into one unless you switch models between runs:
740-
message_param['content'] = '\n\n'.join(texts)
741-
else:
742-
message_param['content'] = None
743-
if tool_calls:
744-
message_param['tool_calls'] = tool_calls
745-
return message_param
808+
return ctx.into_message_param()
746809

747810
def _map_finish_reason(
748811
self, key: Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call']

pydantic_ai_slim/pydantic_ai/models/openrouter.py

Lines changed: 19 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,17 @@
11
from __future__ import annotations as _annotations
22

33
from collections.abc import Iterable
4-
from dataclasses import dataclass
4+
from dataclasses import dataclass, field
55
from typing import Any, Literal, cast
66

77
from pydantic import BaseModel
88
from typing_extensions import TypedDict, assert_never, override
99

1010
from ..exceptions import ModelHTTPError, UnexpectedModelBehavior
1111
from ..messages import (
12-
BuiltinToolCallPart,
13-
BuiltinToolReturnPart,
14-
FilePart,
1512
FinishReason,
16-
ModelResponse,
1713
ModelResponseStreamEvent,
18-
TextPart,
1914
ThinkingPart,
20-
ToolCallPart,
2115
)
2216
from ..profiles import ModelProfileSpec
2317
from ..providers import Provider
@@ -522,40 +516,26 @@ def _process_provider_details(self, response: chat.ChatCompletion) -> dict[str,
522516
provider_details.update(_map_openrouter_provider_details(response))
523517
return provider_details
524518

519+
@dataclass
520+
class _MapModelResposeContext(OpenAIChatModel._MapModelResposeContext): # type: ignore[reportPrivateUsage]
521+
reasoning_details: list[dict[str, Any]] = field(default_factory=list)
522+
523+
def into_message_param(self) -> chat.ChatCompletionAssistantMessageParam:
524+
message_param = super().into_message_param()
525+
if self.reasoning_details:
526+
message_param['reasoning_details'] = self.reasoning_details # type: ignore[reportGeneralTypeIssues]
527+
return message_param
528+
525529
@override
526-
def _map_model_response(self, message: ModelResponse) -> chat.ChatCompletionMessageParam:
527-
texts: list[str] = []
528-
tool_calls: list[chat.ChatCompletionMessageFunctionToolCallParam] = []
529-
reasoning_details: list[dict[str, Any]] = []
530-
for item in message.parts:
531-
if isinstance(item, TextPart):
532-
texts.append(item.content)
533-
elif isinstance(item, ThinkingPart):
534-
if item.provider_name == self.system:
535-
reasoning_details.append(_into_reasoning_detail(item).model_dump())
536-
elif content := item.content: # pragma: lax no cover
537-
start_tag, end_tag = self.profile.thinking_tags
538-
texts.append('\n'.join([start_tag, content, end_tag]))
539-
else:
540-
pass
541-
elif isinstance(item, ToolCallPart):
542-
tool_calls.append(self._map_tool_call(item))
543-
elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart): # pragma: no cover
544-
pass
545-
elif isinstance(item, FilePart): # pragma: no cover
546-
pass
547-
else:
548-
assert_never(item)
549-
message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
550-
if texts:
551-
message_param['content'] = '\n\n'.join(texts)
530+
def _map_response_thinking_part(self, ctx: OpenAIChatModel._MapModelResposeContext, item: ThinkingPart) -> None:
531+
assert isinstance(ctx, self._MapModelResposeContext)
532+
if item.provider_name == self.system:
533+
ctx.reasoning_details.append(_into_reasoning_detail(item).model_dump())
534+
elif content := item.content: # pragma: lax no cover
535+
start_tag, end_tag = self.profile.thinking_tags
536+
ctx.texts.append('\n'.join([start_tag, content, end_tag]))
552537
else:
553-
message_param['content'] = None
554-
if tool_calls:
555-
message_param['tool_calls'] = tool_calls
556-
if reasoning_details:
557-
message_param['reasoning_details'] = reasoning_details # type: ignore[reportGeneralTypeIssues]
558-
return message_param
538+
pass
559539

560540
@property
561541
@override

0 commit comments

Comments
 (0)