Skip to content

Commit 6e14eb4

Browse files
committed
fix tests failure and allow event stream task cusomization
1 parent 32e93a7 commit 6e14eb4

File tree

2 files changed

+178
-228
lines changed

2 files changed

+178
-228
lines changed

pydantic_ai_slim/pydantic_ai/durable_exec/prefect/_agent.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def __init__(
5050
model_task_config: TaskConfig | None = None,
5151
tool_task_config: TaskConfig | None = None,
5252
tool_task_config_by_name: dict[str, TaskConfig | None] | None = None,
53+
event_stream_handler_task_config: TaskConfig | None = None,
5354
prefectify_toolset_func: Callable[
5455
[AbstractToolset[AgentDepsT], TaskConfig, TaskConfig, dict[str, TaskConfig | None]],
5556
AbstractToolset[AgentDepsT],
@@ -67,6 +68,7 @@ def __init__(
6768
model_task_config: The Prefect task config to use for model request tasks. If no config is provided, use the default settings of Prefect.
6869
tool_task_config: The default Prefect task config to use for tool calls. If no config is provided, use the default settings of Prefect.
6970
tool_task_config_by_name: Per-tool task configuration. Keys are tool names, values are TaskConfig or None (None disables task wrapping for that tool).
71+
event_stream_handler_task_config: The Prefect task config to use for the event stream handler task. If no config is provided, use the default settings of Prefect.
7072
prefectify_toolset_func: Optional function to use to prepare toolsets for Prefect by wrapping them in a `PrefectWrapperToolset` that moves methods that require IO to Prefect tasks.
7173
If not provided, only `FunctionToolset` and `MCPServer` will be prepared for Prefect.
7274
The function takes the toolset, the task config, the tool-specific task config, and the tool-specific task config by name.
@@ -85,6 +87,7 @@ def __init__(
8587
self._model_task_config = default_task_config | (model_task_config or {})
8688
self._tool_task_config = default_task_config | (tool_task_config or {})
8789
self._tool_task_config_by_name = tool_task_config_by_name or {}
90+
self._event_stream_handler_task_config = default_task_config | (event_stream_handler_task_config or {})
8891

8992
if not isinstance(wrapped.model, Model):
9093
raise UserError(
@@ -147,7 +150,7 @@ async def _call_event_stream_handler_in_flow(
147150
assert handler is not None
148151

149152
# Create a task to handle each event
150-
@task(name='Handle Stream Event')
153+
@task(name='Handle Stream Event', **self._event_stream_handler_task_config)
151154
async def event_stream_handler_task(event: _messages.AgentStreamEvent) -> None:
152155
async def streamed_response():
153156
yield event

0 commit comments

Comments
 (0)