22
33from collections .abc import Callable
44from dataclasses import dataclass
5- from typing import Any , Literal
5+ from typing import Annotated , Any , Literal
66
7- from pydantic import ConfigDict , with_config
7+ from pydantic import ConfigDict , Discriminator , with_config ,
88from temporalio import activity , workflow
99from temporalio .workflow import ActivityConfig
1010from typing_extensions import Self
1111
1212from pydantic_ai import ToolsetTool
13- from pydantic_ai .exceptions import UserError
13+ from pydantic_ai .exceptions import ApprovalRequired , CallDeferred , ModelRetry , UserError
1414from pydantic_ai .mcp import MCPServer , ToolResult
1515from pydantic_ai .tools import AgentDepsT , RunContext , ToolDefinition
1616
@@ -32,6 +32,33 @@ class _CallToolParams:
3232 serialized_run_context : Any
3333 tool_def : ToolDefinition
3434
35+ @dataclass
36+ class _ApprovalRequired :
37+ kind : Literal ['approval_required' ] = 'approval_required'
38+
39+
40+ @dataclass
41+ class _CallDeferred :
42+ kind : Literal ['call_deferred' ] = 'call_deferred'
43+
44+
45+ @dataclass
46+ class _ModelRetry :
47+ message : str
48+ kind : Literal ['model_retry' ] = 'model_retry'
49+
50+
51+ @dataclass
52+ class _ToolReturn :
53+ result : Any
54+ kind : Literal ['tool_return' ] = 'tool_return'
55+
56+
57+ _CallToolResult = Annotated [
58+ _ApprovalRequired | _CallDeferred | _ModelRetry | _ToolReturn ,
59+ Discriminator ('kind' ),
60+ ]
61+
3562
3663class TemporalMCPServer (TemporalWrapperToolset [AgentDepsT ]):
3764 def __init__ (
@@ -72,14 +99,22 @@ async def get_tools_activity(params: _GetToolsParams, deps: AgentDepsT) -> dict[
7299 get_tools_activity
73100 )
74101
75- async def call_tool_activity (params : _CallToolParams , deps : AgentDepsT ) -> ToolResult :
102+ async def call_tool_activity (params : _CallToolParams , deps : AgentDepsT ) -> _CallToolResult :
76103 run_context = self .run_context_type .deserialize_run_context (params .serialized_run_context , deps = deps )
77- return await self .wrapped .call_tool (
78- params .name ,
79- params .tool_args ,
80- run_context ,
81- self .tool_for_tool_def (params .tool_def ),
82- )
104+ try :
105+ result = await self .wrapped .call_tool (
106+ params .name ,
107+ params .tool_args ,
108+ run_context ,
109+ self .tool_for_tool_def (params .tool_def ),
110+ )
111+ return _ToolReturn (result = result )
112+ except ApprovalRequired :
113+ return _ApprovalRequired ()
114+ except CallDeferred :
115+ return _CallDeferred ()
116+ except ModelRetry as e :
117+ return _ModelRetry (message = e .message )
83118
84119 # Set type hint explicitly so that Temporal can take care of serialization and deserialization
85120 call_tool_activity .__annotations__ ['deps' ] = deps_type
0 commit comments