diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_function_toolset.py b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_function_toolset.py index f67b18170d..dd1f8c1ee3 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_function_toolset.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_function_toolset.py @@ -1,57 +1,22 @@ from __future__ import annotations from collections.abc import Callable -from dataclasses import dataclass -from typing import Annotated, Any, Literal +from typing import Any, Literal -from pydantic import ConfigDict, Discriminator, with_config from temporalio import activity, workflow from temporalio.workflow import ActivityConfig -from typing_extensions import assert_never from pydantic_ai import FunctionToolset, ToolsetTool -from pydantic_ai.exceptions import ApprovalRequired, CallDeferred, ModelRetry, UserError +from pydantic_ai.exceptions import UserError from pydantic_ai.tools import AgentDepsT, RunContext from pydantic_ai.toolsets.function import FunctionToolsetTool from ._run_context import TemporalRunContext -from ._toolset import TemporalWrapperToolset - - -@dataclass -@with_config(ConfigDict(arbitrary_types_allowed=True)) -class _CallToolParams: - name: str - tool_args: dict[str, Any] - serialized_run_context: Any - - -@dataclass -class _ApprovalRequired: - kind: Literal['approval_required'] = 'approval_required' - - -@dataclass -class _CallDeferred: - kind: Literal['call_deferred'] = 'call_deferred' - - -@dataclass -class _ModelRetry: - message: str - kind: Literal['model_retry'] = 'model_retry' - - -@dataclass -class _ToolReturn: - result: Any - kind: Literal['tool_return'] = 'tool_return' - - -_CallToolResult = Annotated[ - _ApprovalRequired | _CallDeferred | _ModelRetry | _ToolReturn, - Discriminator('kind'), -] +from ._toolset import ( + CallToolParams, + CallToolResult, + TemporalWrapperToolset, +) class TemporalFunctionToolset(TemporalWrapperToolset[AgentDepsT]): @@ -70,7 +35,7 @@ def __init__( self.tool_activity_config = tool_activity_config self.run_context_type = run_context_type - async def call_tool_activity(params: _CallToolParams, deps: AgentDepsT) -> _CallToolResult: + async def call_tool_activity(params: CallToolParams, deps: AgentDepsT) -> CallToolResult: name = params.name ctx = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps) try: @@ -84,15 +49,7 @@ async def call_tool_activity(params: _CallToolParams, deps: AgentDepsT) -> _Call # The tool args will already have been validated into their proper types in the `ToolManager`, # but `execute_activity` would have turned them into simple Python types again, so we need to re-validate them. args_dict = tool.args_validator.validate_python(params.tool_args) - try: - result = await self.wrapped.call_tool(name, args_dict, ctx, tool) - return _ToolReturn(result=result) - except ApprovalRequired: - return _ApprovalRequired() - except CallDeferred: - return _CallDeferred() - except ModelRetry as e: - return _ModelRetry(message=e.message) + return await self._wrap_call_tool_result(self.wrapped.call_tool(name, args_dict, ctx, tool)) # Set type hint explicitly so that Temporal can take care of serialization and deserialization call_tool_activity.__annotations__['deps'] = deps_type @@ -123,25 +80,18 @@ async def call_tool( tool_activity_config = self.activity_config | tool_activity_config serialized_run_context = self.run_context_type.serialize_run_context(ctx) - result = await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] - activity=self.call_tool_activity, - args=[ - _CallToolParams( - name=name, - tool_args=tool_args, - serialized_run_context=serialized_run_context, - ), - ctx.deps, - ], - **tool_activity_config, + return self._unwrap_call_tool_result( + await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] + activity=self.call_tool_activity, + args=[ + CallToolParams( + name=name, + tool_args=tool_args, + serialized_run_context=serialized_run_context, + tool_def=None, + ), + ctx.deps, + ], + **tool_activity_config, + ) ) - if isinstance(result, _ApprovalRequired): - raise ApprovalRequired() - elif isinstance(result, _CallDeferred): - raise CallDeferred() - elif isinstance(result, _ModelRetry): - raise ModelRetry(result.message) - elif isinstance(result, _ToolReturn): - return result.result - else: - assert_never(result) diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_mcp_server.py b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_mcp_server.py index fedba5ae8c..3a36494468 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_mcp_server.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_mcp_server.py @@ -11,11 +11,15 @@ from pydantic_ai import ToolsetTool from pydantic_ai.exceptions import UserError -from pydantic_ai.mcp import MCPServer, ToolResult +from pydantic_ai.mcp import MCPServer from pydantic_ai.tools import AgentDepsT, RunContext, ToolDefinition from ._run_context import TemporalRunContext -from ._toolset import TemporalWrapperToolset +from ._toolset import ( + CallToolParams, + CallToolResult, + TemporalWrapperToolset, +) @dataclass @@ -24,15 +28,6 @@ class _GetToolsParams: serialized_run_context: Any -@dataclass -@with_config(ConfigDict(arbitrary_types_allowed=True)) -class _CallToolParams: - name: str - tool_args: dict[str, Any] - serialized_run_context: Any - tool_def: ToolDefinition - - class TemporalMCPServer(TemporalWrapperToolset[AgentDepsT]): def __init__( self, @@ -72,13 +67,16 @@ async def get_tools_activity(params: _GetToolsParams, deps: AgentDepsT) -> dict[ get_tools_activity ) - async def call_tool_activity(params: _CallToolParams, deps: AgentDepsT) -> ToolResult: + async def call_tool_activity(params: CallToolParams, deps: AgentDepsT) -> CallToolResult: run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps) - return await self.wrapped.call_tool( - params.name, - params.tool_args, - run_context, - self.tool_for_tool_def(params.tool_def), + assert isinstance(params.tool_def, ToolDefinition) + return await self._wrap_call_tool_result( + self.wrapped.call_tool( + params.name, + params.tool_args, + run_context, + self.tool_for_tool_def(params.tool_def), + ) ) # Set type hint explicitly so that Temporal can take care of serialization and deserialization @@ -125,22 +123,24 @@ async def call_tool( tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT], - ) -> ToolResult: + ) -> CallToolResult: if not workflow.in_workflow(): return await super().call_tool(name, tool_args, ctx, tool) tool_activity_config = self.activity_config | self.tool_activity_config.get(name, {}) serialized_run_context = self.run_context_type.serialize_run_context(ctx) - return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] - activity=self.call_tool_activity, - args=[ - _CallToolParams( - name=name, - tool_args=tool_args, - serialized_run_context=serialized_run_context, - tool_def=tool.tool_def, - ), - ctx.deps, - ], - **tool_activity_config, + return self._unwrap_call_tool_result( + await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] + activity=self.call_tool_activity, + args=[ + CallToolParams( + name=name, + tool_args=tool_args, + serialized_run_context=serialized_run_context, + tool_def=tool.tool_def, + ), + ctx.deps, + ], + **tool_activity_config, + ) ) diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_toolset.py b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_toolset.py index a8744dbca3..d4adb4b6a7 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_toolset.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_toolset.py @@ -1,17 +1,58 @@ from __future__ import annotations from abc import ABC, abstractmethod -from collections.abc import Callable -from typing import Any, Literal +from collections.abc import Awaitable, Callable +from dataclasses import dataclass +from typing import Annotated, Any, Literal +from pydantic import ConfigDict, Discriminator, with_config from temporalio.workflow import ActivityConfig +from typing_extensions import assert_never from pydantic_ai import AbstractToolset, FunctionToolset, WrapperToolset -from pydantic_ai.tools import AgentDepsT +from pydantic_ai.exceptions import ApprovalRequired, CallDeferred, ModelRetry +from pydantic_ai.tools import AgentDepsT, ToolDefinition from ._run_context import TemporalRunContext +@dataclass +@with_config(ConfigDict(arbitrary_types_allowed=True)) +class CallToolParams: + name: str + tool_args: dict[str, Any] + serialized_run_context: Any + tool_def: ToolDefinition | None + + +@dataclass +class _ApprovalRequired: + kind: Literal['approval_required'] = 'approval_required' + + +@dataclass +class _CallDeferred: + kind: Literal['call_deferred'] = 'call_deferred' + + +@dataclass +class _ModelRetry: + message: str + kind: Literal['model_retry'] = 'model_retry' + + +@dataclass +class _ToolReturn: + result: Any + kind: Literal['tool_return'] = 'tool_return' + + +CallToolResult = Annotated[ + _ApprovalRequired | _CallDeferred | _ModelRetry | _ToolReturn, + Discriminator('kind'), +] + + class TemporalWrapperToolset(WrapperToolset[AgentDepsT], ABC): @property def id(self) -> str: @@ -30,6 +71,29 @@ def visit_and_replace( # Temporalized toolsets cannot be swapped out after the fact. return self + async def _wrap_call_tool_result(self, coro: Awaitable[Any]) -> CallToolResult: + try: + result = await coro + return _ToolReturn(result=result) + except ApprovalRequired: + return _ApprovalRequired() + except CallDeferred: + return _CallDeferred() + except ModelRetry as e: + return _ModelRetry(message=e.message) + + def _unwrap_call_tool_result(self, result: CallToolResult) -> Any: + if isinstance(result, _ToolReturn): + return result.result + elif isinstance(result, _ApprovalRequired): + raise ApprovalRequired() + elif isinstance(result, _CallDeferred): + raise CallDeferred() + elif isinstance(result, _ModelRetry): + raise ModelRetry(result.message) + else: + assert_never(result) + def temporalize_toolset( toolset: AbstractToolset[AgentDepsT],