Skip to content

Commit dacb19d

Browse files
committed
refactor shared code into shared location
1 parent 5fb966c commit dacb19d

File tree

3 files changed

+104
-116
lines changed

3 files changed

+104
-116
lines changed
Lines changed: 14 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,25 @@
11
from __future__ import annotations
22

33
from collections.abc import Callable
4-
from dataclasses import dataclass
5-
from typing import Annotated, Any, Literal
4+
from typing import Any, Literal
65

7-
from pydantic import ConfigDict, Discriminator, with_config
86
from temporalio import activity, workflow
97
from temporalio.workflow import ActivityConfig
10-
from typing_extensions import assert_never
118

129
from pydantic_ai import FunctionToolset, ToolsetTool
13-
from pydantic_ai.exceptions import ApprovalRequired, CallDeferred, ModelRetry, UserError
10+
from pydantic_ai.exceptions import UserError
1411
from pydantic_ai.tools import AgentDepsT, RunContext
1512
from pydantic_ai.toolsets.function import FunctionToolsetTool
1613

1714
from ._run_context import TemporalRunContext
18-
from ._toolset import TemporalWrapperToolset
19-
20-
21-
@dataclass
22-
@with_config(ConfigDict(arbitrary_types_allowed=True))
23-
class _CallToolParams:
24-
name: str
25-
tool_args: dict[str, Any]
26-
serialized_run_context: Any
27-
28-
29-
@dataclass
30-
class _ApprovalRequired:
31-
kind: Literal['approval_required'] = 'approval_required'
32-
33-
34-
@dataclass
35-
class _CallDeferred:
36-
kind: Literal['call_deferred'] = 'call_deferred'
37-
38-
39-
@dataclass
40-
class _ModelRetry:
41-
message: str
42-
kind: Literal['model_retry'] = 'model_retry'
43-
44-
45-
@dataclass
46-
class _ToolReturn:
47-
result: Any
48-
kind: Literal['tool_return'] = 'tool_return'
49-
50-
51-
_CallToolResult = Annotated[
52-
_ApprovalRequired | _CallDeferred | _ModelRetry | _ToolReturn,
53-
Discriminator('kind'),
54-
]
15+
from ._toolset import (
16+
TemporalWrapperToolset,
17+
_CallToolParams,
18+
_CallToolResult,
19+
_ToolReturn,
20+
remap_dataclass_to_exception,
21+
remap_exception_to_dataclass,
22+
)
5523

5624

5725
class TemporalFunctionToolset(TemporalWrapperToolset[AgentDepsT]):
@@ -87,12 +55,8 @@ async def call_tool_activity(params: _CallToolParams, deps: AgentDepsT) -> _Call
8755
try:
8856
result = await self.wrapped.call_tool(name, args_dict, ctx, tool)
8957
return _ToolReturn(result=result)
90-
except ApprovalRequired:
91-
return _ApprovalRequired()
92-
except CallDeferred:
93-
return _CallDeferred()
94-
except ModelRetry as e:
95-
return _ModelRetry(message=e.message)
58+
except Exception as e:
59+
return remap_exception_to_dataclass(e)
9660

9761
# Set type hint explicitly so that Temporal can take care of serialization and deserialization
9862
call_tool_activity.__annotations__['deps'] = deps_type
@@ -130,18 +94,10 @@ async def call_tool(
13094
name=name,
13195
tool_args=tool_args,
13296
serialized_run_context=serialized_run_context,
97+
tool_def=None,
13398
),
13499
ctx.deps,
135100
],
136101
**tool_activity_config,
137102
)
138-
if isinstance(result, _ApprovalRequired):
139-
raise ApprovalRequired()
140-
elif isinstance(result, _CallDeferred):
141-
raise CallDeferred()
142-
elif isinstance(result, _ModelRetry):
143-
raise ModelRetry(result.message)
144-
elif isinstance(result, _ToolReturn):
145-
return result.result
146-
else:
147-
assert_never(result)
103+
return remap_dataclass_to_exception(result)

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_mcp_server.py

Lines changed: 17 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,64 +1,27 @@
11
from __future__ import annotations
22

33
from collections.abc import Callable
4-
from dataclasses import dataclass
5-
from typing import Annotated, Any, Literal
4+
from typing import Any, Literal
65

7-
from pydantic import ConfigDict, Discriminator, with_config
86
from temporalio import activity, workflow
97
from temporalio.workflow import ActivityConfig
108
from typing_extensions import Self
119

1210
from pydantic_ai import ToolsetTool
13-
from pydantic_ai.exceptions import ApprovalRequired, CallDeferred, ModelRetry, UserError
11+
from pydantic_ai.exceptions import UserError
1412
from pydantic_ai.mcp import MCPServer, ToolResult
1513
from pydantic_ai.tools import AgentDepsT, RunContext, ToolDefinition
1614

1715
from ._run_context import TemporalRunContext
18-
from ._toolset import TemporalWrapperToolset
19-
20-
21-
@dataclass
22-
@with_config(ConfigDict(arbitrary_types_allowed=True))
23-
class _GetToolsParams:
24-
serialized_run_context: Any
25-
26-
27-
@dataclass
28-
@with_config(ConfigDict(arbitrary_types_allowed=True))
29-
class _CallToolParams:
30-
name: str
31-
tool_args: dict[str, Any]
32-
serialized_run_context: Any
33-
tool_def: ToolDefinition
34-
35-
36-
@dataclass
37-
class _ApprovalRequired:
38-
kind: Literal['approval_required'] = 'approval_required'
39-
40-
41-
@dataclass
42-
class _CallDeferred:
43-
kind: Literal['call_deferred'] = 'call_deferred'
44-
45-
46-
@dataclass
47-
class _ModelRetry:
48-
message: str
49-
kind: Literal['model_retry'] = 'model_retry'
50-
51-
52-
@dataclass
53-
class _ToolReturn:
54-
result: Any
55-
kind: Literal['tool_return'] = 'tool_return'
56-
57-
58-
_CallToolResult = Annotated[
59-
_ApprovalRequired | _CallDeferred | _ModelRetry | _ToolReturn,
60-
Discriminator('kind'),
61-
]
16+
from ._toolset import (
17+
TemporalWrapperToolset,
18+
_CallToolParams,
19+
_CallToolResult,
20+
_GetToolsParams,
21+
_ToolReturn,
22+
remap_exception_to_dataclass,
23+
remap_dataclass_to_exception,
24+
)
6225

6326

6427
class TemporalMCPServer(TemporalWrapperToolset[AgentDepsT]):
@@ -103,19 +66,16 @@ async def get_tools_activity(params: _GetToolsParams, deps: AgentDepsT) -> dict[
10366
async def call_tool_activity(params: _CallToolParams, deps: AgentDepsT) -> _CallToolResult:
10467
run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
10568
try:
69+
assert isinstance(params.tool_def, ToolDefinition)
10670
result = await self.wrapped.call_tool(
10771
params.name,
10872
params.tool_args,
10973
run_context,
11074
self.tool_for_tool_def(params.tool_def),
11175
)
11276
return _ToolReturn(result=result)
113-
except ApprovalRequired:
114-
return _ApprovalRequired()
115-
except CallDeferred:
116-
return _CallDeferred()
117-
except ModelRetry as e:
118-
return _ModelRetry(message=e.message)
77+
except Exception as e:
78+
return remap_exception_to_dataclass(e)
11979

12080
# Set type hint explicitly so that Temporal can take care of serialization and deserialization
12181
call_tool_activity.__annotations__['deps'] = deps_type
@@ -161,13 +121,13 @@ async def call_tool(
161121
tool_args: dict[str, Any],
162122
ctx: RunContext[AgentDepsT],
163123
tool: ToolsetTool[AgentDepsT],
164-
) -> ToolResult:
124+
) -> _CallToolResult:
165125
if not workflow.in_workflow():
166126
return await super().call_tool(name, tool_args, ctx, tool)
167127

168128
tool_activity_config = self.activity_config | self.tool_activity_config.get(name, {})
169129
serialized_run_context = self.run_context_type.serialize_run_context(ctx)
170-
return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
130+
result = await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
171131
activity=self.call_tool_activity,
172132
args=[
173133
_CallToolParams(
@@ -180,3 +140,4 @@ async def call_tool(
180140
],
181141
**tool_activity_config,
182142
)
143+
return remap_dataclass_to_exception(result)

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_toolset.py

Lines changed: 73 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,63 @@
22

33
from abc import ABC, abstractmethod
44
from collections.abc import Callable
5-
from typing import Any, Literal
5+
from dataclasses import dataclass
6+
from typing import Annotated, Any, Literal
67

8+
from pydantic import ConfigDict, Discriminator, with_config
79
from temporalio.workflow import ActivityConfig
10+
from typing_extensions import assert_never
811

912
from pydantic_ai import AbstractToolset, FunctionToolset, WrapperToolset
10-
from pydantic_ai.tools import AgentDepsT
13+
from pydantic_ai.exceptions import ApprovalRequired, CallDeferred, ModelRetry, UserError
14+
from pydantic_ai.tools import AgentDepsT, ToolDefinition
1115

1216
from ._run_context import TemporalRunContext
1317

1418

19+
@dataclass
20+
@with_config(ConfigDict(arbitrary_types_allowed=True))
21+
class _GetToolsParams:
22+
serialized_run_context: Any
23+
24+
25+
@dataclass
26+
@with_config(ConfigDict(arbitrary_types_allowed=True))
27+
class _CallToolParams:
28+
name: str
29+
tool_args: dict[str, Any]
30+
serialized_run_context: Any
31+
tool_def: ToolDefinition | None
32+
33+
34+
@dataclass
35+
class _ApprovalRequired:
36+
kind: Literal['approval_required'] = 'approval_required'
37+
38+
39+
@dataclass
40+
class _CallDeferred:
41+
kind: Literal['call_deferred'] = 'call_deferred'
42+
43+
44+
@dataclass
45+
class _ModelRetry:
46+
message: str
47+
kind: Literal['model_retry'] = 'model_retry'
48+
49+
50+
@dataclass
51+
class _ToolReturn:
52+
result: Any
53+
kind: Literal['tool_return'] = 'tool_return'
54+
55+
56+
_CallToolResult = Annotated[
57+
_ApprovalRequired | _CallDeferred | _ModelRetry | _ToolReturn,
58+
Discriminator('kind'),
59+
]
60+
61+
1562
class TemporalWrapperToolset(WrapperToolset[AgentDepsT], ABC):
1663
@property
1764
def id(self) -> str:
@@ -31,6 +78,30 @@ def visit_and_replace(
3178
return self
3279

3380

81+
def remap_exception_to_dataclass(e: Exception) -> _CallToolResult:
82+
try:
83+
raise e
84+
except ApprovalRequired:
85+
return _ApprovalRequired()
86+
except CallDeferred:
87+
return _CallDeferred()
88+
except ModelRetry as e:
89+
return _ModelRetry(message=e.message)
90+
91+
92+
def remap_dataclass_to_exception(o: _CallToolResult):
93+
if isinstance(o, _ApprovalRequired):
94+
raise ApprovalRequired()
95+
elif isinstance(o, _CallDeferred):
96+
raise CallDeferred()
97+
elif isinstance(o, _ModelRetry):
98+
raise ModelRetry(o.message)
99+
elif isinstance(o, _ToolReturn):
100+
return o.result
101+
else:
102+
assert_never(o)
103+
104+
34105
def temporalize_toolset(
35106
toolset: AbstractToolset[AgentDepsT],
36107
activity_name_prefix: str,

0 commit comments

Comments
 (0)