diff --git a/pydantic_ai_slim/pydantic_ai/ui/_adapter.py b/pydantic_ai_slim/pydantic_ai/ui/_adapter.py index bb03147b71..58576941fe 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/_adapter.py +++ b/pydantic_ai_slim/pydantic_ai/ui/_adapter.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from collections.abc import AsyncIterator, Sequence -from dataclasses import KW_ONLY, Field, dataclass, replace +from dataclasses import KW_ONLY, Field, dataclass from functools import cached_property from http import HTTPStatus from typing import ( @@ -238,7 +238,7 @@ def run_stream_native( else: state = raw_state - deps = replace(deps, state=state) + deps.state = state elif self.state: raise UserError( f'State is provided but `deps` of type `{type(deps).__name__}` does not implement the `StateHandler` protocol: it needs to be a dataclass with a non-optional `state` field.' diff --git a/pydantic_ai_slim/pydantic_ai/ui/ag_ui/app.py b/pydantic_ai_slim/pydantic_ai/ui/ag_ui/app.py index 2d9cb434f0..1f0fbe5262 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/ag_ui/app.py +++ b/pydantic_ai_slim/pydantic_ai/ui/ag_ui/app.py @@ -3,6 +3,7 @@ from __future__ import annotations from collections.abc import Callable, Mapping, Sequence +from dataclasses import replace from typing import Any, Generic from typing_extensions import Self @@ -18,7 +19,7 @@ from pydantic_ai.toolsets import AbstractToolset from pydantic_ai.usage import RunUsage, UsageLimits -from .. import OnCompleteFunc +from .. import OnCompleteFunc, StateHandler from ._adapter import AGUIAdapter try: @@ -121,6 +122,12 @@ def __init__( async def run_agent(request: Request) -> Response: """Endpoint to run the agent with the provided input data.""" + # `dispatch_request` will store the frontend state from the request on `deps.state` (if it implements the `StateHandler` protocol), + # so we need to copy the deps to avoid different requests mutating the same deps object. + nonlocal deps + if isinstance(deps, StateHandler): # pragma: no branch + deps = replace(deps) + return await AGUIAdapter[AgentDepsT, OutputDataT].dispatch_request( request, agent=agent, diff --git a/tests/test_ag_ui.py b/tests/test_ag_ui.py index 07d18bcfac..0ca9dcc3aa 100644 --- a/tests/test_ag_ui.py +++ b/tests/test_ag_ui.py @@ -1154,15 +1154,21 @@ async def store_state( ), ] - deps = StateDeps(StateInt(value=0)) + seen_deps_states: list[int] = [] for run_input in run_inputs: events = list[dict[str, Any]]() - async for event in run_ag_ui(agent, run_input, deps=deps): + deps = StateDeps(StateInt(value=0)) + + async def on_complete(result: AgentRunResult[Any]): + seen_deps_states.append(deps.state.value) + + async for event in run_ag_ui(agent, run_input, deps=deps, on_complete=on_complete): events.append(json.loads(event.removeprefix('data: '))) assert events == simple_result() assert seen_states == snapshot([41, 0, 0, 42]) + assert seen_deps_states == snapshot([42, 1, 1, 43]) async def test_request_with_state_without_handler() -> None: @@ -1275,8 +1281,10 @@ async def get_state(ctx: RunContext[StateDeps[StateInt]]) -> int: async def test_to_ag_ui() -> None: """Test the agent.to_ag_ui method.""" - agent = Agent(model=FunctionModel(stream_function=simple_stream)) - app = agent.to_ag_ui() + agent = Agent(model=FunctionModel(stream_function=simple_stream), deps_type=StateDeps[StateInt]) + + deps = StateDeps(StateInt(value=0)) + app = agent.to_ag_ui(deps=deps) async with LifespanManager(app): transport = httpx.ASGITransport(app) async with httpx.AsyncClient(transport=transport) as client: @@ -1286,6 +1294,7 @@ async def test_to_ag_ui() -> None: id='msg_1', content='Hello, world!', ), + state=StateInt(value=42), ) async with client.stream( 'POST', @@ -1301,6 +1310,9 @@ async def test_to_ag_ui() -> None: assert events == simple_result() + # Verify the state was not mutated by the run + assert deps.state.value == 0 + async def test_callback_sync() -> None: """Test that sync callbacks work correctly."""