Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pydantic_ai_slim/pydantic_ai/ui/_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Sequence
from dataclasses import KW_ONLY, Field, dataclass, replace
from dataclasses import KW_ONLY, Field, dataclass
from functools import cached_property
from http import HTTPStatus
from typing import (
Expand Down Expand Up @@ -238,7 +238,7 @@ def run_stream_native(
else:
state = raw_state

deps = replace(deps, state=state)
deps.state = state
elif self.state:
raise UserError(
f'State is provided but `deps` of type `{type(deps).__name__}` does not implement the `StateHandler` protocol: it needs to be a dataclass with a non-optional `state` field.'
Expand Down
9 changes: 8 additions & 1 deletion pydantic_ai_slim/pydantic_ai/ui/ag_ui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

from collections.abc import Callable, Mapping, Sequence
from dataclasses import replace
from typing import Any, Generic

from typing_extensions import Self
Expand All @@ -18,7 +19,7 @@
from pydantic_ai.toolsets import AbstractToolset
from pydantic_ai.usage import RunUsage, UsageLimits

from .. import OnCompleteFunc
from .. import OnCompleteFunc, StateHandler
from ._adapter import AGUIAdapter

try:
Expand Down Expand Up @@ -121,6 +122,12 @@ def __init__(

async def run_agent(request: Request) -> Response:
"""Endpoint to run the agent with the provided input data."""
# `dispatch_request` will store the frontend state from the request on `deps.state` (if it implements the `StateHandler` protocol),
# so we need to copy the deps to avoid different requests mutating the same deps object.
nonlocal deps
if isinstance(deps, StateHandler): # pragma: no branch
deps = replace(deps)

return await AGUIAdapter[AgentDepsT, OutputDataT].dispatch_request(
request,
agent=agent,
Expand Down
20 changes: 16 additions & 4 deletions tests/test_ag_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -1154,15 +1154,21 @@ async def store_state(
),
]

deps = StateDeps(StateInt(value=0))
seen_deps_states: list[int] = []

for run_input in run_inputs:
events = list[dict[str, Any]]()
async for event in run_ag_ui(agent, run_input, deps=deps):
deps = StateDeps(StateInt(value=0))

async def on_complete(result: AgentRunResult[Any]):
seen_deps_states.append(deps.state.value)

async for event in run_ag_ui(agent, run_input, deps=deps, on_complete=on_complete):
events.append(json.loads(event.removeprefix('data: ')))

assert events == simple_result()
assert seen_states == snapshot([41, 0, 0, 42])
assert seen_deps_states == snapshot([42, 1, 1, 43])


async def test_request_with_state_without_handler() -> None:
Expand Down Expand Up @@ -1275,8 +1281,10 @@ async def get_state(ctx: RunContext[StateDeps[StateInt]]) -> int:
async def test_to_ag_ui() -> None:
"""Test the agent.to_ag_ui method."""

agent = Agent(model=FunctionModel(stream_function=simple_stream))
app = agent.to_ag_ui()
agent = Agent(model=FunctionModel(stream_function=simple_stream), deps_type=StateDeps[StateInt])

deps = StateDeps(StateInt(value=0))
app = agent.to_ag_ui(deps=deps)
async with LifespanManager(app):
transport = httpx.ASGITransport(app)
async with httpx.AsyncClient(transport=transport) as client:
Expand All @@ -1286,6 +1294,7 @@ async def test_to_ag_ui() -> None:
id='msg_1',
content='Hello, world!',
),
state=StateInt(value=42),
)
async with client.stream(
'POST',
Expand All @@ -1301,6 +1310,9 @@ async def test_to_ag_ui() -> None:

assert events == simple_result()

# Verify the state was not mutated by the run
assert deps.state.value == 0


async def test_callback_sync() -> None:
"""Test that sync callbacks work correctly."""
Expand Down