Skip to content

Commit 974bac2

Browse files
authored
Merge branch 'main' into clai-chat
2 parents 0673277 + 1a35af5 commit 974bac2

File tree

13 files changed

+1765
-124
lines changed

13 files changed

+1765
-124
lines changed

pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_agent.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,21 @@ def dbosify_toolset(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset[Age
101101
step_config=self._mcp_step_config,
102102
)
103103

104+
# Replace FastMCPToolset with DBOSFastMCPToolset
105+
try:
106+
from pydantic_ai.toolsets.fastmcp import FastMCPToolset
107+
108+
from ._fastmcp_toolset import DBOSFastMCPToolset
109+
except ImportError:
110+
pass
111+
else:
112+
if isinstance(toolset, FastMCPToolset):
113+
return DBOSFastMCPToolset(
114+
wrapped=toolset,
115+
step_name_prefix=dbosagent_name,
116+
step_config=self._mcp_step_config,
117+
)
118+
104119
return toolset
105120

106121
dbos_toolsets = [toolset.visit_and_replace(dbosify_toolset) for toolset in wrapped.toolsets]
@@ -336,6 +351,10 @@ async def main():
336351
Returns:
337352
The result of the run.
338353
"""
354+
if model is not None and not isinstance(model, DBOSModel):
355+
raise UserError(
356+
'Non-DBOS model cannot be set at agent run time inside a DBOS workflow, it must be set at agent creation time.'
357+
)
339358
return await self.dbos_wrapped_run_workflow(
340359
user_prompt,
341360
output_type=output_type,
@@ -449,6 +468,10 @@ def run_sync(
449468
Returns:
450469
The result of the run.
451470
"""
471+
if model is not None and not isinstance(model, DBOSModel): # pragma: lax no cover
472+
raise UserError(
473+
'Non-DBOS model cannot be set at agent run time inside a DBOS workflow, it must be set at agent creation time.'
474+
)
452475
return self.dbos_wrapped_run_sync_workflow(
453476
user_prompt,
454477
output_type=output_type,
@@ -838,7 +861,7 @@ async def main():
838861
Returns:
839862
The result of the run.
840863
"""
841-
if model is not None and not isinstance(model, DBOSModel):
864+
if model is not None and not isinstance(model, DBOSModel): # pragma: lax no cover
842865
raise UserError(
843866
'Non-DBOS model cannot be set at agent run time inside a DBOS workflow, it must be set at agent creation time.'
844867
)
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from __future__ import annotations
2+
3+
from pydantic_ai import ToolsetTool
4+
from pydantic_ai.tools import AgentDepsT, ToolDefinition
5+
from pydantic_ai.toolsets.fastmcp import FastMCPToolset
6+
7+
from ._mcp import DBOSMCPToolset
8+
from ._utils import StepConfig
9+
10+
11+
class DBOSFastMCPToolset(DBOSMCPToolset[AgentDepsT]):
12+
"""A wrapper for FastMCPToolset that integrates with DBOS, turning call_tool and get_tools to DBOS steps."""
13+
14+
def __init__(
15+
self,
16+
wrapped: FastMCPToolset[AgentDepsT],
17+
*,
18+
step_name_prefix: str,
19+
step_config: StepConfig,
20+
):
21+
super().__init__(
22+
wrapped,
23+
step_name_prefix=step_name_prefix,
24+
step_config=step_config,
25+
)
26+
27+
def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[AgentDepsT]:
28+
assert isinstance(self.wrapped, FastMCPToolset)
29+
return self.wrapped.tool_for_tool_def(tool_def)
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
from __future__ import annotations
2+
3+
from abc import ABC, abstractmethod
4+
from collections.abc import Callable
5+
from typing import TYPE_CHECKING, Any
6+
7+
from dbos import DBOS
8+
from typing_extensions import Self
9+
10+
from pydantic_ai import AbstractToolset, ToolsetTool, WrapperToolset
11+
from pydantic_ai.tools import AgentDepsT, RunContext, ToolDefinition
12+
13+
from ._utils import StepConfig
14+
15+
if TYPE_CHECKING:
16+
from pydantic_ai.mcp import ToolResult
17+
18+
19+
class DBOSMCPToolset(WrapperToolset[AgentDepsT], ABC):
20+
"""A wrapper for MCP toolset that integrates with DBOS, turning call_tool and get_tools to DBOS steps."""
21+
22+
def __init__(
23+
self,
24+
wrapped: AbstractToolset[AgentDepsT],
25+
*,
26+
step_name_prefix: str,
27+
step_config: StepConfig,
28+
):
29+
super().__init__(wrapped)
30+
self._step_config = step_config or {}
31+
self._step_name_prefix = step_name_prefix
32+
id_suffix = f'__{wrapped.id}' if wrapped.id else ''
33+
self._name = f'{step_name_prefix}__mcp_server{id_suffix}'
34+
35+
# Wrap get_tools in a DBOS step.
36+
@DBOS.step(
37+
name=f'{self._name}.get_tools',
38+
**self._step_config,
39+
)
40+
async def wrapped_get_tools_step(
41+
ctx: RunContext[AgentDepsT],
42+
) -> dict[str, ToolDefinition]:
43+
# Need to return a serializable dict, so we cannot return ToolsetTool directly.
44+
tools = await super(DBOSMCPToolset, self).get_tools(ctx)
45+
# ToolsetTool is not serializable as it holds a SchemaValidator (which is also the same for every MCP tool so unnecessary to pass along the wire every time),
46+
# so we just return the ToolDefinitions and wrap them in ToolsetTool outside of the activity.
47+
return {name: tool.tool_def for name, tool in tools.items()}
48+
49+
self._dbos_wrapped_get_tools_step = wrapped_get_tools_step
50+
51+
# Wrap call_tool in a DBOS step.
52+
@DBOS.step(
53+
name=f'{self._name}.call_tool',
54+
**self._step_config,
55+
)
56+
async def wrapped_call_tool_step(
57+
name: str,
58+
tool_args: dict[str, Any],
59+
ctx: RunContext[AgentDepsT],
60+
tool: ToolsetTool[AgentDepsT],
61+
) -> ToolResult:
62+
return await super(DBOSMCPToolset, self).call_tool(name, tool_args, ctx, tool)
63+
64+
self._dbos_wrapped_call_tool_step = wrapped_call_tool_step
65+
66+
@abstractmethod
67+
def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[AgentDepsT]:
68+
raise NotImplementedError
69+
70+
@property
71+
def id(self) -> str | None:
72+
return self.wrapped.id
73+
74+
async def __aenter__(self) -> Self:
75+
# The wrapped MCP toolset enters itself around listing and calling tools
76+
# so we don't need to enter it here (nor could we because we're not inside a DBOS step).
77+
return self
78+
79+
async def __aexit__(self, *args: Any) -> bool | None:
80+
return None
81+
82+
def visit_and_replace(
83+
self, visitor: Callable[[AbstractToolset[AgentDepsT]], AbstractToolset[AgentDepsT]]
84+
) -> AbstractToolset[AgentDepsT]:
85+
# DBOS-ified toolsets cannot be swapped out after the fact.
86+
return self
87+
88+
async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
89+
tool_defs = await self._dbos_wrapped_get_tools_step(ctx)
90+
return {name: self.tool_for_tool_def(tool_def) for name, tool_def in tool_defs.items()}
91+
92+
async def call_tool(
93+
self,
94+
name: str,
95+
tool_args: dict[str, Any],
96+
ctx: RunContext[AgentDepsT],
97+
tool: ToolsetTool[AgentDepsT],
98+
) -> ToolResult:
99+
return await self._dbos_wrapped_call_tool_step(name, tool_args, ctx, tool)
Lines changed: 12 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,14 @@
11
from __future__ import annotations
22

3-
from abc import ABC
4-
from collections.abc import Callable
5-
from typing import TYPE_CHECKING, Any
6-
7-
from dbos import DBOS
8-
from typing_extensions import Self
9-
10-
from pydantic_ai import AbstractToolset, ToolsetTool, WrapperToolset
11-
from pydantic_ai.tools import AgentDepsT, RunContext
3+
from pydantic_ai import ToolsetTool
4+
from pydantic_ai.mcp import MCPServer
5+
from pydantic_ai.tools import AgentDepsT, ToolDefinition
126

7+
from ._mcp import DBOSMCPToolset
138
from ._utils import StepConfig
149

15-
if TYPE_CHECKING:
16-
from pydantic_ai.mcp import MCPServer, ToolResult
1710

18-
19-
class DBOSMCPServer(WrapperToolset[AgentDepsT], ABC):
11+
class DBOSMCPServer(DBOSMCPToolset[AgentDepsT]):
2012
"""A wrapper for MCPServer that integrates with DBOS, turning call_tool and get_tools to DBOS steps."""
2113

2214
def __init__(
@@ -26,65 +18,12 @@ def __init__(
2618
step_name_prefix: str,
2719
step_config: StepConfig,
2820
):
29-
super().__init__(wrapped)
30-
self._step_config = step_config or {}
31-
self._step_name_prefix = step_name_prefix
32-
id_suffix = f'__{wrapped.id}' if wrapped.id else ''
33-
self._name = f'{step_name_prefix}__mcp_server{id_suffix}'
34-
35-
# Wrap get_tools in a DBOS step.
36-
@DBOS.step(
37-
name=f'{self._name}.get_tools',
38-
**self._step_config,
39-
)
40-
async def wrapped_get_tools_step(
41-
ctx: RunContext[AgentDepsT],
42-
) -> dict[str, ToolsetTool[AgentDepsT]]:
43-
return await super(DBOSMCPServer, self).get_tools(ctx)
44-
45-
self._dbos_wrapped_get_tools_step = wrapped_get_tools_step
46-
47-
# Wrap call_tool in a DBOS step.
48-
@DBOS.step(
49-
name=f'{self._name}.call_tool',
50-
**self._step_config,
21+
super().__init__(
22+
wrapped,
23+
step_name_prefix=step_name_prefix,
24+
step_config=step_config,
5125
)
52-
async def wrapped_call_tool_step(
53-
name: str,
54-
tool_args: dict[str, Any],
55-
ctx: RunContext[AgentDepsT],
56-
tool: ToolsetTool[AgentDepsT],
57-
) -> ToolResult:
58-
return await super(DBOSMCPServer, self).call_tool(name, tool_args, ctx, tool)
5926

60-
self._dbos_wrapped_call_tool_step = wrapped_call_tool_step
61-
62-
@property
63-
def id(self) -> str | None:
64-
return self.wrapped.id
65-
66-
async def __aenter__(self) -> Self:
67-
# The wrapped MCPServer enters itself around listing and calling tools
68-
# so we don't need to enter it here (nor could we because we're not inside a DBOS step).
69-
return self
70-
71-
async def __aexit__(self, *args: Any) -> bool | None:
72-
return None
73-
74-
def visit_and_replace(
75-
self, visitor: Callable[[AbstractToolset[AgentDepsT]], AbstractToolset[AgentDepsT]]
76-
) -> AbstractToolset[AgentDepsT]:
77-
# DBOS-ified toolsets cannot be swapped out after the fact.
78-
return self
79-
80-
async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
81-
return await self._dbos_wrapped_get_tools_step(ctx)
82-
83-
async def call_tool(
84-
self,
85-
name: str,
86-
tool_args: dict[str, Any],
87-
ctx: RunContext[AgentDepsT],
88-
tool: ToolsetTool[AgentDepsT],
89-
) -> ToolResult:
90-
return await self._dbos_wrapped_call_tool_step(name, tool_args, ctx, tool)
27+
def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[AgentDepsT]:
28+
assert isinstance(self.wrapped, MCPServer)
29+
return self.wrapped.tool_for_tool_def(tool_def)

pydantic_ai_slim/pydantic_ai/models/google.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -570,7 +570,7 @@ async def _map_messages(
570570
{
571571
'function_response': {
572572
'name': part.tool_name,
573-
'response': {'call_error': part.model_response()},
573+
'response': {'error': part.model_response()},
574574
'id': part.tool_call_id,
575575
}
576576
}

pydantic_ai_slim/pydantic_ai/models/openrouter.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
from collections.abc import Iterable
44
from dataclasses import dataclass, field
5-
from typing import Any, Literal, cast
5+
from typing import Annotated, Any, Literal, TypeAlias, cast
66

7-
from pydantic import BaseModel
7+
from pydantic import BaseModel, Discriminator
88
from typing_extensions import TypedDict, assert_never, override
99

1010
from ..exceptions import ModelHTTPError
@@ -22,9 +22,13 @@
2222
try:
2323
from openai import APIError, AsyncOpenAI
2424
from openai.types import chat, completion_usage
25-
from openai.types.chat import chat_completion, chat_completion_chunk
25+
from openai.types.chat import chat_completion, chat_completion_chunk, chat_completion_message_function_tool_call
2626

27-
from .openai import OpenAIChatModel, OpenAIChatModelSettings, OpenAIStreamedResponse
27+
from .openai import (
28+
OpenAIChatModel,
29+
OpenAIChatModelSettings,
30+
OpenAIStreamedResponse,
31+
)
2832
except ImportError as _import_error:
2933
raise ImportError(
3034
'Please install `openai` to use the OpenRouter model, '
@@ -341,6 +345,27 @@ def _into_reasoning_detail(thinking_part: ThinkingPart) -> _OpenRouterReasoningD
341345
assert_never(data.type)
342346

343347

348+
class _OpenRouterFunction(chat_completion_message_function_tool_call.Function):
349+
arguments: str | None # type: ignore[reportIncompatibleVariableOverride]
350+
"""
351+
The arguments to call the function with, as generated by the model in JSON
352+
format. Note that the model does not always generate valid JSON, and may
353+
hallucinate parameters not defined by your function schema. Validate the
354+
arguments in your code before calling your function.
355+
"""
356+
357+
358+
class _OpenRouterChatCompletionMessageFunctionToolCall(chat.ChatCompletionMessageFunctionToolCall):
359+
function: _OpenRouterFunction # type: ignore[reportIncompatibleVariableOverride]
360+
"""The function that the model called."""
361+
362+
363+
_OpenRouterChatCompletionMessageToolCallUnion: TypeAlias = Annotated[
364+
_OpenRouterChatCompletionMessageFunctionToolCall | chat.ChatCompletionMessageCustomToolCall,
365+
Discriminator(discriminator='type'),
366+
]
367+
368+
344369
class _OpenRouterCompletionMessage(chat.ChatCompletionMessage):
345370
"""Wrapped chat completion message with OpenRouter specific attributes."""
346371

@@ -350,11 +375,14 @@ class _OpenRouterCompletionMessage(chat.ChatCompletionMessage):
350375
reasoning_details: list[_OpenRouterReasoningDetail] | None = None
351376
"""The reasoning details associated with the message, if any."""
352377

378+
tool_calls: list[_OpenRouterChatCompletionMessageToolCallUnion] | None = None # type: ignore[reportIncompatibleVariableOverride]
379+
"""The tool calls generated by the model, such as function calls."""
380+
353381

354382
class _OpenRouterChoice(chat_completion.Choice):
355383
"""Wraps OpenAI chat completion choice with OpenRouter specific attributes."""
356384

357-
native_finish_reason: str
385+
native_finish_reason: str | None
358386
"""The provided finish reason by the downstream provider from OpenRouter."""
359387

360388
finish_reason: Literal['stop', 'length', 'tool_calls', 'content_filter', 'error'] # type: ignore[reportIncompatibleVariableOverride]

0 commit comments

Comments
 (0)