Skip to content

Commit a941da4

Browse files
committed
Address review comments
1 parent 054c270 commit a941da4

File tree

3 files changed

+24
-23
lines changed

3 files changed

+24
-23
lines changed

docs/durable_execution/prefect.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ Any agent can be wrapped in a [`PrefectAgent`][pydantic_ai.durable_exec.prefect.
6161
* Wraps [tool calls](../tools.md) as Prefect tasks (configurable per-tool).
6262
* Wraps [MCP communication](../mcp/client.md) as Prefect tasks.
6363

64-
Event stream handlers are **not automatically wrapped** by Prefect. If they involve I/O or non-deterministic behavior, you can explicitly decorate them with `@task` from Prefect.
64+
Event stream handlers are **not automatically wrapped** by Prefect. If they involve I/O or non-deterministic behavior, you can explicitly decorate them with `@task` from Prefect. For examples, see the [streaming docs](../agents.md#streaming-all-events)
6565

6666
The original agent, model, and MCP server can still be used as normal outside the Prefect flow.
6767

pydantic_ai_slim/pydantic_ai/durable_exec/prefect/_agent.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from collections.abc import AsyncIterable, AsyncIterator, Iterator, Sequence
3+
from collections.abc import AsyncIterable, AsyncIterator, Callable, Iterator, Sequence
44
from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager
55
from contextvars import ContextVar
66
from typing import Any, overload
@@ -50,6 +50,10 @@ def __init__(
5050
model_task_config: TaskConfig | None = None,
5151
tool_task_config: TaskConfig | None = None,
5252
tool_task_config_by_name: dict[str, TaskConfig | None] | None = None,
53+
prefectify_toolset_func: Callable[
54+
[AbstractToolset[AgentDepsT], TaskConfig, TaskConfig, dict[str, TaskConfig | None]],
55+
AbstractToolset[AgentDepsT],
56+
] = prefectify_toolset,
5357
):
5458
"""Wrap an agent to enable it with Prefect durable flows, by automatically offloading model requests, tool calls, and MCP server communication to Prefect tasks.
5559
@@ -63,6 +67,9 @@ def __init__(
6367
model_task_config: The Prefect task config to use for model request tasks. If no config is provided, use the default settings of Prefect.
6468
tool_task_config: The default Prefect task config to use for tool calls. If no config is provided, use the default settings of Prefect.
6569
tool_task_config_by_name: Per-tool task configuration. Keys are tool names, values are TaskConfig or None (None disables task wrapping for that tool).
70+
prefectify_toolset_func: Optional function to use to prepare toolsets for Prefect by wrapping them in a `PrefectWrapperToolset` that moves methods that require IO to Prefect tasks.
71+
If not provided, only `FunctionToolset` and `MCPServer` will be prepared for Prefect.
72+
The function takes the toolset, the task config, the tool-specific task config, and the tool-specific task config by name.
6673
"""
6774
super().__init__(wrapped)
6875

@@ -91,23 +98,23 @@ def __init__(
9198
)
9299
self._model = prefect_model
93100

94-
prefect_toolsets = [toolset.visit_and_replace(self._prefectify_toolset) for toolset in wrapped.toolsets]
101+
def _prefectify_toolset(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset[AgentDepsT]:
102+
"""Convert a toolset to its Prefect equivalent."""
103+
return prefectify_toolset_func(
104+
toolset,
105+
self._mcp_task_config,
106+
self._tool_task_config,
107+
self._tool_task_config_by_name,
108+
)
109+
110+
prefect_toolsets = [toolset.visit_and_replace(_prefectify_toolset) for toolset in wrapped.toolsets]
95111
self._toolsets = prefect_toolsets
96112

97113
# Context variable to track when we're inside this agent's Prefect flow
98114
self._in_prefect_agent_flow: ContextVar[bool] = ContextVar(
99115
f'_in_prefect_agent_flow_{self._name}', default=False
100116
)
101117

102-
def _prefectify_toolset(self, toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset[AgentDepsT]:
103-
"""Convert a toolset to its Prefect equivalent."""
104-
return prefectify_toolset(
105-
toolset,
106-
mcp_task_config=self._mcp_task_config,
107-
tool_task_config=self._tool_task_config,
108-
tool_task_config_by_name=self._tool_task_config_by_name,
109-
)
110-
111118
@property
112119
def name(self) -> str | None:
113120
return self._name
@@ -812,7 +819,7 @@ def override(
812819
tools: The tools to use instead of the tools registered with the agent.
813820
instructions: The instructions to use instead of the instructions registered with the agent.
814821
"""
815-
if _utils.is_set(model) and not isinstance(model, (PrefectModel)):
822+
if _utils.is_set(model) and not isinstance(model, PrefectModel):
816823
raise UserError(
817824
'Non-Prefect model cannot be contextually overridden inside a Prefect flow, it must be set at agent creation time.'
818825
)

pydantic_ai_slim/pydantic_ai/durable_exec/prefect/_model.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -117,12 +117,9 @@ async def request(
117117
model_request_parameters: ModelRequestParameters,
118118
) -> ModelResponse:
119119
"""Make a model request, wrapped as a Prefect task when in a flow."""
120-
# Get model name for task description
121-
model_name = getattr(self.wrapped, 'model_name', 'unknown')
122-
123-
return await self._wrapped_request.with_options(name=f'Model Request: {model_name}', **self.task_config)(
124-
messages, model_settings, model_request_parameters
125-
)
120+
return await self._wrapped_request.with_options(
121+
name=f'Model Request: {self.wrapped.model_name}', **self.task_config
122+
)(messages, model_settings, model_request_parameters)
126123

127124
@asynccontextmanager
128125
async def request_stream(
@@ -149,10 +146,7 @@ async def request_stream(
149146
return
150147

151148
# If in a flow, consume the stream in a task and return the final response
152-
# Get model name for task description
153-
model_name = getattr(self.wrapped, 'model_name', 'unknown')
154-
155149
response = await self._wrapped_request_stream.with_options(
156-
name=f'Model Request (Streaming): {model_name}', **self.task_config
150+
name=f'Model Request (Streaming): {self.wrapped.model_name}', **self.task_config
157151
)(messages, model_settings, model_request_parameters, run_context)
158152
yield PrefectStreamedResponse(model_request_parameters, response)

0 commit comments

Comments
 (0)