Skip to content

Commit 5a3e023

Browse files
committed
move some stuff around
1 parent 27bdca9 commit 5a3e023

File tree

3 files changed

+80
-90
lines changed

3 files changed

+80
-90
lines changed

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_function_toolset.py

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,9 @@
1313

1414
from ._run_context import TemporalRunContext
1515
from ._toolset import (
16-
CallToolParamsData,
17-
CallToolResultData,
16+
CallToolParams,
17+
CallToolResult,
1818
TemporalWrapperToolset,
19-
ToolReturnData,
20-
remap_dataclass_to_exception,
21-
remap_exception_to_dataclass,
2219
)
2320

2421

@@ -38,7 +35,7 @@ def __init__(
3835
self.tool_activity_config = tool_activity_config
3936
self.run_context_type = run_context_type
4037

41-
async def call_tool_activity(params: CallToolParamsData, deps: AgentDepsT) -> CallToolResultData:
38+
async def call_tool_activity(params: CallToolParams, deps: AgentDepsT) -> CallToolResult:
4239
name = params.name
4340
ctx = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
4441
try:
@@ -52,11 +49,7 @@ async def call_tool_activity(params: CallToolParamsData, deps: AgentDepsT) -> Ca
5249
# The tool args will already have been validated into their proper types in the `ToolManager`,
5350
# but `execute_activity` would have turned them into simple Python types again, so we need to re-validate them.
5451
args_dict = tool.args_validator.validate_python(params.tool_args)
55-
try:
56-
result = await self.wrapped.call_tool(name, args_dict, ctx, tool)
57-
return ToolReturnData(result=result)
58-
except Exception as e:
59-
return remap_exception_to_dataclass(e)
52+
return await self._wrap_call_tool_result(self.wrapped.call_tool(name, args_dict, ctx, tool))
6053

6154
# Set type hint explicitly so that Temporal can take care of serialization and deserialization
6255
call_tool_activity.__annotations__['deps'] = deps_type
@@ -87,17 +80,18 @@ async def call_tool(
8780

8881
tool_activity_config = self.activity_config | tool_activity_config
8982
serialized_run_context = self.run_context_type.serialize_run_context(ctx)
90-
result = await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
91-
activity=self.call_tool_activity,
92-
args=[
93-
CallToolParamsData(
94-
name=name,
95-
tool_args=tool_args,
96-
serialized_run_context=serialized_run_context,
97-
tool_def=None,
98-
),
99-
ctx.deps,
100-
],
101-
**tool_activity_config,
83+
return self._unwrap_call_tool_result(
84+
await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
85+
activity=self.call_tool_activity,
86+
args=[
87+
CallToolParams(
88+
name=name,
89+
tool_args=tool_args,
90+
serialized_run_context=serialized_run_context,
91+
tool_def=None,
92+
),
93+
ctx.deps,
94+
],
95+
**tool_activity_config,
96+
)
10297
)
103-
return remap_dataclass_to_exception(result)

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_mcp_server.py

Lines changed: 32 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from __future__ import annotations
22

33
from collections.abc import Callable
4+
from dataclasses import dataclass
45
from typing import Any, Literal
56

7+
from pydantic import ConfigDict, with_config
68
from temporalio import activity, workflow
79
from temporalio.workflow import ActivityConfig
810
from typing_extensions import Self
@@ -14,16 +16,18 @@
1416

1517
from ._run_context import TemporalRunContext
1618
from ._toolset import (
17-
CallToolParamsData,
18-
CallToolResultData,
19-
GetToolsParamsData,
19+
CallToolParams,
20+
CallToolResult,
2021
TemporalWrapperToolset,
21-
ToolReturnData,
22-
remap_dataclass_to_exception,
23-
remap_exception_to_dataclass,
2422
)
2523

2624

25+
@dataclass
26+
@with_config(ConfigDict(arbitrary_types_allowed=True))
27+
class _GetToolsParams:
28+
serialized_run_context: Any
29+
30+
2731
class TemporalMCPServer(TemporalWrapperToolset[AgentDepsT]):
2832
def __init__(
2933
self,
@@ -49,7 +53,7 @@ def __init__(
4953

5054
self.run_context_type = run_context_type
5155

52-
async def get_tools_activity(params: GetToolsParamsData, deps: AgentDepsT) -> dict[str, ToolDefinition]:
56+
async def get_tools_activity(params: _GetToolsParams, deps: AgentDepsT) -> dict[str, ToolDefinition]:
5357
run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
5458
tools = await self.wrapped.get_tools(run_context)
5559
# ToolsetTool is not serializable as it holds a SchemaValidator (which is also the same for every MCP tool so unnecessary to pass along the wire every time),
@@ -63,19 +67,17 @@ async def get_tools_activity(params: GetToolsParamsData, deps: AgentDepsT) -> di
6367
get_tools_activity
6468
)
6569

66-
async def call_tool_activity(params: CallToolParamsData, deps: AgentDepsT) -> CallToolResultData:
70+
async def call_tool_activity(params: CallToolParams, deps: AgentDepsT) -> CallToolResult:
6771
run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
68-
try:
69-
assert isinstance(params.tool_def, ToolDefinition)
70-
result = await self.wrapped.call_tool(
72+
assert isinstance(params.tool_def, ToolDefinition)
73+
return await self._wrap_call_tool_result(
74+
self.wrapped.call_tool(
7175
params.name,
7276
params.tool_args,
7377
run_context,
7478
self.tool_for_tool_def(params.tool_def),
7579
)
76-
return ToolReturnData(result=result)
77-
except Exception as e:
78-
return remap_exception_to_dataclass(e)
80+
)
7981

8082
# Set type hint explicitly so that Temporal can take care of serialization and deserialization
8183
call_tool_activity.__annotations__['deps'] = deps_type
@@ -108,7 +110,7 @@ async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[
108110
tool_defs = await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
109111
activity=self.get_tools_activity,
110112
args=[
111-
GetToolsParamsData(serialized_run_context=serialized_run_context),
113+
_GetToolsParams(serialized_run_context=serialized_run_context),
112114
ctx.deps,
113115
],
114116
**self.activity_config,
@@ -121,23 +123,24 @@ async def call_tool(
121123
tool_args: dict[str, Any],
122124
ctx: RunContext[AgentDepsT],
123125
tool: ToolsetTool[AgentDepsT],
124-
) -> CallToolResultData:
126+
) -> CallToolResult:
125127
if not workflow.in_workflow():
126128
return await super().call_tool(name, tool_args, ctx, tool)
127129

128130
tool_activity_config = self.activity_config | self.tool_activity_config.get(name, {})
129131
serialized_run_context = self.run_context_type.serialize_run_context(ctx)
130-
result = await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
131-
activity=self.call_tool_activity,
132-
args=[
133-
CallToolParamsData(
134-
name=name,
135-
tool_args=tool_args,
136-
serialized_run_context=serialized_run_context,
137-
tool_def=tool.tool_def,
138-
),
139-
ctx.deps,
140-
],
141-
**tool_activity_config,
132+
return self._unwrap_call_tool_result(
133+
await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
134+
activity=self.call_tool_activity,
135+
args=[
136+
CallToolParams(
137+
name=name,
138+
tool_args=tool_args,
139+
serialized_run_context=serialized_run_context,
140+
tool_def=tool.tool_def,
141+
),
142+
ctx.deps,
143+
],
144+
**tool_activity_config,
145+
)
142146
)
143-
return remap_dataclass_to_exception(result)

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_toolset.py

Lines changed: 30 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from abc import ABC, abstractmethod
4-
from collections.abc import Callable
4+
from collections.abc import Awaitable, Callable
55
from dataclasses import dataclass
66
from typing import Annotated, Any, Literal
77

@@ -18,43 +18,37 @@
1818

1919
@dataclass
2020
@with_config(ConfigDict(arbitrary_types_allowed=True))
21-
class GetToolsParamsData:
22-
serialized_run_context: Any
23-
24-
25-
@dataclass
26-
@with_config(ConfigDict(arbitrary_types_allowed=True))
27-
class CallToolParamsData:
21+
class CallToolParams:
2822
name: str
2923
tool_args: dict[str, Any]
3024
serialized_run_context: Any
3125
tool_def: ToolDefinition | None
3226

3327

3428
@dataclass
35-
class ApprovalRequiredData:
29+
class _ApprovalRequired:
3630
kind: Literal['approval_required'] = 'approval_required'
3731

3832

3933
@dataclass
40-
class CallDeferredData:
34+
class _CallDeferred:
4135
kind: Literal['call_deferred'] = 'call_deferred'
4236

4337

4438
@dataclass
45-
class ModelRetryData:
39+
class _ModelRetry:
4640
message: str
4741
kind: Literal['model_retry'] = 'model_retry'
4842

4943

5044
@dataclass
51-
class ToolReturnData:
45+
class _ToolReturn:
5246
result: Any
5347
kind: Literal['tool_return'] = 'tool_return'
5448

5549

56-
CallToolResultData = Annotated[
57-
ApprovalRequiredData | CallDeferredData | ModelRetryData | ToolReturnData,
50+
CallToolResult = Annotated[
51+
_ApprovalRequired | _CallDeferred | _ModelRetry | _ToolReturn,
5852
Discriminator('kind'),
5953
]
6054

@@ -77,29 +71,28 @@ def visit_and_replace(
7771
# Temporalized toolsets cannot be swapped out after the fact.
7872
return self
7973

80-
81-
def remap_exception_to_dataclass(e: Exception) -> CallToolResultData:
82-
try:
83-
raise e
84-
except ApprovalRequired:
85-
return ApprovalRequiredData()
86-
except CallDeferred:
87-
return CallDeferredData()
88-
except ModelRetry as e:
89-
return ModelRetryData(message=e.message)
90-
91-
92-
def remap_dataclass_to_exception(o: CallToolResultData):
93-
if isinstance(o, ApprovalRequiredData):
94-
raise ApprovalRequired()
95-
elif isinstance(o, CallDeferredData):
96-
raise CallDeferred()
97-
elif isinstance(o, ModelRetryData):
98-
raise ModelRetry(o.message)
99-
elif isinstance(o, ToolReturnData):
100-
return o.result
101-
else:
102-
assert_never(o)
74+
async def _wrap_call_tool_result(self, coro: Awaitable[Any]) -> CallToolResult:
75+
try:
76+
result = await coro
77+
return _ToolReturn(result=result)
78+
except ApprovalRequired:
79+
return _ApprovalRequired()
80+
except CallDeferred:
81+
return _CallDeferred()
82+
except ModelRetry as e:
83+
return _ModelRetry(message=e.message)
84+
85+
def _unwrap_call_tool_result(self, result: CallToolResult) -> Any:
86+
if isinstance(result, _ToolReturn):
87+
return result.result
88+
elif isinstance(result, _ApprovalRequired):
89+
raise ApprovalRequired()
90+
elif isinstance(result, _CallDeferred):
91+
raise CallDeferred()
92+
elif isinstance(result, _ModelRetry):
93+
raise ModelRetry(result.message)
94+
else:
95+
assert_never(result)
10396

10497

10598
def temporalize_toolset(

0 commit comments

Comments
 (0)