Skip to content

Commit b451460

Browse files
committed
fix(3298): add try catch logic in MCP tool call to mimic Function Tool Call
1 parent 0baf38c commit b451460

File tree

1 file changed

+45
-10
lines changed

1 file changed

+45
-10
lines changed

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_mcp_server.py

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22

33
from collections.abc import Callable
44
from dataclasses import dataclass
5-
from typing import Any, Literal
5+
from typing import Annotated, Any, Literal
66

7-
from pydantic import ConfigDict, with_config
7+
from pydantic import ConfigDict, Discriminator, with_config,
88
from temporalio import activity, workflow
99
from temporalio.workflow import ActivityConfig
1010
from typing_extensions import Self
1111

1212
from pydantic_ai import ToolsetTool
13-
from pydantic_ai.exceptions import UserError
13+
from pydantic_ai.exceptions import ApprovalRequired, CallDeferred, ModelRetry, UserError
1414
from pydantic_ai.mcp import MCPServer, ToolResult
1515
from pydantic_ai.tools import AgentDepsT, RunContext, ToolDefinition
1616

@@ -32,6 +32,33 @@ class _CallToolParams:
3232
serialized_run_context: Any
3333
tool_def: ToolDefinition
3434

35+
@dataclass
36+
class _ApprovalRequired:
37+
kind: Literal['approval_required'] = 'approval_required'
38+
39+
40+
@dataclass
41+
class _CallDeferred:
42+
kind: Literal['call_deferred'] = 'call_deferred'
43+
44+
45+
@dataclass
46+
class _ModelRetry:
47+
message: str
48+
kind: Literal['model_retry'] = 'model_retry'
49+
50+
51+
@dataclass
52+
class _ToolReturn:
53+
result: Any
54+
kind: Literal['tool_return'] = 'tool_return'
55+
56+
57+
_CallToolResult = Annotated[
58+
_ApprovalRequired | _CallDeferred | _ModelRetry | _ToolReturn,
59+
Discriminator('kind'),
60+
]
61+
3562

3663
class TemporalMCPServer(TemporalWrapperToolset[AgentDepsT]):
3764
def __init__(
@@ -72,14 +99,22 @@ async def get_tools_activity(params: _GetToolsParams, deps: AgentDepsT) -> dict[
7299
get_tools_activity
73100
)
74101

75-
async def call_tool_activity(params: _CallToolParams, deps: AgentDepsT) -> ToolResult:
102+
async def call_tool_activity(params: _CallToolParams, deps: AgentDepsT) -> _CallToolResult:
76103
run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
77-
return await self.wrapped.call_tool(
78-
params.name,
79-
params.tool_args,
80-
run_context,
81-
self.tool_for_tool_def(params.tool_def),
82-
)
104+
try:
105+
result = await self.wrapped.call_tool(
106+
params.name,
107+
params.tool_args,
108+
run_context,
109+
self.tool_for_tool_def(params.tool_def),
110+
)
111+
return _ToolReturn(result=result)
112+
except ApprovalRequired:
113+
return _ApprovalRequired()
114+
except CallDeferred:
115+
return _CallDeferred()
116+
except ModelRetry as e:
117+
return _ModelRetry(message=e.message)
83118

84119
# Set type hint explicitly so that Temporal can take care of serialization and deserialization
85120
call_tool_activity.__annotations__['deps'] = deps_type

0 commit comments

Comments
 (0)