Skip to content

Commit 2fca506

Browse files
authored
Ensure AG-UI state is isolated between requests. (#2343)
1 parent 168680a commit 2fca506

File tree

3 files changed

+152
-56
lines changed

3 files changed

+152
-56
lines changed

docs/ag-ui.md

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,15 @@ The adapter provides full support for
8282
real-time synchronization between agents and frontend applications.
8383

8484
In the example below we have document state which is shared between the UI and
85-
server using the [`StateDeps`][pydantic_ai.ag_ui.StateDeps] which implements the
86-
[`StateHandler`][pydantic_ai.ag_ui.StateHandler] protocol that can be used to automatically
87-
decode state contained in [`RunAgentInput.state`](https://docs.ag-ui.com/sdk/js/core/types#runagentinput)
88-
when processing requests.
85+
server using the [`StateDeps`][pydantic_ai.ag_ui.StateDeps] [dependencies type](./dependencies.md) that can be used to automatically
86+
validate state contained in [`RunAgentInput.state`](https://docs.ag-ui.com/sdk/js/core/types#runagentinput) using a Pydantic `BaseModel` specified as a generic parameter.
87+
88+
!!! note "Custom dependencies type with AG-UI state"
89+
If you want to use your own dependencies type to hold AG-UI state as well as other things, it needs to implements the
90+
[`StateHandler`][pydantic_ai.ag_ui.StateHandler] protocol, meaning it needs to be a [dataclass](https://docs.python.org/3/library/dataclasses.html) with a non-optional `state` field. This lets Pydantic AI ensure that state is properly isolated between requests by building a new dependencies object each time.
91+
92+
If the `state` field's type is a Pydantic `BaseModel` subclass, the raw state dictionary on the request is automatically validated. If not, you can validate the raw value yourself in your dependencies dataclass's `__post_init__` method.
93+
8994

9095
```python {title="ag_ui_state.py" py="3.10"}
9196
from pydantic import BaseModel

pydantic_ai_slim/pydantic_ai/ag_ui.py

Lines changed: 33 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,25 @@
99
import json
1010
import uuid
1111
from collections.abc import Iterable, Mapping, Sequence
12-
from dataclasses import dataclass, field
12+
from dataclasses import Field, dataclass, field, replace
1313
from http import HTTPStatus
1414
from typing import (
15+
TYPE_CHECKING,
1516
Any,
1617
Callable,
18+
ClassVar,
1719
Final,
1820
Generic,
1921
Protocol,
2022
TypeVar,
2123
runtime_checkable,
2224
)
2325

26+
from pydantic_ai.exceptions import UserError
27+
28+
if TYPE_CHECKING:
29+
pass
30+
2431
try:
2532
from ag_ui.core import (
2633
AssistantMessage,
@@ -288,8 +295,24 @@ async def run(
288295
if not run_input.messages:
289296
raise _NoMessagesError
290297

298+
raw_state: dict[str, Any] = run_input.state or {}
291299
if isinstance(deps, StateHandler):
292-
deps.state = run_input.state
300+
if isinstance(deps.state, BaseModel):
301+
try:
302+
state = type(deps.state).model_validate(raw_state)
303+
except ValidationError as e: # pragma: no cover
304+
raise _InvalidStateError from e
305+
else:
306+
state = raw_state
307+
308+
deps = replace(deps, state=state)
309+
elif raw_state:
310+
raise UserError(
311+
f'AG-UI state is provided but `deps` of type `{type(deps).__name__}` does not implement the `StateHandler` protocol: it needs to be a dataclass with a non-optional `state` field.'
312+
)
313+
else:
314+
# `deps` not being a `StateHandler` is OK if there is no state.
315+
pass
293316

294317
messages = _messages_from_ag_ui(run_input.messages)
295318

@@ -311,7 +334,7 @@ async def run(
311334
yield encoder.encode(
312335
RunErrorEvent(message=e.message, code=e.code),
313336
)
314-
except Exception as e: # pragma: no cover
337+
except Exception as e:
315338
yield encoder.encode(
316339
RunErrorEvent(message=str(e)),
317340
)
@@ -531,7 +554,11 @@ def _messages_from_ag_ui(messages: list[Message]) -> list[ModelMessage]:
531554

532555
@runtime_checkable
533556
class StateHandler(Protocol):
534-
"""Protocol for state handlers in agent runs."""
557+
"""Protocol for state handlers in agent runs. Requires the class to be a dataclass with a `state` field."""
558+
559+
# Has to be a dataclass so we can use `replace` to update the state.
560+
# From https://github.com/python/typeshed/blob/9ab7fde0a0cd24ed7a72837fcb21093b811b80d8/stdlib/_typeshed/__init__.pyi#L352
561+
__dataclass_fields__: ClassVar[dict[str, Field[Any]]]
535562

536563
@property
537564
def state(self) -> State:
@@ -558,6 +585,7 @@ def state(self, state: State) -> None:
558585
"""Type variable for the state type, which must be a subclass of `BaseModel`."""
559586

560587

588+
@dataclass
561589
class StateDeps(Generic[StateT]):
562590
"""Provides AG-UI state management.
563591
@@ -570,42 +598,7 @@ class StateDeps(Generic[StateT]):
570598
Implements the `StateHandler` protocol.
571599
"""
572600

573-
def __init__(self, default: StateT) -> None:
574-
"""Initialize the state with the provided state type."""
575-
self._state = default
576-
577-
@property
578-
def state(self) -> StateT:
579-
"""Get the current state of the agent run.
580-
581-
Returns:
582-
The current run state.
583-
"""
584-
return self._state
585-
586-
@state.setter
587-
def state(self, state: State) -> None:
588-
"""Set the state of the agent run.
589-
590-
This method is called to update the state of the agent run with the
591-
provided state.
592-
593-
Implements the `StateHandler` protocol.
594-
595-
Args:
596-
state: The run state, which must be `None` or model validate for the state type.
597-
598-
Raises:
599-
InvalidStateError: If `state` does not validate.
600-
"""
601-
if state is None:
602-
# If state is None, we keep the current state, which will be the default state.
603-
return
604-
605-
try:
606-
self._state = type(self._state).model_validate(state)
607-
except ValidationError as e: # pragma: no cover
608-
raise _InvalidStateError from e
601+
state: StateT
609602

610603

611604
@dataclass(repr=False)

tests/test_ag_ui.py

Lines changed: 110 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import json
88
import uuid
99
from collections.abc import AsyncIterator
10+
from dataclasses import dataclass
1011
from http import HTTPStatus
1112
from typing import Any
1213

@@ -17,7 +18,9 @@
1718
from inline_snapshot import snapshot
1819
from pydantic import BaseModel
1920

21+
from pydantic_ai._run_context import RunContext
2022
from pydantic_ai.agent import Agent
23+
from pydantic_ai.exceptions import UserError
2124
from pydantic_ai.messages import ModelMessage
2225
from pydantic_ai.models.function import (
2326
AgentInfo,
@@ -27,8 +30,9 @@
2730
DeltaToolCalls,
2831
FunctionModel,
2932
)
33+
from pydantic_ai.models.test import TestModel
3034
from pydantic_ai.output import OutputDataT
31-
from pydantic_ai.tools import AgentDepsT
35+
from pydantic_ai.tools import AgentDepsT, ToolDefinition
3236

3337
from .conftest import IsSameStr
3438

@@ -180,7 +184,7 @@ def create_input(
180184
thread_id=thread_id,
181185
run_id=uuid_str(),
182186
messages=list(messages),
183-
state=state,
187+
state=dict(state) if state else {},
184188
context=[],
185189
tools=tools or [],
186190
forwarded_props=None,
@@ -1050,9 +1054,19 @@ async def stream_function(
10501054
async def test_request_with_state() -> None:
10511055
"""Test request with state modification."""
10521056

1057+
seen_states: list[int] = []
1058+
1059+
async def store_state(
1060+
ctx: RunContext[StateDeps[StateInt]], tool_defs: list[ToolDefinition]
1061+
) -> list[ToolDefinition]:
1062+
seen_states.append(ctx.deps.state.value)
1063+
ctx.deps.state.value += 1
1064+
return tool_defs
1065+
10531066
agent: Agent[StateDeps[StateInt], str] = Agent(
10541067
model=FunctionModel(stream_function=simple_stream),
10551068
deps_type=StateDeps[StateInt], # type: ignore[reportUnknownArgumentType]
1069+
prepare_tools=store_state,
10561070
)
10571071
adapter = _Adapter(agent=agent)
10581072
run_inputs = [
@@ -1074,32 +1088,101 @@ async def test_request_with_state() -> None:
10741088
id='msg_3',
10751089
content='Hello, how are you?',
10761090
),
1091+
),
1092+
create_input(
1093+
UserMessage(
1094+
id='msg_4',
1095+
content='Hello, how are you?',
1096+
),
10771097
state=StateInt(value=42),
10781098
),
10791099
]
10801100

1081-
deps = StateDeps(StateInt())
1101+
deps = StateDeps(StateInt(value=0))
10821102

1083-
last_value = deps.state.value
10841103
for run_input in run_inputs:
10851104
events = list[dict[str, Any]]()
10861105
async for event in adapter.run(run_input, deps=deps):
10871106
events.append(json.loads(event.removeprefix('data: ')))
10881107

10891108
assert events == simple_result()
1090-
assert deps.state.value == run_input.state.value if run_input.state is not None else last_value
1091-
last_value = deps.state.value
1109+
assert seen_states == snapshot(
1110+
[
1111+
41, # run msg_1, prepare_tools call 1
1112+
42, # run msg_1, prepare_tools call 2
1113+
0, # run msg_2, prepare_tools call 1
1114+
1, # run msg_2, prepare_tools call 2
1115+
0, # run msg_3, prepare_tools call 1
1116+
1, # run msg_3, prepare_tools call 2
1117+
42, # run msg_4, prepare_tools call 1
1118+
43, # run msg_4, prepare_tools call 2
1119+
]
1120+
)
1121+
1122+
1123+
async def test_request_with_state_without_handler() -> None:
1124+
agent = Agent(model=FunctionModel(stream_function=simple_stream))
1125+
adapter = _Adapter(agent=agent)
1126+
run_input = create_input(
1127+
UserMessage(
1128+
id='msg_1',
1129+
content='Hello, how are you?',
1130+
),
1131+
state=StateInt(value=41),
1132+
)
1133+
1134+
with pytest.raises(
1135+
UserError,
1136+
match='AG-UI state is provided but `deps` of type `NoneType` does not implement the `StateHandler` protocol: it needs to be a dataclass with a non-optional `state` field.',
1137+
):
1138+
async for _ in adapter.run(run_input):
1139+
pass
1140+
1141+
1142+
async def test_request_with_state_with_custom_handler() -> None:
1143+
@dataclass
1144+
class CustomStateDeps:
1145+
state: dict[str, Any]
1146+
1147+
seen_states: list[dict[str, Any]] = []
1148+
1149+
async def store_state(ctx: RunContext[CustomStateDeps], tool_defs: list[ToolDefinition]) -> list[ToolDefinition]:
1150+
seen_states.append(ctx.deps.state)
1151+
return tool_defs
1152+
1153+
agent: Agent[CustomStateDeps, str] = Agent(
1154+
model=FunctionModel(stream_function=simple_stream),
1155+
deps_type=CustomStateDeps,
1156+
prepare_tools=store_state,
1157+
)
1158+
adapter = _Adapter(agent=agent)
1159+
run_input = create_input(
1160+
UserMessage(
1161+
id='msg_1',
1162+
content='Hello, how are you?',
1163+
),
1164+
state={'value': 42},
1165+
)
1166+
1167+
async for _ in adapter.run(run_input, deps=CustomStateDeps(state={'value': 0})):
1168+
pass
10921169

1093-
assert deps.state.value == 42
1170+
assert seen_states[-1] == {'value': 42}
10941171

10951172

10961173
async def test_concurrent_runs() -> None:
10971174
"""Test concurrent execution of multiple runs."""
10981175
import asyncio
10991176

1100-
agent = Agent(
1101-
model=FunctionModel(stream_function=simple_stream),
1177+
agent: Agent[StateDeps[StateInt], str] = Agent(
1178+
model=TestModel(),
1179+
deps_type=StateDeps[StateInt], # type: ignore[reportUnknownArgumentType]
11021180
)
1181+
1182+
@agent.tool
1183+
async def get_state(ctx: RunContext[StateDeps[StateInt]]) -> int:
1184+
return ctx.deps.state.value
1185+
11031186
adapter = _Adapter(agent=agent)
11041187
concurrent_tasks: list[asyncio.Task[list[dict[str, Any]]]] = []
11051188

@@ -1109,10 +1192,11 @@ async def test_concurrent_runs() -> None:
11091192
id=f'msg_{i}',
11101193
content=f'Message {i}',
11111194
),
1195+
state=StateInt(value=i),
11121196
thread_id=f'test_thread_{i}',
11131197
)
11141198

1115-
task = asyncio.create_task(collect_events_from_adapter(adapter, run_input))
1199+
task = asyncio.create_task(collect_events_from_adapter(adapter, run_input, deps=StateDeps(StateInt())))
11161200
concurrent_tasks.append(task)
11171201

11181202
results = await asyncio.gather(*concurrent_tasks)
@@ -1121,9 +1205,23 @@ async def test_concurrent_runs() -> None:
11211205
for i, events in enumerate(results):
11221206
assert events == [
11231207
{'type': 'RUN_STARTED', 'threadId': f'test_thread_{i}', 'runId': (run_id := IsSameStr())},
1208+
{
1209+
'type': 'TOOL_CALL_START',
1210+
'toolCallId': (tool_call_id := IsSameStr()),
1211+
'toolCallName': 'get_state',
1212+
'parentMessageId': IsStr(),
1213+
},
1214+
{'type': 'TOOL_CALL_END', 'toolCallId': tool_call_id},
1215+
{
1216+
'type': 'TOOL_CALL_RESULT',
1217+
'messageId': IsStr(),
1218+
'toolCallId': tool_call_id,
1219+
'content': str(i),
1220+
'role': 'tool',
1221+
},
11241222
{'type': 'TEXT_MESSAGE_START', 'messageId': (message_id := IsSameStr()), 'role': 'assistant'},
1125-
{'type': 'TEXT_MESSAGE_CONTENT', 'messageId': message_id, 'delta': 'success '},
1126-
{'type': 'TEXT_MESSAGE_CONTENT', 'messageId': message_id, 'delta': '(no tool calls)'},
1223+
{'type': 'TEXT_MESSAGE_CONTENT', 'messageId': message_id, 'delta': '{"get_s'},
1224+
{'type': 'TEXT_MESSAGE_CONTENT', 'messageId': message_id, 'delta': 'tate":' + str(i) + '}'},
11271225
{'type': 'TEXT_MESSAGE_END', 'messageId': message_id},
11281226
{'type': 'RUN_FINISHED', 'threadId': f'test_thread_{i}', 'runId': run_id},
11291227
]

0 commit comments

Comments
 (0)