11from __future__ import annotations
22
33from collections .abc import Callable
4- from dataclasses import dataclass
5- from typing import Annotated , Any , Literal
4+ from typing import Any , Literal
65
7- from pydantic import ConfigDict , Discriminator , with_config
86from temporalio import activity , workflow
97from temporalio .workflow import ActivityConfig
10- from typing_extensions import assert_never
118
129from pydantic_ai import FunctionToolset , ToolsetTool
13- from pydantic_ai .exceptions import ApprovalRequired , CallDeferred , ModelRetry , UserError
10+ from pydantic_ai .exceptions import UserError
1411from pydantic_ai .tools import AgentDepsT , RunContext
1512from pydantic_ai .toolsets .function import FunctionToolsetTool
1613
1714from ._run_context import TemporalRunContext
18- from ._toolset import TemporalWrapperToolset
19-
20-
21- @dataclass
22- @with_config (ConfigDict (arbitrary_types_allowed = True ))
23- class _CallToolParams :
24- name : str
25- tool_args : dict [str , Any ]
26- serialized_run_context : Any
27-
28-
29- @dataclass
30- class _ApprovalRequired :
31- kind : Literal ['approval_required' ] = 'approval_required'
32-
33-
34- @dataclass
35- class _CallDeferred :
36- kind : Literal ['call_deferred' ] = 'call_deferred'
37-
38-
39- @dataclass
40- class _ModelRetry :
41- message : str
42- kind : Literal ['model_retry' ] = 'model_retry'
43-
44-
45- @dataclass
46- class _ToolReturn :
47- result : Any
48- kind : Literal ['tool_return' ] = 'tool_return'
49-
50-
51- _CallToolResult = Annotated [
52- _ApprovalRequired | _CallDeferred | _ModelRetry | _ToolReturn ,
53- Discriminator ('kind' ),
54- ]
15+ from ._toolset import (
16+ CallToolParams ,
17+ CallToolResult ,
18+ TemporalWrapperToolset ,
19+ )
5520
5621
5722class TemporalFunctionToolset (TemporalWrapperToolset [AgentDepsT ]):
@@ -70,7 +35,7 @@ def __init__(
7035 self .tool_activity_config = tool_activity_config
7136 self .run_context_type = run_context_type
7237
73- async def call_tool_activity (params : _CallToolParams , deps : AgentDepsT ) -> _CallToolResult :
38+ async def call_tool_activity (params : CallToolParams , deps : AgentDepsT ) -> CallToolResult :
7439 name = params .name
7540 ctx = self .run_context_type .deserialize_run_context (params .serialized_run_context , deps = deps )
7641 try :
@@ -84,15 +49,7 @@ async def call_tool_activity(params: _CallToolParams, deps: AgentDepsT) -> _Call
8449 # The tool args will already have been validated into their proper types in the `ToolManager`,
8550 # but `execute_activity` would have turned them into simple Python types again, so we need to re-validate them.
8651 args_dict = tool .args_validator .validate_python (params .tool_args )
87- try :
88- result = await self .wrapped .call_tool (name , args_dict , ctx , tool )
89- return _ToolReturn (result = result )
90- except ApprovalRequired :
91- return _ApprovalRequired ()
92- except CallDeferred :
93- return _CallDeferred ()
94- except ModelRetry as e :
95- return _ModelRetry (message = e .message )
52+ return await self ._wrap_call_tool_result (self .wrapped .call_tool (name , args_dict , ctx , tool ))
9653
9754 # Set type hint explicitly so that Temporal can take care of serialization and deserialization
9855 call_tool_activity .__annotations__ ['deps' ] = deps_type
@@ -123,25 +80,18 @@ async def call_tool(
12380
12481 tool_activity_config = self .activity_config | tool_activity_config
12582 serialized_run_context = self .run_context_type .serialize_run_context (ctx )
126- result = await workflow .execute_activity ( # pyright: ignore[reportUnknownMemberType]
127- activity = self .call_tool_activity ,
128- args = [
129- _CallToolParams (
130- name = name ,
131- tool_args = tool_args ,
132- serialized_run_context = serialized_run_context ,
133- ),
134- ctx .deps ,
135- ],
136- ** tool_activity_config ,
83+ return self ._unwrap_call_tool_result (
84+ await workflow .execute_activity ( # pyright: ignore[reportUnknownMemberType]
85+ activity = self .call_tool_activity ,
86+ args = [
87+ CallToolParams (
88+ name = name ,
89+ tool_args = tool_args ,
90+ serialized_run_context = serialized_run_context ,
91+ tool_def = None ,
92+ ),
93+ ctx .deps ,
94+ ],
95+ ** tool_activity_config ,
96+ )
13797 )
138- if isinstance (result , _ApprovalRequired ):
139- raise ApprovalRequired ()
140- elif isinstance (result , _CallDeferred ):
141- raise CallDeferred ()
142- elif isinstance (result , _ModelRetry ):
143- raise ModelRetry (result .message )
144- elif isinstance (result , _ToolReturn ):
145- return result .result
146- else :
147- assert_never (result )
0 commit comments