Skip to content

Commit 8f28c34

Browse files
authored
Let Agent be run in a Temporal workflow by moving model requests, tool calls, and MCP to Temporal activities (#2225)
1 parent 3bc8e43 commit 8f28c34

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+7867
-55
lines changed

docs/api/agent.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,4 @@
1212
- RunOutputDataT
1313
- capture_run_messages
1414
- InstrumentationSettings
15+
- EventStreamHandler

docs/api/durable_exec.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# `pydantic_ai.durable_exec`
2+
3+
::: pydantic_ai.durable_exec.temporal

docs/changelog.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ Pydantic AI is still pre-version 1, so breaking changes will occur, however:
1212
!!! note
1313
Here's a filtered list of the breaking changes for each version to help you upgrade Pydantic AI.
1414

15-
### v0.7.0 (2025-08-08)
15+
### v0.7.0 (2025-08-12)
1616

1717
See [#2458](https://github.com/pydantic/pydantic-ai/pull/2458) - `pydantic_ai.models.StreamedResponse` now yields a `FinalResultEvent` along with the existing `PartStartEvent` and `PartDeltaEvent`. If you're using `pydantic_ai.direct.model_request_stream` or `pydantic_ai.direct.model_request_stream_sync`, you may need to update your code to account for this.
1818

docs/temporal.md

Lines changed: 232 additions & 0 deletions
Large diffs are not rendered by default.

mkdocs.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ nav:
4444
- builtin-tools.md
4545
- common-tools.md
4646
- retries.md
47+
- temporal.md
4748
- MCP:
4849
- mcp/index.md
4950
- mcp/client.md
@@ -75,6 +76,7 @@ nav:
7576
- api/toolsets.md
7677
- api/builtin_tools.md
7778
- api/common_tools.md
79+
- api/durable_exec.md
7880
- api/output.md
7981
- api/result.md
8082
- api/messages.md

pydantic_ai_slim/pydantic_ai/agent/__init__.py

Lines changed: 47 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@
8282
'InstrumentationSettings',
8383
'WrapperAgent',
8484
'AbstractAgent',
85+
'EventStreamHandler',
8586
)
8687

8788

@@ -401,6 +402,11 @@ def name(self, value: str | None) -> None:
401402
"""Set the name of the agent, used for logging."""
402403
self._name = value
403404

405+
@property
406+
def deps_type(self) -> type:
407+
"""The type of dependencies used by the agent."""
408+
return self._deps_type
409+
404410
@property
405411
def output_type(self) -> OutputSpec[OutputDataT]:
406412
"""The type of data output by agent runs, used to validate the data returned by the model, defaults to `str`."""
@@ -593,12 +599,7 @@ async def main():
593599
run_step=state.run_step,
594600
)
595601

596-
toolset = self._get_toolset(additional=toolsets)
597-
598-
if output_toolset is not None:
599-
if self._prepare_output_tools:
600-
output_toolset = PreparedToolset(output_toolset, self._prepare_output_tools)
601-
toolset = CombinedToolset([output_toolset, toolset])
602+
toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets)
602603

603604
async with toolset:
604605
# This will raise errors for any name conflicts
@@ -1240,48 +1241,64 @@ def _get_deps(self: Agent[T, OutputDataT], deps: T) -> T:
12401241
return deps
12411242

12421243
def _get_toolset(
1243-
self, additional: Sequence[AbstractToolset[AgentDepsT]] | None = None
1244+
self,
1245+
output_toolset: AbstractToolset[AgentDepsT] | None | _utils.Unset = _utils.UNSET,
1246+
additional_toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
12441247
) -> AbstractToolset[AgentDepsT]:
1245-
"""Get the combined toolset containing function tools registered directly to the agent and user-provided toolsets including MCP servers.
1248+
"""Get the complete toolset.
12461249
12471250
Args:
1248-
additional: Additional toolsets to add.
1251+
output_toolset: The output toolset to use instead of the one built at agent construction time.
1252+
additional_toolsets: Additional toolsets to add, unless toolsets have been overridden.
12491253
"""
1250-
if some_tools := self._override_tools.get():
1251-
function_toolset = _AgentFunctionToolset(some_tools.value, max_retries=self._max_tool_retries)
1252-
else:
1253-
function_toolset = self._function_toolset
1254+
toolsets = self.toolsets
1255+
# Don't add additional toolsets if the toolsets have been overridden
1256+
if additional_toolsets and self._override_toolsets.get() is None:
1257+
toolsets = [*toolsets, *additional_toolsets]
12541258

1255-
if some_user_toolsets := self._override_toolsets.get():
1256-
user_toolsets = some_user_toolsets.value
1257-
else:
1258-
# Copy the dynamic toolsets to ensure each run has its own instances
1259-
dynamic_toolsets = [dataclasses.replace(toolset) for toolset in self._dynamic_toolsets]
1260-
user_toolsets = [*self._user_toolsets, *dynamic_toolsets, *(additional or [])]
1259+
toolset = CombinedToolset(toolsets)
12611260

1262-
if user_toolsets:
1263-
toolset = CombinedToolset([function_toolset, *user_toolsets])
1264-
else:
1265-
toolset = function_toolset
1261+
# Copy the dynamic toolsets to ensure each run has its own instances
1262+
def copy_dynamic_toolsets(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset[AgentDepsT]:
1263+
if isinstance(toolset, DynamicToolset):
1264+
return dataclasses.replace(toolset)
1265+
else:
1266+
return toolset
1267+
1268+
toolset = toolset.visit_and_replace(copy_dynamic_toolsets)
12661269

12671270
if self._prepare_tools:
12681271
toolset = PreparedToolset(toolset, self._prepare_tools)
12691272

1273+
output_toolset = output_toolset if _utils.is_set(output_toolset) else self._output_toolset
1274+
if output_toolset is not None:
1275+
if self._prepare_output_tools:
1276+
output_toolset = PreparedToolset(output_toolset, self._prepare_output_tools)
1277+
toolset = CombinedToolset([output_toolset, toolset])
1278+
12701279
return toolset
12711280

12721281
@property
12731282
def toolsets(self) -> Sequence[AbstractToolset[AgentDepsT]]:
12741283
"""All toolsets registered on the agent, including a function toolset holding tools that were registered on the agent directly.
12751284
1276-
If a `prepare_tools` function was configured on the agent, this will contain just a `PreparedToolset` wrapping the original toolsets.
1277-
12781285
Output tools are not included.
12791286
"""
1280-
toolset = self._get_toolset()
1281-
if isinstance(toolset, CombinedToolset):
1282-
return toolset.toolsets
1287+
toolsets: list[AbstractToolset[AgentDepsT]] = []
1288+
1289+
if some_tools := self._override_tools.get():
1290+
function_toolset = _AgentFunctionToolset(some_tools.value, max_retries=self._max_tool_retries)
1291+
else:
1292+
function_toolset = self._function_toolset
1293+
toolsets.append(function_toolset)
1294+
1295+
if some_user_toolsets := self._override_toolsets.get():
1296+
user_toolsets = some_user_toolsets.value
12831297
else:
1284-
return [toolset]
1298+
user_toolsets = [*self._user_toolsets, *self._dynamic_toolsets]
1299+
toolsets.extend(user_toolsets)
1300+
1301+
return toolsets
12851302

12861303
def _prepare_output_schema(
12871304
self, output_type: OutputSpec[RunOutputDataT] | None, model_profile: ModelProfile
@@ -1369,7 +1386,7 @@ async def run_mcp_servers(
13691386
class _AgentFunctionToolset(FunctionToolset[AgentDepsT]):
13701387
@property
13711388
def id(self) -> str:
1372-
return '<agent>' # pragma: no cover
1389+
return '<agent>'
13731390

13741391
@property
13751392
def label(self) -> str:

pydantic_ai_slim/pydantic_ai/agent/abstract.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,12 @@ def name(self, value: str | None) -> None:
9393
"""Set the name of the agent, used for logging."""
9494
raise NotImplementedError
9595

96+
@property
97+
@abstractmethod
98+
def deps_type(self) -> type:
99+
"""The type of dependencies used by the agent."""
100+
raise NotImplementedError
101+
96102
@property
97103
@abstractmethod
98104
def output_type(self) -> OutputSpec[OutputDataT]:

pydantic_ai_slim/pydantic_ai/agent/wrapper.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ def name(self) -> str | None:
4343
def name(self, value: str | None) -> None:
4444
self.wrapped.name = value
4545

46+
@property
47+
def deps_type(self) -> type:
48+
return self.wrapped.deps_type
49+
4650
@property
4751
def output_type(self) -> OutputSpec[OutputDataT]:
4852
return self.wrapped.output_type
@@ -196,8 +200,8 @@ async def main():
196200
usage=usage,
197201
infer_name=infer_name,
198202
toolsets=toolsets,
199-
) as result:
200-
yield result
203+
) as run:
204+
yield run
201205

202206
@contextmanager
203207
def override(

pydantic_ai_slim/pydantic_ai/durable_exec/__init__.py

Whitespace-only changes.
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
from __future__ import annotations
2+
3+
import warnings
4+
from collections.abc import Sequence
5+
from dataclasses import replace
6+
from typing import Any, Callable
7+
8+
from pydantic.errors import PydanticUserError
9+
from temporalio.client import ClientConfig, Plugin as ClientPlugin
10+
from temporalio.contrib.pydantic import PydanticPayloadConverter, pydantic_data_converter
11+
from temporalio.converter import DefaultPayloadConverter
12+
from temporalio.worker import Plugin as WorkerPlugin, WorkerConfig
13+
from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner
14+
15+
from ...exceptions import UserError
16+
from ._agent import TemporalAgent
17+
from ._logfire import LogfirePlugin
18+
from ._run_context import TemporalRunContext
19+
from ._toolset import TemporalWrapperToolset
20+
21+
__all__ = [
22+
'TemporalAgent',
23+
'PydanticAIPlugin',
24+
'LogfirePlugin',
25+
'AgentPlugin',
26+
'TemporalRunContext',
27+
'TemporalWrapperToolset',
28+
]
29+
30+
31+
class PydanticAIPlugin(ClientPlugin, WorkerPlugin):
32+
"""Temporal client and worker plugin for Pydantic AI."""
33+
34+
def configure_client(self, config: ClientConfig) -> ClientConfig:
35+
if (data_converter := config.get('data_converter')) and data_converter.payload_converter_class not in (
36+
DefaultPayloadConverter,
37+
PydanticPayloadConverter,
38+
):
39+
warnings.warn( # pragma: no cover
40+
'A non-default Temporal data converter was used which has been replaced with the Pydantic data converter.'
41+
)
42+
43+
config['data_converter'] = pydantic_data_converter
44+
return super().configure_client(config)
45+
46+
def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
47+
runner = config.get('workflow_runner') # pyright: ignore[reportUnknownMemberType]
48+
if isinstance(runner, SandboxedWorkflowRunner): # pragma: no branch
49+
config['workflow_runner'] = replace(
50+
runner,
51+
restrictions=runner.restrictions.with_passthrough_modules(
52+
'pydantic_ai',
53+
'logfire',
54+
'rich',
55+
'httpx',
56+
# Imported inside `logfire._internal.json_encoder` when running `logfire.info` inside an activity with attributes to serialize
57+
'attrs',
58+
# Imported inside `logfire._internal.json_schema` when running `logfire.info` inside an activity with attributes to serialize
59+
'numpy',
60+
'pandas',
61+
),
62+
)
63+
64+
config['workflow_failure_exception_types'] = [
65+
*config.get('workflow_failure_exception_types', []), # pyright: ignore[reportUnknownMemberType]
66+
UserError,
67+
PydanticUserError,
68+
]
69+
70+
return super().configure_worker(config)
71+
72+
73+
class AgentPlugin(WorkerPlugin):
74+
"""Temporal worker plugin for a specific Pydantic AI agent."""
75+
76+
def __init__(self, agent: TemporalAgent[Any, Any]):
77+
self.agent = agent
78+
79+
def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
80+
activities: Sequence[Callable[..., Any]] = config.get('activities', []) # pyright: ignore[reportUnknownMemberType]
81+
# Activities are checked for name conflicts by Temporal.
82+
config['activities'] = [*activities, *self.agent.temporal_activities]
83+
return super().configure_worker(config)

0 commit comments

Comments
 (0)