Skip to content

Commit 9c31521

Browse files
wreed4DouweM
andauthored
Fix MCPServer error handling with Temporal (#3299)
Co-authored-by: Douwe Maan <[email protected]>
1 parent 1f3b100 commit 9c31521

File tree

3 files changed

+120
-106
lines changed

3 files changed

+120
-106
lines changed
Lines changed: 23 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,22 @@
11
from __future__ import annotations
22

33
from collections.abc import Callable
4-
from dataclasses import dataclass
5-
from typing import Annotated, Any, Literal
4+
from typing import Any, Literal
65

7-
from pydantic import ConfigDict, Discriminator, with_config
86
from temporalio import activity, workflow
97
from temporalio.workflow import ActivityConfig
10-
from typing_extensions import assert_never
118

129
from pydantic_ai import FunctionToolset, ToolsetTool
13-
from pydantic_ai.exceptions import ApprovalRequired, CallDeferred, ModelRetry, UserError
10+
from pydantic_ai.exceptions import UserError
1411
from pydantic_ai.tools import AgentDepsT, RunContext
1512
from pydantic_ai.toolsets.function import FunctionToolsetTool
1613

1714
from ._run_context import TemporalRunContext
18-
from ._toolset import TemporalWrapperToolset
19-
20-
21-
@dataclass
22-
@with_config(ConfigDict(arbitrary_types_allowed=True))
23-
class _CallToolParams:
24-
name: str
25-
tool_args: dict[str, Any]
26-
serialized_run_context: Any
27-
28-
29-
@dataclass
30-
class _ApprovalRequired:
31-
kind: Literal['approval_required'] = 'approval_required'
32-
33-
34-
@dataclass
35-
class _CallDeferred:
36-
kind: Literal['call_deferred'] = 'call_deferred'
37-
38-
39-
@dataclass
40-
class _ModelRetry:
41-
message: str
42-
kind: Literal['model_retry'] = 'model_retry'
43-
44-
45-
@dataclass
46-
class _ToolReturn:
47-
result: Any
48-
kind: Literal['tool_return'] = 'tool_return'
49-
50-
51-
_CallToolResult = Annotated[
52-
_ApprovalRequired | _CallDeferred | _ModelRetry | _ToolReturn,
53-
Discriminator('kind'),
54-
]
15+
from ._toolset import (
16+
CallToolParams,
17+
CallToolResult,
18+
TemporalWrapperToolset,
19+
)
5520

5621

5722
class TemporalFunctionToolset(TemporalWrapperToolset[AgentDepsT]):
@@ -70,7 +35,7 @@ def __init__(
7035
self.tool_activity_config = tool_activity_config
7136
self.run_context_type = run_context_type
7237

73-
async def call_tool_activity(params: _CallToolParams, deps: AgentDepsT) -> _CallToolResult:
38+
async def call_tool_activity(params: CallToolParams, deps: AgentDepsT) -> CallToolResult:
7439
name = params.name
7540
ctx = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
7641
try:
@@ -84,15 +49,7 @@ async def call_tool_activity(params: _CallToolParams, deps: AgentDepsT) -> _Call
8449
# The tool args will already have been validated into their proper types in the `ToolManager`,
8550
# but `execute_activity` would have turned them into simple Python types again, so we need to re-validate them.
8651
args_dict = tool.args_validator.validate_python(params.tool_args)
87-
try:
88-
result = await self.wrapped.call_tool(name, args_dict, ctx, tool)
89-
return _ToolReturn(result=result)
90-
except ApprovalRequired:
91-
return _ApprovalRequired()
92-
except CallDeferred:
93-
return _CallDeferred()
94-
except ModelRetry as e:
95-
return _ModelRetry(message=e.message)
52+
return await self._wrap_call_tool_result(self.wrapped.call_tool(name, args_dict, ctx, tool))
9653

9754
# Set type hint explicitly so that Temporal can take care of serialization and deserialization
9855
call_tool_activity.__annotations__['deps'] = deps_type
@@ -123,25 +80,18 @@ async def call_tool(
12380

12481
tool_activity_config = self.activity_config | tool_activity_config
12582
serialized_run_context = self.run_context_type.serialize_run_context(ctx)
126-
result = await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
127-
activity=self.call_tool_activity,
128-
args=[
129-
_CallToolParams(
130-
name=name,
131-
tool_args=tool_args,
132-
serialized_run_context=serialized_run_context,
133-
),
134-
ctx.deps,
135-
],
136-
**tool_activity_config,
83+
return self._unwrap_call_tool_result(
84+
await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
85+
activity=self.call_tool_activity,
86+
args=[
87+
CallToolParams(
88+
name=name,
89+
tool_args=tool_args,
90+
serialized_run_context=serialized_run_context,
91+
tool_def=None,
92+
),
93+
ctx.deps,
94+
],
95+
**tool_activity_config,
96+
)
13797
)
138-
if isinstance(result, _ApprovalRequired):
139-
raise ApprovalRequired()
140-
elif isinstance(result, _CallDeferred):
141-
raise CallDeferred()
142-
elif isinstance(result, _ModelRetry):
143-
raise ModelRetry(result.message)
144-
elif isinstance(result, _ToolReturn):
145-
return result.result
146-
else:
147-
assert_never(result)

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_mcp_server.py

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,15 @@
1111

1212
from pydantic_ai import ToolsetTool
1313
from pydantic_ai.exceptions import UserError
14-
from pydantic_ai.mcp import MCPServer, ToolResult
14+
from pydantic_ai.mcp import MCPServer
1515
from pydantic_ai.tools import AgentDepsT, RunContext, ToolDefinition
1616

1717
from ._run_context import TemporalRunContext
18-
from ._toolset import TemporalWrapperToolset
18+
from ._toolset import (
19+
CallToolParams,
20+
CallToolResult,
21+
TemporalWrapperToolset,
22+
)
1923

2024

2125
@dataclass
@@ -24,15 +28,6 @@ class _GetToolsParams:
2428
serialized_run_context: Any
2529

2630

27-
@dataclass
28-
@with_config(ConfigDict(arbitrary_types_allowed=True))
29-
class _CallToolParams:
30-
name: str
31-
tool_args: dict[str, Any]
32-
serialized_run_context: Any
33-
tool_def: ToolDefinition
34-
35-
3631
class TemporalMCPServer(TemporalWrapperToolset[AgentDepsT]):
3732
def __init__(
3833
self,
@@ -72,13 +67,16 @@ async def get_tools_activity(params: _GetToolsParams, deps: AgentDepsT) -> dict[
7267
get_tools_activity
7368
)
7469

75-
async def call_tool_activity(params: _CallToolParams, deps: AgentDepsT) -> ToolResult:
70+
async def call_tool_activity(params: CallToolParams, deps: AgentDepsT) -> CallToolResult:
7671
run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
77-
return await self.wrapped.call_tool(
78-
params.name,
79-
params.tool_args,
80-
run_context,
81-
self.tool_for_tool_def(params.tool_def),
72+
assert isinstance(params.tool_def, ToolDefinition)
73+
return await self._wrap_call_tool_result(
74+
self.wrapped.call_tool(
75+
params.name,
76+
params.tool_args,
77+
run_context,
78+
self.tool_for_tool_def(params.tool_def),
79+
)
8280
)
8381

8482
# Set type hint explicitly so that Temporal can take care of serialization and deserialization
@@ -125,22 +123,24 @@ async def call_tool(
125123
tool_args: dict[str, Any],
126124
ctx: RunContext[AgentDepsT],
127125
tool: ToolsetTool[AgentDepsT],
128-
) -> ToolResult:
126+
) -> CallToolResult:
129127
if not workflow.in_workflow():
130128
return await super().call_tool(name, tool_args, ctx, tool)
131129

132130
tool_activity_config = self.activity_config | self.tool_activity_config.get(name, {})
133131
serialized_run_context = self.run_context_type.serialize_run_context(ctx)
134-
return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
135-
activity=self.call_tool_activity,
136-
args=[
137-
_CallToolParams(
138-
name=name,
139-
tool_args=tool_args,
140-
serialized_run_context=serialized_run_context,
141-
tool_def=tool.tool_def,
142-
),
143-
ctx.deps,
144-
],
145-
**tool_activity_config,
132+
return self._unwrap_call_tool_result(
133+
await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
134+
activity=self.call_tool_activity,
135+
args=[
136+
CallToolParams(
137+
name=name,
138+
tool_args=tool_args,
139+
serialized_run_context=serialized_run_context,
140+
tool_def=tool.tool_def,
141+
),
142+
ctx.deps,
143+
],
144+
**tool_activity_config,
145+
)
146146
)

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_toolset.py

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,58 @@
11
from __future__ import annotations
22

33
from abc import ABC, abstractmethod
4-
from collections.abc import Callable
5-
from typing import Any, Literal
4+
from collections.abc import Awaitable, Callable
5+
from dataclasses import dataclass
6+
from typing import Annotated, Any, Literal
67

8+
from pydantic import ConfigDict, Discriminator, with_config
79
from temporalio.workflow import ActivityConfig
10+
from typing_extensions import assert_never
811

912
from pydantic_ai import AbstractToolset, FunctionToolset, WrapperToolset
10-
from pydantic_ai.tools import AgentDepsT
13+
from pydantic_ai.exceptions import ApprovalRequired, CallDeferred, ModelRetry
14+
from pydantic_ai.tools import AgentDepsT, ToolDefinition
1115

1216
from ._run_context import TemporalRunContext
1317

1418

19+
@dataclass
20+
@with_config(ConfigDict(arbitrary_types_allowed=True))
21+
class CallToolParams:
22+
name: str
23+
tool_args: dict[str, Any]
24+
serialized_run_context: Any
25+
tool_def: ToolDefinition | None
26+
27+
28+
@dataclass
29+
class _ApprovalRequired:
30+
kind: Literal['approval_required'] = 'approval_required'
31+
32+
33+
@dataclass
34+
class _CallDeferred:
35+
kind: Literal['call_deferred'] = 'call_deferred'
36+
37+
38+
@dataclass
39+
class _ModelRetry:
40+
message: str
41+
kind: Literal['model_retry'] = 'model_retry'
42+
43+
44+
@dataclass
45+
class _ToolReturn:
46+
result: Any
47+
kind: Literal['tool_return'] = 'tool_return'
48+
49+
50+
CallToolResult = Annotated[
51+
_ApprovalRequired | _CallDeferred | _ModelRetry | _ToolReturn,
52+
Discriminator('kind'),
53+
]
54+
55+
1556
class TemporalWrapperToolset(WrapperToolset[AgentDepsT], ABC):
1657
@property
1758
def id(self) -> str:
@@ -30,6 +71,29 @@ def visit_and_replace(
3071
# Temporalized toolsets cannot be swapped out after the fact.
3172
return self
3273

74+
async def _wrap_call_tool_result(self, coro: Awaitable[Any]) -> CallToolResult:
75+
try:
76+
result = await coro
77+
return _ToolReturn(result=result)
78+
except ApprovalRequired:
79+
return _ApprovalRequired()
80+
except CallDeferred:
81+
return _CallDeferred()
82+
except ModelRetry as e:
83+
return _ModelRetry(message=e.message)
84+
85+
def _unwrap_call_tool_result(self, result: CallToolResult) -> Any:
86+
if isinstance(result, _ToolReturn):
87+
return result.result
88+
elif isinstance(result, _ApprovalRequired):
89+
raise ApprovalRequired()
90+
elif isinstance(result, _CallDeferred):
91+
raise CallDeferred()
92+
elif isinstance(result, _ModelRetry):
93+
raise ModelRetry(result.message)
94+
else:
95+
assert_never(result)
96+
3397

3498
def temporalize_toolset(
3599
toolset: AbstractToolset[AgentDepsT],

0 commit comments

Comments
 (0)